# embedding-lab / app.py
# zazou2552's picture
# Update app.py
# a716f42 verified
import gradio as gr
from sentence_transformers import SentenceTransformer, util
# Already-constructed models keyed by model name, so each one is built only once.
_cache = {}


def load_model(name: str):
    """Return the ``SentenceTransformer`` for *name*, building it on first use.

    Args:
        name: Model identifier passed straight to ``SentenceTransformer``.

    Returns:
        The cached model instance for *name*.
    """
    model = _cache.get(name)
    if model is None:
        # First request for this name: construct the model and remember it.
        model = SentenceTransformer(name)
        _cache[name] = model
    return model
# Prompt styles selectable in the UI; "auto" is resolved per model ID in compare().
FAMILIES = [
    "auto",
    "plain (no prefix)",
    "e5",
    "gte",
]
def prep(text: str, family: str):
    """Strip *text* and apply the prompt prefix required by *family*.

    Args:
        text: Raw input text; ``None`` or an empty value is treated as ``""``.
        family: Prompt style. Only ``"e5"`` adds a prefix (``"query: "``);
            every other value, including ``"gte"``, returns the stripped text
            unchanged.

    Returns:
        The stripped text, prefixed with ``"query: "`` for the E5 family.
    """
    cleaned = text.strip() if text else ""
    # E5-style instruction prompt; all other families use the bare text.
    return f"query: {cleaned}" if family == "e5" else cleaned
def _resolve_family(model_id: str, family: str) -> str:
    """Return the prompt family to use for *model_id*.

    An explicit *family* is returned unchanged; ``"auto"`` is resolved with a
    simple substring heuristic on the lower-cased model ID.
    """
    if family != "auto":
        return family
    lowered = model_id.lower()
    if "e5" in lowered:
        return "e5"
    if "gte" in lowered:
        return "gte"
    return "plain (no prefix)"


def _score_sort_key(row):
    """Sort key: numeric scores first (highest first), then error rows."""
    score = row[2]
    if isinstance(score, (int, float)) and not isinstance(score, bool):
        return (0, -score)  # negated so an ascending sort yields descending scores
    return (1, 0)


def compare(text_a: str, text_b: str, model_ids: str, family: str, normalize: bool):
    """Compute the cosine similarity of two texts under each listed model.

    Args:
        text_a: First text to embed.
        text_b: Second text to embed.
        model_ids: Comma-separated model identifiers; blank entries are dropped.
        family: Prompt style handed to ``prep``. ``"auto"`` picks a style per
            model from its ID ("e5" / "gte" substring, otherwise plain).
        normalize: Forwarded to ``model.encode`` as ``normalize_embeddings``.

    Returns:
        A list of ``[model_id, family, score]`` rows, with the score rounded
        to 6 decimals, sorted by score in descending order. A model that
        raises while loading or encoding contributes
        ``[model_id, "error", message]`` instead; error rows come after all
        scored rows. Returns ``[]`` when either text is empty or
        whitespace-only, or when no model ID is given.
    """
    ids = [part.strip() for part in (model_ids or "").split(",") if part.strip()]
    # prep() strips its input, so whitespace-only text is effectively empty and
    # must be rejected here rather than encoded as "".
    if not ids or not (text_a or "").strip() or not (text_b or "").strip():
        return []

    rows = []
    for mid in ids:
        fam_use = _resolve_family(mid, family)
        try:
            a_text = prep(text_a, fam_use)
            b_text = prep(text_b, fam_use)
            model = load_model(mid)
            a = model.encode(a_text, convert_to_tensor=True, normalize_embeddings=normalize)
            b = model.encode(b_text, convert_to_tensor=True, normalize_embeddings=normalize)
            score = util.cos_sim(a, b).item()
        except Exception as e:
            # Best-effort per model: one failing model must not hide the
            # results of the others, so report it as a row and continue.
            rows.append([mid, "error", f"⚠️ {type(e).__name__}: {str(e)[:120]}"])
        else:
            rows.append([mid, fam_use, round(score, 6)])

    rows.sort(key=_score_sort_key)
    return rows
# UI definition: two text inputs, a model-ID list, prompt/normalization
# options, and a results table filled by compare().
with gr.Blocks(title="Medical Embedding Similarity") as demo:
    gr.Markdown(
        value=(
            "## 🩺 Embedding Similarity (Medical models)\n"
            "Enter two texts, paste one or more Hugging Face model IDs (comma-separated), "
            "and view cosine similarity per model."
        )
    )

    # The two texts to compare.
    with gr.Row():
        text_a = gr.Textbox(
            label="Text A",
            lines=3,
            placeholder="e.g., bone in the breast",
        )
        text_b = gr.Textbox(
            label="Text B",
            lines=3,
            placeholder="e.g., sternum",
        )

    model_ids = gr.Textbox(
        label="Model IDs (comma-separated)",
        placeholder="e.g., cambridgeltl/SapBERT-from-PubMedBERT-fulltext, pritamdeka/biomedical-sentence-transformers, sentence-transformers/all-MiniLM-L6-v2",
    )

    # Encoding options.
    with gr.Row():
        family = gr.Dropdown(choices=FAMILIES, value="auto", label="Prompt style")
        normalize = gr.Checkbox(
            value=True,
            label="Normalize embeddings (recommended for cosine)",
        )

    btn = gr.Button("Compute similarity")

    # The third column holds either a numeric score or an error message,
    # so every column is declared as a string.
    out = gr.Dataframe(
        headers=["model", "family", "cosine_similarity / error"],
        datatype=["str", "str", "str"],
        wrap=True,
    )

    btn.click(
        fn=compare,
        inputs=[text_a, text_b, model_ids, family, normalize],
        outputs=out,
    )

# Start the app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()