| """ | |
| Neural-Guided Monte Carlo Tree Search (MCTS). | |
| Implements AlphaZero-style MCTS with: | |
| - Policy and value network guidance | |
| - PUCT (Predictor + UCT) selection | |
| - Dirichlet noise for exploration | |
| - Virtual loss for parallel search | |
| - Temperature-based action selection | |
| Based on: | |
| - "Mastering the Game of Go with Deep Neural Networks and Tree Search" (AlphaGo) | |
| - "Mastering Chess and Shogi by Self-Play with a General RL Algorithm" (AlphaZero) | |
| """ | |
| from __future__ import annotations | |
| import math | |
| from collections.abc import Callable | |
| from dataclasses import dataclass | |
| from typing import Any | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| from ...training.system_config import MCTSConfig | |

class GameState:
    """
    Abstract game/problem state interface.

    Users should subclass this for their specific domain.
    """

    def get_legal_actions(self) -> list[Any]:
        """Return list of legal actions from this state."""
        raise NotImplementedError

    def apply_action(self, action: Any) -> GameState:
        """Apply action and return the new state."""
        raise NotImplementedError

    def is_terminal(self) -> bool:
        """Check if this is a terminal state."""
        raise NotImplementedError

    def get_reward(self, player: int = 1) -> float:
        """Get the reward for the given player (1 or -1)."""
        raise NotImplementedError

    def to_tensor(self) -> torch.Tensor:
        """Convert state to a tensor for neural network input."""
        raise NotImplementedError

    def get_canonical_form(self, player: int) -> GameState:  # noqa: ARG002
        """Get the state from the perspective of the given player."""
        return self

    def get_hash(self) -> str:
        """Get a unique hash for this state (used for caching)."""
        raise NotImplementedError

    def action_to_index(self, action: Any) -> int:
        """
        Map an action to its index in the neural network's action space.

        This method should return the index corresponding to the action
        in the network's policy output vector.

        The default implementation uses string-based mapping for Tic-Tac-Toe-style
        actions (e.g., "0,0" -> 0, "0,1" -> 1, etc.). Override this method
        for custom action mappings.

        Args:
            action: The action to map

        Returns:
            Index in the action space (0 to action_size - 1)
        """
        # Default implementation for grid-based actions like "row,col"
        if isinstance(action, str) and "," in action:
            row, col = map(int, action.split(","))
            # Assume a 3x3 grid by default - override for different sizes
            return row * 3 + col
        # For other action types, assume they are already indices
        return int(action)
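
# Illustrative sketch only (not part of this module's API): a minimal concrete
# GameState for 3x3 Tic-Tac-Toe, showing how the abstract interface above is
# typically filled in. Cells hold 1, -1, or 0; actions use the "row,col"
# string format that the default action_to_index() already understands.
class _ExampleTicTacToeState(GameState):
    def __init__(self, board: np.ndarray | None = None, player: int = 1):
        self.board = board if board is not None else np.zeros((3, 3), dtype=np.int8)
        self.player = player  # Player to move: 1 or -1

    def get_legal_actions(self) -> list[Any]:
        return [f"{r},{c}" for r, c in zip(*np.where(self.board == 0))]

    def apply_action(self, action: Any) -> GameState:
        row, col = map(int, action.split(","))
        board = self.board.copy()
        board[row, col] = self.player
        return _ExampleTicTacToeState(board, -self.player)

    def is_terminal(self) -> bool:
        return self._winner() != 0 or not (self.board == 0).any()

    def get_reward(self, player: int = 1) -> float:
        # +1 if `player` won, -1 if they lost, 0 for a draw or non-terminal state
        return float(self._winner() * player)

    def to_tensor(self) -> torch.Tensor:
        return torch.tensor(self.board, dtype=torch.float32).flatten()

    def get_hash(self) -> str:
        return f"{self.board.tobytes().hex()}:{self.player}"

    def _winner(self) -> int:
        # Rows, columns, and both diagonals; a line summing to +/-3 decides the game
        lines = [*self.board, *self.board.T, self.board.diagonal(), np.fliplr(self.board).diagonal()]
        for line in lines:
            total = int(line.sum())
            if abs(total) == 3:
                return int(np.sign(total))
        return 0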

class NeuralMCTSNode:
    """
    MCTS node with neural network guidance.

    Stores statistics for PUCT selection and backpropagation.
    """

    def __init__(
        self,
        state: GameState,
        parent: NeuralMCTSNode | None = None,
        action: Any | None = None,
        prior: float = 0.0,
    ):
        self.state = state
        self.parent = parent
        self.action = action  # Action that led to this node
        self.prior = prior  # Prior probability from the policy network

        # Statistics
        self.visit_count: int = 0
        self.value_sum: float = 0.0
        self.virtual_loss: float = 0.0

        # Children: action -> NeuralMCTSNode
        self.children: dict[Any, NeuralMCTSNode] = {}

        # Caching
        self.is_expanded: bool = False
        self.is_terminal: bool = state.is_terminal()

    def value(self) -> float:
        """Average value (Q-value) of this node."""
        if self.visit_count == 0:
            return 0.0
        return self.value_sum / self.visit_count

    def expand(
        self,
        policy_probs: np.ndarray,
        valid_actions: list[Any],
    ):
        """
        Expand node by creating children for all legal actions.

        Args:
            policy_probs: Prior probabilities from the policy network
            valid_actions: List of legal actions
        """
        self.is_expanded = True
        for action, prior in zip(valid_actions, policy_probs, strict=True):
            if action not in self.children:
                next_state = self.state.apply_action(action)
                self.children[action] = NeuralMCTSNode(
                    state=next_state,
                    parent=self,
                    action=action,
                    prior=prior,
                )

    def select_child(self, c_puct: float) -> tuple[Any, NeuralMCTSNode]:
        """
        Select the best child using the PUCT rule:

            PUCT(s, a) = Q(s, a) + c_puct * P(s, a) * sqrt(N(s)) / (1 + N(s, a))

        Args:
            c_puct: Exploration constant

        Returns:
            (action, child_node) tuple
        """
        best_score = -float("inf")
        best_action = None
        best_child = None

        # Precompute sqrt term for efficiency
        sqrt_parent_visits = math.sqrt(self.visit_count)

        for action, child in self.children.items():
            # Q-value (average value); value() is a method and must be called
            q_value = child.value()
            # U-value (exploration bonus); virtual loss in the denominator
            # discourages parallel searches from piling onto the same child
            u_value = c_puct * child.prior * sqrt_parent_visits / (1 + child.visit_count + child.virtual_loss)
            # PUCT score
            puct_score = q_value + u_value
            if puct_score > best_score:
                best_score = puct_score
                best_action = action
                best_child = child

        return best_action, best_child
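
    # Worked example (illustrative numbers only): with c_puct = 1.5 and a parent
    # visited N(s) = 100 times, a child with Q = 0.10, prior P = 0.30, and
    # N(s, a) = 20 scores 0.10 + 1.5 * 0.30 * 10 / 21 ≈ 0.31, while an unvisited
    # child with Q = 0.0 and P = 0.05 scores 0.0 + 1.5 * 0.05 * 10 / 1 = 0.75.
    # Low-visit children therefore keep a large exploration bonus until their
    # visit counts catch up with their priors.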

    def add_virtual_loss(self, virtual_loss: float):
        """Add virtual loss for parallel search."""
        self.virtual_loss += virtual_loss

    def revert_virtual_loss(self, virtual_loss: float):
        """Remove virtual loss after the search completes."""
        self.virtual_loss -= virtual_loss

    def update(self, value: float):
        """Update node statistics with a search result."""
        self.visit_count += 1
        self.value_sum += value

    def get_action_probs(self, temperature: float = 1.0) -> dict[Any, float]:
        """
        Get action selection probabilities based on visit counts.

        Args:
            temperature: Temperature parameter
                - temperature -> 0: argmax (deterministic)
                - temperature = 1: proportional to visits
                - temperature -> inf: uniform

        Returns:
            Dictionary mapping actions to probabilities
        """
        if not self.children:
            return {}

        if temperature == 0:
            # Deterministic: select the most visited action(s)
            visits = {action: child.visit_count for action, child in self.children.items()}
            max_visits = max(visits.values())
            best_actions = [a for a, v in visits.items() if v == max_visits]
            # Uniform over the best actions
            prob = 1.0 / len(best_actions)
            return {a: (prob if a in best_actions else 0.0) for a in self.children}

        # Temperature-scaled visits
        visits = np.array([child.visit_count for child in self.children.values()])
        actions = list(self.children.keys())
        if temperature != 1.0:
            visits = visits ** (1.0 / temperature)

        # Normalize to probabilities
        probs = visits / visits.sum()
        return dict(zip(actions, probs, strict=True))
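
# Worked example (illustrative numbers only): for visit counts [40, 30, 20, 10],
# temperature 1.0 gives probabilities [0.4, 0.3, 0.2, 0.1]; temperature 0.5
# squares the counts, sharpening to roughly [0.53, 0.30, 0.13, 0.03]; and
# temperature 0 collapses to [1, 0, 0, 0]. Self-play typically starts with a
# higher temperature for diverse openings and anneals toward greedy play.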

class NeuralMCTS:
    """
    Neural-guided MCTS for decision making.

    Combines tree search with neural network evaluation
    using the AlphaZero algorithm.
    """

    def __init__(
        self,
        policy_value_network: nn.Module,
        config: MCTSConfig,
        device: str = "cpu",
    ):
        """
        Initialize neural MCTS.

        Args:
            policy_value_network: Network that outputs (policy, value)
            config: MCTS configuration
            device: Device for neural network evaluation
        """
        self.network = policy_value_network
        self.config = config
        self.device = device

        # Caching for network evaluations
        self.cache: dict[str, tuple[np.ndarray, float]] = {}
        self.cache_hits = 0
        self.cache_misses = 0

    def add_dirichlet_noise(
        self,
        policy_probs: np.ndarray,
        epsilon: float | None = None,
        alpha: float | None = None,
    ) -> np.ndarray:
        """
        Add Dirichlet noise to the policy for exploration (at the root only).

            policy' = (1 - epsilon) * policy + epsilon * noise

        Args:
            epsilon: Mixing parameter (defaults to config)
            alpha: Dirichlet concentration parameter (defaults to config)
            policy_probs: Original policy probabilities

        Returns:
            Noised policy probabilities
        """
        # Explicit None checks so that 0.0 can be passed to disable noise or flatten it
        epsilon = self.config.dirichlet_epsilon if epsilon is None else epsilon
        alpha = self.config.dirichlet_alpha if alpha is None else alpha
        noise = np.random.dirichlet([alpha] * len(policy_probs))
        return (1 - epsilon) * policy_probs + epsilon * noise
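
    # Worked example (illustrative numbers only): with epsilon = 0.25, a policy
    # [0.7, 0.2, 0.1] and a sampled noise vector [0.1, 0.8, 0.1] mix to
    # 0.75 * policy + 0.25 * noise = [0.55, 0.35, 0.10], so even a low-prior
    # root action keeps a real chance of being explored. AlphaZero used
    # alpha = 0.3 for chess; smaller alpha concentrates the noise on fewer actions.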

    async def evaluate_state(self, state: GameState, add_noise: bool = False) -> tuple[np.ndarray, float]:
        """
        Evaluate a state using the neural network.

        Args:
            state: Game state to evaluate
            add_noise: Whether to add Dirichlet noise (for root exploration)

        Returns:
            (policy_probs, value) tuple
        """
        # Check cache
        state_hash = state.get_hash()
        if not add_noise and state_hash in self.cache:
            self.cache_hits += 1
            return self.cache[state_hash]
        self.cache_misses += 1

        # Get legal actions
        legal_actions = state.get_legal_actions()
        if not legal_actions:
            return np.array([]), 0.0

        # Convert state to tensor
        state_tensor = state.to_tensor().unsqueeze(0).to(self.device)

        # Network forward pass (inference only, so no gradients are needed)
        with torch.no_grad():
            policy_logits, value = self.network(state_tensor)

        # Convert to numpy
        policy_logits = policy_logits.squeeze(0).detach().cpu().numpy()
        value = value.item()

        # Action masking: map legal actions to their indices in the action space
        action_mask = np.full_like(policy_logits, -np.inf)  # Mask all actions initially
        action_indices = []

        # Map legal actions to their network output indices
        for action in legal_actions:
            try:
                action_idx = state.action_to_index(action)
                if 0 <= action_idx < len(policy_logits):
                    action_mask[action_idx] = 0  # Unmask legal actions
                    action_indices.append(action_idx)
            except (ValueError, IndexError, AttributeError) as e:
                # Fallback: if action_to_index fails, use a sequential mapping
                print(f"Warning: action_to_index failed for action {action}: {e}")
                action_indices = list(range(len(legal_actions)))
                action_mask = np.full_like(policy_logits, -np.inf)
                action_mask[action_indices] = 0
                break

        # Apply the mask before softmax for numerical stability
        masked_logits = policy_logits + action_mask

        # Compute softmax over legal actions only
        exp_logits = np.exp(masked_logits - np.max(masked_logits))  # Subtract max for stability
        policy_probs_full = exp_logits / exp_logits.sum()

        # Extract probabilities for legal actions in order
        policy_probs = policy_probs_full[action_indices]

        # Renormalize so the probabilities sum to 1 (guards against numerical error)
        if policy_probs.sum() > 0:
            policy_probs = policy_probs / policy_probs.sum()
        else:
            # Fallback: uniform distribution over legal actions
            policy_probs = np.ones(len(legal_actions)) / len(legal_actions)

        # Add Dirichlet noise if requested (root exploration)
        if add_noise:
            policy_probs = self.add_dirichlet_noise(policy_probs)

        # Cache the result (only the noise-free evaluation)
        if not add_noise:
            self.cache[state_hash] = (policy_probs, value)

        return policy_probs, value

    async def search(
        self,
        root_state: GameState,
        num_simulations: int | None = None,
        temperature: float = 1.0,
        add_root_noise: bool = True,
    ) -> tuple[dict[Any, float], NeuralMCTSNode]:
        """
        Run MCTS search from the root state.

        Args:
            root_state: Initial state
            num_simulations: Number of MCTS simulations
            temperature: Temperature for action selection
            add_root_noise: Whether to add Dirichlet noise at the root

        Returns:
            (action_probs, root_node) tuple
        """
        num_simulations = num_simulations or self.config.num_simulations

        # Create root node
        root = NeuralMCTSNode(state=root_state)

        # Expand root
        policy_probs, _ = await self.evaluate_state(root_state, add_noise=add_root_noise)
        legal_actions = root_state.get_legal_actions()
        root.expand(policy_probs, legal_actions)

        # Run simulations
        for _ in range(num_simulations):
            await self._simulate(root)

        # Get action probabilities
        action_probs = root.get_action_probs(temperature)

        return action_probs, root

    async def _simulate(self, node: NeuralMCTSNode) -> float:
        """
        Run a single MCTS simulation (select, expand, evaluate, backpropagate).

        Args:
            node: Root node for this simulation

        Returns:
            Value from this simulation
        """
        path: list[NeuralMCTSNode] = []

        # Selection: traverse the tree using PUCT
        current = node
        while current.is_expanded and not current.is_terminal:
            # Add virtual loss for parallel search
            current.add_virtual_loss(self.config.virtual_loss)
            path.append(current)
            # Select best child
            _, current = current.select_child(self.config.c_puct)

        # Add leaf to path
        path.append(current)
        current.add_virtual_loss(self.config.virtual_loss)

        # Evaluate leaf node
        if current.is_terminal:
            # Terminal node: use the game result
            value = current.state.get_reward()
        else:
            # Non-terminal: expand and evaluate with the network
            policy_probs, value = await self.evaluate_state(current.state, add_noise=False)
            if not current.is_expanded:
                legal_actions = current.state.get_legal_actions()
                current.expand(policy_probs, legal_actions)

        # Backpropagate
        for node_in_path in reversed(path):
            node_in_path.update(value)
            node_in_path.revert_virtual_loss(self.config.virtual_loss)
            # Flip value for the opponent
            value = -value

        return value

    def select_action(
        self,
        action_probs: dict[Any, float],
        temperature: float = 1.0,
        deterministic: bool = False,
    ) -> Any:
        """
        Select an action from a probability distribution.

        Args:
            action_probs: Action probability dictionary
            temperature: Temperature (unused if deterministic=True)
            deterministic: If True, select the action with the highest probability

        Returns:
            Selected action
        """
        if not action_probs:
            return None

        actions = list(action_probs.keys())
        probs = list(action_probs.values())

        if deterministic or temperature == 0:
            return actions[int(np.argmax(probs))]

        # Sample by index so arbitrary (non-array-like) action objects are supported
        idx = int(np.random.choice(len(actions), p=probs))
        return actions[idx]

    def clear_cache(self):
        """Clear the evaluation cache."""
        self.cache.clear()
        self.cache_hits = 0
        self.cache_misses = 0

    def get_cache_stats(self) -> dict:
        """Get cache performance statistics."""
        total = self.cache_hits + self.cache_misses
        hit_rate = self.cache_hits / total if total > 0 else 0.0
        return {
            "cache_size": len(self.cache),
            "cache_hits": self.cache_hits,
            "cache_misses": self.cache_misses,
            "hit_rate": hit_rate,
        }
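
# Example usage (illustrative sketch, kept as a comment so importing this module
# has no side effects). `MyGameState` and `MyPolicyValueNet` are hypothetical
# placeholders for a concrete GameState subclass and a policy-value network whose
# forward pass returns (policy_logits, value); neither is defined here.
#
# import asyncio
#
# async def choose_move() -> Any:
#     mcts = NeuralMCTS(MyPolicyValueNet(), MCTSConfig(), device="cpu")
#     state = MyGameState()
#     action_probs, _root = await mcts.search(state, num_simulations=200, temperature=1.0)
#     return mcts.select_action(action_probs, deterministic=True)
#
# best_action = asyncio.run(choose_move())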

# Training data collection
@dataclass
class MCTSExample:
    """Training example from MCTS self-play."""

    state: torch.Tensor  # State representation
    policy_target: np.ndarray  # Target policy (visit counts)
    value_target: float  # Target value (game outcome)
    player: int  # Player to move (1 or -1)

class SelfPlayCollector:
    """
    Collect training data from self-play games.

    Uses MCTS to generate high-quality training examples.
    """

    def __init__(
        self,
        mcts: NeuralMCTS,
        config: MCTSConfig,
    ):
        self.mcts = mcts
        self.config = config

    async def play_game(
        self,
        initial_state: GameState,
        temperature_threshold: int | None = None,
    ) -> list[MCTSExample]:
        """
        Play a single self-play game.

        Args:
            initial_state: Starting game state
            temperature_threshold: Move number at which to switch to greedy play

        Returns:
            List of training examples from the game
        """
        temperature_threshold = temperature_threshold or self.config.temperature_threshold

        examples: list[MCTSExample] = []
        state = initial_state
        player = 1  # Current player (1 or -1)
        move_count = 0

        while not state.is_terminal():
            # Determine temperature
            temperature = (
                self.config.temperature_init if move_count < temperature_threshold else self.config.temperature_final
            )

            # Run MCTS
            action_probs, _root = await self.mcts.search(state, temperature=temperature, add_root_noise=True)

            # Store training example
            # Convert action probs to an array over the legal actions
            probs = np.array(list(action_probs.values()))
            examples.append(
                MCTSExample(
                    state=state.to_tensor(),
                    policy_target=probs,
                    value_target=0.0,  # Filled in with the game outcome below
                    player=player,
                )
            )

            # Select and apply action
            action = self.mcts.select_action(action_probs, temperature=temperature)
            state = state.apply_action(action)

            # Switch player
            player = -player
            move_count += 1

        # Get game outcome (from player 1's perspective)
        outcome = state.get_reward()

        # Assign values to examples
        for example in examples:
            # Value is from the perspective of the player who made the move
            example.value_target = outcome if example.player == 1 else -outcome

        return examples

    async def generate_batch(self, num_games: int, initial_state_fn: Callable[[], GameState]) -> list[MCTSExample]:
        """
        Generate a batch of training examples from multiple games.

        Args:
            num_games: Number of games to play
            initial_state_fn: Function that returns an initial game state

        Returns:
            Combined list of training examples
        """
        all_examples = []

        for _ in range(num_games):
            initial_state = initial_state_fn()
            examples = await self.play_game(initial_state)
            all_examples.extend(examples)

            # Clear the evaluation cache periodically to bound memory use
            if len(self.mcts.cache) > 10000:
                self.mcts.clear_cache()

        return all_examples
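
# Example usage (illustrative sketch, kept as a comment so importing this module
# has no side effects). As above, `MyGameState` and `MyPolicyValueNet` are
# hypothetical placeholders, not part of this module.
#
# import asyncio
#
# async def collect_examples() -> list[MCTSExample]:
#     config = MCTSConfig()
#     mcts = NeuralMCTS(MyPolicyValueNet(), config, device="cpu")
#     collector = SelfPlayCollector(mcts, config)
#     # Each game yields one MCTSExample per move, with value targets filled in
#     # from the final outcome of that game.
#     return await collector.generate_batch(num_games=10, initial_state_fn=MyGameState)
#
# training_examples = asyncio.run(collect_examples())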