achenyx1412 commited on
Commit
4bd3201
·
verified ·
1 Parent(s): e68aadb

Update graphrag_agent.py

Browse files
Files changed (1) hide show
  1. graphrag_agent.py +35 -26
graphrag_agent.py CHANGED
@@ -700,35 +700,44 @@ def whether_to_interact(state):
700
 
701
  # 数据存放路径
702
  DATA_DIR = "data"
703
- ZIP_FILE = "data.zip"
704
-
705
- # 如果 data/ 已存在且非空,就跳过下载
706
- if os.path.exists(DATA_DIR) and any(os.scandir(DATA_DIR)):
707
- print("✅ 已检测到本地 data 文件夹,跳过下载。")
708
-
709
- else:
710
- # 如果没有解压好的 data,但有 data.zip,则直接解压
711
- if os.path.exists(ZIP_FILE):
712
- print("📦 检测到本地 data.zip,正在解压...")
713
- with ZipFile(ZIP_FILE, "r") as zip_ref:
714
- zip_ref.extractall(DATA_DIR)
715
- print("✅ 已成功解压本地 data.zip")
 
 
716
 
717
- else:
718
- # 如果连 data.zip 都没有,才从 HF 下载
719
- print("🌐 未检测到本地数据,开始从 Hugging Face 下载 data.zip...")
720
- zip_path = hf_hub_download(
721
- repo_id="achenyx1412/DGADIS",
722
- filename="data.zip",
 
 
 
 
 
 
 
 
723
  repo_type="dataset",
724
- token=HF_TOKEN
 
 
725
  )
726
-
727
- # 解压
728
- print("📦 正在解压 data.zip...")
729
- with ZipFile(zip_path, "r") as zip_ref:
730
- zip_ref.extractall(DATA_DIR)
731
- print("✅ 已成功下载并解压 data.zip")
732
 
733
  def neo4j_retrieval(state: MyState):
734
  logger.info("---NODE: neo4j_retrieval---")
 
700
 
701
  # 数据存放路径
702
  DATA_DIR = "data"
703
+ os.makedirs(DATA_DIR, exist_ok=True)
704
+
705
+ # Hugging Face Dataset repo ID
706
+ REPO_ID = "achenyx1412/DGADIS"
707
+
708
+ # 需要下载的文件列表
709
+ FILES = [
710
+ "faiss_node+desc.index",
711
+ "faiss_node+desc.pkl",
712
+ "faiss_node.index",
713
+ "faiss_node.pkl",
714
+ "faiss_triple3.index",
715
+ "faiss_triple3.pkl",
716
+ "kg.gpickle"
717
+ ]
718
 
719
+ # 遍历文件,逐个下载
720
+ for file_name in FILES:
721
+ local_path = os.path.join(DATA_DIR, file_name)
722
+
723
+ # 如果本地已存在,则跳过下载
724
+ if os.path.exists(local_path):
725
+ print(f"✅ 已检测到本地文件 {file_name},跳过下载。")
726
+ continue
727
+
728
+ print(f"🌐 正在从 Hugging Face 下载 {file_name} ...")
729
+ try:
730
+ hf_hub_download(
731
+ repo_id=REPO_ID,
732
+ filename=file_name,
733
  repo_type="dataset",
734
+ token=HF_TOKEN,
735
+ local_dir=DATA_DIR,
736
+ local_dir_use_symlinks=False # 防止 symlink 问题
737
  )
738
+ print(f"✅ 已成功下载 {file_name}")
739
+ except Exception as e:
740
+ print(f" 下载 {file_name} 失败: {e}")
 
 
 
741
 
742
  def neo4j_retrieval(state: MyState):
743
  logger.info("---NODE: neo4j_retrieval---")