Spaces (build status: error)
Commit: Update graphrag_agent.py
File changed: graphrag_agent.py (+70 -79)
@@ -45,23 +45,20 @@ HF_TOKEN = os.getenv("HF_TOKEN")
 Entrez.email = ENTREZ_EMAIL
 MAX_TOKENS = 128000
 encoding = tiktoken.get_encoding("cl100k_base")
-
-
-
+tokenizer = AutoTokenizer.from_pretrained("cambridgeltl/SapBERT-from-PubMedBERT-fulltext", token=HF_TOKEN)
+model = AutoModel.from_pretrained("cambridgeltl/SapBERT-from-PubMedBERT-fulltext", token=HF_TOKEN).to(DEVICE)
+model.eval()
 
-
-
-
+bi_tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-m3", token=HF_TOKEN)
+bi_model = AutoModel.from_pretrained("BAAI/bge-m3", token=HF_TOKEN)
+bi_model.eval()
 
-
-
-
-sapbert_client = InferenceClient(
-    provider="hf-inference",
-    api_key=HF_TOKEN
-)
-bge_client = InferenceClient(provider="hf-inference", api_key=HF_TOKEN)
-cross_client = InferenceClient(provider="hf-inference", api_key=HF_TOKEN)
+cross_tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-reranker-v2-m3", token=HF_TOKEN)
+cross_model = AutoModelForSequenceClassification.from_pretrained("BAAI/bge-reranker-v2-m3", token=HF_TOKEN)
+cross_model.eval()
+#sapbert_client = InferenceClient(provider="hf-inference", api_key=HF_TOKEN)
+#bge_client = InferenceClient(provider="hf-inference", api_key=HF_TOKEN)
+#cross_client = InferenceClient(provider="hf-inference", api_key=HF_TOKEN)
 # ======================== Global variables ========================
 faiss_indices = {}
 metadata = {}
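Reviewer note on this hunk: the commit swaps the three InferenceClient wrappers for locally loaded transformers models and keeps the old client setup as comments. Only the SapBERT model is moved with .to(DEVICE); bi_model and cross_model are left on the default device, so their inputs must stay on CPU unless they are moved too. A minimal sketch of consistent device handling, assuming DEVICE is defined earlier in graphrag_agent.py (its definition is not part of this diff):

import torch

# Assumption: this mirrors how DEVICE is presumably defined near the top of the file.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

def to_model_device(inputs, model):
    # Tokenizer output tensors start on CPU; move them to wherever the model
    # lives, otherwise a CUDA-resident model raises a device-mismatch error.
    device = next(model.parameters()).device
    return {k: v.to(device) for k, v in inputs.items()}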
@@ -164,7 +161,7 @@ def _extract_json_from_text(text: str) -> Dict[str, Any]:
             return {}
     return {}
 
-
+def embed_entity(entity_text: str):
     if not tokenizer or not model:
         raise ValueError("embedding model not loaded")
     with torch.no_grad():
@@ -176,21 +173,15 @@ def _extract_json_from_text(text: str) -> Dict[str, Any]:
         embedding = outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy()
         return embedding
 
-def embed_entity(entity_text: str):
-    """
-    Get SapBERT embeddings via the Hugging Face Inference API
-    """
-    try:
-        result = sapbert_client.feature_extraction(
-            entity_text,
-            model="cambridgeltl/SapBERT-from-PubMedBERT-fulltext"
-        )
+#def embed_entity(entity_text: str):
+#    """Get SapBERT embeddings via the Hugging Face Inference API"""
+#    try:result = sapbert_client.feature_extraction(entity_text, model="cambridgeltl/SapBERT-from-PubMedBERT-fulltext")
         # The result is usually list[list[float]]; take the mean or the first token
-        embedding = [sum(x)/len(x) for x in zip(*result)]  # average over each dimension
-        return embedding
-    except Exception as e:
-        print(f"Embedding error: {e}")
-        return None
+#        embedding = [sum(x)/len(x) for x in zip(*result)]  # average over each dimension
+#        return embedding
+#    except Exception as e:
+#        print(f"Embedding error: {e}")
+#        return None
 
 def search_pubmed(pubmed_query: str, max_results: int = 3) -> str:
     try:
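The two hunks above comment out the Inference API version of embed_entity and re-point the name at the locally loaded SapBERT, which mean-pools the last hidden state. A self-contained sketch of that path; the model name comes from the diff, while the tokenizer settings are assumptions:

import torch
from transformers import AutoModel, AutoTokenizer

sap_tok = AutoTokenizer.from_pretrained("cambridgeltl/SapBERT-from-PubMedBERT-fulltext")
sap_model = AutoModel.from_pretrained("cambridgeltl/SapBERT-from-PubMedBERT-fulltext")
sap_model.eval()

def embed_entity_sketch(entity_text: str):
    # Tokenize, forward pass without gradients, then mean-pool over tokens,
    # matching outputs.last_hidden_state.mean(dim=1) in the diff.
    inputs = sap_tok(entity_text, return_tensors="pt", truncation=True, max_length=512)  # max_length is an assumption
    with torch.no_grad():
        outputs = sap_model(**inputs)
    return outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy()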
@@ -702,7 +693,7 @@ def whether_to_interact(state):
         return "user_input"
     elif interaction == "sufficient":
         print("Decision: information is sufficient; proceeding to Neo4j retrieval.")
-        return "
+        return "kg_retrieval"
     else:
         return "stop_flow"
 
@@ -722,51 +713,7 @@ with ZipFile(zip_path, "r") as zip_ref:
     zip_ref.extractall("data/")
 
 print("✅ Successfully downloaded and extracted data.zip")
-def rerank_paths_cloud(query_text, path_kv):
-    try:
-        # 1. query embedding
-        query_emb = bge_client.feature_extraction(query_text, model="BAAI/bge-m3")
-        query_emb = torch.tensor(query_emb[0]).unsqueeze(0)
-        query_emb = F.normalize(query_emb, dim=-1)
-
-        # 2. path embeddings
-        path_keys = list(path_kv.keys())
-        all_cand_embs = []
-        for pk in path_keys:
-            cand_emb = bge_client.feature_extraction(pk, model="BAAI/bge-m3")
-            emb_tensor = torch.tensor(cand_emb[0]).unsqueeze(0)
-            emb_tensor = F.normalize(emb_tensor, dim=-1)
-            all_cand_embs.append(emb_tensor)
-
-        cand_embs = torch.cat(all_cand_embs, dim=0)
-        sim_scores = torch.matmul(query_emb, cand_embs.T).squeeze(0).tolist()
-
-        scored_paths = list(zip(path_keys, sim_scores))
-        scored_paths.sort(key=lambda x: x[1], reverse=True)
-        top100 = scored_paths[:100]
-
-        # 3. cross-encoder rerank
-        pairs = [(query_text, pk) for pk, _ in top100]
-        all_cross_scores = []
-        for q, pk in pairs:
-            input_pair = [(q, pk)]
-            scores = cross_client.text_classification(
-                input_pair,
-                model="BAAI/bge-reranker-v2-m3"
-            )
-            all_cross_scores.append(scores[0]["score"])
-
-        rerank_final = list(zip([p[1] for p in top100], all_cross_scores))
-        rerank_final.sort(key=lambda x: x[1], reverse=True)
-        top30 = rerank_final[:30]
-
-        top30_values = [path_kv[pk] for pk, _ in top30]
-        return {"neo4j_retrieval": top30_values}
 
-    except Exception as e:
-        print(f"rerank error: {e}")
-        fallback_values = list(path_kv.values())[:50]
-        return {"neo4j_retrieval": fallback_values}
 def neo4j_retrieval(state: MyState):
     logger.info("---NODE: neo4j_retrieval---")
     #user_query = [message.content for message in state["messages"] if hasattr(message, 'content')]
@@ -892,7 +839,51 @@ def neo4j_retrieval(state: MyState):
         except Exception as e:
             logger.warning(f"'{entity}' failed in faiss: {e}")
             continue
-
+    try:
+        query_inputs = bi_tokenizer(query_text, return_tensors="pt", truncation=True, max_length=512, padding=True)
+        with torch.no_grad():
+            query_emb = bi_model(**query_inputs).last_hidden_state[:, 0]
+        query_emb = F.normalize(query_emb, dim=-1)
+
+        path_keys = list(path_kv.keys())
+        batch_size = 32
+        all_cand_embs = []
+        with torch.no_grad():
+            for i in range(0, len(path_keys), batch_size):
+                batch = path_keys[i:i + batch_size]
+                cand_inputs = bi_tokenizer(batch, return_tensors="pt", truncation=True, max_length=512, padding=True)
+                cand_embs_batch = bi_model(**cand_inputs).last_hidden_state[:, 0]
+                cand_embs_batch = F.normalize(cand_embs_batch, dim=-1)
+                all_cand_embs.append(cand_embs_batch)
+
+        cand_embs = torch.cat(all_cand_embs, dim=0)
+        sim_scores = torch.matmul(query_emb, cand_embs.T).squeeze(0).tolist()
+        scored_paths = list(zip(path_keys, sim_scores))
+        scored_paths.sort(key=lambda x: x[1], reverse=True)
+
+        top100 = scored_paths[:100]
+        pairs = [(query_text, pk) for pk, _ in top100]
+        all_cross_scores = []
+        cross_batch_size = 16
+        with torch.no_grad():
+            for i in range(0, len(pairs), cross_batch_size):
+                batch_pairs = pairs[i:i + cross_batch_size]
+                inputs = cross_tokenizer(batch_pairs, padding=True, truncation=True, max_length=512, return_tensors="pt")
+                scores = cross_model(**inputs).logits.view(-1).tolist()
+                all_cross_scores.extend(scores)
+
+        rerank_final = list(zip([p[0] for p in top100], all_cross_scores))
+        rerank_final.sort(key=lambda x: x[1], reverse=True)
+        top30 = rerank_final[:30]
+
+        top30_values = [path_kv[pk] for pk, _ in top30]
+        logger.info(f"Cross-encoder reranked top 30 paths: {top30_values}")
+        return {"neo4j_retrieval": top30_values}
+
+    except Exception as e:
+        logger.warning(f"rerank error: {e}")
+        fallback_values = list(path_kv.values())[:50]
+        return {"neo4j_retrieval": fallback_values}
 
 
 def decide_router(state: MyState) -> dict:
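The deleted rerank_paths_cloud and the code inlined into neo4j_retrieval above implement the same two-stage scheme: bge-m3 CLS embeddings recall the top 100 paths by cosine similarity, then the bge-reranker-v2-m3 cross-encoder reorders them and keeps 30. The rewrite also fixes a real bug: the old version zipped p[1] (the similarity score) with the cross scores, so path_kv[pk] would later be indexed by a float; the new code uses p[0], the path key. A distilled, self-contained sketch of the pattern (the function name and single-batch encoding are illustrative; the diff batches 32 bi-encoder / 16 cross-encoder items at a time):

import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoModelForSequenceClassification, AutoTokenizer

bi_tok = AutoTokenizer.from_pretrained("BAAI/bge-m3")
bi_enc = AutoModel.from_pretrained("BAAI/bge-m3").eval()
ce_tok = AutoTokenizer.from_pretrained("BAAI/bge-reranker-v2-m3")
ce_model = AutoModelForSequenceClassification.from_pretrained("BAAI/bge-reranker-v2-m3").eval()

def rerank(query: str, cands: list[str], recall_k: int = 100, top_k: int = 30) -> list[str]:
    with torch.no_grad():
        # Stage 1: CLS-token embeddings, L2-normalized, scored by dot product.
        q = F.normalize(bi_enc(**bi_tok(query, return_tensors="pt", truncation=True)).last_hidden_state[:, 0], dim=-1)
        c = F.normalize(bi_enc(**bi_tok(cands, return_tensors="pt", padding=True, truncation=True)).last_hidden_state[:, 0], dim=-1)
        recall = (q @ c.T).squeeze(0).topk(min(recall_k, len(cands))).indices.tolist()
        # Stage 2: the cross-encoder scores each (query, candidate) pair jointly.
        pairs = [(query, cands[i]) for i in recall]
        scores = ce_model(**ce_tok(pairs, padding=True, truncation=True, return_tensors="pt")).logits.view(-1)
        order = scores.argsort(descending=True)[:top_k].tolist()
    return [cands[recall[i]] for i in order]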
@@ -1009,7 +1000,7 @@ def build_graphrag_agent():
     builder = StateGraph(MyState)
     builder.add_node("parse_query", parse_query)
     builder.add_node("user_input", user_input)
-    builder.add_node("
+    builder.add_node("kg_retrieval", neo4j_retrieval)
     builder.add_node("decide_router", decide_router)
     builder.add_node("api_search", api_search)
     builder.add_node("llm_answer", llm_answer)
@@ -1020,11 +1011,11 @@ def build_graphrag_agent():
         whether_to_interact,
         {
             "user_input": "user_input",
-            "
+            "kg_retrieval": "kg_retrieval"
         }
     )
     builder.add_edge("user_input", "parse_query")
-    builder.add_edge("
+    builder.add_edge("kg_retrieval", "decide_router")
     builder.add_conditional_edges(
         "decide_router",
         lambda state: state["route"],
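The last three hunks are one fix: whether_to_interact now returns "kg_retrieval", a node with exactly that name is registered, its outgoing edge leads to decide_router, and the conditional-edge mapping routes the string to the node. In LangGraph, the string returned by a routing function must match a registered node name, which is what this commit restores. A minimal runnable sketch of the same wiring pattern (the stub state and node bodies are illustrative, not the app's real logic):

from typing import TypedDict
from langgraph.graph import END, START, StateGraph

class MyState(TypedDict):
    route: str

def parse_query(state: MyState) -> MyState:
    return state

def kg_retrieval(state: MyState) -> MyState:  # stands in for neo4j_retrieval
    return state

def decide_router(state: MyState) -> dict:
    return {"route": "stop"}

builder = StateGraph(MyState)
builder.add_node("parse_query", parse_query)
builder.add_node("kg_retrieval", kg_retrieval)  # name must match the router's return value
builder.add_node("decide_router", decide_router)
builder.add_edge(START, "parse_query")
builder.add_conditional_edges(
    "parse_query",
    lambda s: "kg_retrieval",                   # plays the role of whether_to_interact
    {"kg_retrieval": "kg_retrieval"},
)
builder.add_edge("kg_retrieval", "decide_router")
builder.add_conditional_edges("decide_router", lambda s: s["route"], {"stop": END})
graph = builder.compile()
print(graph.invoke({"route": ""}))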