achenyx1412 committed
Commit 235e8c2 · verified · 1 Parent(s): 55d38fb

Update graphrag_agent.py

Files changed (1):
  1. graphrag_agent.py +70 -79
graphrag_agent.py CHANGED
@@ -45,23 +45,20 @@ HF_TOKEN = os.getenv("HF_TOKEN")
 Entrez.email = ENTREZ_EMAIL
 MAX_TOKENS = 128000
 encoding = tiktoken.get_encoding("cl100k_base")
-# tokenizer = AutoTokenizer.from_pretrained("cambridgeltl/SapBERT-from-PubMedBERT-fulltext", token=HF_TOKEN)
-# model = AutoModel.from_pretrained("cambridgeltl/SapBERT-from-PubMedBERT-fulltext", token=HF_TOKEN).to(DEVICE)
-# model.eval()
+tokenizer = AutoTokenizer.from_pretrained("cambridgeltl/SapBERT-from-PubMedBERT-fulltext", token=HF_TOKEN)
+model = AutoModel.from_pretrained("cambridgeltl/SapBERT-from-PubMedBERT-fulltext", token=HF_TOKEN).to(DEVICE)
+model.eval()
 
-# bi_tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-m3", token=HF_TOKEN)
-# bi_model = AutoModel.from_pretrained("BAAI/bge-m3", token=HF_TOKEN)
-# bi_model.eval()
+bi_tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-m3", token=HF_TOKEN)
+bi_model = AutoModel.from_pretrained("BAAI/bge-m3", token=HF_TOKEN)
+bi_model.eval()
 
-# cross_tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-reranker-v2-m3", token=HF_TOKEN)
-# cross_model = AutoModelForSequenceClassification.from_pretrained("BAAI/bge-reranker-v2-m3", token=HF_TOKEN)
-# cross_model.eval()
-sapbert_client = InferenceClient(
-    provider="hf-inference",
-    api_key=HF_TOKEN,
-)
-bge_client = InferenceClient(provider="hf-inference", api_key=HF_TOKEN)
-cross_client = InferenceClient(provider="hf-inference", api_key=HF_TOKEN)
+cross_tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-reranker-v2-m3", token=HF_TOKEN)
+cross_model = AutoModelForSequenceClassification.from_pretrained("BAAI/bge-reranker-v2-m3", token=HF_TOKEN)
+cross_model.eval()
+# sapbert_client = InferenceClient(provider="hf-inference", api_key=HF_TOKEN)
+# bge_client = InferenceClient(provider="hf-inference", api_key=HF_TOKEN)
+# cross_client = InferenceClient(provider="hf-inference", api_key=HF_TOKEN)
 # ======================== Global variables ========================
 faiss_indices = {}
 metadata = {}
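Note: this hunk swaps the three hosted hf-inference InferenceClient handles for locally loaded checkpoints, so SapBERT, BGE-M3, and the BGE reranker now sit in process memory for the Space's whole lifetime (only SapBERT is moved to DEVICE; the other two stay on CPU). A hedged alternative, not part of the commit, would be to defer each load until first use:

    # Hypothetical lazy-loading variant; get_encoder is illustrative, not from the file.
    from functools import lru_cache
    from transformers import AutoModel, AutoTokenizer

    @lru_cache(maxsize=None)
    def get_encoder(name: str):
        tok = AutoTokenizer.from_pretrained(name, token=HF_TOKEN)
        mod = AutoModel.from_pretrained(name, token=HF_TOKEN)
        mod.eval()
        return tok, mod  # cached: each checkpoint is loaded at most once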
@@ -164,7 +161,7 @@ def _extract_json_from_text(text: str) -> Dict[str, Any]:
         return {}
     return {}
 
-# def embed_entity(entity_text: str):
+def embed_entity(entity_text: str):
     if not tokenizer or not model:
         raise ValueError("embedding model not loaded")
     with torch.no_grad():
@@ -176,21 +173,15 @@ def _extract_json_from_text(text: str) -> Dict[str, Any]:
         embedding = outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy()
         return embedding
 
-def embed_entity(entity_text: str):
-    """
-    Get a SapBERT embedding via the Hugging Face Inference API
-    """
-    try:
-        result = sapbert_client.feature_extraction(
-            entity_text,
-            model="cambridgeltl/SapBERT-from-PubMedBERT-fulltext"
-        )
+# def embed_entity(entity_text: str):
+#     """Get a SapBERT embedding via the Hugging Face Inference API"""
+#     try:
+#         result = sapbert_client.feature_extraction(entity_text, model="cambridgeltl/SapBERT-from-PubMedBERT-fulltext")
         # The result is usually list[list[float]]; take the mean or the first token
-        embedding = [sum(x)/len(x) for x in zip(*result)]  # average over each dimension
-        return embedding
-    except Exception as e:
-        print(f"Embedding error: {e}")
-        return None
+#         embedding = [sum(x)/len(x) for x in zip(*result)]  # average over each dimension
+#         return embedding
+#     except Exception as e:
+#         print(f"Embedding error: {e}")
+#         return None
 
 def search_pubmed(pubmed_query: str, max_results: int = 3) -> str:
     try:
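Note: in the parent revision the local embed_entity header was commented out while its indented body was left in place, which is a syntax error at import time; uncommenting the def (and commenting out the Inference-API variant as a whole) repairs that. A usage sketch, under the assumption that SapBERT-from-PubMedBERT-fulltext is a BERT-base model with hidden size 768, which must match the dimension the FAISS indices were built with:

    # Illustrative call only; the entity string is an arbitrary example.
    vec = embed_entity("myocardial infarction")
    print(vec.shape)  # expected: (768,) - mean-pooled last_hidden_state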
@@ -702,7 +693,7 @@ def whether_to_interact(state):
         return "user_input"
     elif interaction == "sufficient":
         print("Decision: information is sufficient; proceeding to Neo4j retrieval.")
-        return "neo4j_retrieval"
+        return "kg_retrieval"
     else:
         return "stop_flow"
 
@@ -722,51 +713,7 @@ with ZipFile(zip_path, "r") as zip_ref:
     zip_ref.extractall("data/")
 
 print("✅ Successfully downloaded and extracted data.zip")
-def rerank_paths_cloud(query_text, path_kv):
-    try:
-        # 1. query embedding
-        query_emb = bge_client.feature_extraction(query_text, model="BAAI/bge-m3")
-        query_emb = torch.tensor(query_emb[0]).unsqueeze(0)
-        query_emb = F.normalize(query_emb, dim=-1)
-
-        # 2. path embeddings
-        path_keys = list(path_kv.keys())
-        all_cand_embs = []
-        for pk in path_keys:
-            cand_emb = bge_client.feature_extraction(pk, model="BAAI/bge-m3")
-            emb_tensor = torch.tensor(cand_emb[0]).unsqueeze(0)
-            emb_tensor = F.normalize(emb_tensor, dim=-1)
-            all_cand_embs.append(emb_tensor)
-
-        cand_embs = torch.cat(all_cand_embs, dim=0)
-        sim_scores = torch.matmul(query_emb, cand_embs.T).squeeze(0).tolist()
-
-        scored_paths = list(zip(path_keys, sim_scores))
-        scored_paths.sort(key=lambda x: x[1], reverse=True)
-        top100 = scored_paths[:100]
-
-        # 3. cross-encoder rerank
-        pairs = [(query_text, pk) for pk, _ in top100]
-        all_cross_scores = []
-        for q, pk in pairs:
-            input_pair = [(q, pk)]
-            scores = cross_client.text_classification(
-                input_pair,
-                model="BAAI/bge-reranker-v2-m3"
-            )
-            all_cross_scores.append(scores[0]["score"])
-
-        rerank_final = list(zip([p[1] for p in top100], all_cross_scores))
-        rerank_final.sort(key=lambda x: x[1], reverse=True)
-        top30 = rerank_final[:30]
-
-        top30_values = [path_kv[pk] for pk, _ in top30]
-        return {"neo4j_retrieval": top30_values}
 
-    except Exception as e:
-        print(f"rerank error: {e}")
-        fallback_values = list(path_kv.values())[:50]
-        return {"neo4j_retrieval": fallback_values}
 def neo4j_retrieval(state: MyState):
     logger.info("---NODE: neo4j_retrieval---")
     # user_query = [message.content for message in state["messages"] if hasattr(message, 'content')]
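Note: the deleted rerank_paths_cloud issued one HTTP round trip per candidate path for BGE-M3 embeddings and another per (query, path) pair for the reranker, so its cost grew linearly with the number of retrieved paths. It also carried a latent bug: the final zip paired bi-encoder scores with cross-encoder scores instead of path keys, so the closing path_kv lookup could never succeed. A minimal repro of that bug and its fix (toy values, not from the file):

    top100 = [("pathA", 0.91), ("pathB", 0.88)]               # (key, bi-encoder score)
    cross_scores = [1.2, 3.4]
    broken = list(zip([p[1] for p in top100], cross_scores))  # [(0.91, 1.2), ...] - keys lost
    fixed = list(zip([p[0] for p in top100], cross_scores))   # [("pathA", 1.2), ...]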
@@ -892,7 +839,51 @@ def neo4j_retrieval(state: MyState):
         except Exception as e:
             logger.warning(f"'{entity}' failed in faiss: {e}")
             continue
-    return rerank_paths_cloud(query_text, path_kv)
+    try:
+        query_inputs = bi_tokenizer(query_text, return_tensors="pt", truncation=True, max_length=512, padding=True)
+        with torch.no_grad():
+            query_emb = bi_model(**query_inputs).last_hidden_state[:, 0]
+            query_emb = F.normalize(query_emb, dim=-1)
+
+        path_keys = list(path_kv.keys())
+        batch_size = 32
+        all_cand_embs = []
+        with torch.no_grad():
+            for i in range(0, len(path_keys), batch_size):
+                batch = path_keys[i:i + batch_size]
+                cand_inputs = bi_tokenizer(batch, return_tensors="pt", truncation=True, max_length=512, padding=True)
+                cand_embs_batch = bi_model(**cand_inputs).last_hidden_state[:, 0]
+                cand_embs_batch = F.normalize(cand_embs_batch, dim=-1)
+                all_cand_embs.append(cand_embs_batch)
+
+        cand_embs = torch.cat(all_cand_embs, dim=0)
+        sim_scores = torch.matmul(query_emb, cand_embs.T).squeeze(0).tolist()
+        scored_paths = list(zip(path_keys, sim_scores))
+        scored_paths.sort(key=lambda x: x[1], reverse=True)
+
+        top100 = scored_paths[:100]
+        pairs = [(query_text, pk) for pk, _ in top100]
+        all_cross_scores = []
+        cross_batch_size = 16
+        with torch.no_grad():
+            for i in range(0, len(pairs), cross_batch_size):
+                batch_pairs = pairs[i:i + cross_batch_size]
+                inputs = cross_tokenizer(batch_pairs, padding=True, truncation=True, max_length=512, return_tensors="pt")
+                scores = cross_model(**inputs).logits.view(-1).tolist()
+                all_cross_scores.extend(scores)
+
+        rerank_final = list(zip([p[0] for p in top100], all_cross_scores))
+        rerank_final.sort(key=lambda x: x[1], reverse=True)
+        top30 = rerank_final[:30]
+
+        top30_values = [path_kv[pk] for pk, _ in top30]
+        logger.info(f"Cross-encoder reranked top-30 paths: {top30_values}")
+        return {"neo4j_retrieval": top30_values}
+
+    except Exception as e:
+        logger.warning(f"rerank error: {e}")
+        fallback_values = list(path_kv.values())[:50]
+        return {"neo4j_retrieval": fallback_values}
 
 
 def decide_router(state: MyState) -> dict:
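Note: the in-process replacement is a standard two-stage retrieval funnel: normalize CLS-token embeddings (BGE models use the [CLS] vector as the sentence embedding, hence last_hidden_state[:, 0]), keep the cosine top 100, rescore those pairs with the cross-encoder's raw logits, return the top 30, and fall back to the first 50 unranked paths on any failure. A self-contained sketch of the coarse stage, with random tensors standing in for real BGE-M3 embeddings (dense dimension 1024):

    import torch
    import torch.nn.functional as F

    q = F.normalize(torch.randn(1, 1024), dim=-1)    # query CLS embedding
    c = F.normalize(torch.randn(500, 1024), dim=-1)  # 500 candidate-path embeddings
    sims = (q @ c.T).squeeze(0)                      # cosine similarity per path
    shortlist = sims.topk(100).indices               # candidates for the cross-encoder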
@@ -1009,7 +1000,7 @@ def build_graphrag_agent():
     builder = StateGraph(MyState)
     builder.add_node("parse_query", parse_query)
     builder.add_node("user_input", user_input)
-    builder.add_node("neo4j_retrieval", neo4j_retrieval)
+    builder.add_node("kg_retrieval", neo4j_retrieval)
     builder.add_node("decide_router", decide_router)
     builder.add_node("api_search", api_search)
     builder.add_node("llm_answer", llm_answer)
@@ -1020,11 +1011,11 @@ def build_graphrag_agent():
         whether_to_interact,
         {
             "user_input": "user_input",
-            "neo4j_retrieval": "neo4j_retrieval"
+            "kg_retrieval": "kg_retrieval"
         }
     )
     builder.add_edge("user_input", "parse_query")
-    builder.add_edge("neo4j_retrieval", "decide_router")
+    builder.add_edge("kg_retrieval", "decide_router")
     builder.add_conditional_edges(
         "decide_router",
         lambda state: state["route"],
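Note: the rename from "neo4j_retrieval" to "kg_retrieval" (mirrored by whether_to_interact's return value in the earlier hunk) is most plausibly forced by LangGraph itself: neo4j_retrieval writes to a state key also named "neo4j_retrieval", and StateGraph rejects a node whose name shadows a state channel. A minimal sketch of the collision, assuming a recent langgraph release:

    from typing import TypedDict
    from langgraph.graph import StateGraph

    class DemoState(TypedDict):
        neo4j_retrieval: list

    builder = StateGraph(DemoState)
    # builder.add_node("neo4j_retrieval", lambda s: {"neo4j_retrieval": []})
    # ^ raises ValueError: the node name collides with the state key it writes to
    builder.add_node("kg_retrieval", lambda s: {"neo4j_retrieval": []})  # fine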
 