Spaces: Build error

Update graphrag_agent.py

graphrag_agent.py  CHANGED  +68 -5
@@ -42,9 +42,17 @@ HF_TOKEN = os.getenv("HF_TOKEN")
 Entrez.email = ENTREZ_EMAIL
 MAX_TOKENS = 128000
 encoding = tiktoken.get_encoding("cl100k_base")
-
-
-
+tokenizer = AutoTokenizer.from_pretrained("cambridgeltl/SapBERT-from-PubMedBERT-fulltext", token=HF_TOKEN)
+model = AutoModel.from_pretrained("cambridgeltl/SapBERT-from-PubMedBERT-fulltext", token=HF_TOKEN).to(DEVICE)
+model.eval()
+
+bi_tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-m3", token=HF_TOKEN)
+bi_model = AutoModel.from_pretrained("BAAI/bge-m3", token=HF_TOKEN)
+bi_model.eval()
+
+cross_tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-reranker-v2-m3", token=HF_TOKEN)
+cross_model = AutoModelForSequenceClassification.from_pretrained("BAAI/bge-reranker-v2-m3", token=HF_TOKEN)
+cross_model.eval()
 # ======================== Global variables ========================
 faiss_indices = {}
 metadata = {}
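Note: the added lines reference torch, F (torch.nn.functional), the transformers Auto* classes, and a DEVICE constant, none of which appear in this diff. A minimal sketch of what the top of graphrag_agent.py would need to provide (an assumption; the file may already define these, and a missing one could explain the build error the Space reports):

# Hypothetical import block assumed by the hunks in this commit.
import torch
import torch.nn.functional as F
from transformers import (
    AutoTokenizer,
    AutoModel,
    AutoModelForSequenceClassification,
)

# Device used by .to(DEVICE) in the SapBERT setup and in embed_entity.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"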
@@ -147,6 +155,18 @@ def _extract_json_from_text(text: str) -> Dict[str, Any]:
         return {}
     return {}
 
+def embed_entity(entity_text: str):
+    if not tokenizer or not model:
+        raise ValueError("embedding model not loaded")
+    with torch.no_grad():
+        inputs = tokenizer(
+            entity_text, return_tensors="pt",
+            padding=True, truncation=True, max_length=64
+        ).to(DEVICE)
+        outputs = model(**inputs)
+        embedding = outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy()
+        return embedding
+
 def embed_entity_cloud(entity_text: str):
     """
     Fetch SapBERT embeddings via the Hugging Face Inference API
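For reference, a usage sketch of the new local embed_entity (hypothetical smoke test; assumes the SapBERT weights loaded successfully — SapBERT-from-PubMedBERT-fulltext is a BERT-base model, so the mean-pooled vector should be 768-dimensional):

# Mean pooling over tokens, then squeeze() -> 1-D numpy array.
vec = embed_entity("myocardial infarction")
assert vec.shape == (768,)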
@@ -163,7 +183,6 @@ def embed_entity_cloud(entity_text: str):
         print(f"Embedding error: {e}")
         return None
 
-
 def search_pubmed(pubmed_query: str, max_results: int = 3) -> str:
     try:
         handle = Entrez.esearch(db="pubmed", term=pubmed_query, retmax=max_results)
@@ -864,7 +883,51 @@ def neo4j_retrieval(state: MyState):
             logger.warning(f"'{entity}' failed in faiss: {e}")
             continue
 
-
+    try:
+        query_inputs = bi_tokenizer(query_text, return_tensors="pt", truncation=True, max_length=512, padding=True)
+        with torch.no_grad():
+            query_emb = bi_model(**query_inputs).last_hidden_state[:, 0]
+            query_emb = F.normalize(query_emb, dim=-1)
+
+        path_keys = list(path_kv.keys())
+        batch_size = 32
+        all_cand_embs = []
+        with torch.no_grad():
+            for i in range(0, len(path_keys), batch_size):
+                batch = path_keys[i:i + batch_size]
+                cand_inputs = bi_tokenizer(batch, return_tensors="pt", truncation=True, max_length=512, padding=True)
+                cand_embs_batch = bi_model(**cand_inputs).last_hidden_state[:, 0]
+                cand_embs_batch = F.normalize(cand_embs_batch, dim=-1)
+                all_cand_embs.append(cand_embs_batch)
+
+        cand_embs = torch.cat(all_cand_embs, dim=0)
+        sim_scores = torch.matmul(query_emb, cand_embs.T).squeeze(0).tolist()
+        scored_paths = list(zip(path_keys, sim_scores))
+        scored_paths.sort(key=lambda x: x[1], reverse=True)
+
+        top100 = scored_paths[:100]
+        pairs = [(query_text, pk) for pk, _ in top100]
+        all_cross_scores = []
+        cross_batch_size = 16
+        with torch.no_grad():
+            for i in range(0, len(pairs), cross_batch_size):
+                batch_pairs = pairs[i:i + cross_batch_size]
+                inputs = cross_tokenizer(batch_pairs, padding=True, truncation=True, max_length=512, return_tensors="pt")
+                scores = cross_model(**inputs).logits.view(-1).tolist()
+                all_cross_scores.extend(scores)
+
+        rerank_final = list(zip([p[0] for p in top100], all_cross_scores))
+        rerank_final.sort(key=lambda x: x[1], reverse=True)
+        top30 = rerank_final[:30]
+
+        top30_values = [path_kv[pk] for pk, _ in top30]
+        logger.info(f"Cross-encoder reranked top-30 paths: {top30_values}")
+        return {"neo4j_retrieval": top30_values}
+
+    except Exception as e:
+        logger.warning(f"rerank error: {e}")
+        fallback_values = list(path_kv.values())[:50]
+        return {"neo4j_retrieval": fallback_values}
 
 def decide_router(state: MyState) -> dict:
     print("---EDGE: decide_router---")
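The block added to neo4j_retrieval is a standard two-stage retrieve-and-rerank: the BGE-M3 bi-encoder embeds the query and every candidate path (CLS token, L2-normalized, so a dot product equals cosine similarity) and keeps the top 100, then the bge-reranker-v2-m3 cross-encoder rescores those 100 query-path pairs jointly and keeps the top 30. A self-contained sketch of the same pattern on toy data (the model names come from the diff; the query and candidate paths are hypothetical):

import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification

bi_tok = AutoTokenizer.from_pretrained("BAAI/bge-m3")
bi_enc = AutoModel.from_pretrained("BAAI/bge-m3").eval()
ce_tok = AutoTokenizer.from_pretrained("BAAI/bge-reranker-v2-m3")
ce_enc = AutoModelForSequenceClassification.from_pretrained("BAAI/bge-reranker-v2-m3").eval()

query = "drugs that inhibit EGFR"
paths = [
    "gefitinib - inhibits -> EGFR",
    "aspirin - inhibits -> COX-1",
    "erlotinib - inhibits -> EGFR",
]

def cls_embed(texts):
    # CLS-token embeddings, L2-normalized (dot product == cosine similarity).
    inputs = bi_tok(texts, return_tensors="pt", padding=True, truncation=True, max_length=512)
    with torch.no_grad():
        emb = bi_enc(**inputs).last_hidden_state[:, 0]
    return F.normalize(emb, dim=-1)

# Stage 1: bi-encoder recall -- keep the 2 most similar paths (100 in the diff).
sims = (cls_embed([query]) @ cls_embed(paths).T).squeeze(0)
shortlist = [paths[i] for i in sims.argsort(descending=True)[:2].tolist()]

# Stage 2: cross-encoder rerank -- score each (query, path) pair jointly.
pairs = [(query, p) for p in shortlist]
inputs = ce_tok(pairs, padding=True, truncation=True, max_length=512, return_tensors="pt")
with torch.no_grad():
    scores = ce_enc(**inputs).logits.view(-1)
ranked = [shortlist[i] for i in scores.argsort(descending=True).tolist()]
print(ranked)  # the EGFR paths should outrank the COX-1 path

The split earns its cost: the bi-encoder scores N candidates with N + 1 cheap independent forward passes and scales to large candidate sets, while the cross-encoder attends over each query-path pair jointly and is far more accurate but too slow to run on everything, so it only ever sees the shortlist.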