"""
LangGraph Integration Module - Extract graph building with new MCTS core integration.

Provides:
- Graph building extracted from LangGraphMultiAgentFramework
- Integration with new deterministic MCTS core
- Backward compatibility with original process() signature
- Support for parallel HRM/TRM execution
"""
from __future__ import annotations

import asyncio
import operator
import time
from typing import Annotated, Any, NotRequired, TypedDict

# LangGraph imports (these would be installed dependencies).
# When LangGraph is absent, stub out the three names so this module still
# imports; build_graph() raises a clear ImportError at use time instead.
try:
    from langgraph.checkpoint.memory import MemorySaver
    from langgraph.graph import END, StateGraph
except ImportError:
    # Stubs for development without LangGraph installed
    StateGraph = None
    END = "END"
    MemorySaver = None

# Import new MCTS modules (hard dependency of this module)
from .mcts.config import ConfigPreset, MCTSConfig, create_preset_config
from .mcts.core import MCTSEngine, MCTSNode, MCTSState
from .mcts.experiments import ExperimentTracker
from .mcts.policies import (
    HybridRolloutPolicy,
)

# Neural Meta-Controller imports (optional feature).
# _META_CONTROLLER_AVAILABLE gates every meta-controller code path below;
# the None stubs keep type references valid when the packages are missing.
try:
    from src.agents.meta_controller.base import (
        AbstractMetaController,
        MetaControllerFeatures,
    )
    from src.agents.meta_controller.bert_controller import BERTMetaController
    from src.agents.meta_controller.config_loader import (
        MetaControllerConfig,
        MetaControllerConfigLoader,
    )
    from src.agents.meta_controller.rnn_controller import RNNMetaController

    _META_CONTROLLER_AVAILABLE = True
except ImportError:
    _META_CONTROLLER_AVAILABLE = False
    AbstractMetaController = None  # type: ignore
    MetaControllerFeatures = None  # type: ignore
    RNNMetaController = None  # type: ignore
    BERTMetaController = None  # type: ignore
    MetaControllerConfig = None  # type: ignore
    MetaControllerConfigLoader = None  # type: ignore
class AgentState(TypedDict):
    """Shared state for LangGraph agent framework.

    Required keys must be present in the initial state; NotRequired keys are
    filled in lazily by individual graph nodes as the run progresses.
    """

    # Input
    query: str        # user query text
    use_mcts: bool    # enable the MCTS simulator node
    use_rag: bool     # enable RAG context retrieval

    # RAG context (populated by _retrieve_context_node)
    rag_context: NotRequired[str]
    retrieved_docs: NotRequired[list[dict]]

    # Agent results (populated by the HRM/TRM/parallel nodes)
    hrm_results: NotRequired[dict]
    trm_results: NotRequired[dict]
    # operator.add: LangGraph concatenates each node's contribution instead
    # of overwriting, so outputs accumulate across nodes and iterations.
    agent_outputs: Annotated[list[dict], operator.add]

    # MCTS simulation (updated for new core)
    mcts_root: NotRequired[Any]  # MCTSNode
    mcts_iterations: NotRequired[int]
    mcts_best_action: NotRequired[str]
    mcts_stats: NotRequired[dict]
    mcts_config: NotRequired[dict]

    # Evaluation (populated by aggregate/consensus nodes)
    confidence_scores: NotRequired[dict[str, float]]
    consensus_reached: NotRequired[bool]
    consensus_score: NotRequired[float]

    # Control flow
    iteration: int
    max_iterations: int

    # Neural Meta-Controller (optional)
    routing_history: NotRequired[list[dict]]
    meta_controller_predictions: NotRequired[list[dict]]
    last_routed_agent: NotRequired[str]

    # Output (populated by _synthesize_node)
    final_response: NotRequired[str]
    metadata: NotRequired[dict]
class GraphBuilder:
    """
    Builds and configures the LangGraph state machine for multi-agent orchestration.

    Extracts graph building logic from LangGraphMultiAgentFramework for modularity.
    """

    def __init__(
        self,
        hrm_agent,
        trm_agent,
        model_adapter,
        logger,
        vector_store=None,
        mcts_config: MCTSConfig | None = None,
        top_k_retrieval: int = 5,
        max_iterations: int = 3,
        consensus_threshold: float = 0.75,
        enable_parallel_agents: bool = True,
        meta_controller_config: Any | None = None,
    ):
        """
        Initialize graph builder.

        Args:
            hrm_agent: HRM agent instance
            trm_agent: TRM agent instance
            model_adapter: Model adapter for LLM calls
            logger: Logger instance
            vector_store: Optional vector store for RAG
            mcts_config: MCTS configuration (uses balanced preset if None)
            top_k_retrieval: Number of documents for RAG
            max_iterations: Maximum agent iterations
            consensus_threshold: Threshold for consensus
            enable_parallel_agents: Run HRM/TRM in parallel
            meta_controller_config: Optional neural meta-controller configuration
        """
        self.hrm_agent = hrm_agent
        self.trm_agent = trm_agent
        self.model_adapter = model_adapter
        self.logger = logger
        self.vector_store = vector_store
        self.top_k_retrieval = top_k_retrieval
        self.max_iterations = max_iterations
        self.consensus_threshold = consensus_threshold
        self.enable_parallel_agents = enable_parallel_agents

        # MCTS configuration (balanced preset when the caller passes None)
        self.mcts_config = mcts_config or create_preset_config(ConfigPreset.BALANCED)

        # MCTS engine with deterministic behavior (seeded RNG)
        self.mcts_engine = MCTSEngine(
            seed=self.mcts_config.seed,
            exploration_weight=self.mcts_config.exploration_weight,
            progressive_widening_k=self.mcts_config.progressive_widening_k,
            progressive_widening_alpha=self.mcts_config.progressive_widening_alpha,
            max_parallel_rollouts=self.mcts_config.max_parallel_rollouts,
            cache_size_limit=self.mcts_config.cache_size_limit,
        )

        # Experiment tracking
        self.experiment_tracker = ExperimentTracker(name="langgraph_mcts")

        # Neural Meta-Controller (optional; remains rule-based unless
        # _init_meta_controller succeeds and flips use_neural_routing)
        self.meta_controller: Any | None = None
        self.meta_controller_config = meta_controller_config
        self.use_neural_routing = False
        if meta_controller_config is not None:
            self._init_meta_controller(meta_controller_config)

    def build_graph(self) -> StateGraph:
        """
        Build LangGraph state machine.

        Returns:
            Configured StateGraph

        Raises:
            ImportError: If LangGraph is not installed (StateGraph stubbed to None).
        """
        if StateGraph is None:
            raise ImportError("LangGraph not installed. Install with: pip install langgraph")

        workflow = StateGraph(AgentState)

        # Add nodes
        workflow.add_node("entry", self._entry_node)
        workflow.add_node("retrieve_context", self._retrieve_context_node)
        workflow.add_node("route_decision", self._route_decision_node)
        workflow.add_node("parallel_agents", self._parallel_agents_node)
        workflow.add_node("hrm_agent", self._hrm_agent_node)
        workflow.add_node("trm_agent", self._trm_agent_node)
        workflow.add_node("mcts_simulator", self._mcts_simulator_node)
        workflow.add_node("aggregate_results", self._aggregate_results_node)
        workflow.add_node("evaluate_consensus", self._evaluate_consensus_node)
        workflow.add_node("synthesize", self._synthesize_node)

        # Define edges
        workflow.set_entry_point("entry")
        workflow.add_edge("entry", "retrieve_context")
        workflow.add_edge("retrieve_context", "route_decision")

        # Conditional routing: _route_to_agents returns one of these keys
        workflow.add_conditional_edges(
            "route_decision",
            self._route_to_agents,
            {
                "parallel": "parallel_agents",
                "hrm": "hrm_agent",
                "trm": "trm_agent",
                "mcts": "mcts_simulator",
                "aggregate": "aggregate_results",
            },
        )

        # Parallel agents to aggregation
        workflow.add_edge("parallel_agents", "aggregate_results")

        # Sequential agent nodes
        workflow.add_edge("hrm_agent", "aggregate_results")
        workflow.add_edge("trm_agent", "aggregate_results")
        workflow.add_edge("mcts_simulator", "aggregate_results")

        # Aggregation to evaluation
        workflow.add_edge("aggregate_results", "evaluate_consensus")

        # Conditional consensus check: loop back to routing until consensus
        # or the iteration cap is reached (see _check_consensus)
        workflow.add_conditional_edges(
            "evaluate_consensus",
            self._check_consensus,
            {
                "synthesize": "synthesize",
                "iterate": "route_decision",
            },
        )

        # Synthesis to end
        workflow.add_edge("synthesize", END)

        return workflow

    def _entry_node(self, state: AgentState) -> dict:
        """Initialize state and parse query."""
        self.logger.info(f"Entry node: {state['query'][:100]}")
        return {
            "iteration": 0,
            "agent_outputs": [],
            "mcts_config": self.mcts_config.to_dict(),
        }

    def _retrieve_context_node(self, state: AgentState) -> dict:
        """Retrieve context from vector store using RAG."""
        # Skip retrieval when RAG is disabled or no store is configured
        if not state.get("use_rag", True) or not self.vector_store:
            return {"rag_context": ""}

        query = state["query"]

        # Retrieve documents
        docs = self.vector_store.similarity_search(query, k=self.top_k_retrieval)

        # Format context
        context = "\n\n".join([doc.page_content for doc in docs])
        self.logger.info(f"Retrieved {len(docs)} documents")
        return {
            "rag_context": context,
            "retrieved_docs": [{"content": doc.page_content, "metadata": doc.metadata} for doc in docs],
        }

    def _route_decision_node(self, _state: AgentState) -> dict:
        """Prepare routing decision (no state changes; routing happens in the conditional edge)."""
        return {}

    def _init_meta_controller(self, config: Any) -> None:
        """
        Initialize the neural meta-controller based on configuration.

        Args:
            config: MetaControllerConfig or dict with configuration

        Raises:
            Re-raises initialization errors unless the config opts into
            fallback_to_rule_based.
        """
        if not _META_CONTROLLER_AVAILABLE:
            self.logger.warning("Meta-controller modules not available. Falling back to rule-based routing.")
            return

        try:
            # Handle both config object and dict
            mc_config = MetaControllerConfigLoader.load_from_dict(config) if isinstance(config, dict) else config

            if not mc_config.enabled:
                self.logger.info("Neural meta-controller disabled in config")
                return

            # Initialize based on type
            if mc_config.type == "rnn":
                self.meta_controller = RNNMetaController(
                    name="GraphBuilder_RNN",
                    seed=mc_config.inference.seed,
                    hidden_dim=mc_config.rnn.hidden_dim,
                    num_layers=mc_config.rnn.num_layers,
                    dropout=mc_config.rnn.dropout,
                    device=mc_config.inference.device,
                )
                # Load trained model if path specified
                if mc_config.rnn.model_path:
                    self.meta_controller.load_model(mc_config.rnn.model_path)
                    self.logger.info(f"Loaded RNN model from {mc_config.rnn.model_path}")
            elif mc_config.type == "bert":
                self.meta_controller = BERTMetaController(
                    name="GraphBuilder_BERT",
                    seed=mc_config.inference.seed,
                    model_name=mc_config.bert.model_name,
                    lora_r=mc_config.bert.lora_r,
                    lora_alpha=mc_config.bert.lora_alpha,
                    lora_dropout=mc_config.bert.lora_dropout,
                    device=mc_config.inference.device,
                    use_lora=mc_config.bert.use_lora,
                )
                # Load trained model if path specified
                if mc_config.bert.model_path:
                    self.meta_controller.load_model(mc_config.bert.model_path)
                    self.logger.info(f"Loaded BERT model from {mc_config.bert.model_path}")
            else:
                raise ValueError(f"Unknown meta-controller type: {mc_config.type}")

            self.use_neural_routing = True
            self.logger.info(f"Initialized {mc_config.type.upper()} neural meta-controller")
        except Exception as e:
            self.logger.error(f"Failed to initialize meta-controller: {e}")
            if hasattr(config, "fallback_to_rule_based") and config.fallback_to_rule_based:
                self.logger.warning("Falling back to rule-based routing")
            else:
                raise

    def _extract_meta_controller_features(self, state: AgentState) -> Any:
        """
        Extract features from AgentState for meta-controller prediction.

        Args:
            state: Current agent state

        Returns:
            MetaControllerFeatures instance, or None when the meta-controller
            modules are unavailable.
        """
        if not _META_CONTROLLER_AVAILABLE or MetaControllerFeatures is None:
            return None

        # Extract HRM confidence (0.0 when HRM has not run yet)
        hrm_conf = 0.0
        if "hrm_results" in state:
            hrm_conf = state["hrm_results"].get("metadata", {}).get("decomposition_quality_score", 0.5)

        # Extract TRM confidence
        trm_conf = 0.0
        if "trm_results" in state:
            trm_conf = state["trm_results"].get("metadata", {}).get("final_quality_score", 0.5)

        # Extract MCTS value
        mcts_val = 0.0
        if "mcts_stats" in state:
            mcts_val = state["mcts_stats"].get("best_action_value", 0.5)

        # Consensus score
        consensus = state.get("consensus_score", 0.0)

        # Last agent used
        last_agent = state.get("last_routed_agent", "none")

        # Iteration
        iteration = state.get("iteration", 0)

        # Query length
        query_length = len(state.get("query", ""))

        # Has RAG context
        has_rag = bool(state.get("rag_context", ""))

        return MetaControllerFeatures(
            hrm_confidence=hrm_conf,
            trm_confidence=trm_conf,
            mcts_value=mcts_val,
            consensus_score=consensus,
            last_agent=last_agent,
            iteration=iteration,
            query_length=query_length,
            has_rag_context=has_rag,
        )

    def _neural_route_decision(self, state: AgentState) -> str:
        """
        Make routing decision using neural meta-controller.

        Falls back to rule-based routing when features are unavailable, the
        predicted agent has already run, or prediction raises.

        Args:
            state: Current agent state

        Returns:
            Route decision string ("parallel", "hrm", "trm", "mcts", "aggregate")
        """
        try:
            features = self._extract_meta_controller_features(state)
            if features is None:
                return self._rule_based_route_decision(state)

            prediction = self.meta_controller.predict(features)

            # Log prediction for debugging
            self.logger.debug(
                f"Neural routing: agent={prediction.agent}, "
                f"confidence={prediction.confidence:.3f}, "
                f"probs={prediction.probabilities}"
            )

            # Map agent prediction to route
            agent = prediction.agent

            # Handle routing based on predicted agent; only route to an agent
            # that has not produced results yet.
            # (BUGFIX: removed a dead `state.get("iteration", 0)` expression
            # whose result was discarded.)
            if agent == "hrm":
                if "hrm_results" not in state:
                    return "hrm"
            elif agent == "trm":
                if "trm_results" not in state:
                    return "trm"
            elif agent == "mcts" and state.get("use_mcts", False) and "mcts_stats" not in state:
                return "mcts"

            # If predicted agent already ran or not applicable, use rule-based
            return self._rule_based_route_decision(state)
        except Exception as e:
            self.logger.error(f"Neural routing failed: {e}")
            # Fallback to rule-based routing
            return self._rule_based_route_decision(state)

    def _rule_based_route_decision(self, state: AgentState) -> str:
        """
        Make routing decision using rule-based logic.

        Args:
            state: Current agent state

        Returns:
            Route decision string
        """
        iteration = state.get("iteration", 0)

        # First iteration: run HRM and TRM
        if iteration == 0:
            if self.enable_parallel_agents:
                if "hrm_results" not in state and "trm_results" not in state:
                    return "parallel"
            else:
                if "hrm_results" not in state:
                    return "hrm"
                elif "trm_results" not in state:
                    return "trm"

        # Run MCTS if enabled and not yet done
        if state.get("use_mcts", False) and "mcts_stats" not in state:
            return "mcts"

        return "aggregate"

    def _route_to_agents(self, state: AgentState) -> str:
        """Route to appropriate agent based on state."""
        # Use neural routing if enabled
        if self.use_neural_routing and self.meta_controller is not None:
            return self._neural_route_decision(state)
        # Fall back to rule-based routing
        return self._rule_based_route_decision(state)

    async def _parallel_agents_node(self, state: AgentState) -> dict:
        """Execute HRM and TRM agents in parallel."""
        self.logger.info("Executing HRM and TRM agents in parallel")

        # Run both agents concurrently
        hrm_task = asyncio.create_task(
            self.hrm_agent.process(
                query=state["query"],
                rag_context=state.get("rag_context"),
            )
        )
        trm_task = asyncio.create_task(
            self.trm_agent.process(
                query=state["query"],
                rag_context=state.get("rag_context"),
            )
        )

        # Await both results
        hrm_result, trm_result = await asyncio.gather(hrm_task, trm_task)

        # Combine outputs
        return {
            "hrm_results": {
                "response": hrm_result["response"],
                "metadata": hrm_result["metadata"],
            },
            "trm_results": {
                "response": trm_result["response"],
                "metadata": trm_result["metadata"],
            },
            "agent_outputs": [
                {
                    "agent": "hrm",
                    "response": hrm_result["response"],
                    "confidence": hrm_result["metadata"].get("decomposition_quality_score", 0.7),
                },
                {
                    "agent": "trm",
                    "response": trm_result["response"],
                    "confidence": trm_result["metadata"].get("final_quality_score", 0.7),
                },
            ],
        }

    async def _hrm_agent_node(self, state: AgentState) -> dict:
        """Execute HRM agent."""
        self.logger.info("Executing HRM agent")
        result = await self.hrm_agent.process(
            query=state["query"],
            rag_context=state.get("rag_context"),
        )
        return {
            "hrm_results": {
                "response": result["response"],
                "metadata": result["metadata"],
            },
            "agent_outputs": [
                {
                    "agent": "hrm",
                    "response": result["response"],
                    "confidence": result["metadata"].get("decomposition_quality_score", 0.7),
                }
            ],
        }

    async def _trm_agent_node(self, state: AgentState) -> dict:
        """Execute TRM agent."""
        self.logger.info("Executing TRM agent")
        result = await self.trm_agent.process(
            query=state["query"],
            rag_context=state.get("rag_context"),
        )
        return {
            "trm_results": {
                "response": result["response"],
                "metadata": result["metadata"],
            },
            "agent_outputs": [
                {
                    "agent": "trm",
                    "response": result["response"],
                    "confidence": result["metadata"].get("final_quality_score", 0.7),
                }
            ],
        }

    async def _mcts_simulator_node(self, state: AgentState) -> dict:
        """Execute MCTS simulation using new deterministic engine."""
        self.logger.info("Executing MCTS simulation with deterministic engine")
        start_time = time.perf_counter()

        # Reset engine for this simulation
        self.mcts_engine.clear_cache()

        # Create root state
        root_state = MCTSState(
            state_id="root",
            features={
                "query": state["query"][:100],  # Truncate for hashing
                "has_hrm": "hrm_results" in state,
                "has_trm": "trm_results" in state,
            },
        )
        root = MCTSNode(
            state=root_state,
            rng=self.mcts_engine.rng,
        )

        # Define action generator based on domain.
        # Depth is encoded in the state_id as underscore-joined segments.
        def action_generator(mcts_state: MCTSState) -> list[str]:
            """Generate available actions for state."""
            depth = len(mcts_state.state_id.split("_")) - 1
            if depth == 0:
                # Root level actions
                return ["action_A", "action_B", "action_C", "action_D"]
            elif depth < self.mcts_config.max_tree_depth:
                # Subsequent actions
                return ["continue", "refine", "fallback", "escalate"]
            else:
                return []  # Terminal

        # Define state transition
        def state_transition(mcts_state: MCTSState, action: str) -> MCTSState:
            """Compute next state from action."""
            new_id = f"{mcts_state.state_id}_{action}"
            new_features = mcts_state.features.copy()
            new_features["last_action"] = action
            new_features["depth"] = len(new_id.split("_")) - 1
            return MCTSState(state_id=new_id, features=new_features)

        # Create rollout policy using agent results
        def heuristic_fn(mcts_state: MCTSState) -> float:
            """Evaluate state using agent confidence."""
            base = 0.5
            # Bias based on agent confidence
            if state.get("hrm_results"):
                hrm_conf = state["hrm_results"]["metadata"].get("decomposition_quality_score", 0.5)
                base += hrm_conf * 0.2
            if state.get("trm_results"):
                trm_conf = state["trm_results"]["metadata"].get("final_quality_score", 0.5)
                base += trm_conf * 0.2
            return min(base, 1.0)

        rollout_policy = HybridRolloutPolicy(
            heuristic_fn=heuristic_fn,
            heuristic_weight=0.7,
            random_weight=0.3,
        )

        # Run MCTS search
        best_action, stats = await self.mcts_engine.search(
            root=root,
            num_iterations=self.mcts_config.num_iterations,
            action_generator=action_generator,
            state_transition=state_transition,
            rollout_policy=rollout_policy,
            max_rollout_depth=self.mcts_config.max_rollout_depth,
            selection_policy=self.mcts_config.selection_policy,
        )

        end_time = time.perf_counter()
        execution_time_ms = (end_time - start_time) * 1000

        # Compute tree statistics
        tree_depth = self.mcts_engine.get_tree_depth(root)
        tree_node_count = self.mcts_engine.count_nodes(root)

        # Track experiment
        self.experiment_tracker.create_result(
            experiment_id=f"mcts_{int(time.time())}",
            config=self.mcts_config,
            mcts_stats=stats,
            execution_time_ms=execution_time_ms,
            tree_depth=tree_depth,
            tree_node_count=tree_node_count,
            metadata={
                "query": state["query"][:100],
                "has_rag": state.get("use_rag", False),
            },
        )

        self.logger.info(
            f"MCTS complete: best_action={best_action}, "
            f"iterations={stats['iterations']}, "
            f"cache_hit_rate={stats['cache_hit_rate']:.2%}"
        )

        return {
            "mcts_root": root,
            "mcts_best_action": best_action,
            "mcts_stats": stats,
            "agent_outputs": [
                {
                    "agent": "mcts",
                    "response": (
                        f"Simulated {stats['iterations']} scenarios with "
                        f"seed {self.mcts_config.seed}. "
                        f"Recommended action: {best_action} "
                        f"(visits={stats['best_action_visits']}, "
                        f"value={stats['best_action_value']:.3f})"
                    ),
                    # Visit share of the best action, guarded against
                    # division by zero when no iterations ran.
                    "confidence": min(
                        stats["best_action_visits"] / stats["iterations"] if stats["iterations"] > 0 else 0.5,
                        1.0,
                    ),
                }
            ],
        }

    def _aggregate_results_node(self, state: AgentState) -> dict:
        """Aggregate results from all agents."""
        self.logger.info("Aggregating agent results")
        agent_outputs = state.get("agent_outputs", [])
        confidence_scores = {output["agent"]: output["confidence"] for output in agent_outputs}
        return {"confidence_scores": confidence_scores}

    def _evaluate_consensus_node(self, state: AgentState) -> dict:
        """Evaluate consensus among agents and advance the iteration counter."""
        agent_outputs = state.get("agent_outputs", [])

        # BUGFIX: no node previously incremented `iteration`, so
        # _check_consensus's max_iterations guard could never fire and a
        # no-consensus run looped route_decision -> aggregate -> evaluate
        # indefinitely. Advance the counter once per evaluation pass here.
        next_iteration = state.get("iteration", 0) + 1

        # With fewer than two outputs there is nothing to disagree with.
        if len(agent_outputs) < 2:
            return {
                "consensus_reached": True,
                "consensus_score": 1.0,
                "iteration": next_iteration,
            }

        avg_confidence = sum(o["confidence"] for o in agent_outputs) / len(agent_outputs)
        consensus_reached = avg_confidence >= self.consensus_threshold

        self.logger.info(f"Consensus: {consensus_reached} (score={avg_confidence:.2f})")
        return {
            "consensus_reached": consensus_reached,
            "consensus_score": avg_confidence,
            "iteration": next_iteration,
        }

    def _check_consensus(self, state: AgentState) -> str:
        """Check if consensus reached or need more iterations."""
        if state.get("consensus_reached", False):
            return "synthesize"
        if state.get("iteration", 0) >= state.get("max_iterations", self.max_iterations):
            return "synthesize"
        return "iterate"

    async def _synthesize_node(self, state: AgentState) -> dict:
        """Synthesize final response from agent outputs."""
        self.logger.info("Synthesizing final response")

        agent_outputs = state.get("agent_outputs", [])

        synthesis_prompt = f"""Query: {state["query"]}

Agent Outputs:
"""
        for output in agent_outputs:
            synthesis_prompt += f"""
{output["agent"].upper()} (confidence={output["confidence"]:.2f}):
{output["response"]}
"""

        synthesis_prompt += """
Synthesize these outputs into a comprehensive final response.
Prioritize higher-confidence outputs. Integrate insights from all agents.

Final Response:"""

        try:
            response = await self.model_adapter.generate(
                prompt=synthesis_prompt,
                temperature=0.5,
            )
            final_response = response.text
        except Exception as e:
            # Best-effort fallback: surface the single highest-confidence
            # agent output instead of failing the whole run.
            self.logger.error(f"Synthesis failed: {e}")
            best_output = max(agent_outputs, key=lambda o: o["confidence"])
            final_response = best_output["response"]

        metadata = {
            "agents_used": [o["agent"] for o in agent_outputs],
            "confidence_scores": state.get("confidence_scores", {}),
            "consensus_score": state.get("consensus_score", 0.0),
            "iterations": state.get("iteration", 0),
            "mcts_config": state.get("mcts_config", {}),
        }
        if state.get("mcts_stats"):
            metadata["mcts_stats"] = state["mcts_stats"]

        return {
            "final_response": final_response,
            "metadata": metadata,
        }
class IntegratedFramework:
    """
    Integrated multi-agent framework with new MCTS core.

    Thin facade over GraphBuilder: constructs the agents, builds and compiles
    the LangGraph app, and exposes the original process() entry point so
    existing callers of LangGraphMultiAgentFramework keep working.
    """

    def __init__(
        self,
        model_adapter,
        logger,
        vector_store=None,
        _embedding_model=None,
        hrm_config: dict | None = None,
        trm_config: dict | None = None,
        mcts_config: MCTSConfig | None = None,
        top_k_retrieval: int = 5,
        max_iterations: int = 3,
        consensus_threshold: float = 0.75,
        enable_parallel_agents: bool = True,
    ):
        """
        Initialize integrated framework.

        Backward compatible with LangGraphMultiAgentFramework.
        """
        self.model_adapter = model_adapter
        self.logger = logger
        self.vector_store = vector_store

        # Agents are a soft dependency: when the modules are missing the
        # framework still constructs, with both agents set to None.
        try:
            from improved_hrm_agent import HRMAgent
            from improved_trm_agent import TRMAgent

            self.hrm_agent = HRMAgent(
                model_adapter=model_adapter,
                logger=logger,
                **(hrm_config or {}),
            )
            self.trm_agent = TRMAgent(
                model_adapter=model_adapter,
                logger=logger,
                **(trm_config or {}),
            )
        except ImportError:
            self.hrm_agent = None
            self.trm_agent = None
            self.logger.warning("Could not import HRM/TRM agents")

        # Delegate all orchestration wiring to GraphBuilder.
        self.graph_builder = GraphBuilder(
            hrm_agent=self.hrm_agent,
            trm_agent=self.trm_agent,
            model_adapter=model_adapter,
            logger=logger,
            vector_store=vector_store,
            mcts_config=mcts_config,
            top_k_retrieval=top_k_retrieval,
            max_iterations=max_iterations,
            consensus_threshold=consensus_threshold,
            enable_parallel_agents=enable_parallel_agents,
        )

        # Compile the graph only when LangGraph is actually installed;
        # otherwise leave graph/app as None and fail lazily in process().
        if StateGraph is None:
            self.graph = None
            self.app = None
        else:
            self.graph = self.graph_builder.build_graph()
            self.memory = MemorySaver() if MemorySaver else None
            if self.memory:
                self.app = self.graph.compile(checkpointer=self.memory)
            else:
                self.app = self.graph.compile()

        self.logger.info("Integrated framework initialized with new MCTS core")

    async def process(
        self,
        query: str,
        use_rag: bool = True,
        use_mcts: bool = False,
        config: dict | None = None,
    ) -> dict:
        """
        Process query through LangGraph.

        Backward compatible with original signature.

        Args:
            query: User query to process
            use_rag: Enable RAG context retrieval
            use_mcts: Enable MCTS simulation
            config: Optional LangGraph config

        Returns:
            Dictionary with response, metadata, and state

        Raises:
            RuntimeError: If LangGraph is not installed.
        """
        if self.app is None:
            raise RuntimeError("LangGraph not available. Install with: pip install langgraph")

        seed_state = {
            "query": query,
            "use_rag": use_rag,
            "use_mcts": use_mcts,
            "iteration": 0,
            "max_iterations": self.graph_builder.max_iterations,
            "agent_outputs": [],
        }
        run_config = config or {"configurable": {"thread_id": "default"}}

        outcome = await self.app.ainvoke(seed_state, config=run_config)

        return {
            "response": outcome.get("final_response", ""),
            "metadata": outcome.get("metadata", {}),
            "state": outcome,
        }

    def get_experiment_tracker(self) -> ExperimentTracker:
        """Get the experiment tracker for analysis."""
        return self.graph_builder.experiment_tracker

    def set_mcts_seed(self, seed: int) -> None:
        """Set MCTS seed for deterministic behavior."""
        builder = self.graph_builder
        builder.mcts_engine.reset_seed(seed)
        builder.mcts_config.seed = seed