File size: 3,256 Bytes
182ca5a
 
 
a716f42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
182ca5a
 
a716f42
 
182ca5a
a716f42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
182ca5a
 
a716f42
 
 
 
182ca5a
 
a716f42
 
182ca5a
a716f42
 
 
 
 
 
 
182ca5a
 
a716f42
 
 
 
 
 
182ca5a
 
a716f42
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import gradio as gr
from sentence_transformers import SentenceTransformer, util

# Module-level cache of loaded models, keyed by the model name passed to
# load_model(). Keeping instances here avoids constructing the same
# SentenceTransformer again on later requests.
_cache = {}

def load_model(name: str):
    """Return the SentenceTransformer for *name*, building it on first use.

    Instances are memoized in the module-level ``_cache`` dict, so repeated
    calls with the same name return the same object. If construction raises,
    nothing is stored and the exception propagates to the caller.
    """
    if name in _cache:
        return _cache[name]

    model = SentenceTransformer(name)
    _cache[name] = model
    return model

# Choices offered by the "Prompt style" dropdown. "auto" is resolved per model
# ID inside compare(); of the remaining values only "e5" makes prep() add a
# prefix to the text.
FAMILIES = ["auto", "plain (no prefix)", "e5", "gte"]

def prep(text: str, family: str):
    """Strip *text* and apply the prompt prefix that *family* calls for.

    Only the ``"e5"`` family gets a prefix (``"query: "``). Every other
    family value, including ``"gte"``, returns the stripped text unchanged.
    A falsy *text* (``None`` or ``""``) is treated as the empty string.
    """
    cleaned = text.strip() if text else ""
    # E5-style instruction prompt; all other families are passed through as-is.
    return f"query: {cleaned}" if family == "e5" else cleaned

def _detect_family(model_id: str) -> str:
    """Guess the prompt family from substrings of the model ID (simple heuristic)."""
    lowered = model_id.lower()
    if "e5" in lowered:
        return "e5"
    if "gte" in lowered:
        return "gte"
    return "plain (no prefix)"


def compare(text_a: str, text_b: str, model_ids: str, family: str, normalize: bool):
    """Compute the cosine similarity of two texts under each requested model.

    Args:
        text_a: First text to embed.
        text_b: Second text to embed.
        model_ids: Comma-separated model IDs; blank entries are ignored.
        family: Prompt style passed to ``prep``. ``"auto"`` picks a style per
            model from its ID.
        normalize: Forwarded to ``encode`` as ``normalize_embeddings``.

    Returns:
        A list of ``[model_id, family, value]`` rows. Successful rows carry the
        similarity rounded to 6 decimals and come first, sorted by descending
        score. Rows for models that failed carry ``"error"`` as the family and
        a short message as the value, and follow in input order. An empty list
        is returned when either text is empty or whitespace-only, or when no
        model ID was given.
    """
    ids = [m.strip() for m in (model_ids or "").split(",") if m.strip()]
    # prep() strips the texts, so whitespace-only input would be embedded as an
    # empty string and produce a meaningless score; treat it as missing.
    if not (text_a or "").strip() or not (text_b or "").strip() or not ids:
        return []

    scored = []
    failed = []
    for mid in ids:
        fam_use = _detect_family(mid) if family == "auto" else family
        try:
            a_text = prep(text_a, fam_use)
            b_text = prep(text_b, fam_use)

            model = load_model(mid)
            a = model.encode(a_text, convert_to_tensor=True, normalize_embeddings=normalize)
            b = model.encode(b_text, convert_to_tensor=True, normalize_embeddings=normalize)
            score = round(util.cos_sim(a, b).item(), 6)
        except Exception as e:
            # Per-model boundary: one bad model ID (download failure, wrong
            # architecture, ...) must not abort the remaining comparisons.
            failed.append([mid, "error", f"⚠️ {type(e).__name__}: {str(e)[:120]}"])
        else:
            scored.append([mid, fam_use, score])

    # Highest similarity first; the sort is stable, so ties keep input order.
    scored.sort(key=lambda row: row[2], reverse=True)
    return scored + failed

# Build the Gradio UI: two text inputs, a model-ID box, prompt-style and
# normalization options, and a results table that compare() fills in.
with gr.Blocks(title="Medical Embedding Similarity") as demo:
    gr.Markdown("## 🩺 Embedding Similarity (Medical models)\n"
                "Enter two texts, paste one or more Hugging Face model IDs (comma-separated), "
                "and view cosine similarity per model.")

    # The two texts to compare, side by side.
    with gr.Row():
        text_a = gr.Textbox(label="Text A", lines=3, placeholder="e.g., bone in the breast")
        text_b = gr.Textbox(label="Text B", lines=3, placeholder="e.g., sternum")

    # Free-form list of model IDs; compare() splits it on commas.
    model_ids = gr.Textbox(
        label="Model IDs (comma-separated)",
        placeholder="e.g., cambridgeltl/SapBERT-from-PubMedBERT-fulltext, pritamdeka/biomedical-sentence-transformers, sentence-transformers/all-MiniLM-L6-v2"
    )
    # Options forwarded to compare() as `family` and `normalize`.
    with gr.Row():
        family = gr.Dropdown(FAMILIES, value="auto", label="Prompt style")
        normalize = gr.Checkbox(True, label="Normalize embeddings (recommended for cosine)")

    btn = gr.Button("Compute similarity")
    # All columns are declared as strings: the third one carries either a
    # similarity score or an error message, depending on the row.
    out = gr.Dataframe(
        headers=["model", "family", "cosine_similarity / error"],
        datatype=["str", "str", "str"],
        wrap=True
    )
    # On click, call compare() with the inputs in this order and show the
    # returned rows in the table.
    btn.click(compare, [text_a, text_b, model_ids, family, normalize], out)

# Launch the app only when this file is run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()