Spaces:
Sleeping
Sleeping
from operator import itemgetter

import gradio as gr
from sentence_transformers import SentenceTransformer, util
# Loaded models keyed by model name, so repeated requests reuse the instance
# instead of constructing it again.
_cache = {}


def load_model(name: str):
    """Return the SentenceTransformer for *name*, constructing it on first use.

    Args:
        name: Model identifier passed straight to ``SentenceTransformer``.

    Returns:
        The cached model instance for ``name``.

    Raises:
        Whatever ``SentenceTransformer(name)`` raises; nothing is cached in
        that case, so a later call retries the load.
    """
    if name not in _cache:
        _cache[name] = SentenceTransformer(name)
    return _cache[name]
# Prompt styles offered to the user. prep() only special-cases "e5"; every
# other entry (including "auto" and "gte") leaves the text unprefixed.
FAMILIES = ["auto", "plain (no prefix)", "e5", "gte"]


def prep(text: str, family: str):
    """Return *text* stripped of surrounding whitespace, formatted for *family*.

    Args:
        text: Raw input text; ``None`` or an empty value is treated as ``""``.
        family: Prompt style name. ``"e5"`` prepends the ``"query: "`` prefix;
            any other value returns the stripped text unchanged.

    Returns:
        The prepared string.
    """
    t = (text or "").strip()
    if family == "e5":
        # E5-style instruction prompt
        return f"query: {t}"
    # "gte", plain and any unrecognised family: no prefix is added.
    return t
def compare(text_a: str, text_b: str, model_ids: str, family: str, normalize: bool):
    """Score the cosine similarity of two texts under each requested model.

    Args:
        text_a: First text to embed.
        text_b: Second text to embed.
        model_ids: Comma-separated model identifiers; blank entries are
            ignored.
        family: Prompt style. ``"auto"`` picks a style per model from its ID
            (contains "e5" -> "e5", contains "gte" -> "gte", otherwise
            "plain (no prefix)"); any other value is used as-is for every
            model.
        normalize: Forwarded to ``encode`` as ``normalize_embeddings``.

    Returns:
        A list of ``[model_id, family_used, score_or_error]`` rows. Rows that
        were scored come first, ordered by descending score (rounded to six
        decimals, ties kept in input order). Rows for models that raised
        follow in input order, with ``"error"`` as the family and a truncated
        message in the last column. Returns ``[]`` when either text is blank
        or no model ID is given.
    """
    ids = [m.strip() for m in (model_ids or "").split(",") if m.strip()]
    # prep() strips its input, so whitespace-only text would be embedded as an
    # empty string; treat it the same as missing text.
    if not (text_a or "").strip() or not (text_b or "").strip() or not ids:
        return []

    scored = []
    failed = []
    for mid in ids:
        # auto-detect family if requested (simple substring heuristic)
        if family == "auto":
            lowered = mid.lower()
            if "e5" in lowered:
                fam_use = "e5"
            elif "gte" in lowered:
                fam_use = "gte"
            else:
                fam_use = "plain (no prefix)"
        else:
            fam_use = family

        try:
            a_text = prep(text_a, fam_use)
            b_text = prep(text_b, fam_use)
            model = load_model(mid)
            a = model.encode(a_text, convert_to_tensor=True, normalize_embeddings=normalize)
            b = model.encode(b_text, convert_to_tensor=True, normalize_embeddings=normalize)
            score = util.cos_sim(a, b).item()
        except Exception as e:
            # One bad model ID must not abort the whole comparison; report it
            # as a row instead.
            failed.append([mid, "error", f"⚠️ {type(e).__name__}: {str(e)[:120]}"])
        else:
            scored.append([mid, fam_use, round(score, 6)])

    # Numeric scores first (descending), then errors. Keeping the two groups
    # apart avoids re-parsing the display column to tell them apart.
    scored.sort(key=itemgetter(2), reverse=True)
    return scored + failed
# UI definition. Components are created in display order.
# NOTE(review): the nesting of the Row blocks is reconstructed from the
# component sequence (indentation was lost in the source) — confirm layout.
with gr.Blocks(title="Medical Embedding Similarity") as demo:
    gr.Markdown(
        "## 🩺 Embedding Similarity (Medical models)\n"
        "Enter two texts, paste one or more Hugging Face model IDs (comma-separated), "
        "and view cosine similarity per model."
    )
    with gr.Row():
        text_a = gr.Textbox(label="Text A", lines=3, placeholder="e.g., bone in the breast")
        text_b = gr.Textbox(label="Text B", lines=3, placeholder="e.g., sternum")
    model_ids = gr.Textbox(
        label="Model IDs (comma-separated)",
        placeholder="e.g., cambridgeltl/SapBERT-from-PubMedBERT-fulltext, pritamdeka/biomedical-sentence-transformers, sentence-transformers/all-MiniLM-L6-v2",
    )
    with gr.Row():
        family = gr.Dropdown(choices=FAMILIES, value="auto", label="Prompt style")
        normalize = gr.Checkbox(value=True, label="Normalize embeddings (recommended for cosine)")
    btn = gr.Button("Compute similarity")
    # The last column carries either a score or an error message, hence the
    # combined header and the "str" datatype for every column.
    out = gr.Dataframe(
        headers=["model", "family", "cosine_similarity / error"],
        datatype=["str", "str", "str"],
        wrap=True,
    )
    # Inputs are passed to compare() positionally, in this order.
    btn.click(fn=compare, inputs=[text_a, text_b, model_ids, family, normalize], outputs=out)


if __name__ == "__main__":
    demo.launch()