MEDICAL_SERVER / server_medical_RAM_optimize.py
huynguyen6906's picture
Update server_medical_RAM_optimize.py
c94a47f verified
import os
import numpy as np
import h5py
import hnswlib
import torch
import open_clip
import torch
from flask import Flask, request, jsonify
from flask_cors import CORS
from PIL import Image
import requests
import io
import base64
from huggingface_hub import hf_hub_download
from flask import Response, send_file
import tempfile
# When enabled, each search hit also carries its URL in `image_data` (enabled for speed).
PREFETCH_IMAGES = True
# Base64-encoded PNG bytes returned by the /placeholder route.
PLACEHOLDER_BASE64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChwGA60e6kgAAAABJRU5ErkJggg=="

app = Flask(__name__)
CORS(app, origins=['*'])

print("\n" + "="*50)
print("📥 INITIALIZING MEDICAL SERVER...")
print("="*50)

# Dataset configuration.
HF_TOKEN = os.environ.get("HF_TOKEN")
DATASET_ID = "huynguyen6906/Medical_server_data"

# Download the data files from the Hugging Face dataset repository.
# If anything goes wrong, BOTH paths fall back to files in the working directory.
try:
    print(f"Downloading data from {DATASET_ID}...")
    H5_FILE_PATH = hf_hub_download(
        repo_id=DATASET_ID,
        filename="Medical_Embedded.h5",
        repo_type="dataset",
        token=HF_TOKEN,
    )
    BIN_FILE_PATH = hf_hub_download(
        repo_id=DATASET_ID,
        filename="Medical_Embedded.bin",
        repo_type="dataset",
        token=HF_TOKEN,
    )
    print(f"✅ Data loaded: {H5_FILE_PATH}")
except Exception as e:
    print(f"❌ Error downloading data: {e}")
    H5_FILE_PATH = 'Medical_Embedded.h5'
    BIN_FILE_PATH = 'Medical_Embedded.bin'
class ImageSearchEngine:
    """Similarity search over precomputed image embeddings using BiomedCLIP.

    Text or image queries are encoded with the BiomedCLIP model, then looked
    up in an hnswlib cosine index. Result URLs are read from the ``urls``
    dataset of the H5 file.

    Attributes set by ``__init__``:
        device: ``"cuda"`` when available, otherwise ``"cpu"``.
        model, preprocess, tokenizer: objects returned by ``open_clip``.
        h5_file: open ``h5py.File`` handle, or ``None`` in empty mode.
        index: ``hnswlib.Index``, or ``None`` in empty mode.
        index_ready: ``True`` only once the index was loaded from disk.
        max_elements: number of entries in the ``urls`` dataset (0 in empty mode).
        dim: embedding dimensionality (512 in empty mode).
    """

    MODEL_NAME = 'hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224'

    def __init__(self, h5_file_path=H5_FILE_PATH):
        """Load the model and, when the data files exist, the embeddings and index.

        Args:
            h5_file_path: path to the H5 file holding ``embeddings`` and ``urls``.
        """
        print("Initializing Search Engine...")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Start from "empty mode" defaults so every attribute always exists.
        self.h5_file = None
        self.index = None
        self.index_ready = False
        self.max_elements = 0
        self.dim = 512

        print("Loading BiomedCLIP-PubMedBERT_256-vit_base_patch16_224...")
        self.model, _, self.preprocess = open_clip.create_model_and_transforms(self.MODEL_NAME)
        self.tokenizer = open_clip.get_tokenizer(self.MODEL_NAME)
        # Query tensors are moved to self.device, so the model must live there
        # too; eval() switches off training-only layers for inference.
        self.model.to(self.device)
        self.model.eval()

        if not os.path.exists(h5_file_path):
            # Keep the server running without data so it can still be debugged.
            print("⚠️ H5 file not found. Running in empty mode.")
            return

        self.h5_file = h5py.File(h5_file_path, 'r')
        self.dim = self.h5_file['embeddings'].shape[1]
        self.max_elements = len(self.h5_file['urls'])
        print(f"Loaded {self.max_elements} image embeddings. Dim: {self.dim}")

        self.index = hnswlib.Index(space='cosine', dim=self.dim)
        if os.path.exists(BIN_FILE_PATH):
            print(f"⚡ Loading Index from {BIN_FILE_PATH}...")
            self.index.load_index(BIN_FILE_PATH, max_elements=self.max_elements)
            self.index.set_ef(400)
            self.index_ready = True
        else:
            # Without the saved index there is nothing to query; search() returns [].
            print("⚠️ BIN file not found.")

    def text_to_vector(self, text):
        """Encode one string (or a list of strings) into L2-normalised vectors.

        Returns:
            A 2-D numpy array with one row per input text.
        """
        if isinstance(text, str):
            text = [text]
        tokens = self.tokenizer(text).to(self.device)
        with torch.no_grad():
            text_features = self.model.encode_text(tokens)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        return text_features.cpu().numpy()

    def image_to_vector(self, image):
        """Encode a single image into an L2-normalised float32 vector.

        Args:
            image: an image accepted by the open_clip preprocessing transform
                (the routes pass an RGB ``PIL.Image``).

        Returns:
            A 1-D ``np.float32`` array.
        """
        image_tensor = self.preprocess(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            image_features = self.model.encode_image(image_tensor)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        return image_features.cpu().numpy().astype(np.float32)[0]

    def search(self, vector, k=10):
        """Return up to ``k`` nearest images for ``vector``.

        Args:
            vector: query embedding; only the first query row is used.
            k: number of neighbours requested.

        Returns:
            A list of dicts with ``path``, ``url`` and ``score`` (``1 - cosine
            distance``), plus ``image_data`` when ``PREFETCH_IMAGES`` is on.
            Empty when no data or index is loaded, or when ``k`` is not positive.
        """
        if self.max_elements == 0 or not self.index_ready:
            return []

        # hnswlib raises when asked for more neighbours than it holds.
        k = min(int(k), self.index.get_current_count())
        if k <= 0:
            return []

        indices, distances = self.index.knn_query(vector, k=k)
        urls = self.h5_file['urls']
        results = []
        for idx, dist in zip(indices[0], distances[0]):
            # Labels come back as numpy integers; use a plain int for h5py.
            url_bytes = urls[int(idx)]
            url = url_bytes.decode('utf-8') if isinstance(url_bytes, bytes) else str(url_bytes)
            url = url.strip()
            result = {
                'path': url,
                'url': url,
                'score': float(1 - dist),
            }
            # With prefetch on, send the URL itself (no base64 payload needed);
            # the frontend is expected to load it through the /i/ route.
            if PREFETCH_IMAGES:
                result['image_data'] = url
            results.append(result)
        return results
# Single engine instance created at import time; the health and search routes use it.
search_engine = ImageSearchEngine()
# --- ROUTES ---
@app.route('/health', methods=['GET'])
def health_check():
    """Return a JSON status payload with the engine's ``max_elements`` count."""
    payload = {
        'status': 'healthy',
        'total_images': search_engine.max_elements,
    }
    return jsonify(payload)
@app.route('/search', methods=['POST'])
def search_text():
    """Search images by text.

    Expects a JSON object body with optional ``query`` (default ``''``) and
    ``k`` (default 20).

    Returns:
        200 with ``{'results': [...]}`` on success,
        400 with ``{'error': ...}`` when the body is not a JSON object or
        ``k`` is not a positive integer,
        500 with ``{'error': ...}`` on any other failure.
    """
    try:
        # silent=True yields None instead of raising on a missing/invalid body,
        # so malformed requests are reported as client errors, not as 500s.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return jsonify({'error': 'Request body must be a JSON object'}), 400

        try:
            k = int(data.get('k', 20))
        except (TypeError, ValueError):
            return jsonify({'error': 'k must be an integer'}), 400
        if k < 1:
            return jsonify({'error': 'k must be a positive integer'}), 400

        query = data.get('query', '')
        vector = search_engine.text_to_vector(query)
        results = search_engine.search(vector, k=k)
        return jsonify({'results': results})
    except Exception as e:
        return jsonify({'error': str(e)}), 500
@app.route('/search/image', methods=['POST'])
def search_image_file():
    """Search images by an uploaded image.

    Expects multipart form data with an ``image`` file and an optional ``k``
    field (default 20).

    Returns:
        200 with ``{'results': [...]}`` on success,
        400 with ``{'error': ...}`` when the image is missing or unreadable,
        or ``k`` is not a positive integer,
        500 with ``{'error': ...}`` on any other failure.
    """
    try:
        if 'image' not in request.files:
            return jsonify({'error': 'No image provided'}), 400

        try:
            k = int(request.form.get('k', 20))
        except (TypeError, ValueError):
            return jsonify({'error': 'k must be an integer'}), 400
        if k < 1:
            return jsonify({'error': 'k must be a positive integer'}), 400

        file = request.files['image']
        try:
            img = Image.open(file.stream).convert('RGB')
        except OSError:
            # Pillow signals unreadable/truncated uploads with OSError
            # subclasses; that is a bad request, not a server fault.
            return jsonify({'error': 'Invalid image file'}), 400

        vector = search_engine.image_to_vector(img)
        results = search_engine.search(vector, k=k)
        return jsonify({'results': results})
    except Exception as e:
        return jsonify({'error': str(e)}), 500
@app.route('/i/<path:image_url>')
def fast_proxy(image_url):
    """Return a tiny HTML page that sends the browser to ``https://<image_url>``.

    The part of the path after ``/i/`` is treated as a host + path and is
    prefixed with ``https://``. Examples:
        /i/i.redd.it/abc123.jpg          -> https://i.redd.it/abc123.jpg
        /i/pbs.twimg.com/media/xyz.jpg   -> https://pbs.twimg.com/media/xyz.jpg

    Args:
        image_url: everything after ``/i/`` in the request path (untrusted).

    Returns:
        ``(html, 200, headers)`` with a JavaScript redirect and a
        ``<noscript>`` meta-refresh fallback.
    """
    import html
    import json

    # The part after /i/ is joined back into a full URL.
    full_url = 'https://' + image_url

    # image_url is attacker-controlled, so it must be escaped for each context
    # it is written into. json.dumps produces a quoted JS string literal; the
    # extra replacements stop "</script>" or HTML entities from breaking out
    # of the script element.
    js_url = (
        json.dumps(full_url)
        .replace('<', '\\u003c')
        .replace('>', '\\u003e')
        .replace('&', '\\u0026')
    )
    # Attribute-context escaping for the meta refresh fallback.
    attr_url = html.escape(full_url, quote=True)

    # NOTE(review): this still redirects to any host the caller names;
    # consider an allow-list if open redirects are a concern.
    return f'''
<script>location.replace({js_url})</script>
<noscript><meta http-equiv="refresh" content="0;url={attr_url}"></noscript>
''', 200, {'Content-Type': 'text/html'}
@app.route('/placeholder')
def placeholder():
    """Serve the decoded ``PLACEHOLDER_BASE64`` bytes as an ``image/png`` response."""
    png_bytes = base64.b64decode(PLACEHOLDER_BASE64)
    return Response(png_bytes, mimetype='image/png')
if __name__ == '__main__':
    # Run Flask's built-in server on every interface, fixed port 7860.
    port = 7860
    app.run(port=port, host='0.0.0.0')