"""
Neural-Guided Monte Carlo Tree Search (MCTS).

Implements AlphaZero-style MCTS with:
- Policy and value network guidance
- PUCT (Predictor + UCT) selection
- Dirichlet noise for exploration
- Virtual loss for parallel search
- Temperature-based action selection

Based on:
- "Mastering the Game of Go with Deep Neural Networks and Tree Search" (AlphaGo)
- "Mastering Chess and Shogi by Self-Play with a General RL Algorithm" (AlphaZero)
"""

from __future__ import annotations

import math
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any

import numpy as np
import torch
import torch.nn as nn

from ...training.system_config import MCTSConfig


@dataclass
class GameState:
    """
    Abstract game/problem state interface.

    Users should subclass this for their specific domain.
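
    A minimal subclass sketch (TicTacToeState is a hypothetical example, not
    provided by this module):

        class TicTacToeState(GameState):
            def get_legal_actions(self): ...
            def apply_action(self, action): ...
            def is_terminal(self): ...
            def get_reward(self, player=1): ...
            def to_tensor(self): ...
            def get_hash(self): ...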
    """

    def get_legal_actions(self) -> list[Any]:
        """Return list of legal actions from this state."""
        raise NotImplementedError

    def apply_action(self, action: Any) -> GameState:
        """Apply action and return new state."""
        raise NotImplementedError

    def is_terminal(self) -> bool:
        """Check if this is a terminal state."""
        raise NotImplementedError

    def get_reward(self, player: int = 1) -> float:
        """Get reward for the player (1 or -1)."""
        raise NotImplementedError

    def to_tensor(self) -> torch.Tensor:
        """Convert state to tensor for neural network input."""
        raise NotImplementedError

    def get_canonical_form(self, player: int) -> GameState:  # noqa: ARG002
        """Get state from perspective of given player."""
        return self

    def get_hash(self) -> str:
        """Get unique hash for this state (for caching)."""
        raise NotImplementedError

    def action_to_index(self, action: Any) -> int:
        """
        Map action to its index in the neural network's action space.

        This method should return the index corresponding to the action
        in the network's policy output vector.

        Default implementation uses string-based mapping for Tic-Tac-Toe style
        actions (e.g., "0,0" -> 0, "0,1" -> 1, etc.). Override this method
        for custom action mappings.

        Args:
            action: The action to map

        Returns:
            Index in the action space (0 to action_size-1)
        """
        # Default implementation for grid-based actions like "row,col"
        if isinstance(action, str) and "," in action:
            row, col = map(int, action.split(","))
            # Assume 3x3 grid by default - override for different sizes
            return row * 3 + col
        # For other action types, assume they are already indices
        return int(action)


class NeuralMCTSNode:
    """
    MCTS node with neural network guidance.

    Stores statistics for PUCT selection and backpropagation.
    """

    def __init__(
        self,
        state: GameState,
        parent: NeuralMCTSNode | None = None,
        action: Any | None = None,
        prior: float = 0.0,
    ):
        self.state = state
        self.parent = parent
        self.action = action  # Action that led to this node
        self.prior = prior  # Prior probability from policy network

        # Statistics
        self.visit_count: int = 0
        self.value_sum: float = 0.0
        self.virtual_loss: float = 0.0

        # Children: action -> NeuralMCTSNode
        self.children: dict[Any, NeuralMCTSNode] = {}

        # Caching
        self.is_expanded: bool = False
        self.is_terminal: bool = state.is_terminal()

    @property
    def value(self) -> float:
        """Average value (Q-value) of this node."""
        if self.visit_count == 0:
            return 0.0
        return self.value_sum / self.visit_count

    def expand(
        self,
        policy_probs: np.ndarray,
        valid_actions: list[Any],
    ):
        """
        Expand node by creating children for all legal actions.

        Args:
            policy_probs: Prior probabilities from policy network
            valid_actions: List of legal actions
        """
        self.is_expanded = True

        for action, prior in zip(valid_actions, policy_probs, strict=True):
            if action not in self.children:
                next_state = self.state.apply_action(action)
                self.children[action] = NeuralMCTSNode(
                    state=next_state,
                    parent=self,
                    action=action,
                    prior=prior,
                )

    def select_child(self, c_puct: float) -> tuple[Any, NeuralMCTSNode]:
        """
        Select best child using PUCT algorithm.

        PUCT = Q(s,a) + c_puct * P(s,a) * sqrt(N(s)) / (1 + N(s,a))
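
        Here Q(s, a) is the child's mean value from the perspective of the
        player to move at this node. For example, with c_puct = 1.5,
        P(s, a) = 0.4, N(s) = 9 and N(s, a) = 2, the exploration bonus is
        1.5 * 0.4 * sqrt(9) / (1 + 2) = 0.6.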

        Args:
            c_puct: Exploration constant

        Returns:
            (action, child_node) tuple
        """
        best_score = -float("inf")
        best_action = None
        best_child = None

        # Precompute sqrt term for efficiency
        sqrt_parent_visits = math.sqrt(self.visit_count)

        for action, child in self.children.items():
            # Q-value from the selecting (parent) player's perspective.
            # Backpropagation flips the sign every ply, so child.value is stored
            # from the child player's perspective and must be negated here.
            q_value = -child.value

            # U-value (exploration bonus)
            u_value = c_puct * child.prior * sqrt_parent_visits / (1 + child.visit_count + child.virtual_loss)

            # PUCT score
            puct_score = q_value + u_value

            if puct_score > best_score:
                best_score = puct_score
                best_action = action
                best_child = child

        return best_action, best_child

    def add_virtual_loss(self, virtual_loss: float):
        """Add virtual loss for parallel search."""
        self.virtual_loss += virtual_loss

    def revert_virtual_loss(self, virtual_loss: float):
        """Remove virtual loss after search completes."""
        self.virtual_loss -= virtual_loss

    def update(self, value: float):
        """Update node statistics with search result."""
        self.visit_count += 1
        self.value_sum += value

    def get_action_probs(self, temperature: float = 1.0) -> dict[Any, float]:
        """
        Get action selection probabilities based on visit counts.

        Args:
            temperature: Temperature parameter
                - temperature -> 0: argmax (deterministic)
                - temperature = 1: proportional to visits
                - temperature -> inf: uniform
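
                For example, visit counts of [30, 10, 10] give probabilities
                [0.6, 0.2, 0.2] at temperature 1 and [1.0, 0.0, 0.0] in the
                limit temperature -> 0.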

        Returns:
            Dictionary mapping actions to probabilities
        """
        if not self.children:
            return {}

        if temperature == 0:
            # Deterministic: select most visited
            visits = {action: child.visit_count for action, child in self.children.items()}
            max_visits = max(visits.values())
            best_actions = [a for a, v in visits.items() if v == max_visits]

            # Uniform over best actions
            prob = 1.0 / len(best_actions)
            return {a: (prob if a in best_actions else 0.0) for a in self.children}

        # Temperature-scaled visits
        visits = np.array([child.visit_count for child in self.children.values()])
        actions = list(self.children.keys())

        if temperature != 1.0:
            visits = visits ** (1.0 / temperature)

        # Normalize to probabilities
        probs = visits / visits.sum()

        return dict(zip(actions, probs, strict=True))


class NeuralMCTS:
    """
    Neural-guided MCTS for decision making.

    Combines tree search with neural network evaluation
    using the AlphaZero algorithm.
    """

    def __init__(
        self,
        policy_value_network: nn.Module,
        config: MCTSConfig,
        device: str = "cpu",
    ):
        """
        Initialize neural MCTS.

        Args:
            policy_value_network: Network that outputs (policy, value)
            config: MCTS configuration
            device: Device for neural network
        """
        self.network = policy_value_network
        self.config = config
        self.device = device

        # Caching for network evaluations
        self.cache: dict[str, tuple[np.ndarray, float]] = {}
        self.cache_hits = 0
        self.cache_misses = 0

    def add_dirichlet_noise(
        self,
        policy_probs: np.ndarray,
        epsilon: float | None = None,
        alpha: float | None = None,
    ) -> np.ndarray:
        """
        Add Dirichlet noise to policy for exploration (at root only).

        Policy' = (1 - epsilon) * Policy + epsilon * Noise
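
        With epsilon = 0.25 and alpha = 0.3 (the values used for chess in the
        AlphaZero paper), each prior keeps 75% of its original mass and the
        remaining 25% comes from the Dirichlet sample.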

        Args:
            policy_probs: Original policy probabilities
            epsilon: Mixing parameter (defaults to config)
            alpha: Dirichlet concentration parameter (defaults to config)

        Returns:
            Noised policy probabilities
        """
        # Explicit None checks so that a caller-supplied 0.0 is not overridden
        epsilon = self.config.dirichlet_epsilon if epsilon is None else epsilon
        alpha = self.config.dirichlet_alpha if alpha is None else alpha

        noise = np.random.dirichlet([alpha] * len(policy_probs))
        return (1 - epsilon) * policy_probs + epsilon * noise

    async def evaluate_state(self, state: GameState, add_noise: bool = False) -> tuple[np.ndarray, float]:
        """
        Evaluate state using neural network.

        Args:
            state: Game state to evaluate
            add_noise: Whether to add Dirichlet noise (for root exploration)

        Returns:
            (policy_probs, value) tuple
        """
        # Check cache
        state_hash = state.get_hash()
        if not add_noise and state_hash in self.cache:
            self.cache_hits += 1
            return self.cache[state_hash]

        self.cache_misses += 1

        # Get legal actions
        legal_actions = state.get_legal_actions()
        if not legal_actions:
            return np.array([]), 0.0

        # Convert state to tensor
        state_tensor = state.to_tensor().unsqueeze(0).to(self.device)

        # Network forward pass. no_grad is applied inline rather than as a
        # decorator, since decorating an async function may not keep the
        # context active while the awaited body runs.
        with torch.no_grad():
            policy_logits, value = self.network(state_tensor)

        # Convert to numpy (detach to drop any remaining autograd history)
        policy_logits = policy_logits.squeeze(0).detach().cpu().numpy()
        value = value.item()

        # Proper action masking: Map legal actions to their indices in the action space
        # Create a mask for legal actions
        action_mask = np.full_like(policy_logits, -np.inf)  # Mask all actions initially
        action_indices = []

        # Map legal actions to their network output indices
        for action in legal_actions:
            try:
                action_idx = state.action_to_index(action)
                if not 0 <= action_idx < len(policy_logits):
                    raise IndexError(f"action index {action_idx} is outside the policy head of size {len(policy_logits)}")
                action_mask[action_idx] = 0  # Unmask legal action
                action_indices.append(action_idx)
            except (ValueError, IndexError, AttributeError) as e:
                # Fallback: if the mapping fails, assume the policy head is ordered
                # like get_legal_actions() and use a sequential mapping so that the
                # extracted probabilities stay aligned with legal_actions
                print(f"Warning: action_to_index failed for action {action}: {e}")
                action_indices = list(range(len(legal_actions)))
                action_mask = np.full_like(policy_logits, -np.inf)
                action_mask[action_indices] = 0
                break

        # Apply mask before softmax for numerical stability
        masked_logits = policy_logits + action_mask

        # Compute softmax over legal actions only
        exp_logits = np.exp(masked_logits - np.max(masked_logits))  # Subtract max for stability
        policy_probs_full = exp_logits / exp_logits.sum()

        # Extract probabilities for legal actions in order
        policy_probs = policy_probs_full[action_indices]

        # Normalize to ensure probabilities sum to 1 (handle numerical errors)
        if policy_probs.sum() > 0:
            policy_probs = policy_probs / policy_probs.sum()
        else:
            # Fallback: uniform distribution over legal actions
            policy_probs = np.ones(len(legal_actions)) / len(legal_actions)

        # Add Dirichlet noise if requested (root exploration)
        if add_noise:
            policy_probs = self.add_dirichlet_noise(policy_probs)

        # Cache result (without noise)
        if not add_noise:
            self.cache[state_hash] = (policy_probs, value)

        return policy_probs, value

    async def search(
        self,
        root_state: GameState,
        num_simulations: int | None = None,
        temperature: float = 1.0,
        add_root_noise: bool = True,
    ) -> tuple[dict[Any, float], NeuralMCTSNode]:
        """
        Run MCTS search from root state.

        Args:
            root_state: Initial state
            num_simulations: Number of MCTS simulations
            temperature: Temperature for action selection
            add_root_noise: Whether to add Dirichlet noise to root

        Returns:
            (action_probs, root_node) tuple
        """
        num_simulations = num_simulations or self.config.num_simulations

        # Create root node
        root = NeuralMCTSNode(state=root_state)

        # Expand root
        policy_probs, _ = await self.evaluate_state(root_state, add_noise=add_root_noise)
        legal_actions = root_state.get_legal_actions()
        root.expand(policy_probs, legal_actions)

        # Run simulations
        for _ in range(num_simulations):
            await self._simulate(root)

        # Get action probabilities
        action_probs = root.get_action_probs(temperature)

        return action_probs, root

    async def _simulate(self, node: NeuralMCTSNode) -> float:
        """
        Run single MCTS simulation (select, expand, evaluate, backpropagate).

        Args:
            node: Root node for this simulation

        Returns:
            Value from this simulation
        """
        path: list[NeuralMCTSNode] = []

        # Selection: traverse tree using PUCT
        current = node
        while current.is_expanded and not current.is_terminal:
            # Add virtual loss for parallel search
            current.add_virtual_loss(self.config.virtual_loss)
            path.append(current)

            # Select best child
            _, current = current.select_child(self.config.c_puct)

        # Add leaf to path
        path.append(current)
        current.add_virtual_loss(self.config.virtual_loss)

        # Evaluate leaf node
        if current.is_terminal:
            # Terminal node: use the game result. Backpropagation below assumes
            # this reward is from the perspective of the player to move at the
            # leaf, matching the value network's convention.
            value = current.state.get_reward()
        else:
            # Non-terminal: expand and evaluate with network
            policy_probs, value = await self.evaluate_state(current.state, add_noise=False)

            if not current.is_expanded:
                legal_actions = current.state.get_legal_actions()
                current.expand(policy_probs, legal_actions)

        # Backpropagate, flipping the sign each ply so every node accumulates
        # values from the perspective of its own player to move
        leaf_value = value
        for node_in_path in reversed(path):
            node_in_path.update(value)
            node_in_path.revert_virtual_loss(self.config.virtual_loss)

            # Flip value for opponent
            value = -value

        return leaf_value

    def select_action(
        self,
        action_probs: dict[Any, float],
        temperature: float = 1.0,
        deterministic: bool = False,
    ) -> Any:
        """
        Select action from probability distribution.

        Args:
            action_probs: Action probability dictionary
            temperature: Temperature (unused if deterministic=True)
            deterministic: If True, select action with highest probability

        Returns:
            Selected action
        """
        if not action_probs:
            return None

        actions = list(action_probs.keys())
        probs = list(action_probs.values())

        if deterministic or temperature == 0:
            return actions[np.argmax(probs)]

        # Sample by index so that actions need not be numpy-compatible scalars
        # (np.random.choice fails on lists of tuples or other composite actions)
        idx = np.random.choice(len(actions), p=probs)
        return actions[idx]

    def clear_cache(self):
        """Clear the evaluation cache."""
        self.cache.clear()
        self.cache_hits = 0
        self.cache_misses = 0

    def get_cache_stats(self) -> dict:
        """Get cache performance statistics."""
        total = self.cache_hits + self.cache_misses
        hit_rate = self.cache_hits / total if total > 0 else 0.0

        return {
            "cache_size": len(self.cache),
            "cache_hits": self.cache_hits,
            "cache_misses": self.cache_misses,
            "hit_rate": hit_rate,
        }


# Training data collection
@dataclass
class MCTSExample:
    """Training example from MCTS self-play."""

    state: torch.Tensor  # State representation
    policy_target: np.ndarray  # Target policy (visit counts)
    value_target: float  # Target value (game outcome)
    player: int  # Player to move (1 or -1)


class SelfPlayCollector:
    """
    Collect training data from self-play games.

    Uses MCTS to generate high-quality training examples.
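
    Illustrative usage (sketch only; my_initial_state is a user-supplied factory
    returning a fresh GameState):

        collector = SelfPlayCollector(mcts=mcts, config=config)
        examples = await collector.generate_batch(
            num_games=10, initial_state_fn=my_initial_state
        )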
    """

    def __init__(
        self,
        mcts: NeuralMCTS,
        config: MCTSConfig,
    ):
        self.mcts = mcts
        self.config = config

    async def play_game(
        self,
        initial_state: GameState,
        temperature_threshold: int | None = None,
    ) -> list[MCTSExample]:
        """
        Play a single self-play game.

        Args:
            initial_state: Starting game state
            temperature_threshold: Move number to switch to greedy play

        Returns:
            List of training examples from the game
        """
        if temperature_threshold is None:
            temperature_threshold = self.config.temperature_threshold

        examples: list[MCTSExample] = []
        state = initial_state
        player = 1  # Current player (1 or -1)
        move_count = 0

        while not state.is_terminal():
            # Determine temperature
            temperature = (
                self.config.temperature_init if move_count < temperature_threshold else self.config.temperature_final
            )

            # Run MCTS
            action_probs, root = await self.mcts.search(state, temperature=temperature, add_root_noise=True)

            # Store training example. The policy target covers only the legal
            # actions, in the order returned by the search (not the full action
            # space), so the training code must use the same ordering.
            probs = np.array(list(action_probs.values()))

            examples.append(
                MCTSExample(
                    state=state.to_tensor(),
                    policy_target=probs,
                    value_target=0.0,  # Will be filled with game outcome
                    player=player,
                )
            )

            # Select and apply action
            action = self.mcts.select_action(action_probs, temperature=temperature)
            state = state.apply_action(action)

            # Switch player
            player = -player
            move_count += 1

        # Get game outcome
        outcome = state.get_reward()

        # Assign values to examples
        for example in examples:
            # Value is from perspective of the player who made the move
            example.value_target = outcome if example.player == 1 else -outcome

        return examples

    async def generate_batch(self, num_games: int, initial_state_fn: Callable[[], GameState]) -> list[MCTSExample]:
        """
        Generate a batch of training examples from multiple games.

        Args:
            num_games: Number of games to play
            initial_state_fn: Function that returns initial game state

        Returns:
            Combined list of training examples
        """
        all_examples = []

        for _ in range(num_games):
            initial_state = initial_state_fn()
            examples = await self.play_game(initial_state)
            all_examples.extend(examples)

            # Clear cache periodically
            if len(self.mcts.cache) > 10000:
                self.mcts.clear_cache()

        return all_examples