"""Flask server exposing text->image and image->image search over medical images.

Embeddings and URLs are read from an HDF5 file and queried through a
pre-built hnswlib index; queries are embedded with BiomedCLIP via open_clip.
"""

import base64
import io
import os
import tempfile

import h5py
import hnswlib
import numpy as np
import open_clip
import requests
import torch
from flask import Flask, Response, jsonify, redirect, request, send_file
from flask_cors import CORS
from huggingface_hub import hf_hub_download
from PIL import Image

# When enabled, each search result also carries the image URL in 'image_data'.
PREFETCH_IMAGES = True
# Base64-encoded PNG served by the /placeholder route.
PLACEHOLDER_BASE64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChwGA60e6kgAAAABJRU5ErkJggg=="

app = Flask(__name__)
CORS(app, origins=['*'])

print("\n" + "="*50)
print("📥 INITIALIZING MEDICAL SERVER...")
print("="*50)

# Dataset configuration
HF_TOKEN = os.environ.get("HF_TOKEN")
DATASET_ID = "huynguyen6906/Medical_server_data"

# Download the data files from the Hugging Face dataset repository.
try:
    print(f"Downloading data from {DATASET_ID}...")
    H5_FILE_PATH = hf_hub_download(repo_id=DATASET_ID, filename="Medical_Embedded.h5", repo_type="dataset", token=HF_TOKEN)
    BIN_FILE_PATH = hf_hub_download(repo_id=DATASET_ID, filename="Medical_Embedded.bin", repo_type="dataset", token=HF_TOKEN)
    print(f"✅ Data loaded: {H5_FILE_PATH}")
except Exception as e:
    # Best-effort: fall back to files in the working directory so the server
    # can still start (possibly in empty mode) when the download fails.
    print(f"❌ Error downloading data: {str(e)}")
    H5_FILE_PATH = 'Medical_Embedded.h5'
    BIN_FILE_PATH = 'Medical_Embedded.bin'


class ImageSearchEngine:
    """Embeds queries with BiomedCLIP and looks up nearest images in an hnswlib index."""

    def __init__(self, h5_file_path=H5_FILE_PATH):
        """Load the CLIP model, open the HDF5 file and load the hnswlib index.

        Args:
            h5_file_path: Path to the HDF5 file containing the 'embeddings'
                and 'urls' datasets. If the file does not exist the engine
                runs in "empty mode" and every search returns no results.
        """
        print("Initializing Search Engine...")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Always define these so that empty mode leaves no missing attributes.
        self.h5_file = None
        self.index = None
        self.index_loaded = False

        print("Loading BiomedCLIP-PubMedBERT_256-vit_base_patch16_224...")
        self.model, _, self.preprocess = open_clip.create_model_and_transforms('hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224')
        self.tokenizer = open_clip.get_tokenizer('hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224')
        # Inputs are moved to self.device below, so the model must live there
        # too; eval() disables training-only behaviour during inference.
        self.model.to(self.device)
        self.model.eval()

        if not os.path.exists(h5_file_path):
            # Run without data instead of crashing at startup (helps debugging).
            print("⚠️ H5 file not found. Running in empty mode.")
            self.max_elements = 0
            self.dim = 512
            return

        self.h5_file = h5py.File(h5_file_path, 'r')
        self.dim = self.h5_file['embeddings'].shape[1]
        self.max_elements = len(self.h5_file['urls'])
        print(f"Loaded {self.max_elements} image embeddings. Dim: {self.dim}")

        self.index = hnswlib.Index(space='cosine', dim=self.dim)

        if os.path.exists(BIN_FILE_PATH):
            print(f"⚡ Loading Index from {BIN_FILE_PATH}...")
            self.index.load_index(BIN_FILE_PATH, max_elements=self.max_elements)
            self.index.set_ef(400)
            self.index_loaded = True
        else:
            print("⚠️ BIN file not found.")

    def text_to_vector(self, text):
        """Return L2-normalised text embeddings as a 2-D numpy array.

        Args:
            text: A single string or a list of strings.
        """
        if isinstance(text, str):
            text = [text]
        tokens = self.tokenizer(text).to(self.device)
        with torch.no_grad():
            text_features = self.model.encode_text(tokens)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        return text_features.cpu().numpy()

    def image_to_vector(self, image):
        """Return the L2-normalised float32 embedding (1-D) of one image.

        Args:
            image: An image accepted by the open_clip preprocess transform
                (the routes pass an RGB PIL image).
        """
        image_tensor = self.preprocess(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            image_features = self.model.encode_image(image_tensor)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        return image_features.cpu().numpy().astype(np.float32)[0]

    def search(self, vector, k=10):
        """Return up to ``k`` nearest images for a query embedding.

        Args:
            vector: Query embedding (1-D, or 2-D with a single row).
            k: Maximum number of results requested.

        Returns:
            A list of dicts with 'path', 'url' and 'score' (1 - cosine
            distance), plus 'image_data' when PREFETCH_IMAGES is enabled.
            Empty when no data is loaded or ``k`` is not positive.

        Raises:
            RuntimeError: If data is loaded but the index file was not found.
        """
        if self.max_elements == 0:
            return []
        if not self.index_loaded:
            raise RuntimeError("Search index is not loaded")

        # hnswlib raises when asked for more neighbours than it can return.
        k = min(int(k), self.max_elements)
        if k <= 0:
            return []

        indices, distances = self.index.knn_query(vector, k=k)

        results = []
        for idx, dist in zip(indices[0], distances[0]):
            url_bytes = self.h5_file['urls'][int(idx)]
            url = url_bytes.decode('utf-8') if isinstance(url_bytes, bytes) else str(url_bytes)
            url = url.strip()

            result = {
                'path': url,
                'url': url,
                'score': float(1 - dist)
            }

            # With prefetch on, send the URL itself rather than base64 data;
            # the frontend is expected to load it through the /i/ route.
            if PREFETCH_IMAGES:
                result['image_data'] = url

            results.append(result)
        return results


search_engine = ImageSearchEngine()


# --- ROUTES ---

@app.route('/health', methods=['GET'])
def health_check():
    """Report liveness and the number of indexed images."""
    return jsonify({'status': 'healthy', 'total_images': search_engine.max_elements})


@app.route('/search', methods=['POST'])
def search_text():
    """Text search. Expects a JSON body ``{"query": str, "k": int (optional)}``."""
    try:
        # silent=True yields None instead of raising on a missing/invalid body.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return jsonify({'error': 'Invalid or missing JSON body'}), 400
        query = data.get('query', '')
        k = int(data.get('k', 20))
        vector = search_engine.text_to_vector(query)
        results = search_engine.search(vector, k=k)
        return jsonify({'results': results})
    except Exception as e:
        return jsonify({'error': str(e)}), 500


@app.route('/search/image', methods=['POST'])
def search_image_file():
    """Image search. Expects multipart form data with 'image' and optional 'k'."""
    try:
        if 'image' not in request.files:
            return jsonify({'error': 'No image provided'}), 400
        file = request.files['image']
        k = int(request.form.get('k', 20))
        img = Image.open(file.stream).convert('RGB')
        vector = search_engine.image_to_vector(img)
        results = search_engine.search(vector, k=k)
        return jsonify({'results': results})
    except Exception as e:
        return jsonify({'error': str(e)}), 500


@app.route('/i/<path:image_url>')
def fast_proxy(image_url):
    """Redirect to the image whose scheme-less URL follows ``/i/``.

    Examples:
        /i/i.redd.it/abc123.jpg        -> https://i.redd.it/abc123.jpg
        /i/pbs.twimg.com/media/xyz.jpg -> https://pbs.twimg.com/media/xyz.jpg
    """
    # image_url is everything after /i/; prepend the scheme to rebuild the URL.
    full_url = 'https://' + image_url
    # The path converter does not capture the query string, so forward it.
    if request.query_string:
        full_url += '?' + request.query_string.decode('utf-8', errors='replace')
    # NOTE(review): this redirects to any host the client supplies (open
    # redirect); restrict to an allow-list if that is not intended.
    return redirect(full_url, code=302)


@app.route('/placeholder')
def placeholder():
    """Serve the PNG stored in PLACEHOLDER_BASE64."""
    img = base64.b64decode(PLACEHOLDER_BASE64)
    return Response(img, mimetype='image/png')


if __name__ == '__main__':
    port = 7860
    app.run(host='0.0.0.0', port=port)