import os
import base64

import numpy as np
import h5py
import hnswlib
import torch
import open_clip
from flask import Flask, request, jsonify, Response
from flask_cors import CORS
from PIL import Image
from huggingface_hub import hf_hub_download
PREFETCH_IMAGES = True  # turn on for speed: return raw URLs instead of base64 payloads
PLACEHOLDER_BASE64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChwGA60e6kgAAAABJRU5ErkJggg=="
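# PLACEHOLDER_BASE64 decodes to a tiny 1x1 PNG, served by the placeholder
# endpoint below as a lightweight fallback image.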
app = Flask(__name__)
CORS(app, origins=['*'])

print("\n" + "="*50)
print("📥 INITIALIZING MEDICAL SERVER...")
print("="*50)
# Dataset configuration
HF_TOKEN = os.environ.get("HF_TOKEN")
DATASET_ID = "huynguyen6906/Medical_server_data"
# Download the embedding files from the Hugging Face dataset
try:
    print(f"Downloading data from {DATASET_ID}...")
    H5_FILE_PATH = hf_hub_download(repo_id=DATASET_ID, filename="Medical_Embedded.h5", repo_type="dataset", token=HF_TOKEN)
    BIN_FILE_PATH = hf_hub_download(repo_id=DATASET_ID, filename="Medical_Embedded.bin", repo_type="dataset", token=HF_TOKEN)
    print(f"✅ Data loaded: {H5_FILE_PATH}")
except Exception as e:
    print(f"❌ Error downloading data: {e}")
    # Fall back to local paths so the server can still start for debugging
    H5_FILE_PATH = 'Medical_Embedded.h5'
    BIN_FILE_PATH = 'Medical_Embedded.bin'
class ImageSearchEngine:
    def __init__(self, h5_file_path=H5_FILE_PATH):
        print("Initializing Search Engine...")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        print("Loading BiomedCLIP-PubMedBERT_256-vit_base_patch16_224...")
        self.model, _, self.preprocess = open_clip.create_model_and_transforms(
            'hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224')
        self.tokenizer = open_clip.get_tokenizer(
            'hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224')
        # Move the model to the selected device and switch to inference mode;
        # without this, encoding fails whenever the input tensors are sent to CUDA.
        self.model = self.model.to(self.device)
        self.model.eval()

        if not os.path.exists(h5_file_path):
            # Run in empty mode when the H5 file is missing so the server
            # does not crash immediately (helps debugging).
            print("⚠️ H5 file not found. Running in empty mode.")
            self.max_elements = 0
            self.dim = 512
            return

        self.h5_file = h5py.File(h5_file_path, 'r')
        self.dim = self.h5_file['embeddings'].shape[1]
        self.max_elements = len(self.h5_file['urls'])
        print(f"Loaded {self.max_elements} image embeddings. Dim: {self.dim}")

        self.index = hnswlib.Index(space='cosine', dim=self.dim)
        if os.path.exists(BIN_FILE_PATH):
            print(f"⚡ Loading Index from {BIN_FILE_PATH}...")
            self.index.load_index(BIN_FILE_PATH, max_elements=self.max_elements)
            self.index.set_ef(400)  # higher ef -> better recall, slower queries
        else:
            print("⚠️ BIN file not found.")
    def text_to_vector(self, text):
        if isinstance(text, str):
            text = [text]
        tokens = self.tokenizer(text).to(self.device)
        with torch.no_grad():
            text_features = self.model.encode_text(tokens)
            # L2-normalize so cosine distance in the index matches a dot product
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        return text_features.cpu().numpy()
    def image_to_vector(self, image):
        image_tensor = self.preprocess(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            image_features = self.model.encode_image(image_tensor)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        return image_features.cpu().numpy().astype(np.float32)[0]
    def search(self, vector, k=10):
        if self.max_elements == 0:
            return []
        indices, distances = self.index.knn_query(vector, k=k)
        results = []
        for idx, dist in zip(indices[0], distances[0]):
            url_bytes = self.h5_file['urls'][idx]
            url = url_bytes.decode('utf-8') if isinstance(url_bytes, bytes) else str(url_bytes)
            url = url.strip()
            result = {
                'path': url,
                'url': url,
                'score': float(1 - dist)  # cosine distance -> similarity
            }
            # With prefetch enabled, send the URL directly; the frontend loads it
            # through the /i/ proxy route, which is much faster than base64.
            if PREFETCH_IMAGES:
                result['image_data'] = url  # no base64 needed anymore
            results.append(result)
        return results
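# A search() call returns a list of dicts shaped like
#   {'path': <url>, 'url': <url>, 'score': <cosine similarity>,
#    'image_data': <url, present only when PREFETCH_IMAGES is True>}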
search_engine = ImageSearchEngine()

# --- ROUTES ---
@app.route('/health')  # assumed route path
def health_check():
    return jsonify({'status': 'healthy', 'total_images': search_engine.max_elements})
@app.route('/search/text', methods=['POST'])  # assumed route path
def search_text():
    try:
        data = request.get_json()
        query = data.get('query', '')
        k = int(data.get('k', 20))
        vector = search_engine.text_to_vector(query)
        results = search_engine.search(vector, k=k)
        return jsonify({'results': results})
    except Exception as e:
        return jsonify({'error': str(e)}), 500
@app.route('/search/image', methods=['POST'])  # assumed route path
def search_image_file():
    try:
        if 'image' not in request.files:
            return jsonify({'error': 'No image provided'}), 400
        file = request.files['image']
        k = int(request.form.get('k', 20))
        img = Image.open(file.stream).convert('RGB')
        vector = search_engine.image_to_vector(img)
        results = search_engine.search(vector, k=k)
        return jsonify({'results': results})
    except Exception as e:
        return jsonify({'error': str(e)}), 500
@app.route('/i/<path:image_url>')  # route inferred from the /i/ examples below
def fast_proxy(image_url):
    """
    The target is already reachable over https:// -> just redirect straight
    to it, no further checks needed.
    Examples: /i/i.redd.it/abc123.jpg        -> https://i.redd.it/abc123.jpg
              /i/pbs.twimg.com/media/xyz.jpg -> https://pbs.twimg.com/media/xyz.jpg
    """
    # image_url is everything after /i/ -> prepend the scheme to rebuild the full URL
    full_url = 'https://' + image_url
    return f'''
    <script>location.replace("{full_url}")</script>
    <noscript><meta http-equiv="refresh" content="0;url={full_url}"></noscript>
    ''', 200, {'Content-Type': 'text/html'}
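# Note: a plain HTTP 302 (`from flask import redirect` then
# `return redirect(full_url)`) would achieve the same effect without
# shipping an HTML/JS shim to the client.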
@app.route('/placeholder')  # assumed route path
def placeholder():
    img = base64.b64decode(PLACEHOLDER_BASE64)
    return Response(img, mimetype='image/png')
if __name__ == '__main__':
    port = 7860
    app.run(host='0.0.0.0', port=port)
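
# Example client usage (a sketch; assumes the route paths declared above,
# and 'scan.jpg' is a hypothetical local file):
#
#   import requests
#   r = requests.post('http://localhost:7860/search/text',
#                     json={'query': 'chest x-ray showing pneumonia', 'k': 5})
#   print(r.json()['results'])
#
#   with open('scan.jpg', 'rb') as f:
#       r = requests.post('http://localhost:7860/search/image',
#                         files={'image': f}, data={'k': 5})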