Spaces:
Build error
Build error
Update graphrag_agent.py
Browse files- graphrag_agent.py +35 -26
graphrag_agent.py
CHANGED
|
@@ -700,35 +700,44 @@ def whether_to_interact(state):
|
|
| 700 |
|
| 701 |
# 数据存放路径
|
| 702 |
DATA_DIR = "data"
|
| 703 |
-
|
| 704 |
-
|
| 705 |
-
#
|
| 706 |
-
|
| 707 |
-
|
| 708 |
-
|
| 709 |
-
|
| 710 |
-
|
| 711 |
-
|
| 712 |
-
|
| 713 |
-
|
| 714 |
-
|
| 715 |
-
|
|
|
|
|
|
|
| 716 |
|
| 717 |
-
|
| 718 |
-
|
| 719 |
-
|
| 720 |
-
|
| 721 |
-
|
| 722 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 723 |
repo_type="dataset",
|
| 724 |
-
token=HF_TOKEN
|
|
|
|
|
|
|
| 725 |
)
|
| 726 |
-
|
| 727 |
-
|
| 728 |
-
print("
|
| 729 |
-
with ZipFile(zip_path, "r") as zip_ref:
|
| 730 |
-
zip_ref.extractall(DATA_DIR)
|
| 731 |
-
print("✅ 已成功下载并解压 data.zip")
|
| 732 |
|
| 733 |
def neo4j_retrieval(state: MyState):
|
| 734 |
logger.info("---NODE: neo4j_retrieval---")
|
|
|
|
| 700 |
|
| 701 |
# 数据存放路径
|
| 702 |
DATA_DIR = "data"
|
| 703 |
+
os.makedirs(DATA_DIR, exist_ok=True)
|
| 704 |
+
|
| 705 |
+
# Hugging Face Dataset repo ID
|
| 706 |
+
REPO_ID = "achenyx1412/DGADIS"
|
| 707 |
+
|
| 708 |
+
# 需要下载的文件列表
|
| 709 |
+
FILES = [
|
| 710 |
+
"faiss_node+desc.index",
|
| 711 |
+
"faiss_node+desc.pkl",
|
| 712 |
+
"faiss_node.index",
|
| 713 |
+
"faiss_node.pkl",
|
| 714 |
+
"faiss_triple3.index",
|
| 715 |
+
"faiss_triple3.pkl",
|
| 716 |
+
"kg.gpickle"
|
| 717 |
+
]
|
| 718 |
|
| 719 |
+
# 遍历文件,逐个下载
|
| 720 |
+
for file_name in FILES:
|
| 721 |
+
local_path = os.path.join(DATA_DIR, file_name)
|
| 722 |
+
|
| 723 |
+
# 如果本地已存在,则跳过下载
|
| 724 |
+
if os.path.exists(local_path):
|
| 725 |
+
print(f"✅ 已检测到本地文件 {file_name},跳过下载。")
|
| 726 |
+
continue
|
| 727 |
+
|
| 728 |
+
print(f"🌐 正在从 Hugging Face 下载 {file_name} ...")
|
| 729 |
+
try:
|
| 730 |
+
hf_hub_download(
|
| 731 |
+
repo_id=REPO_ID,
|
| 732 |
+
filename=file_name,
|
| 733 |
repo_type="dataset",
|
| 734 |
+
token=HF_TOKEN,
|
| 735 |
+
local_dir=DATA_DIR,
|
| 736 |
+
local_dir_use_symlinks=False # 防止 symlink 问题
|
| 737 |
)
|
| 738 |
+
print(f"✅ 已成功下载 {file_name}")
|
| 739 |
+
except Exception as e:
|
| 740 |
+
print(f"❌ 下载 {file_name} 失败: {e}")
|
|
|
|
|
|
|
|
|
|
| 741 |
|
| 742 |
def neo4j_retrieval(state: MyState):
|
| 743 |
logger.info("---NODE: neo4j_retrieval---")
|