Spaces:
Running
Running
| """ | |
| MCTS Configuration Module - Parameter management and presets. | |
| Provides: | |
| - MCTSConfig dataclass with all parameters | |
| - Validation of parameter bounds | |
| - Preset configurations (fast, balanced, thorough) | |
| - Serialization support for experiment tracking | |
| """ | |
| from __future__ import annotations | |
| import json | |
| from dataclasses import asdict, dataclass | |
| from enum import Enum | |
| from typing import Any | |
| from .policies import SelectionPolicy | |
| class ConfigPreset(Enum): | |
| """Preset configuration names.""" | |
| FAST = "fast" | |
| BALANCED = "balanced" | |
| THOROUGH = "thorough" | |
| EXPLORATION_HEAVY = "exploration_heavy" | |
| EXPLOITATION_HEAVY = "exploitation_heavy" | |
| class MCTSConfig: | |
| """ | |
| Complete configuration for MCTS engine. | |
| All MCTS parameters are centralized here with validation. | |
| Supports serialization for experiment tracking and reproducibility. | |
| """ | |
| # Core MCTS parameters | |
| num_iterations: int = 100 | |
| """Number of MCTS iterations to run.""" | |
| seed: int = 42 | |
| """Random seed for deterministic behavior.""" | |
| exploration_weight: float = 1.414 | |
| """UCB1 exploration constant (c). Higher = more exploration.""" | |
| # Progressive widening | |
| progressive_widening_k: float = 1.0 | |
| """Progressive widening coefficient. Higher = more conservative.""" | |
| progressive_widening_alpha: float = 0.5 | |
| """Progressive widening exponent. Lower = more aggressive expansion.""" | |
| # Rollout configuration | |
| max_rollout_depth: int = 10 | |
| """Maximum depth for rollout simulations.""" | |
| rollout_policy: str = "hybrid" | |
| """Rollout policy: 'random', 'greedy', 'hybrid'.""" | |
| # Action selection | |
| selection_policy: SelectionPolicy = SelectionPolicy.MAX_VISITS | |
| """Policy for final action selection.""" | |
| # Parallelization | |
| max_parallel_rollouts: int = 4 | |
| """Maximum concurrent rollout simulations.""" | |
| # Caching | |
| enable_cache: bool = True | |
| """Enable simulation result caching.""" | |
| cache_size_limit: int = 10000 | |
| """Maximum number of cached simulation results.""" | |
| # Tree structure | |
| max_tree_depth: int = 20 | |
| """Maximum depth of MCTS tree.""" | |
| max_children_per_node: int = 50 | |
| """Maximum children per node (action branching limit).""" | |
| # Early termination | |
| early_termination_threshold: float = 0.95 | |
| """Stop if best action has this fraction of total visits.""" | |
| min_iterations_before_termination: int = 50 | |
| """Minimum iterations before early termination check.""" | |
| # Value bounds | |
| min_value: float = 0.0 | |
| """Minimum value for normalization.""" | |
| max_value: float = 1.0 | |
| """Maximum value for normalization.""" | |
| # Metadata | |
| name: str = "default" | |
| """Configuration name for tracking.""" | |
| description: str = "" | |
| """Description of this configuration.""" | |
| def __post_init__(self): | |
| """Validate configuration parameters after initialization.""" | |
| self.validate() | |
| def validate(self) -> None: | |
| """ | |
| Validate all configuration parameters. | |
| Raises: | |
| ValueError: If any parameter is out of valid bounds. | |
| """ | |
| errors = [] | |
| # Core parameters | |
| if self.num_iterations < 1: | |
| errors.append("num_iterations must be >= 1") | |
| if self.num_iterations > 100000: | |
| errors.append("num_iterations should be <= 100000 for practical use") | |
| if self.exploration_weight < 0: | |
| errors.append("exploration_weight must be >= 0") | |
| if self.exploration_weight > 10: | |
| errors.append("exploration_weight should be <= 10") | |
| # Progressive widening | |
| if self.progressive_widening_k <= 0: | |
| errors.append("progressive_widening_k must be > 0") | |
| if not 0 < self.progressive_widening_alpha < 1: | |
| errors.append("progressive_widening_alpha must be in (0, 1)") | |
| # Rollout | |
| if self.max_rollout_depth < 1: | |
| errors.append("max_rollout_depth must be >= 1") | |
| if self.rollout_policy not in ["random", "greedy", "hybrid", "llm"]: | |
| errors.append("rollout_policy must be one of: random, greedy, hybrid, llm") | |
| # Parallelization | |
| if self.max_parallel_rollouts < 1: | |
| errors.append("max_parallel_rollouts must be >= 1") | |
| if self.max_parallel_rollouts > 100: | |
| errors.append("max_parallel_rollouts should be <= 100") | |
| # Caching | |
| if self.cache_size_limit < 0: | |
| errors.append("cache_size_limit must be >= 0") | |
| # Tree structure | |
| if self.max_tree_depth < 1: | |
| errors.append("max_tree_depth must be >= 1") | |
| if self.max_children_per_node < 1: | |
| errors.append("max_children_per_node must be >= 1") | |
| # Early termination | |
| if not 0 < self.early_termination_threshold <= 1: | |
| errors.append("early_termination_threshold must be in (0, 1]") | |
| if self.min_iterations_before_termination < 1: | |
| errors.append("min_iterations_before_termination must be >= 1") | |
| if self.min_iterations_before_termination > self.num_iterations: | |
| errors.append("min_iterations_before_termination must be <= num_iterations") | |
| # Value bounds | |
| if self.min_value >= self.max_value: | |
| errors.append("min_value must be < max_value") | |
| if errors: | |
| raise ValueError("Invalid MCTS configuration:\n" + "\n".join(f" - {e}" for e in errors)) | |
| def to_dict(self) -> dict[str, Any]: | |
| """ | |
| Convert configuration to dictionary for serialization. | |
| Returns: | |
| Dictionary representation of config. | |
| """ | |
| d = asdict(self) | |
| # Convert enum to string | |
| d["selection_policy"] = self.selection_policy.value | |
| return d | |
| def to_json(self, indent: int = 2) -> str: | |
| """ | |
| Serialize configuration to JSON string. | |
| Args: | |
| indent: JSON indentation level | |
| Returns: | |
| JSON string representation | |
| """ | |
| return json.dumps(self.to_dict(), indent=indent) | |
| def from_dict(cls, data: dict[str, Any]) -> MCTSConfig: | |
| """ | |
| Create configuration from dictionary. | |
| Args: | |
| data: Dictionary with configuration parameters | |
| Returns: | |
| MCTSConfig instance | |
| """ | |
| # Convert selection_policy string back to enum | |
| if "selection_policy" in data and isinstance(data["selection_policy"], str): | |
| data["selection_policy"] = SelectionPolicy(data["selection_policy"]) | |
| return cls(**data) | |
| def from_json(cls, json_str: str) -> MCTSConfig: | |
| """ | |
| Deserialize configuration from JSON string. | |
| Args: | |
| json_str: JSON string | |
| Returns: | |
| MCTSConfig instance | |
| """ | |
| data = json.loads(json_str) | |
| return cls.from_dict(data) | |
| def copy(self, **overrides) -> MCTSConfig: | |
| """ | |
| Create a copy with optional parameter overrides. | |
| Args: | |
| **overrides: Parameters to override | |
| Returns: | |
| New MCTSConfig instance | |
| """ | |
| data = self.to_dict() | |
| data.update(overrides) | |
| return self.from_dict(data) | |
| def __repr__(self) -> str: | |
| return ( | |
| f"MCTSConfig(name={self.name!r}, " | |
| f"iterations={self.num_iterations}, " | |
| f"c={self.exploration_weight}, " | |
| f"widening_k={self.progressive_widening_k}, " | |
| f"widening_alpha={self.progressive_widening_alpha})" | |
| ) | |
| def create_preset_config(preset: ConfigPreset) -> MCTSConfig: | |
| """ | |
| Create a preset configuration. | |
| Args: | |
| preset: Preset type to create | |
| Returns: | |
| MCTSConfig with preset parameters | |
| """ | |
| if preset == ConfigPreset.FAST: | |
| return MCTSConfig( | |
| name="fast", | |
| description="Fast search with minimal iterations", | |
| num_iterations=25, | |
| exploration_weight=1.414, | |
| progressive_widening_k=0.5, # Aggressive widening | |
| progressive_widening_alpha=0.5, | |
| max_rollout_depth=5, | |
| rollout_policy="random", | |
| selection_policy=SelectionPolicy.MAX_VISITS, | |
| max_parallel_rollouts=8, | |
| cache_size_limit=1000, | |
| early_termination_threshold=0.8, | |
| min_iterations_before_termination=10, | |
| ) | |
| elif preset == ConfigPreset.BALANCED: | |
| return MCTSConfig( | |
| name="balanced", | |
| description="Balanced search for typical use cases", | |
| num_iterations=100, | |
| exploration_weight=1.414, | |
| progressive_widening_k=1.0, | |
| progressive_widening_alpha=0.5, | |
| max_rollout_depth=10, | |
| rollout_policy="hybrid", | |
| selection_policy=SelectionPolicy.MAX_VISITS, | |
| max_parallel_rollouts=4, | |
| cache_size_limit=10000, | |
| early_termination_threshold=0.9, | |
| min_iterations_before_termination=50, | |
| ) | |
| elif preset == ConfigPreset.THOROUGH: | |
| return MCTSConfig( | |
| name="thorough", | |
| description="Thorough search for high-stakes decisions", | |
| num_iterations=500, | |
| exploration_weight=1.414, | |
| progressive_widening_k=2.0, # Conservative widening | |
| progressive_widening_alpha=0.6, | |
| max_rollout_depth=20, | |
| rollout_policy="hybrid", | |
| selection_policy=SelectionPolicy.ROBUST_CHILD, | |
| max_parallel_rollouts=4, | |
| cache_size_limit=50000, | |
| early_termination_threshold=0.95, | |
| min_iterations_before_termination=200, | |
| ) | |
| elif preset == ConfigPreset.EXPLORATION_HEAVY: | |
| return MCTSConfig( | |
| name="exploration_heavy", | |
| description="High exploration for diverse action discovery", | |
| num_iterations=200, | |
| exploration_weight=2.5, # High exploration | |
| progressive_widening_k=0.8, # More widening | |
| progressive_widening_alpha=0.4, # Aggressive | |
| max_rollout_depth=15, | |
| rollout_policy="random", | |
| selection_policy=SelectionPolicy.MAX_VISITS, | |
| max_parallel_rollouts=6, | |
| cache_size_limit=20000, | |
| early_termination_threshold=0.95, | |
| min_iterations_before_termination=100, | |
| ) | |
| elif preset == ConfigPreset.EXPLOITATION_HEAVY: | |
| return MCTSConfig( | |
| name="exploitation_heavy", | |
| description="High exploitation for known-good action refinement", | |
| num_iterations=150, | |
| exploration_weight=0.5, # Low exploration | |
| progressive_widening_k=3.0, # Conservative | |
| progressive_widening_alpha=0.7, # Very conservative | |
| max_rollout_depth=10, | |
| rollout_policy="greedy", | |
| selection_policy=SelectionPolicy.MAX_VALUE, | |
| max_parallel_rollouts=4, | |
| cache_size_limit=10000, | |
| early_termination_threshold=0.85, | |
| min_iterations_before_termination=75, | |
| ) | |
| else: | |
| raise ValueError(f"Unknown preset: {preset}") | |
| # Default configurations for easy access | |
| DEFAULT_CONFIG = MCTSConfig() | |
| FAST_CONFIG = create_preset_config(ConfigPreset.FAST) | |
| BALANCED_CONFIG = create_preset_config(ConfigPreset.BALANCED) | |
| THOROUGH_CONFIG = create_preset_config(ConfigPreset.THOROUGH) | |