| """ | |
| MCTS Policies Module - Selection, rollout, and evaluation policies. | |
| Provides: | |
| - UCB1 with configurable exploration weight | |
| - Rollout heuristics (random, greedy, hybrid) | |
| - Action selection policies (max visits, max value, robust child) | |
| - Progressive widening parameters | |
| """ | |
| from __future__ import annotations | |
| import math | |
| from abc import ABC, abstractmethod | |
| from collections.abc import Awaitable, Callable | |
| from enum import Enum | |
| from typing import TYPE_CHECKING | |
| import numpy as np | |
| if TYPE_CHECKING: | |
| from .core import MCTSState | |

def ucb1(
    value_sum: float,
    visits: int,
    parent_visits: int,
    c: float = 1.414,
) -> float:
    """
    Upper Confidence Bound 1 (UCB1) formula for tree selection.

    Formula: Q(s,a) + c * sqrt(N(s)) / sqrt(N(s,a))

    Note: this is the square-root-ratio variant; the classical UCB1
    bound uses sqrt(ln N(s) / N(s,a)). Both shrink the exploration
    bonus as a child is visited more often relative to its parent.

    Args:
        value_sum: Total accumulated value for the node
        visits: Number of visits to the node
        parent_visits: Number of visits to the parent node
        c: Exploration weight constant (default sqrt(2))

    Returns:
        UCB1 score for node selection
    """
    if visits == 0:
        # Unvisited children are always preferred
        return float("inf")
    exploitation = value_sum / visits
    exploration = c * math.sqrt(parent_visits / visits)
    return exploitation + exploration
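
# Worked example with made-up counts (illustrative only):
# value_sum=3.0, visits=5, parent_visits=20 gives
#   exploitation = 3.0 / 5 = 0.6
#   exploration  = 1.414 * sqrt(20 / 5) = 1.414 * 2 = 2.828
# so ucb1(3.0, 5, 20) ~= 3.43, while an unvisited sibling still wins
# with a score of +inf.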

def ucb1_tuned(
    value_sum: float,
    value_squared_sum: float,
    visits: int,
    parent_visits: int,
    c: float = 1.0,
) -> float:
    """
    UCB1-Tuned variant with a variance estimate.

    Provides tighter bounds by considering value variance.

    Args:
        value_sum: Total accumulated value
        value_squared_sum: Sum of squared values (for variance)
        visits: Number of visits
        parent_visits: Parent visit count
        c: Exploration constant

    Returns:
        UCB1-Tuned score
    """
    if visits == 0:
        return float("inf")
    mean_value = value_sum / visits
    variance = value_squared_sum / visits - mean_value**2
    variance = max(0.0, variance)  # Guard against small negatives from rounding

    # Variance bound term: V = variance + sqrt(2 ln N / n), capped at 1/4
    # (the maximum possible variance of a [0, 1]-valued variable)
    ln_parent = math.log(parent_visits)
    variance_bound = variance + math.sqrt(2 * ln_parent / visits)
    min_bound = min(0.25, variance_bound)

    exploitation = mean_value
    exploration = c * math.sqrt(ln_parent / visits * min_bound)
    return exploitation + exploration
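
# Worked example with illustrative sums: value_sum=3.0,
# value_squared_sum=2.0, visits=5, parent_visits=20 gives
#   mean = 0.6, variance = 2.0/5 - 0.36 = 0.04
#   V = 0.04 + sqrt(2 * ln(20) / 5) ~= 1.13, capped to 0.25
#   score = 0.6 + sqrt(ln(20) / 5 * 0.25) ~= 0.99
# The 1/4 cap keeps the bonus for high-variance nodes no larger than
# plain UCB1's.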

class SelectionPolicy(Enum):
    """Policy for selecting the final action after MCTS search."""

    MAX_VISITS = "max_visits"
    """Select the action with the most visits (most robust)."""

    MAX_VALUE = "max_value"
    """Select the action with the highest average value (greedy)."""

    ROBUST_CHILD = "robust_child"
    """Select the action balancing visits and value."""

    SECURE_CHILD = "secure_child"
    """Select the action with the highest lower confidence bound."""

class RolloutPolicy(ABC):
    """Abstract base class for rollout/simulation policies."""

    @abstractmethod
    async def evaluate(
        self,
        state: MCTSState,
        rng: np.random.Generator,
        max_depth: int = 10,
    ) -> float:
        """
        Evaluate a state through rollout simulation.

        Args:
            state: State to evaluate
            rng: Seeded random number generator
            max_depth: Maximum rollout depth

        Returns:
            Estimated value in [0, 1] range
        """

class RandomRolloutPolicy(RolloutPolicy):
    """Random rollout policy - uniform random evaluation."""

    def __init__(self, base_value: float = 0.5, noise_scale: float = 0.3):
        """
        Initialize random rollout policy.

        Args:
            base_value: Base value for evaluations
            noise_scale: Scale of random noise
        """
        self.base_value = base_value
        self.noise_scale = noise_scale

    async def evaluate(
        self,
        _state: MCTSState,
        rng: np.random.Generator,
        _max_depth: int = 10,
    ) -> float:
        """Generate a random evaluation with noise, clipped to [0, 1]."""
        noise = rng.uniform(-self.noise_scale, self.noise_scale)
        value = self.base_value + noise
        return max(0.0, min(1.0, value))
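
# Illustrative behaviour: with the defaults (base_value=0.5,
# noise_scale=0.3) every evaluation is uniform in [0.2, 0.8], e.g.
#   policy = RandomRolloutPolicy()
#   value = await policy.evaluate(state, np.random.default_rng(0))
# where `state` is any MCTSState (this policy ignores it).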

class GreedyRolloutPolicy(RolloutPolicy):
    """Greedy rollout policy using domain heuristics."""

    def __init__(
        self,
        heuristic_fn: Callable[[MCTSState], float],
        noise_scale: float = 0.05,
    ):
        """
        Initialize greedy rollout policy.

        Args:
            heuristic_fn: Function to evaluate state heuristically
            noise_scale: Small noise for tie-breaking
        """
        self.heuristic_fn = heuristic_fn
        self.noise_scale = noise_scale

    async def evaluate(
        self,
        state: MCTSState,
        rng: np.random.Generator,
        _max_depth: int = 10,
    ) -> float:
        """Evaluate using the heuristic with small tie-breaking noise."""
        base_value = self.heuristic_fn(state)
        noise = rng.uniform(-self.noise_scale, self.noise_scale)
        value = base_value + noise
        return max(0.0, min(1.0, value))
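
# Usage sketch with a toy heuristic (the constant heuristic here is an
# assumption for illustration; real callers pass a domain function):
#   policy = GreedyRolloutPolicy(heuristic_fn=lambda s: 0.8, noise_scale=0.05)
#   value = await policy.evaluate(state, np.random.default_rng(0))
# yields a value in [0.75, 0.85] after clipping to [0, 1].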

class HybridRolloutPolicy(RolloutPolicy):
    """Hybrid policy combining random and heuristic evaluation."""

    def __init__(
        self,
        heuristic_fn: Callable[[MCTSState], float] | None = None,
        heuristic_weight: float = 0.7,
        random_weight: float = 0.3,
        base_random_value: float = 0.5,
        noise_scale: float = 0.2,
    ):
        """
        Initialize hybrid rollout policy.

        Args:
            heuristic_fn: Optional heuristic evaluation function
            heuristic_weight: Weight for heuristic component
            random_weight: Weight for random component
            base_random_value: Base value for random component
            noise_scale: Noise scale for random component
        """
        self.heuristic_fn = heuristic_fn
        self.heuristic_weight = heuristic_weight
        self.random_weight = random_weight
        self.base_random_value = base_random_value
        self.noise_scale = noise_scale

        # Normalize weights so they sum to 1
        total_weight = heuristic_weight + random_weight
        if total_weight > 0:
            self.heuristic_weight /= total_weight
            self.random_weight /= total_weight

    async def evaluate(
        self,
        state: MCTSState,
        rng: np.random.Generator,
        _max_depth: int = 10,
    ) -> float:
        """Combine heuristic and random evaluation."""
        # Random component
        random_noise = rng.uniform(-self.noise_scale, self.noise_scale)
        random_value = self.base_random_value + random_noise

        # Heuristic component (falls back to the base value if no heuristic)
        if self.heuristic_fn is not None:
            heuristic_value = self.heuristic_fn(state)
        else:
            heuristic_value = self.base_random_value

        # Weighted combination, clipped to [0, 1]
        value = self.heuristic_weight * heuristic_value + self.random_weight * random_value
        return max(0.0, min(1.0, value))
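
# Weight normalization example: heuristic_weight=2.0, random_weight=1.0
# are rescaled to 2/3 and 1/3, so with a heuristic returning 0.9 and a
# random component centred on 0.5 the expected value is about
#   (2/3) * 0.9 + (1/3) * 0.5 ~= 0.77
# (illustrative numbers; the noise term spreads individual samples).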

class LLMRolloutPolicy(RolloutPolicy):
    """Rollout policy that uses an LLM for state evaluation."""

    def __init__(
        self,
        evaluate_fn: Callable[[MCTSState], Awaitable[float]],
        cache_results: bool = True,
    ):
        """
        Initialize LLM rollout policy.

        Args:
            evaluate_fn: Async function to evaluate state with LLM
            cache_results: Whether to cache evaluation results
        """
        self.evaluate_fn = evaluate_fn
        self.cache_results = cache_results
        self._cache: dict = {}

    async def evaluate(
        self,
        state: MCTSState,
        _rng: np.random.Generator,
        _max_depth: int = 10,
    ) -> float:
        """Evaluate state using the LLM, with optional result caching."""
        state_key = state.to_hash_key()
        if self.cache_results and state_key in self._cache:
            return self._cache[state_key]

        value = await self.evaluate_fn(state)
        value = max(0.0, min(1.0, value))  # Clamp the LLM's score to [0, 1]

        if self.cache_results:
            self._cache[state_key] = value
        return value
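
# Usage sketch with a stubbed evaluator (the stub stands in for a real
# LLM call and is an assumption for illustration):
#   async def fake_llm_eval(state: MCTSState) -> float:
#       return 0.75
#   policy = LLMRolloutPolicy(evaluate_fn=fake_llm_eval)
#   value = await policy.evaluate(state, np.random.default_rng(0))
# Repeated calls for the same state hit the cache and skip the LLM,
# assuming MCTSState.to_hash_key() is stable across calls.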

class ProgressiveWideningConfig:
    """Configuration for progressive widening in MCTS."""

    def __init__(
        self,
        k: float = 1.0,
        alpha: float = 0.5,
    ):
        """
        Configure progressive widening parameters.

        Progressive widening expands when: visits > k * num_children^alpha

        Args:
            k: Coefficient controlling expansion threshold
            alpha: Exponent controlling growth rate

        Common configurations:
            - k=1.0, alpha=0.5: Moderate widening (default)
            - k=2.0, alpha=0.5: Conservative (fewer expansions)
            - k=0.5, alpha=0.5: Aggressive (more expansions)
            - k=1.0, alpha=0.3: Very aggressive
            - k=1.0, alpha=0.7: Very conservative
        """
        if k <= 0:
            raise ValueError("k must be positive")
        if not 0 < alpha < 1:
            raise ValueError("alpha must be in (0, 1)")
        self.k = k
        self.alpha = alpha

    def should_expand(self, visits: int, num_children: int) -> bool:
        """
        Check if expansion should occur.

        Args:
            visits: Number of visits to node
            num_children: Current number of children

        Returns:
            True if should expand, False otherwise
        """
        threshold = self.k * (num_children**self.alpha)
        return visits > threshold

    def min_visits_for_expansion(self, num_children: int) -> int:
        """
        Calculate minimum visits needed to expand to the next child.

        Args:
            num_children: Current number of children

        Returns:
            Minimum visit count for expansion
        """
        threshold = self.k * (num_children**self.alpha)
        # should_expand uses a strict inequality, so when the threshold is
        # integral (e.g. 2.0) the answer is threshold + 1; ceil() alone
        # would return a count that still fails the check.
        return int(math.floor(threshold)) + 1

    def __repr__(self) -> str:
        return f"ProgressiveWideningConfig(k={self.k}, alpha={self.alpha})"

def compute_action_probabilities(
    children_stats: list[dict],
    temperature: float = 1.0,
) -> list[float]:
    """
    Compute action probabilities by exponentiating visit counts.

    Each probability is proportional to visits ** (1 / temperature)
    (equivalently, a softmax over log visit counts scaled by 1/T).

    Args:
        children_stats: List of dicts with 'visits' key
        temperature: Temperature parameter (lower = more deterministic)

    Returns:
        List of probabilities for each action
    """
    if not children_stats:
        return []
    visits = np.array([c["visits"] for c in children_stats], dtype=float)

    if temperature == 0:
        # Deterministic: assign 1.0 to the max, 0 to the others
        probs = np.zeros_like(visits)
        probs[np.argmax(visits)] = 1.0
        return probs.tolist()

    # Apply temperature
    scaled_visits = visits ** (1.0 / temperature)
    total = scaled_visits.sum()
    if total == 0:
        # No visits anywhere: fall back to a uniform distribution
        return [1.0 / len(visits)] * len(visits)
    probs = scaled_visits / total
    return probs.tolist()
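
# Example: visits of [10, 5, 5] give
#   temperature=1.0 -> [0.50, 0.25, 0.25]
#   temperature=0.5 -> [100, 25, 25] / 150 ~= [0.667, 0.167, 0.167]
#   temperature=0   -> [1.0, 0.0, 0.0]
# Lower temperatures sharpen the distribution toward the most-visited action.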

def select_action_stochastic(
    children_stats: list[dict],
    rng: np.random.Generator,
    temperature: float = 1.0,
) -> int:
    """
    Stochastically select an action based on visit counts.

    Args:
        children_stats: List of child statistics
        rng: Random number generator
        temperature: Temperature for the visit-count distribution

    Returns:
        Index of selected action
    """
    probs = compute_action_probabilities(children_stats, temperature)
    if not probs:
        raise ValueError("No actions to select from")
    # Cast: rng.choice returns a NumPy integer, not a Python int
    return int(rng.choice(len(probs), p=probs))
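
# Usage sketch with a seeded generator (the stats are illustrative):
#   rng = np.random.default_rng(42)
#   stats = [{"visits": 10}, {"visits": 5}, {"visits": 5}]
#   idx = select_action_stochastic(stats, rng, temperature=1.0)
# idx is 0 roughly half the time; as temperature -> 0 the choice
# becomes deterministic (always the most-visited action).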