""" MCTS Configuration Module - Parameter management and presets. Provides: - MCTSConfig dataclass with all parameters - Validation of parameter bounds - Preset configurations (fast, balanced, thorough) - Serialization support for experiment tracking """ from __future__ import annotations import json from dataclasses import asdict, dataclass from enum import Enum from typing import Any from .policies import SelectionPolicy class ConfigPreset(Enum): """Preset configuration names.""" FAST = "fast" BALANCED = "balanced" THOROUGH = "thorough" EXPLORATION_HEAVY = "exploration_heavy" EXPLOITATION_HEAVY = "exploitation_heavy" @dataclass class MCTSConfig: """ Complete configuration for MCTS engine. All MCTS parameters are centralized here with validation. Supports serialization for experiment tracking and reproducibility. """ # Core MCTS parameters num_iterations: int = 100 """Number of MCTS iterations to run.""" seed: int = 42 """Random seed for deterministic behavior.""" exploration_weight: float = 1.414 """UCB1 exploration constant (c). Higher = more exploration.""" # Progressive widening progressive_widening_k: float = 1.0 """Progressive widening coefficient. Higher = more conservative.""" progressive_widening_alpha: float = 0.5 """Progressive widening exponent. Lower = more aggressive expansion.""" # Rollout configuration max_rollout_depth: int = 10 """Maximum depth for rollout simulations.""" rollout_policy: str = "hybrid" """Rollout policy: 'random', 'greedy', 'hybrid'.""" # Action selection selection_policy: SelectionPolicy = SelectionPolicy.MAX_VISITS """Policy for final action selection.""" # Parallelization max_parallel_rollouts: int = 4 """Maximum concurrent rollout simulations.""" # Caching enable_cache: bool = True """Enable simulation result caching.""" cache_size_limit: int = 10000 """Maximum number of cached simulation results.""" # Tree structure max_tree_depth: int = 20 """Maximum depth of MCTS tree.""" max_children_per_node: int = 50 """Maximum children per node (action branching limit).""" # Early termination early_termination_threshold: float = 0.95 """Stop if best action has this fraction of total visits.""" min_iterations_before_termination: int = 50 """Minimum iterations before early termination check.""" # Value bounds min_value: float = 0.0 """Minimum value for normalization.""" max_value: float = 1.0 """Maximum value for normalization.""" # Metadata name: str = "default" """Configuration name for tracking.""" description: str = "" """Description of this configuration.""" def __post_init__(self): """Validate configuration parameters after initialization.""" self.validate() def validate(self) -> None: """ Validate all configuration parameters. Raises: ValueError: If any parameter is out of valid bounds. """ errors = [] # Core parameters if self.num_iterations < 1: errors.append("num_iterations must be >= 1") if self.num_iterations > 100000: errors.append("num_iterations should be <= 100000 for practical use") if self.exploration_weight < 0: errors.append("exploration_weight must be >= 0") if self.exploration_weight > 10: errors.append("exploration_weight should be <= 10") # Progressive widening if self.progressive_widening_k <= 0: errors.append("progressive_widening_k must be > 0") if not 0 < self.progressive_widening_alpha < 1: errors.append("progressive_widening_alpha must be in (0, 1)") # Rollout if self.max_rollout_depth < 1: errors.append("max_rollout_depth must be >= 1") if self.rollout_policy not in ["random", "greedy", "hybrid", "llm"]: errors.append("rollout_policy must be one of: random, greedy, hybrid, llm") # Parallelization if self.max_parallel_rollouts < 1: errors.append("max_parallel_rollouts must be >= 1") if self.max_parallel_rollouts > 100: errors.append("max_parallel_rollouts should be <= 100") # Caching if self.cache_size_limit < 0: errors.append("cache_size_limit must be >= 0") # Tree structure if self.max_tree_depth < 1: errors.append("max_tree_depth must be >= 1") if self.max_children_per_node < 1: errors.append("max_children_per_node must be >= 1") # Early termination if not 0 < self.early_termination_threshold <= 1: errors.append("early_termination_threshold must be in (0, 1]") if self.min_iterations_before_termination < 1: errors.append("min_iterations_before_termination must be >= 1") if self.min_iterations_before_termination > self.num_iterations: errors.append("min_iterations_before_termination must be <= num_iterations") # Value bounds if self.min_value >= self.max_value: errors.append("min_value must be < max_value") if errors: raise ValueError("Invalid MCTS configuration:\n" + "\n".join(f" - {e}" for e in errors)) def to_dict(self) -> dict[str, Any]: """ Convert configuration to dictionary for serialization. Returns: Dictionary representation of config. """ d = asdict(self) # Convert enum to string d["selection_policy"] = self.selection_policy.value return d def to_json(self, indent: int = 2) -> str: """ Serialize configuration to JSON string. Args: indent: JSON indentation level Returns: JSON string representation """ return json.dumps(self.to_dict(), indent=indent) @classmethod def from_dict(cls, data: dict[str, Any]) -> MCTSConfig: """ Create configuration from dictionary. Args: data: Dictionary with configuration parameters Returns: MCTSConfig instance """ # Convert selection_policy string back to enum if "selection_policy" in data and isinstance(data["selection_policy"], str): data["selection_policy"] = SelectionPolicy(data["selection_policy"]) return cls(**data) @classmethod def from_json(cls, json_str: str) -> MCTSConfig: """ Deserialize configuration from JSON string. Args: json_str: JSON string Returns: MCTSConfig instance """ data = json.loads(json_str) return cls.from_dict(data) def copy(self, **overrides) -> MCTSConfig: """ Create a copy with optional parameter overrides. Args: **overrides: Parameters to override Returns: New MCTSConfig instance """ data = self.to_dict() data.update(overrides) return self.from_dict(data) def __repr__(self) -> str: return ( f"MCTSConfig(name={self.name!r}, " f"iterations={self.num_iterations}, " f"c={self.exploration_weight}, " f"widening_k={self.progressive_widening_k}, " f"widening_alpha={self.progressive_widening_alpha})" ) def create_preset_config(preset: ConfigPreset) -> MCTSConfig: """ Create a preset configuration. Args: preset: Preset type to create Returns: MCTSConfig with preset parameters """ if preset == ConfigPreset.FAST: return MCTSConfig( name="fast", description="Fast search with minimal iterations", num_iterations=25, exploration_weight=1.414, progressive_widening_k=0.5, # Aggressive widening progressive_widening_alpha=0.5, max_rollout_depth=5, rollout_policy="random", selection_policy=SelectionPolicy.MAX_VISITS, max_parallel_rollouts=8, cache_size_limit=1000, early_termination_threshold=0.8, min_iterations_before_termination=10, ) elif preset == ConfigPreset.BALANCED: return MCTSConfig( name="balanced", description="Balanced search for typical use cases", num_iterations=100, exploration_weight=1.414, progressive_widening_k=1.0, progressive_widening_alpha=0.5, max_rollout_depth=10, rollout_policy="hybrid", selection_policy=SelectionPolicy.MAX_VISITS, max_parallel_rollouts=4, cache_size_limit=10000, early_termination_threshold=0.9, min_iterations_before_termination=50, ) elif preset == ConfigPreset.THOROUGH: return MCTSConfig( name="thorough", description="Thorough search for high-stakes decisions", num_iterations=500, exploration_weight=1.414, progressive_widening_k=2.0, # Conservative widening progressive_widening_alpha=0.6, max_rollout_depth=20, rollout_policy="hybrid", selection_policy=SelectionPolicy.ROBUST_CHILD, max_parallel_rollouts=4, cache_size_limit=50000, early_termination_threshold=0.95, min_iterations_before_termination=200, ) elif preset == ConfigPreset.EXPLORATION_HEAVY: return MCTSConfig( name="exploration_heavy", description="High exploration for diverse action discovery", num_iterations=200, exploration_weight=2.5, # High exploration progressive_widening_k=0.8, # More widening progressive_widening_alpha=0.4, # Aggressive max_rollout_depth=15, rollout_policy="random", selection_policy=SelectionPolicy.MAX_VISITS, max_parallel_rollouts=6, cache_size_limit=20000, early_termination_threshold=0.95, min_iterations_before_termination=100, ) elif preset == ConfigPreset.EXPLOITATION_HEAVY: return MCTSConfig( name="exploitation_heavy", description="High exploitation for known-good action refinement", num_iterations=150, exploration_weight=0.5, # Low exploration progressive_widening_k=3.0, # Conservative progressive_widening_alpha=0.7, # Very conservative max_rollout_depth=10, rollout_policy="greedy", selection_policy=SelectionPolicy.MAX_VALUE, max_parallel_rollouts=4, cache_size_limit=10000, early_termination_threshold=0.85, min_iterations_before_termination=75, ) else: raise ValueError(f"Unknown preset: {preset}") # Default configurations for easy access DEFAULT_CONFIG = MCTSConfig() FAST_CONFIG = create_preset_config(ConfigPreset.FAST) BALANCED_CONFIG = create_preset_config(ConfigPreset.BALANCED) THOROUGH_CONFIG = create_preset_config(ConfigPreset.THOROUGH)