"""
MCTS Policies Module - Selection, rollout, and evaluation policies.
Provides:
- UCB1 with configurable exploration weight
- Rollout heuristics (random, greedy, hybrid)
- Action selection policies (max visits, max value, robust child, secure child)
- Progressive widening parameters
"""
from __future__ import annotations
import math
from abc import ABC, abstractmethod
from collections.abc import Awaitable, Callable
from enum import Enum
from typing import TYPE_CHECKING
import numpy as np
if TYPE_CHECKING:
from .core import MCTSState
def ucb1(
value_sum: float,
visits: int,
parent_visits: int,
c: float = 1.414,
) -> float:
"""
Upper Confidence Bound 1 (UCB1) formula for tree selection.
    Formula: Q(s,a) + c * sqrt(ln(N(s)) / N(s,a))
Args:
value_sum: Total accumulated value for the node
visits: Number of visits to the node
parent_visits: Number of visits to the parent node
c: Exploration weight constant (default sqrt(2))
Returns:
UCB1 score for node selection
"""
if visits == 0:
return float("inf")
exploitation = value_sum / visits
    exploration = c * math.sqrt(math.log(max(1, parent_visits)) / visits)
return exploitation + exploration
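# Worked example (values approximate): with 20 parent visits, a child with
# mean value 0.6 seen 5 times scores about 0.6 + 1.414 * sqrt(ln(20) / 5)
# ≈ 1.69, outranking a child with mean value 0.8 seen 15 times
# (0.8 + 1.414 * sqrt(ln(20) / 15) ≈ 1.43): the exploration bonus steers
# selection toward the less-visited branch, and an unvisited child scores
# inf, so every child is tried at least once.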
def ucb1_tuned(
value_sum: float,
value_squared_sum: float,
visits: int,
parent_visits: int,
c: float = 1.0,
) -> float:
"""
UCB1-Tuned variant with variance estimate.
Provides tighter bounds by considering value variance.
Args:
value_sum: Total accumulated value
value_squared_sum: Sum of squared values (for variance)
visits: Number of visits
parent_visits: Parent visit count
c: Exploration constant
Returns:
UCB1-Tuned score
"""
if visits == 0:
return float("inf")
mean_value = value_sum / visits
variance = value_squared_sum / visits - mean_value**2
variance = max(0, variance) # Ensure non-negative
# Variance bound term
    ln_parent = math.log(max(1, parent_visits))
variance_bound = variance + math.sqrt(2 * ln_parent / visits)
min_bound = min(0.25, variance_bound)
exploitation = mean_value
exploration = c * math.sqrt(ln_parent / visits * min_bound)
return exploitation + exploration
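# Worked example (values approximate): for value_sum=3.0,
# value_squared_sum=2.0, visits=5, parent_visits=20, the sample variance is
# 2.0/5 - 0.6**2 = 0.04, the variance bound caps at 0.25, and the score is
# 0.6 + sqrt(ln(20)/5 * 0.25) ≈ 0.99: a smaller bonus than plain UCB1 grants
# when the observed values cluster tightly around their mean.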
class SelectionPolicy(Enum):
"""Policy for selecting the final action after MCTS search."""
MAX_VISITS = "max_visits"
"""Select action with most visits (most robust)."""
MAX_VALUE = "max_value"
"""Select action with highest average value (greedy)."""
ROBUST_CHILD = "robust_child"
"""Select action balancing visits and value."""
SECURE_CHILD = "secure_child"
"""Select action with lowest lower confidence bound."""
class RolloutPolicy(ABC):
"""Abstract base class for rollout/simulation policies."""
@abstractmethod
async def evaluate(
self,
state: MCTSState,
rng: np.random.Generator,
max_depth: int = 10,
) -> float:
"""
Evaluate a state through rollout simulation.
Args:
state: State to evaluate
rng: Seeded random number generator
max_depth: Maximum rollout depth
Returns:
Estimated value in [0, 1] range
"""
pass
class RandomRolloutPolicy(RolloutPolicy):
"""Random rollout policy - uniform random evaluation."""
def __init__(self, base_value: float = 0.5, noise_scale: float = 0.3):
"""
Initialize random rollout policy.
Args:
base_value: Base value for evaluations
noise_scale: Scale of random noise
"""
self.base_value = base_value
self.noise_scale = noise_scale
async def evaluate(
self,
_state: MCTSState,
rng: np.random.Generator,
_max_depth: int = 10,
) -> float:
"""Generate random evaluation with noise."""
noise = rng.uniform(-self.noise_scale, self.noise_scale)
value = self.base_value + noise
return max(0.0, min(1.0, value))
class GreedyRolloutPolicy(RolloutPolicy):
"""Greedy rollout policy using domain heuristics."""
def __init__(
self,
heuristic_fn: Callable[[MCTSState], float],
noise_scale: float = 0.05,
):
"""
Initialize greedy rollout policy.
Args:
heuristic_fn: Function to evaluate state heuristically
noise_scale: Small noise for tie-breaking
"""
self.heuristic_fn = heuristic_fn
self.noise_scale = noise_scale
async def evaluate(
self,
state: MCTSState,
rng: np.random.Generator,
_max_depth: int = 10,
) -> float:
"""Evaluate using heuristic with small noise."""
base_value = self.heuristic_fn(state)
noise = rng.uniform(-self.noise_scale, self.noise_scale)
value = base_value + noise
return max(0.0, min(1.0, value))
class HybridRolloutPolicy(RolloutPolicy):
"""Hybrid policy combining random and heuristic evaluation."""
def __init__(
self,
heuristic_fn: Callable[[MCTSState], float] | None = None,
heuristic_weight: float = 0.7,
random_weight: float = 0.3,
base_random_value: float = 0.5,
noise_scale: float = 0.2,
):
"""
Initialize hybrid rollout policy.
Args:
heuristic_fn: Optional heuristic evaluation function
heuristic_weight: Weight for heuristic component
random_weight: Weight for random component
base_random_value: Base value for random component
noise_scale: Noise scale for random component
"""
self.heuristic_fn = heuristic_fn
self.heuristic_weight = heuristic_weight
self.random_weight = random_weight
self.base_random_value = base_random_value
self.noise_scale = noise_scale
# Normalize weights
total_weight = heuristic_weight + random_weight
if total_weight > 0:
self.heuristic_weight /= total_weight
self.random_weight /= total_weight
async def evaluate(
self,
state: MCTSState,
rng: np.random.Generator,
_max_depth: int = 10,
) -> float:
"""Combine heuristic and random evaluation."""
# Random component
random_noise = rng.uniform(-self.noise_scale, self.noise_scale)
random_value = self.base_random_value + random_noise
# Heuristic component
        heuristic_value = (
            self.heuristic_fn(state)
            if self.heuristic_fn is not None
            else self.base_random_value
        )
# Combine
value = self.heuristic_weight * heuristic_value + self.random_weight * random_value
return max(0.0, min(1.0, value))
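# Usage sketch: _toy_depth_heuristic is a hypothetical stand-in for a real
# domain heuristic; any Callable[[MCTSState], float] returning values in
# [0, 1] can back GreedyRolloutPolicy or HybridRolloutPolicy alike.
def _toy_depth_heuristic(state: MCTSState) -> float:
    # Assumes, purely for illustration, that states expose a numeric depth
    # attribute and that deeper states are closer to a goal.
    return min(1.0, getattr(state, "depth", 0) / 10.0)
def _build_example_hybrid_policy() -> HybridRolloutPolicy:
    # The 0.7/0.3 weights already sum to 1; __init__ would renormalize them
    # otherwise.
    return HybridRolloutPolicy(
        heuristic_fn=_toy_depth_heuristic,
        heuristic_weight=0.7,
        random_weight=0.3,
    )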
class LLMRolloutPolicy(RolloutPolicy):
"""Rollout policy that uses an LLM for state evaluation."""
def __init__(
self,
evaluate_fn: Callable[[MCTSState], Awaitable[float]],
cache_results: bool = True,
):
"""
Initialize LLM rollout policy.
Args:
evaluate_fn: Async function to evaluate state with LLM
cache_results: Whether to cache evaluation results
"""
self.evaluate_fn = evaluate_fn
self.cache_results = cache_results
self._cache: dict = {}
async def evaluate(
self,
state: MCTSState,
_rng: np.random.Generator,
_max_depth: int = 10,
) -> float:
"""Evaluate state using LLM."""
state_key = state.to_hash_key()
if self.cache_results and state_key in self._cache:
return self._cache[state_key]
value = await self.evaluate_fn(state)
value = max(0.0, min(1.0, value))
if self.cache_results:
self._cache[state_key] = value
return value
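# Usage sketch with hypothetical stand-ins: _stub_llm_score plays the role
# of a real async LLM call, and any object exposing to_hash_key() can serve
# as the state. A repeated evaluation of the same state is answered from the
# cache instead of re-invoking the scorer.
async def _stub_llm_score(_state: MCTSState) -> float:
    return 0.8  # a real implementation would await an LLM request here
async def _llm_rollout_demo(state: MCTSState) -> float:
    policy = LLMRolloutPolicy(evaluate_fn=_stub_llm_score)
    rng = np.random.default_rng(0)
    first = await policy.evaluate(state, rng)   # invokes _stub_llm_score
    second = await policy.evaluate(state, rng)  # served from the cache
    assert first == second
    return first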
class ProgressiveWideningConfig:
"""Configuration for progressive widening in MCTS."""
def __init__(
self,
k: float = 1.0,
alpha: float = 0.5,
):
"""
Configure progressive widening parameters.
Progressive widening expands when: visits > k * num_children^alpha
Args:
k: Coefficient controlling expansion threshold
alpha: Exponent controlling growth rate
Common configurations:
- k=1.0, alpha=0.5: Moderate widening (default)
- k=2.0, alpha=0.5: Conservative (fewer expansions)
- k=0.5, alpha=0.5: Aggressive (more expansions)
- k=1.0, alpha=0.3: Very aggressive
- k=1.0, alpha=0.7: Very conservative
"""
if k <= 0:
raise ValueError("k must be positive")
if not 0 < alpha < 1:
raise ValueError("alpha must be in (0, 1)")
self.k = k
self.alpha = alpha
def should_expand(self, visits: int, num_children: int) -> bool:
"""
Check if expansion should occur.
Args:
visits: Number of visits to node
num_children: Current number of children
Returns:
True if should expand, False otherwise
"""
threshold = self.k * (num_children**self.alpha)
return visits > threshold
def min_visits_for_expansion(self, num_children: int) -> int:
"""
Calculate minimum visits needed to expand to next child.
Args:
num_children: Current number of children
Returns:
Minimum visit count for expansion
"""
        threshold = self.k * (num_children**self.alpha)
        # should_expand uses a strict inequality, so the minimum qualifying
        # count is the smallest integer strictly greater than the threshold.
        return int(math.floor(threshold)) + 1
def __repr__(self) -> str:
return f"ProgressiveWideningConfig(k={self.k}, alpha={self.alpha})"
def compute_action_probabilities(
children_stats: list[dict],
temperature: float = 1.0,
) -> list[float]:
"""
Compute action probabilities from visit counts using softmax.
Args:
children_stats: List of dicts with 'visits' key
temperature: Temperature parameter (lower = more deterministic)
Returns:
List of probabilities for each action
"""
if not children_stats:
return []
visits = np.array([c["visits"] for c in children_stats], dtype=float)
if temperature == 0:
# Deterministic: assign 1.0 to max, 0 to others
probs = np.zeros_like(visits)
probs[np.argmax(visits)] = 1.0
return probs.tolist()
    # Apply temperature (counts are raised to 1/temperature, then normalized)
    scaled_visits = visits ** (1.0 / temperature)
    total = scaled_visits.sum()
    if total == 0:
        # No child has been visited yet; fall back to a uniform distribution.
        return [1.0 / len(visits)] * len(visits)
    probs = scaled_visits / total
    return probs.tolist()
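# Worked example: visit counts [10, 5, 1] yield probabilities
# [0.625, 0.3125, 0.0625] at temperature=1.0; at temperature=0.5 the counts
# are squared first, sharpening the distribution to roughly
# [0.794, 0.198, 0.008]; temperature=0 collapses it to [1.0, 0.0, 0.0].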
def select_action_stochastic(
children_stats: list[dict],
rng: np.random.Generator,
temperature: float = 1.0,
) -> int:
"""
Stochastically select action based on visit counts.
Args:
children_stats: List of child statistics
rng: Random number generator
temperature: Temperature for softmax
Returns:
Index of selected action
"""
probs = compute_action_probabilities(children_stats, temperature)
if not probs:
raise ValueError("No actions to select from")
    # rng.choice returns a NumPy integer; cast to a plain int for callers.
    return int(rng.choice(len(probs), p=probs))
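# Usage sketch: drawing one action index from the visit-count distribution;
# the stats below are a hypothetical result of a finished search. Most draws
# return index 0, but indices 1 and 2 keep nonzero probability.
def _stochastic_selection_demo() -> int:
    stats = [{"visits": 10}, {"visits": 5}, {"visits": 1}]
    rng = np.random.default_rng(42)
    return select_action_stochastic(stats, rng, temperature=1.0)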