achenyx1412 committed on
Commit
248f5f0
·
verified ·
1 Parent(s): 148cd09

Update graphrag_agent.py

Browse files
Files changed (1) hide show
  1. graphrag_agent.py +68 -5
graphrag_agent.py CHANGED
@@ -42,9 +42,17 @@ HF_TOKEN = os.getenv("HF_TOKEN")
42
  Entrez.email = ENTREZ_EMAIL
43
  MAX_TOKENS = 128000
44
  encoding = tiktoken.get_encoding("cl100k_base")
45
- sapbert_client = InferenceClient(provider="hf-inference",api_key=HF_TOKEN,)
46
- bge_client = InferenceClient(provider="hf-inference", api_key=HF_TOKEN)
47
- cross_client = InferenceClient(provider="hf-inference", api_key=HF_TOKEN)
 
 
 
 
 
 
 
 
48
  # ======================== 全局变量 ========================
49
  faiss_indices = {}
50
  metadata = {}
@@ -147,6 +155,18 @@ def _extract_json_from_text(text: str) -> Dict[str, Any]:
147
  return {}
148
  return {}
149
 
 
 
 
 
 
 
 
 
 
 
 
 
150
  def embed_entity_cloud(entity_text: str):
151
  """
152
  使用 Hugging Face Inference API 获取 SapBERT 嵌入
@@ -163,7 +183,6 @@ def embed_entity_cloud(entity_text: str):
163
  print(f"Embedding error: {e}")
164
  return None
165
 
166
-
167
  def search_pubmed(pubmed_query: str, max_results: int = 3) -> str:
168
  try:
169
  handle = Entrez.esearch(db="pubmed", term=pubmed_query, retmax=max_results)
@@ -864,7 +883,51 @@ def neo4j_retrieval(state: MyState):
864
  logger.warning(f"'{entity}'failed in faiss {e}")
865
  continue
866
 
867
- return rerank_paths_cloud(query_text, path_kv)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
868
 
869
  def decide_router(state: MyState) -> dict:
870
  print("---EDGE: decide_router---")
 
# ---- External-service and embedding-model configuration ----
Entrez.email = ENTREZ_EMAIL  # NCBI Entrez requires a contact email on API calls
MAX_TOKENS = 128000  # token budget — presumably measured with `encoding` below; verify
encoding = tiktoken.get_encoding("cl100k_base")

# SapBERT: local biomedical entity-embedding model (used by embed_entity).
tokenizer = AutoTokenizer.from_pretrained("cambridgeltl/SapBERT-from-PubMedBERT-fulltext",token=HF_TOKEN)
model = AutoModel.from_pretrained("cambridgeltl/SapBERT-from-PubMedBERT-fulltext",token=HF_TOKEN).to(DEVICE)
model.eval()  # inference only

# BGE-M3 bi-encoder: first-stage recall scoring of candidate graph paths.
# NOTE(review): unlike the SapBERT model above, bi_model and cross_model are
# never moved to DEVICE and so run on CPU — confirm this is intentional.
bi_tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-m3",token=HF_TOKEN)
bi_model = AutoModel.from_pretrained("BAAI/bge-m3",token=HF_TOKEN)
bi_model.eval()

# BGE reranker cross-encoder: second-stage re-ranking of recalled paths.
cross_tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-reranker-v2-m3",token=HF_TOKEN)
cross_model = AutoModelForSequenceClassification.from_pretrained("BAAI/bge-reranker-v2-m3",token=HF_TOKEN)
cross_model.eval()
# ======================== Global variables ========================
faiss_indices = {}  # FAISS index registry — presumably filled at startup; verify against loader
metadata = {}       # metadata paired with the FAISS indices above
 
155
  return {}
156
  return {}
157
 
158
def embed_entity(entity_text: str):
    """Embed a biomedical entity string with the local SapBERT model.

    Args:
        entity_text: The entity mention to embed (truncated to 64 tokens).

    Returns:
        A numpy array: the mean-pooled last hidden state of the model,
        with the batch dimension squeezed away.

    Raises:
        ValueError: if the module-level tokenizer/model were not loaded.
    """
    # Use `is None` rather than truthiness: HF tokenizers define __len__
    # (vocab size), so `not tokenizer` is a fragile way to test "loaded".
    if tokenizer is None or model is None:
        raise ValueError("embedding model not loaded")
    with torch.no_grad():
        inputs = tokenizer(
            entity_text, return_tensors="pt",
            padding=True, truncation=True, max_length=64
        ).to(DEVICE)
        outputs = model(**inputs)
        # Mean-pool over the sequence dimension; move to CPU for numpy.
        embedding = outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy()
    return embedding
169
+
170
  def embed_entity_cloud(entity_text: str):
171
  """
172
  使用 Hugging Face Inference API 获取 SapBERT 嵌入
 
183
  print(f"Embedding error: {e}")
184
  return None
185
 
 
186
  def search_pubmed(pubmed_query: str, max_results: int = 3) -> str:
187
  try:
188
  handle = Entrez.esearch(db="pubmed", term=pubmed_query, retmax=max_results)
 
883
  logger.warning(f"'{entity}'failed in faiss {e}")
884
  continue
885
 
886
    # Two-stage rerank of the candidate paths gathered above in path_kv:
    #   1) BGE-M3 bi-encoder recall over ALL path keys (cheap, batched),
    #   2) BGE cross-encoder rerank of the top-100, keeping the top-30.
    try:
        # Embed the query: CLS-token embedding, L2-normalised so that a
        # dot product below equals cosine similarity.
        query_inputs = bi_tokenizer(query_text, return_tensors="pt", truncation=True, max_length=512,padding=True)
        with torch.no_grad():
            query_emb = bi_model(**query_inputs).last_hidden_state[:, 0]
            query_emb = F.normalize(query_emb, dim=-1)

        # Embed every candidate path key in batches of 32.
        path_keys = list(path_kv.keys())
        batch_size = 32
        all_cand_embs = []
        with torch.no_grad():
            for i in range(0, len(path_keys), batch_size):
                batch = path_keys[i:i + batch_size]
                cand_inputs = bi_tokenizer(batch, return_tensors="pt", truncation=True, max_length=512,padding=True)
                cand_embs_batch = bi_model(**cand_inputs).last_hidden_state[:, 0]
                cand_embs_batch = F.normalize(cand_embs_batch, dim=-1)
                all_cand_embs.append(cand_embs_batch)

        # Cosine similarity of the query against every candidate, sorted
        # best-first.
        cand_embs = torch.cat(all_cand_embs, dim=0)
        sim_scores = torch.matmul(query_emb, cand_embs.T).squeeze(0).tolist()
        scored_paths = list(zip(path_keys, sim_scores))
        scored_paths.sort(key=lambda x: x[1], reverse=True)

        # Stage 2: score (query, path) pairs for the top-100 recalled keys
        # with the cross-encoder; its single logit is the relevance score.
        top100 = scored_paths[:100]
        pairs = [(query_text, pk) for pk, _ in top100]
        all_cross_scores = []
        cross_batch_size = 16
        with torch.no_grad():
            for i in range(0, len(pairs), cross_batch_size):
                batch_pairs = pairs[i:i + cross_batch_size]
                inputs = cross_tokenizer(batch_pairs, padding=True, truncation=True, max_length=512,return_tensors="pt")
                scores = cross_model(**inputs).logits.view(-1).tolist()
                all_cross_scores.extend(scores)

        # Re-sort by cross-encoder score and keep the best 30 paths.
        rerank_final = list(zip([p[0] for p in top100], all_cross_scores))
        rerank_final.sort(key=lambda x: x[1], reverse=True)
        top30 = rerank_final[:30]

        # Map surviving keys back to their stored path values.
        top30_values = [path_kv[pk] for pk, _ in top30]
        logger.info(f"Cross-encoder reranked 30 path: {top30_values}")
        return {"neo4j_retrieval": top30_values}

    except Exception as e:
        # Best-effort fallback: if any rerank stage fails, return the first
        # 50 paths unranked instead of failing the whole retrieval node.
        logger.warning(f"rerank error: {e}")
        fallback_values = list(path_kv.values())[:50]
        return {"neo4j_retrieval": fallback_values}
931
 
932
  def decide_router(state: MyState) -> dict:
933
  print("---EDGE: decide_router---")