"""Gradio demo comparing cosine similarity of two texts across embedding models.

Paste one or more Hugging Face model IDs (comma-separated); each model encodes
both texts and the per-model cosine similarity is shown in a table, best first.
"""

import gradio as gr
from sentence_transformers import SentenceTransformer, util

# Loaded models keyed by model ID, so repeated comparisons don't reload weights.
# NOTE(review): unbounded — every distinct ID a user enters stays in memory for
# the lifetime of the process; consider a size cap if this is deployed publicly.
_cache = {}


def load_model(name: str):
    """Return the SentenceTransformer for *name*, loading it on first use."""
    if name not in _cache:
        _cache[name] = SentenceTransformer(name)
    return _cache[name]


# Choices for the "Prompt style" dropdown; "auto" picks a family per model ID.
FAMILIES = ["auto", "plain (no prefix)", "e5", "gte"]


def prep(text: str, family: str):
    """Strip *text* (``None`` becomes ``""``) and apply the prompt style for *family*.

    E5 models get a ``"query: "`` prefix; since this is a symmetric similarity
    task, the same prefix is applied to both texts. GTE and plain models are
    given the stripped text unchanged.
    """
    t = (text or "").strip()
    if family == "e5":
        return f"query: {t}"
    return t


def _resolve_family(model_id: str, family: str) -> str:
    """Map the dropdown choice to a concrete family, guessing from the ID for "auto"."""
    if family != "auto":
        return family
    lowered = model_id.lower()
    # Simple substring heuristic on the model ID.
    if "e5" in lowered:
        return "e5"
    if "gte" in lowered:
        return "gte"
    return "plain (no prefix)"


def _score_sort_key(row):
    """Sort key placing numeric scores first (highest first), then error rows."""
    value = row[2]
    if isinstance(value, float):
        return (0, -value)
    return (1, 0)


def compare(text_a: str, text_b: str, model_ids: str, family: str, normalize: bool):
    """Compute the cosine similarity of *text_a* and *text_b* under each model.

    Args:
        text_a: First text.
        text_b: Second text.
        model_ids: Comma-separated Hugging Face model IDs; blanks are ignored
            and duplicates are scored only once.
        family: One of ``FAMILIES``; ``"auto"`` guesses per model ID.
        normalize: Passed to ``encode(normalize_embeddings=...)``.

    Returns:
        A list of ``[model_id, family, score_or_error]`` rows. Successful rows
        carry a float rounded to 6 places and are sorted by descending score;
        rows for models that failed carry a short error string and come last.
        Returns ``[]`` when either text is blank or no model ID was given.
    """
    # De-duplicate IDs while keeping the order the user entered them.
    ids = list(
        dict.fromkeys(m.strip() for m in (model_ids or "").split(",") if m.strip())
    )
    # Check stripped text: whitespace-only input would otherwise be encoded as "".
    if not (text_a or "").strip() or not (text_b or "").strip() or not ids:
        return []

    rows = []
    for mid in ids:
        try:
            fam_use = _resolve_family(mid, family)
            model = load_model(mid)
            # Encode both texts in one batch (a single forward pass per model).
            emb = model.encode(
                [prep(text_a, fam_use), prep(text_b, fam_use)],
                convert_to_tensor=True,
                normalize_embeddings=normalize,
            )
            score = util.cos_sim(emb[0], emb[1]).item()
            rows.append([mid, fam_use, round(score, 6)])
        except Exception as e:  # UI boundary: one bad model must not abort the table
            rows.append([mid, "error", f"⚠️ {type(e).__name__}: {str(e)[:120]}"])

    rows.sort(key=_score_sort_key)
    return rows


with gr.Blocks(title="Medical Embedding Similarity") as demo:
    gr.Markdown(
        "## 🩺 Embedding Similarity (Medical models)\n"
        "Enter two texts, paste one or more Hugging Face model IDs (comma-separated), "
        "and view cosine similarity per model."
    )
    with gr.Row():
        text_a = gr.Textbox(label="Text A", lines=3, placeholder="e.g., bone in the breast")
        text_b = gr.Textbox(label="Text B", lines=3, placeholder="e.g., sternum")
    model_ids = gr.Textbox(
        label="Model IDs (comma-separated)",
        placeholder=(
            "e.g., cambridgeltl/SapBERT-from-PubMedBERT-fulltext, "
            "pritamdeka/biomedical-sentence-transformers, "
            "sentence-transformers/all-MiniLM-L6-v2"
        ),
    )
    with gr.Row():
        family = gr.Dropdown(FAMILIES, value="auto", label="Prompt style")
        # util.cos_sim L2-normalises its inputs itself, so this flag has little
        # effect on the reported cosine score; it changes the raw embeddings.
        normalize = gr.Checkbox(True, label="Normalize embeddings (recommended for cosine)")
    btn = gr.Button("Compute similarity")
    # The score column mixes floats and error strings, hence "str" for all columns.
    out = gr.Dataframe(
        headers=["model", "family", "cosine_similarity / error"],
        datatype=["str", "str", "str"],
        wrap=True,
    )
    btn.click(compare, [text_a, text_b, model_ids, family, normalize], out)

if __name__ == "__main__":
    demo.launch()