"""
MCTS Policies Module - Selection, rollout, and evaluation policies.

Provides:
- UCB1 with configurable exploration weight
- Rollout heuristics (random, greedy, hybrid)
- Action selection policies (max visits, max value, robust child)
- Progressive widening parameters
"""

from __future__ import annotations

import math
from abc import ABC, abstractmethod
from collections.abc import Awaitable, Callable
from enum import Enum
from typing import TYPE_CHECKING

import numpy as np

if TYPE_CHECKING:
    from .core import MCTSState


def ucb1(
    value_sum: float,
    visits: int,
    parent_visits: int,
    c: float = 1.414,
) -> float:
    """
    Upper Confidence Bound 1 (UCB1) formula for tree selection.

    Formula: Q(s,a) + c * sqrt(ln N(s) / N(s,a))

    Args:
        value_sum: Total accumulated value for the node
        visits: Number of visits to the node
        parent_visits: Number of visits to the parent node
        c: Exploration weight constant (default sqrt(2))

    Returns:
        UCB1 score for node selection
    """
    if visits == 0:
        return float("inf")

    exploitation = value_sum / visits
    # Classical UCB1 exploration term; max() guards against log(0) for a fresh parent.
    exploration = c * math.sqrt(math.log(max(parent_visits, 1)) / visits)

    return exploitation + exploration
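

def _example_ucb1() -> None:
    """Illustrative sketch (not part of the module API): two children with the
    same mean value (0.5) under the same parent; the less-visited child gets
    the larger exploration bonus and therefore the higher UCB1 score.
    """
    often = ucb1(value_sum=25.0, visits=50, parent_visits=100)
    rarely = ucb1(value_sum=2.5, visits=5, parent_visits=100)
    assert rarely > often
    # An unvisited child always wins selection outright.
    assert ucb1(value_sum=0.0, visits=0, parent_visits=100) == float("inf")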


def ucb1_tuned(
    value_sum: float,
    value_squared_sum: float,
    visits: int,
    parent_visits: int,
    c: float = 1.0,
) -> float:
    """
    UCB1-Tuned variant with variance estimate.

    Provides tighter bounds by considering value variance.

    Args:
        value_sum: Total accumulated value
        value_squared_sum: Sum of squared values (for variance)
        visits: Number of visits
        parent_visits: Parent visit count
        c: Exploration constant

    Returns:
        UCB1-Tuned score
    """
    if visits == 0:
        return float("inf")

    mean_value = value_sum / visits
    variance = value_squared_sum / visits - mean_value**2
    variance = max(0, variance)  # Ensure non-negative

    # Variance bound term
    ln_parent = math.log(max(parent_visits, 1))  # Guard against log(0)
    variance_bound = variance + math.sqrt(2 * ln_parent / visits)
    min_bound = min(0.25, variance_bound)

    exploitation = mean_value
    exploration = c * math.sqrt(ln_parent / visits * min_bound)

    return exploitation + exploration
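

def _example_ucb1_tuned() -> None:
    """Illustrative sketch (not part of the module API): with enough samples,
    UCB1-Tuned gives a low-variance child a smaller exploration bonus than a
    high-variance child with the same mean (0.5 here).
    """
    # 1000 rewards of exactly 0.5: sum = 500, sum of squares = 250 (variance 0).
    steady = ucb1_tuned(500.0, 250.0, visits=1000, parent_visits=10_000)
    # 500 rewards of 0.0 and 500 of 1.0: sum = 500, sum of squares = 500
    # (variance 0.25, which hits the 1/4 cap on the variance bound).
    noisy = ucb1_tuned(500.0, 500.0, visits=1000, parent_visits=10_000)
    assert steady < noisy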


class SelectionPolicy(Enum):
    """Policy for selecting the final action after MCTS search."""

    MAX_VISITS = "max_visits"
    """Select action with most visits (most robust)."""

    MAX_VALUE = "max_value"
    """Select action with highest average value (greedy)."""

    ROBUST_CHILD = "robust_child"
    """Select action balancing visits and value."""

    SECURE_CHILD = "secure_child"
    """Select action with lowest lower confidence bound."""


class RolloutPolicy(ABC):
    """Abstract base class for rollout/simulation policies."""

    @abstractmethod
    async def evaluate(
        self,
        state: MCTSState,
        rng: np.random.Generator,
        max_depth: int = 10,
    ) -> float:
        """
        Evaluate a state through rollout simulation.

        Args:
            state: State to evaluate
            rng: Seeded random number generator
            max_depth: Maximum rollout depth

        Returns:
            Estimated value in [0, 1] range
        """
        pass


class RandomRolloutPolicy(RolloutPolicy):
    """Random rollout policy - uniform random evaluation."""

    def __init__(self, base_value: float = 0.5, noise_scale: float = 0.3):
        """
        Initialize random rollout policy.

        Args:
            base_value: Base value for evaluations
            noise_scale: Scale of random noise
        """
        self.base_value = base_value
        self.noise_scale = noise_scale

    async def evaluate(
        self,
        _state: MCTSState,
        rng: np.random.Generator,
        _max_depth: int = 10,
    ) -> float:
        """Generate random evaluation with noise."""
        noise = rng.uniform(-self.noise_scale, self.noise_scale)
        value = self.base_value + noise
        return max(0.0, min(1.0, value))
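

async def _example_random_rollout() -> None:
    """Illustrative sketch (not part of the module API): a seeded generator
    makes rollout values reproducible. RandomRolloutPolicy ignores the state
    argument, so None stands in for a real MCTSState here.
    """
    policy = RandomRolloutPolicy(base_value=0.5, noise_scale=0.3)
    rng = np.random.default_rng(7)
    value = await policy.evaluate(None, rng)  # type: ignore[arg-type]
    assert 0.2 <= value <= 0.8  # base_value +/- noise_scale, then clipped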


class GreedyRolloutPolicy(RolloutPolicy):
    """Greedy rollout policy using domain heuristics."""

    def __init__(
        self,
        heuristic_fn: Callable[[MCTSState], float],
        noise_scale: float = 0.05,
    ):
        """
        Initialize greedy rollout policy.

        Args:
            heuristic_fn: Function to evaluate state heuristically
            noise_scale: Small noise for tie-breaking
        """
        self.heuristic_fn = heuristic_fn
        self.noise_scale = noise_scale

    async def evaluate(
        self,
        state: MCTSState,
        rng: np.random.Generator,
        _max_depth: int = 10,
    ) -> float:
        """Evaluate using heuristic with small noise."""
        base_value = self.heuristic_fn(state)
        noise = rng.uniform(-self.noise_scale, self.noise_scale)
        value = base_value + noise
        return max(0.0, min(1.0, value))


class HybridRolloutPolicy(RolloutPolicy):
    """Hybrid policy combining random and heuristic evaluation."""

    def __init__(
        self,
        heuristic_fn: Callable[[MCTSState], float] | None = None,
        heuristic_weight: float = 0.7,
        random_weight: float = 0.3,
        base_random_value: float = 0.5,
        noise_scale: float = 0.2,
    ):
        """
        Initialize hybrid rollout policy.

        Args:
            heuristic_fn: Optional heuristic evaluation function
            heuristic_weight: Weight for heuristic component
            random_weight: Weight for random component
            base_random_value: Base value for random component
            noise_scale: Noise scale for random component
        """
        self.heuristic_fn = heuristic_fn
        self.heuristic_weight = heuristic_weight
        self.random_weight = random_weight
        self.base_random_value = base_random_value
        self.noise_scale = noise_scale

        # Normalize weights so they sum to 1
        total_weight = heuristic_weight + random_weight
        if total_weight <= 0:
            raise ValueError("heuristic_weight + random_weight must be positive")
        self.heuristic_weight /= total_weight
        self.random_weight /= total_weight

    async def evaluate(
        self,
        state: MCTSState,
        rng: np.random.Generator,
        _max_depth: int = 10,
    ) -> float:
        """Combine heuristic and random evaluation."""
        # Random component
        random_noise = rng.uniform(-self.noise_scale, self.noise_scale)
        random_value = self.base_random_value + random_noise

        # Heuristic component (fall back to the base value when no heuristic is set)
        if self.heuristic_fn is not None:
            heuristic_value = self.heuristic_fn(state)
        else:
            heuristic_value = self.base_random_value

        # Combine
        value = self.heuristic_weight * heuristic_value + self.random_weight * random_value

        return max(0.0, min(1.0, value))
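

def _example_hybrid_weights() -> None:
    """Illustrative sketch (not part of the module API): constructor weights
    are normalized, so (7.0, 3.0) behaves exactly like (0.7, 0.3).
    """
    policy = HybridRolloutPolicy(heuristic_weight=7.0, random_weight=3.0)
    assert abs(policy.heuristic_weight - 0.7) < 1e-12
    assert abs(policy.random_weight - 0.3) < 1e-12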


class LLMRolloutPolicy(RolloutPolicy):
    """Rollout policy that uses an LLM for state evaluation."""

    def __init__(
        self,
        evaluate_fn: Callable[[MCTSState], Awaitable[float]],
        cache_results: bool = True,
    ):
        """
        Initialize LLM rollout policy.

        Args:
            evaluate_fn: Async function to evaluate state with LLM
            cache_results: Whether to cache evaluation results
        """
        self.evaluate_fn = evaluate_fn
        self.cache_results = cache_results
        self._cache: dict = {}

    async def evaluate(
        self,
        state: MCTSState,
        _rng: np.random.Generator,
        _max_depth: int = 10,
    ) -> float:
        """Evaluate state using LLM."""
        state_key = state.to_hash_key()

        if self.cache_results and state_key in self._cache:
            return self._cache[state_key]

        value = await self.evaluate_fn(state)
        value = max(0.0, min(1.0, value))

        if self.cache_results:
            self._cache[state_key] = value

        return value
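

async def _example_llm_rollout_cache() -> None:
    """Illustrative sketch (not part of the module API): with caching enabled,
    the (potentially expensive) evaluate_fn runs once per distinct state. The
    _StubState class is hypothetical; real callers pass an MCTSState.
    """

    class _StubState:
        def to_hash_key(self) -> str:
            return "state-0"

    calls = 0

    async def fake_llm(_state: object) -> float:
        nonlocal calls
        calls += 1
        return 0.8

    policy = LLMRolloutPolicy(evaluate_fn=fake_llm, cache_results=True)
    rng = np.random.default_rng(0)
    state = _StubState()
    first = await policy.evaluate(state, rng)  # type: ignore[arg-type]
    second = await policy.evaluate(state, rng)  # type: ignore[arg-type]
    assert first == second == 0.8
    assert calls == 1  # second call was served from the cache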


class ProgressiveWideningConfig:
    """Configuration for progressive widening in MCTS."""

    def __init__(
        self,
        k: float = 1.0,
        alpha: float = 0.5,
    ):
        """
        Configure progressive widening parameters.

        Progressive widening expands when: visits > k * num_children^alpha

        Args:
            k: Coefficient controlling expansion threshold
            alpha: Exponent controlling growth rate

        Common configurations:
        - k=1.0, alpha=0.5: Moderate widening (default)
        - k=2.0, alpha=0.5: Conservative (fewer expansions)
        - k=0.5, alpha=0.5: Aggressive (more expansions)
        - k=1.0, alpha=0.3: Very aggressive
        - k=1.0, alpha=0.7: Very conservative
        """
        if k <= 0:
            raise ValueError("k must be positive")
        if not 0 < alpha < 1:
            raise ValueError("alpha must be in (0, 1)")

        self.k = k
        self.alpha = alpha

    def should_expand(self, visits: int, num_children: int) -> bool:
        """
        Check if expansion should occur.

        Args:
            visits: Number of visits to node
            num_children: Current number of children

        Returns:
            True if should expand, False otherwise
        """
        threshold = self.k * (num_children**self.alpha)
        return visits > threshold

    def min_visits_for_expansion(self, num_children: int) -> int:
        """
        Calculate minimum visits needed to expand to next child.

        Args:
            num_children: Current number of children

        Returns:
            Minimum visit count for expansion
        """
        threshold = self.k * (num_children**self.alpha)
        # should_expand requires visits strictly greater than the threshold,
        # so the minimum qualifying count is floor(threshold) + 1.
        return math.floor(threshold) + 1

    def __repr__(self) -> str:
        return f"ProgressiveWideningConfig(k={self.k}, alpha={self.alpha})"


def compute_action_probabilities(
    children_stats: list[dict],
    temperature: float = 1.0,
) -> list[float]:
    """
    Compute action probabilities from visit counts using softmax.

    Args:
        children_stats: List of dicts with 'visits' key
        temperature: Temperature parameter (lower = more deterministic)

    Returns:
        List of probabilities for each action
    """
    if not children_stats:
        return []

    visits = np.array([c["visits"] for c in children_stats], dtype=float)

    if temperature == 0:
        # Deterministic: assign 1.0 to max, 0 to others
        probs = np.zeros_like(visits)
        probs[np.argmax(visits)] = 1.0
        return probs.tolist()

    # Apply temperature
    scaled_visits = visits ** (1.0 / temperature)
    total = scaled_visits.sum()
    if total == 0:
        # All children unvisited: fall back to a uniform distribution
        return [1.0 / len(children_stats)] * len(children_stats)
    return (scaled_visits / total).tolist()
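

def _example_action_probabilities() -> None:
    """Illustrative sketch (not part of the module API): lowering the
    temperature concentrates probability mass on the most-visited action, and
    temperature 0 is fully greedy.
    """
    stats = [{"visits": 30}, {"visits": 10}]
    assert compute_action_probabilities(stats, temperature=1.0)[0] == 0.75
    assert compute_action_probabilities(stats, temperature=0.5)[0] == 0.9
    assert compute_action_probabilities(stats, temperature=0.0) == [1.0, 0.0]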


def select_action_stochastic(
    children_stats: list[dict],
    rng: np.random.Generator,
    temperature: float = 1.0,
) -> int:
    """
    Stochastically select action based on visit counts.

    Args:
        children_stats: List of child statistics
        rng: Random number generator
        temperature: Temperature for softmax

    Returns:
        Index of selected action
    """
    probs = compute_action_probabilities(children_stats, temperature)
    if not probs:
        raise ValueError("No actions to select from")
    return int(rng.choice(len(probs), p=probs))  # cast numpy integer to int
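

def _example_stochastic_selection() -> None:
    """Illustrative sketch (not part of the module API): a seeded generator
    makes action selection reproducible, and temperature 0 collapses to the
    argmax over visit counts.
    """
    stats = [{"visits": 5}, {"visits": 95}]
    rng = np.random.default_rng(42)
    picks = [select_action_stochastic(stats, rng, temperature=1.0) for _ in range(8)]
    assert all(pick in (0, 1) for pick in picks)
    assert select_action_stochastic(stats, rng, temperature=0.0) == 1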