| """ | |
| Experiment Tracking Module - Track, analyze, and compare MCTS experiments. | |
| Provides: | |
| - Experiment run tracking (seed, params, results) | |
| - Statistical analysis of MCTS performance | |
| - Comparison utilities for different configurations | |
| - Export to JSON/CSV for analysis | |
| """ | |
| from __future__ import annotations | |
| import csv | |
| import json | |
| import statistics | |
| from dataclasses import asdict, dataclass, field | |
| from datetime import datetime | |
| from pathlib import Path | |
| from typing import Any | |
| from .config import MCTSConfig | |
@dataclass
class ExperimentResult:
    """Result of a single MCTS experiment run."""

    # Identification
    experiment_id: str
    timestamp: str = field(default_factory=lambda: datetime.now().isoformat())

    # Configuration
    config: dict[str, Any] | None = None
    seed: int = 42

    # Core results
    best_action: str | None = None
    best_action_value: float = 0.0
    best_action_visits: int = 0
    root_visits: int = 0

    # Performance metrics
    total_iterations: int = 0
    total_simulations: int = 0
    execution_time_ms: float = 0.0

    # Cache statistics
    cache_hits: int = 0
    cache_misses: int = 0
    cache_hit_rate: float = 0.0

    # Tree statistics
    tree_depth: int = 0
    tree_node_count: int = 0
    branching_factor: float = 0.0

    # Action distribution
    action_stats: dict[str, dict[str, Any]] = field(default_factory=dict)

    # Optional metadata
    metadata: dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary."""
        return asdict(self)

    def to_json(self, indent: int = 2) -> str:
        """Serialize to JSON."""
        return json.dumps(self.to_dict(), indent=indent)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> ExperimentResult:
        """Create from dictionary."""
        return cls(**data)

    @classmethod
    def from_json(cls, json_str: str) -> ExperimentResult:
        """Deserialize from JSON."""
        return cls.from_dict(json.loads(json_str))
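
# A minimal round-trip sketch (illustrative only, not part of the module API):
# the field values below are made up for the example.
#
#     result = ExperimentResult(experiment_id="exp-001", seed=7, best_action="expand")
#     restored = ExperimentResult.from_json(result.to_json())
#     assert restored.best_action == "expand"
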
class ExperimentTracker:
    """
    Track and analyze MCTS experiments.

    Features:
    - Store multiple experiment results
    - Statistical analysis across runs
    - Configuration comparison
    - Export to JSON/CSV
    """

    def __init__(self, name: str = "mcts_experiments"):
        """
        Initialize experiment tracker.

        Args:
            name: Name of this experiment series
        """
        self.name = name
        self.results: list[ExperimentResult] = []
        self.created_at = datetime.now().isoformat()

    def add_result(self, result: ExperimentResult) -> None:
        """
        Add an experiment result.

        Args:
            result: ExperimentResult to add
        """
        self.results.append(result)

    def create_result(
        self,
        experiment_id: str,
        config: MCTSConfig,
        mcts_stats: dict[str, Any],
        execution_time_ms: float = 0.0,
        tree_depth: int = 0,
        tree_node_count: int = 0,
        metadata: dict[str, Any] | None = None,
    ) -> ExperimentResult:
        """
        Create and add an experiment result from MCTS statistics.

        Args:
            experiment_id: Unique ID for this experiment
            config: MCTS configuration used
            mcts_stats: Statistics dict from MCTSEngine.search()
            execution_time_ms: Execution time in milliseconds
            tree_depth: Depth of MCTS tree
            tree_node_count: Total nodes in tree
            metadata: Optional additional metadata

        Returns:
            Created ExperimentResult
        """
        # Calculate branching factor (non-root nodes averaged over tree depth)
        branching_factor = 0.0
        if tree_node_count > 1 and tree_depth > 0:
            branching_factor = (tree_node_count - 1) / tree_depth

        result = ExperimentResult(
            experiment_id=experiment_id,
            config=config.to_dict(),
            seed=config.seed,
            best_action=mcts_stats.get("best_action"),
            best_action_value=mcts_stats.get("best_action_value", 0.0),
            best_action_visits=mcts_stats.get("best_action_visits", 0),
            root_visits=mcts_stats.get("root_visits", 0),
            total_iterations=mcts_stats.get("iterations", 0),
            total_simulations=mcts_stats.get("total_simulations", 0),
            execution_time_ms=execution_time_ms,
            cache_hits=mcts_stats.get("cache_hits", 0),
            cache_misses=mcts_stats.get("cache_misses", 0),
            cache_hit_rate=mcts_stats.get("cache_hit_rate", 0.0),
            tree_depth=tree_depth,
            tree_node_count=tree_node_count,
            branching_factor=branching_factor,
            action_stats=mcts_stats.get("action_stats", {}),
            metadata=metadata or {},
        )
        self.add_result(result)
        return result

    def get_summary_statistics(self) -> dict[str, Any]:
        """
        Compute summary statistics across all experiments.

        Returns:
            Dictionary of summary statistics
        """
        if not self.results:
            return {"error": "No results to analyze"}

        # Extract metrics
        best_values = [r.best_action_value for r in self.results]
        best_visits = [r.best_action_visits for r in self.results]
        exec_times = [r.execution_time_ms for r in self.results]
        cache_rates = [r.cache_hit_rate for r in self.results]
        tree_depths = [r.tree_depth for r in self.results]
        node_counts = [r.tree_node_count for r in self.results]

        def compute_stats(values: list[float]) -> dict[str, float]:
            """Compute basic statistics."""
            if not values:
                return {}
            return {
                "mean": statistics.mean(values),
                "std": statistics.stdev(values) if len(values) > 1 else 0.0,
                "min": min(values),
                "max": max(values),
                "median": statistics.median(values),
            }

        # Best action consistency
        best_actions = [r.best_action for r in self.results]
        action_counts = {}
        for action in best_actions:
            action_counts[action] = action_counts.get(action, 0) + 1
        most_common_action = max(action_counts.items(), key=lambda x: x[1])
        consistency_rate = most_common_action[1] / len(best_actions)

        return {
            "num_experiments": len(self.results),
            "best_action_value_stats": compute_stats(best_values),
            "best_action_visits_stats": compute_stats(best_visits),
            "execution_time_ms_stats": compute_stats(exec_times),
            "cache_hit_rate_stats": compute_stats(cache_rates),
            "tree_depth_stats": compute_stats(tree_depths),
            "tree_node_count_stats": compute_stats(node_counts),
            "action_consistency": {
                "most_common_action": most_common_action[0],
                "consistency_rate": consistency_rate,
                "action_distribution": action_counts,
            },
        }

    def compare_configs(
        self,
        config_names: list[str] | None = None,
    ) -> dict[str, dict[str, Any]]:
        """
        Compare performance across different configurations.

        Args:
            config_names: Specific config names to compare (all if None)

        Returns:
            Dictionary mapping config names to their statistics
        """
        # Group results by configuration name
        grouped: dict[str, list[ExperimentResult]] = {}
        for result in self.results:
            if result.config is None:
                continue
            config_name = result.config.get("name", "unnamed")
            if config_names and config_name not in config_names:
                continue
            if config_name not in grouped:
                grouped[config_name] = []
            grouped[config_name].append(result)

        # Compute statistics for each group
        comparison = {}
        for name, results in grouped.items():
            values = [r.best_action_value for r in results]
            times = [r.execution_time_ms for r in results]
            visits = [r.best_action_visits for r in results]
            comparison[name] = {
                "num_runs": len(results),
                "avg_value": statistics.mean(values) if values else 0.0,
                "std_value": statistics.stdev(values) if len(values) > 1 else 0.0,
                "avg_time_ms": statistics.mean(times) if times else 0.0,
                "avg_visits": statistics.mean(visits) if visits else 0.0,
                "value_per_ms": (
                    statistics.mean(values) / statistics.mean(times)
                    if times and statistics.mean(times) > 0
                    else 0.0
                ),
            }
        return comparison
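
    # Illustrative shape of the dict returned by compare_configs(); the config
    # name and numbers below are made up for the example.
    #
    #     {
    #         "baseline": {
    #             "num_runs": 5,
    #             "avg_value": 0.62,
    #             "std_value": 0.04,
    #             "avg_time_ms": 118.0,
    #             "avg_visits": 240.0,
    #             "value_per_ms": 0.0053,
    #         },
    #     }
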
    def analyze_seed_consistency(self, seed: int) -> dict[str, Any]:
        """
        Analyze consistency of results for a specific seed.

        Args:
            seed: Seed value to analyze

        Returns:
            Analysis of determinism for this seed
        """
        seed_results = [r for r in self.results if r.seed == seed]
        if not seed_results:
            return {"error": f"No results found for seed {seed}"}

        # Check if all results are identical
        best_actions = [r.best_action for r in seed_results]
        best_values = [r.best_action_value for r in seed_results]
        best_visits = [r.best_action_visits for r in seed_results]
        is_deterministic = (
            len(set(best_actions)) == 1
            and len(set(best_values)) == 1
            and len(set(best_visits)) == 1
        )

        return {
            "seed": seed,
            "num_runs": len(seed_results),
            "is_deterministic": is_deterministic,
            "unique_actions": list(set(best_actions)),
            "value_variance": statistics.variance(best_values) if len(best_values) > 1 else 0.0,
            "visits_variance": statistics.variance(best_visits) if len(best_visits) > 1 else 0.0,
        }

    def export_to_json(self, file_path: str) -> None:
        """
        Export all results to JSON file.

        Args:
            file_path: Path to output file
        """
        data = {
            "name": self.name,
            "created_at": self.created_at,
            "num_experiments": len(self.results),
            "results": [r.to_dict() for r in self.results],
            "summary": self.get_summary_statistics(),
        }
        path = Path(file_path)
        path.parent.mkdir(parents=True, exist_ok=True)
        with open(path, "w") as f:
            json.dump(data, f, indent=2)

    def export_to_csv(self, file_path: str) -> None:
        """
        Export results to CSV file for spreadsheet analysis.

        Args:
            file_path: Path to output file
        """
        if not self.results:
            return

        path = Path(file_path)
        path.parent.mkdir(parents=True, exist_ok=True)

        # Define CSV columns
        fieldnames = [
            "experiment_id",
            "timestamp",
            "seed",
            "config_name",
            "num_iterations",
            "exploration_weight",
            "best_action",
            "best_action_value",
            "best_action_visits",
            "root_visits",
            "total_simulations",
            "execution_time_ms",
            "cache_hit_rate",
            "tree_depth",
            "tree_node_count",
            "branching_factor",
        ]

        with open(path, "w", newline="") as f:
            writer = csv.DictWriter(f, fieldnames=fieldnames)
            writer.writeheader()
            for result in self.results:
                row = {
                    "experiment_id": result.experiment_id,
                    "timestamp": result.timestamp,
                    "seed": result.seed,
                    "config_name": (result.config.get("name", "unnamed") if result.config else "unknown"),
                    "num_iterations": (result.config.get("num_iterations", 0) if result.config else 0),
                    "exploration_weight": (result.config.get("exploration_weight", 0) if result.config else 0),
                    "best_action": result.best_action,
                    "best_action_value": result.best_action_value,
                    "best_action_visits": result.best_action_visits,
                    "root_visits": result.root_visits,
                    "total_simulations": result.total_simulations,
                    "execution_time_ms": result.execution_time_ms,
                    "cache_hit_rate": result.cache_hit_rate,
                    "tree_depth": result.tree_depth,
                    "tree_node_count": result.tree_node_count,
                    "branching_factor": result.branching_factor,
                }
                writer.writerow(row)

    @classmethod
    def load_from_json(cls, file_path: str) -> ExperimentTracker:
        """
        Load experiment tracker from JSON file.

        Args:
            file_path: Path to JSON file

        Returns:
            Loaded ExperimentTracker
        """
        with open(file_path) as f:
            data = json.load(f)
        tracker = cls(name=data.get("name", "loaded_experiments"))
        tracker.created_at = data.get("created_at", tracker.created_at)
        for result_data in data.get("results", []):
            tracker.results.append(ExperimentResult.from_dict(result_data))
        return tracker

    def clear(self) -> None:
        """Clear all results."""
        self.results.clear()

    def __len__(self) -> int:
        return len(self.results)

    def __repr__(self) -> str:
        return f"ExperimentTracker(name={self.name!r}, num_results={len(self.results)})"

def run_determinism_test(
    engine_factory,
    config: MCTSConfig,
    num_runs: int = 3,
) -> tuple[bool, dict[str, Any]]:
    """
    Test that MCTS produces deterministic results with the same seed.

    Args:
        engine_factory: Factory function to create MCTSEngine
        config: Configuration to test
        num_runs: Number of runs to compare

    Returns:
        Tuple of (is_deterministic, analysis_dict)
    """
    tracker = ExperimentTracker(name="determinism_test")  # unused in this stub; would collect per-run results

    # This is a stub - the actual implementation would run the engine num_runs
    # times via engine_factory, record each run in the tracker, and compare the
    # recorded results to verify determinism.
    analysis = {
        "config": config.to_dict(),
        "num_runs": num_runs,
        "is_deterministic": True,  # Would be computed from actual runs
        "message": "Determinism test requires actual engine execution",
    }
    return True, analysis
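
# Minimal usage sketch (illustrative only): the experiment values below are
# made up, and in real use results would come from ExperimentTracker.create_result
# with an actual MCTSConfig and the statistics returned by MCTSEngine.search().
if __name__ == "__main__":
    tracker = ExperimentTracker(name="demo_experiments")
    for i, value in enumerate([0.61, 0.63, 0.62]):
        tracker.add_result(
            ExperimentResult(
                experiment_id=f"demo-{i}",
                seed=42,
                best_action="expand",
                best_action_value=value,
                best_action_visits=200 + i,
                execution_time_ms=100.0 + i,
            )
        )
    print(tracker)
    print(json.dumps(tracker.get_summary_statistics(), indent=2))
    print(tracker.analyze_seed_consistency(seed=42))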