""" Debug utilities for multi-agent MCTS framework. Provides: - MCTS tree visualization (text-based) - Step-by-step MCTS execution logging when LOG_LEVEL=DEBUG - UCB score logging at each selection - State diff tracking between iterations - Export tree to DOT format for graphviz """ import logging import os from typing import Any from .logging import get_logger class MCTSDebugger: """ Comprehensive debugger for MCTS operations. Provides detailed step-by-step logging, tree visualization, and state tracking for MCTS execution. """ def __init__(self, session_id: str = "default", enabled: bool | None = None): """ Initialize MCTS debugger. Args: session_id: Unique identifier for debug session enabled: Enable debugging (defaults to LOG_LEVEL=DEBUG) """ self.session_id = session_id self.logger = get_logger("observability.debug") # Auto-enable if LOG_LEVEL is DEBUG if enabled is None: log_level = os.environ.get("LOG_LEVEL", "INFO").upper() self.enabled = log_level == "DEBUG" else: self.enabled = enabled # State tracking self._iteration_count = 0 self._state_history: list[dict[str, Any]] = [] self._selection_history: list[dict[str, Any]] = [] self._ucb_history: list[dict[str, float]] = [] def log_iteration_start(self, iteration: int) -> None: """Log the start of an MCTS iteration.""" if not self.enabled: return self._iteration_count = iteration self.logger.debug( f"=== MCTS Iteration {iteration} START ===", extra={ "debug_event": "iteration_start", "mcts_iteration": iteration, "session_id": self.session_id, }, ) def log_selection( self, node_id: str, ucb_score: float, visits: int, value: float, depth: int, children_count: int, is_selected: bool = False, ) -> None: """Log UCB score and selection decision for a node.""" if not self.enabled: return selection_data = { "node_id": node_id, "ucb_score": round(ucb_score, 6), "visits": visits, "value": round(value, 6), "avg_value": round(value / max(visits, 1), 6), "depth": depth, "children_count": children_count, "is_selected": is_selected, } self._selection_history.append(selection_data) log_msg = f"Selection: node={node_id} UCB={ucb_score:.4f} visits={visits} value={value:.4f} depth={depth}" if is_selected: log_msg += " [SELECTED]" self.logger.debug( log_msg, extra={ "debug_event": "mcts_selection", "selection": selection_data, "session_id": self.session_id, "mcts_iteration": self._iteration_count, }, ) def log_ucb_comparison( self, parent_id: str, children_ucb: dict[str, float], selected_child: str, ) -> None: """Log UCB score comparison for all children of a node.""" if not self.enabled: return self._ucb_history.append(children_ucb) ucb_summary = ", ".join( [ f"{cid}={score:.4f}{'*' if cid == selected_child else ''}" for cid, score in sorted(children_ucb.items(), key=lambda x: x[1], reverse=True) ] ) self.logger.debug( f"UCB Comparison at {parent_id}: {ucb_summary}", extra={ "debug_event": "ucb_comparison", "parent_id": parent_id, "children_ucb": {k: round(v, 6) for k, v in children_ucb.items()}, "selected_child": selected_child, "session_id": self.session_id, "mcts_iteration": self._iteration_count, }, ) def log_expansion( self, parent_id: str, action: str, new_node_id: str, available_actions: list[str], ) -> None: """Log node expansion details.""" if not self.enabled: return self.logger.debug( f"Expansion: parent={parent_id} action={action} new_node={new_node_id} " f"available={len(available_actions)} actions", extra={ "debug_event": "mcts_expansion", "parent_id": parent_id, "action": action, "new_node_id": new_node_id, "available_actions": available_actions, "session_id": self.session_id, "mcts_iteration": self._iteration_count, }, ) def log_simulation( self, node_id: str, simulation_result: float, simulation_details: dict[str, Any] | None = None, ) -> None: """Log simulation/rollout results.""" if not self.enabled: return self.logger.debug( f"Simulation: node={node_id} result={simulation_result:.4f}", extra={ "debug_event": "mcts_simulation", "node_id": node_id, "simulation_result": round(simulation_result, 6), "simulation_details": simulation_details or {}, "session_id": self.session_id, "mcts_iteration": self._iteration_count, }, ) def log_backpropagation( self, path: list[str], value: float, updates: list[dict[str, Any]], ) -> None: """Log backpropagation path and value updates.""" if not self.enabled: return self.logger.debug( f"Backprop: path={' -> '.join(path)} value={value:.4f}", extra={ "debug_event": "mcts_backprop", "path": path, "value": round(value, 6), "updates": updates, "session_id": self.session_id, "mcts_iteration": self._iteration_count, }, ) def log_iteration_end( self, iteration: int, best_action: str, best_ucb: float, tree_size: int, ) -> None: """Log the end of an MCTS iteration.""" if not self.enabled: return self.logger.debug( f"=== MCTS Iteration {iteration} END === " f"best_action={best_action} UCB={best_ucb:.4f} tree_size={tree_size}", extra={ "debug_event": "iteration_end", "mcts_iteration": iteration, "best_action": best_action, "best_ucb": round(best_ucb, 6), "tree_size": tree_size, "session_id": self.session_id, }, ) def log_state_diff( self, old_state: dict[str, Any], new_state: dict[str, Any], description: str = "State change", ) -> None: """Log differences between two states.""" if not self.enabled: return diff = self._compute_state_diff(old_state, new_state) if diff: self._state_history.append( { "iteration": self._iteration_count, "description": description, "diff": diff, } ) self.logger.debug( f"State diff: {description}", extra={ "debug_event": "state_diff", "description": description, "diff": diff, "session_id": self.session_id, "mcts_iteration": self._iteration_count, }, ) def _compute_state_diff( self, old: dict[str, Any], new: dict[str, Any], prefix: str = "", ) -> dict[str, Any]: """Compute differences between two dictionaries.""" diff = {} all_keys = set(old.keys()) | set(new.keys()) for key in all_keys: full_key = f"{prefix}.{key}" if prefix else key if key not in old: diff[full_key] = {"added": new[key]} elif key not in new: diff[full_key] = {"removed": old[key]} elif old[key] != new[key]: if isinstance(old[key], dict) and isinstance(new[key], dict): nested_diff = self._compute_state_diff(old[key], new[key], full_key) diff.update(nested_diff) else: diff[full_key] = {"old": old[key], "new": new[key]} return diff def get_debug_summary(self) -> dict[str, Any]: """Get summary of debug information collected.""" return { "session_id": self.session_id, "total_iterations": self._iteration_count, "selection_history_count": len(self._selection_history), "state_changes_count": len(self._state_history), "ucb_comparisons_count": len(self._ucb_history), } def visualize_mcts_tree( root_node: Any, max_depth: int = 10, max_children: int = 5, show_ucb: bool = True, indent: str = " ", ) -> str: """ Generate text-based visualization of MCTS tree. Args: root_node: Root MCTSNode max_depth: Maximum depth to visualize max_children: Maximum children to show per node show_ucb: Show UCB scores indent: Indentation string Returns: Text representation of the tree """ lines = ["MCTS Tree Visualization", "=" * 40] def render_node(node: Any, depth: int = 0, prefix: str = "") -> None: if depth > max_depth: lines.append(f"{prefix}{indent}... (max depth reached)") return # Node info visits = getattr(node, "visits", 0) value = getattr(node, "value", 0.0) action = getattr(node, "action", "root") state_id = getattr(node, "state_id", "unknown") avg_value = value / max(visits, 1) node_info = f"[{state_id}] action={action} visits={visits} value={value:.3f} avg={avg_value:.3f}" if show_ucb and hasattr(node, "ucb1") and visits > 0: try: ucb = node.ucb1() if ucb != float("inf"): node_info += f" UCB={ucb:.3f}" except Exception: pass lines.append(f"{prefix}{node_info}") # Children children = getattr(node, "children", []) if children: # Sort by visits (most visited first) sorted_children = sorted(children, key=lambda c: getattr(c, "visits", 0), reverse=True) display_children = sorted_children[:max_children] for i, child in enumerate(display_children): is_last = i == len(display_children) - 1 child_prefix = prefix + indent + ("└── " if is_last else "├── ") next_prefix = prefix + indent + (" " if is_last else "│ ") lines.append(f"{child_prefix[:-4]}") render_node(child, depth + 1, next_prefix) if len(children) > max_children: lines.append(f"{prefix}{indent}... and {len(children) - max_children} more children") render_node(root_node) lines.append("=" * 40) return "\n".join(lines) def export_tree_to_dot( root_node: Any, filename: str = "mcts_tree.dot", max_depth: int = 10, include_ucb: bool = True, ) -> str: """ Export MCTS tree to DOT format for graphviz visualization. Args: root_node: Root MCTSNode filename: Output filename (optional) max_depth: Maximum depth to export include_ucb: Include UCB scores in labels Returns: DOT format string """ lines = [ "digraph MCTSTree {", ' graph [rankdir=TB, label="MCTS Tree", fontsize=16];', " node [shape=box, style=filled, fillcolor=lightblue];", " edge [fontsize=10];", "", ] node_counter = [0] # Use list for mutable counter in closure def add_node(node: Any, depth: int = 0, parent_dot_id: str | None = None) -> None: if depth > max_depth: return # Generate unique DOT ID dot_id = f"node_{node_counter[0]}" node_counter[0] += 1 # Node attributes visits = getattr(node, "visits", 0) value = getattr(node, "value", 0.0) action = getattr(node, "action", "root") state_id = getattr(node, "state_id", "unknown") avg_value = value / max(visits, 1) # Build label label_parts = [ f"ID: {state_id}", f"Action: {action}", f"Visits: {visits}", f"Value: {value:.3f}", f"Avg: {avg_value:.3f}", ] if include_ucb and hasattr(node, "ucb1") and visits > 0: try: ucb = node.ucb1() if ucb != float("inf"): label_parts.append(f"UCB: {ucb:.3f}") except Exception: pass label = "\\n".join(label_parts) # Color based on value if avg_value >= 0.7: color = "lightgreen" elif avg_value >= 0.4: color = "lightyellow" else: color = "lightcoral" lines.append(f' {dot_id} [label="{label}", fillcolor={color}];') # Edge from parent if parent_dot_id: lines.append(f' {parent_dot_id} -> {dot_id} [label="{action}"];') # Process children children = getattr(node, "children", []) for child in children: add_node(child, depth + 1, dot_id) add_node(root_node) lines.append("}") dot_content = "\n".join(lines) # Write to file if filename provided if filename: with open(filename, "w") as f: f.write(dot_content) return dot_content def print_debug_banner(message: str, char: str = "=", width: int = 60) -> None: """Print a debug banner message.""" logger = get_logger("observability.debug") border = char * width logger.debug(border) logger.debug(f"{message:^{width}}") logger.debug(border) def log_agent_state_snapshot( agent_name: str, state: dict[str, Any], include_keys: list[str] | None = None, ) -> None: """ Log a snapshot of agent state for debugging. Args: agent_name: Name of the agent state: Current state dictionary include_keys: Specific keys to include (None = all) """ logger = get_logger("observability.debug") filtered_state = {k: state.get(k) for k in include_keys if k in state} if include_keys else state logger.debug( f"Agent {agent_name} state snapshot", extra={ "debug_event": "agent_state_snapshot", "agent_name": agent_name, "state": filtered_state, }, ) def enable_verbose_debugging() -> None: """Enable verbose debugging by setting LOG_LEVEL to DEBUG.""" os.environ["LOG_LEVEL"] = "DEBUG" # Reconfigure root logger root_logger = logging.getLogger() root_logger.setLevel(logging.DEBUG) for handler in root_logger.handlers: handler.setLevel(logging.DEBUG) logger = get_logger("observability.debug") logger.info("Verbose debugging ENABLED") def disable_verbose_debugging() -> None: """Disable verbose debugging by setting LOG_LEVEL to INFO.""" os.environ["LOG_LEVEL"] = "INFO" root_logger = logging.getLogger() root_logger.setLevel(logging.INFO) for handler in root_logger.handlers: handler.setLevel(logging.INFO) logger = get_logger("observability.debug") logger.info("Verbose debugging DISABLED")