Spaces:
Build error
Build error
| import os | |
| import gradio as gr | |
| import logging | |
| from typing import List, Tuple | |
| from graphrag_agent import ( | |
| build_graphrag_agent, | |
| system_initialized, | |
| MyState | |
| ) | |
| from langchain_core.messages import HumanMessage | |
# ======================== Configuration ========================
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name=__name__)

# ======================== Module-level state ========================
# Compiled GraphRAG agent; remains None until check_system_status() builds it.
compiled_graph = None
# Per-session conversation state, keyed by the session id kept in gr.State.
# Each value is a dict with "waiting_for_clarification", "current_graph_state"
# and "conversation_messages".
session_states: dict = dict()
# ======================== System initialisation ========================
def check_system_status():
    """Compile the GraphRAG agent once the data layer is ready and report status.

    Importing ``graphrag_agent`` already loads the data. This function only
    builds the graph, at most once, and stores it in the module-level
    ``compiled_graph``.

    Returns:
        A human-readable status string for the "System Status" textbox. Failures
        are logged with their traceback and reported as an "❌ Error: ..." string
        rather than raised.
    """
    global compiled_graph
    try:
        # NOTE(review): ``system_initialized`` is bound via
        # ``from graphrag_agent import system_initialized``. It is therefore a
        # snapshot taken at import time. If graphrag_agent flips the flag later,
        # this module will not see the change. Confirm initialisation is synchronous.
        if system_initialized and compiled_graph is None:
            logger.info("🔧 Building GraphRAG Agent...")
            compiled_graph = build_graphrag_agent()
            logger.info("✅ GraphRAG Agent compiled successfully!")
            return "✅ System ready! You can start asking questions."
        elif system_initialized:
            return "✅ System already initialized."
        else:
            return "⏳ System is still initializing, please wait..."
    except Exception as e:
        # logger.exception keeps the traceback, which the previous logger.error dropped.
        logger.exception(f"❌ Initialization check failed: {e}")
        return f"❌ Error: {str(e)}"
# ======================== Core message handling ========================
def _get_session_state(session_id: str) -> dict:
    """Return the state dict for *session_id*, creating an empty one on first use."""
    return session_states.setdefault(
        session_id,
        {
            "waiting_for_clarification": False,
            "current_graph_state": None,
            "conversation_messages": [],
        },
    )


def _reset_clarification(session_state: dict) -> None:
    """Clear the pending-clarification flag together with the graph state saved for it."""
    session_state["waiting_for_clarification"] = False
    session_state["current_graph_state"] = None


def _answer_clarification(message: str, history: list, session_state: dict):
    """Feed the user's clarification into the paused query and try to finish it."""
    logger.info(f"📝 User provided clarification: {message}")
    # Keep the clarification in the session transcript so later queries include it.
    session_state["conversation_messages"].append(HumanMessage(content=message))

    prev_state = session_state["current_graph_state"]
    # NOTE(review): prev_state is the value stored under the "user_input" key of a
    # compiled_graph.stream() event (see _run_new_query). Events are keyed by node
    # name, which suggests LangGraph's "updates" stream mode. In that mode the value
    # is the node's partial update, not the full graph state. Confirm it carries
    # everything the graph needs when re-invoked below.
    from graphrag_agent import user_input  # imported lazily, as before

    # Let the user_input node merge the reply into the saved state.
    user_input_result = user_input(prev_state, user_reply_text=message)
    prev_state.update(user_input_result)

    # The agent may still ask for more details (not expected, but handled).
    if user_input_result.get("need_user_reply", False):
        ai_msg = user_input_result.get("ai_message", "Please provide more information.")
        history.append((message, f"🤔 {ai_msg}"))
        return "", history, "Status: Still waiting for clarification"

    # Re-run the graph from its entry point with the enriched state.
    result = compiled_graph.invoke(prev_state)
    final_answer = result.get("llm_answer", "No answer generated.")
    history.append((message, final_answer))
    _reset_clarification(session_state)
    return "", history, "✅ Status: Answer generated (after clarification)"


def _run_new_query(message: str, history: list, session_state: dict):
    """Stream a fresh query through the graph, pausing if the agent asks for clarification."""
    logger.info(f"🔍 Processing new query: {message}")
    session_state["conversation_messages"].append(HumanMessage(content=message))
    # Pass a copy so the graph run cannot mutate the session transcript list.
    initial_state = {"messages": session_state["conversation_messages"].copy()}

    final_event = None
    # Stream (rather than invoke) so the user_input node's output can be inspected.
    for event in compiled_graph.stream(initial_state):
        # Each event maps a node name to that node's output.
        if "user_input" in event:
            user_input_state = event["user_input"] or {}
            if user_input_state.get("need_user_reply", False):
                clarification_message = user_input_state.get(
                    "ai_message", "Please provide more information."
                )
                # Park the query; the next message in this session resumes it.
                session_state["waiting_for_clarification"] = True
                session_state["current_graph_state"] = user_input_state
                history.append((message, f"🤔 **Need More Information:**\n\n{clarification_message}\n\n*Please provide additional details.*"))
                return "", history, "⏳ Status: Waiting for clarification"
        final_event = event

    if not final_event:
        history.append((message, "❌ Error: No response from agent."))
        return "", history, "❌ Status: Error"

    # The last event holds the final node's output. A node may emit None,
    # which previously crashed on .get().
    node_output = next(iter(final_event.values())) or {}
    final_answer = node_output.get("llm_answer", "No answer generated.")
    history.append((message, final_answer))
    return "", history, "✅ Status: Answer generated"


def process_message(message: str, history: List[Tuple[str, str]], session_id: str):
    """Handle one chat message from the Gradio UI.

    A session is either idle, in which case the message starts a new query, or
    waiting for a clarification the agent asked for, in which case the message
    resumes the paused query.

    Args:
        message: Text the user typed.
        history: Chat history as ``[(user_msg, bot_msg), ...]``; appended to in place.
            ``None`` is treated as an empty history.
        session_id: Key into ``session_states`` identifying this browser session.

    Returns:
        ``("", updated_history, status_text)``. The empty string clears the input box.
        Exceptions from the agent are not raised; they are logged and reported
        in the chat and status text.
    """
    if history is None:
        history = []
    # Reject blank submissions first so they never add an empty chat turn.
    if not message or not message.strip():
        return "", history, "System Status: Waiting for input"
    if not system_initialized or compiled_graph is None:
        history.append((message, "⚠️ System not initialized. Please wait for data loading..."))
        return "", history, "System Status: Not Ready"

    try:
        session_state = _get_session_state(session_id)
        if session_state["waiting_for_clarification"]:
            return _answer_clarification(message, history, session_state)
        return _run_new_query(message, history, session_state)
    except Exception as e:
        # logger.exception routes the traceback through logging instead of raw stderr.
        logger.exception(f"❌ Error processing message: {e}")
        error_msg = f"❌ Error: {str(e)}\n\nPlease try again or rephrase your question."
        history.append((message, error_msg))
        # Drop any half-finished clarification (flag AND saved graph state) so the
        # session does not keep a stale snapshot around.
        if session_id in session_states:
            _reset_clarification(session_states[session_id])
        return "", history, f"❌ Status: Error - {str(e)}"
# ======================== Session reset ========================
def clear_session(session_id: str):
    """Forget everything stored for *session_id* and blank the chat window.

    Args:
        session_id: Key into ``session_states``; an unknown id is silently ignored.

    Returns:
        ``([], status_text)``: an empty history for the Chatbot plus a status line.
    """
    session_states.pop(session_id, None)
    return [], "🔄 Status: Session cleared"
# ======================== Gradio UI ========================
def create_demo():
    """Build the Gradio Blocks UI for the dental QA agent.

    The layout contains:
    - a system-status box;
    - a chat column with the chatbot, input box, Send/Clear buttons and example questions;
    - a side panel with session status and tips;
    - a collapsible system-information accordion.

    Handlers are wired to ``check_system_status`` (page load), ``process_message``
    (Send / Enter) and ``clear_session`` (Clear).

    Returns:
        The configured, not yet launched, ``gr.Blocks`` app.
    """
    import uuid  # local import: only needed to mint per-session ids

    with gr.Blocks(
        theme=gr.themes.Soft(),
        title="Dental Knowledge QA System",
        # NOTE(review): neither class below is attached to any component
        # (no elem_classes are set), so these rules currently have no effect.
        css="""
        .chatbot-container {
            height: 600px !important;
        }
        .status-box {
            background-color: #f0f0f0;
            padding: 10px;
            border-radius: 5px;
            font-family: monospace;
        }
        """,
    ) as demo:
        # Title
        gr.Markdown("""
        # 🦷 Dental Knowledge QA System
        ### GraphRAG-based Agent with Knowledge Graph & External Search
        Ask any dental medicine question. The system will:
        1. Parse your query and extract entities
        2. **Ask for clarification if needed** 🔄
        3. Search the knowledge graph
        4. Optionally search PubMed & Wikipedia
        5. Generate a comprehensive answer with sources
        """)

        # System status (filled in by check_system_status on page load)
        with gr.Row():
            init_status = gr.Textbox(
                label="🔧 System Status",
                value="Checking system status...",
                interactive=False,
                lines=2,
            )

        # Main conversation area
        with gr.Row():
            with gr.Column(scale=4):
                # avatar_images takes image file paths or URLs. The former "🤖"
                # was an emoji, not a file, so no bot avatar is configured.
                # NOTE(review): bubble_full_width is deprecated in newer Gradio
                # releases; drop it if the pinned version rejects it.
                chatbot = gr.Chatbot(
                    label="💬 Conversation",
                    height=600,
                    bubble_full_width=False,
                )
                with gr.Row():
                    msg_input = gr.Textbox(
                        label="Your Question",
                        placeholder="Ask a dental question... (e.g., 'What is the treatment of gingivitis?')",
                        lines=2,
                        scale=4,
                    )
                    with gr.Column(scale=1):
                        send_btn = gr.Button("Send 📤", variant="primary", size="lg")
                        clear_btn = gr.Button("Clear 🗑️", size="lg")
                # Example questions
                gr.Examples(
                    examples=[
                        "What is the treatment of gingivitis?",
                        "What are the complications of implant restoration?",
                        "How to prevent dental caries in children?",
                        "What is the best treatment?",  # this one triggers the clarification interaction
                        "印模材料凝固后,其软度通常用什么指标表示?",
                    ],
                    inputs=msg_input,
                    label="📝 Example Questions",
                )

            # Right-hand status panel
            with gr.Column(scale=1):
                session_status = gr.Textbox(
                    label="📊 Session Status",
                    value="Status: Ready",
                    interactive=False,
                    lines=3,
                )
                gr.Markdown("""
                ### 💡 Tips
                - Ask clear, specific questions
                - If asked for clarification, provide more details
                - Include context when possible
                ### 🔄 Interaction Flow
                1. **Insufficient info** → System asks for clarification
                2. You provide details → System continues
                3. **Sufficient info** → Direct answer
                """)

        # Technical information
        with gr.Accordion("ℹ️ System Information", open=False):
            gr.Markdown("""
            ### Architecture
            - **Knowledge Graph**: Neo4j-based dental knowledge
            - **Retrieval**: FAISS + SapBERT embeddings
            - **Reranking**: BGE bi-encoder + cross-encoder
            - **LLM**: DeepSeek Chat
            - **External Sources**: PubMed, Wikipedia
            ### Features
            - Multi-turn conversation support
            - Intelligent query clarification
            - Entity extraction and linking
            - Source attribution
            """)

        # Hidden per-session id. gr.State calls a callable default on each page
        # load. The previous id(object()) returned the address of an object that
        # is freed immediately; CPython commonly reuses that address, so distinct
        # sessions could receive the same key and share conversation state.
        # uuid4 yields an effectively unique id per session.
        session_id = gr.State(value=lambda: uuid.uuid4().hex)

        # ===== Event wiring =====
        # Report system status (and build the agent) when the page loads.
        demo.load(
            fn=check_system_status,
            outputs=init_status,
        )

        # Send on button click or on Enter in the textbox, with identical wiring.
        for trigger in (send_btn.click, msg_input.submit):
            trigger(
                fn=process_message,
                inputs=[msg_input, chatbot, session_id],
                outputs=[msg_input, chatbot, session_status],
            )

        # Clear the conversation and forget the session's server-side state.
        clear_btn.click(
            fn=clear_session,
            inputs=session_id,
            outputs=[chatbot, session_status],
        )

    return demo
# ======================== Entry point ========================
if __name__ == "__main__":
    # Listen on all interfaces on port 7860; show_error surfaces errors in the UI.
    launch_kwargs = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "show_error": True,
    }
    demo = create_demo()
    # Enable the request queue, holding at most 20 pending events.
    demo.queue(max_size=20)
    demo.launch(**launch_kwargs)