"""
Experiment Tracking Module - Track, analyze, and compare MCTS experiments.
Provides:
- Experiment run tracking (seed, params, results)
- Statistical analysis of MCTS performance
- Comparison utilities for different configurations
- Export to JSON/CSV for analysis
"""
from __future__ import annotations
import csv
import json
import statistics
from dataclasses import asdict, dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Any
from .config import MCTSConfig
@dataclass
class ExperimentResult:
    """Result of a single MCTS experiment run."""

    # Identification
    experiment_id: str
    timestamp: str = field(default_factory=lambda: datetime.now().isoformat())

    # Configuration
    config: dict[str, Any] | None = None
    seed: int = 42

    # Core results
    best_action: str | None = None
    best_action_value: float = 0.0
    best_action_visits: int = 0
    root_visits: int = 0

    # Performance metrics
    total_iterations: int = 0
    total_simulations: int = 0
    execution_time_ms: float = 0.0

    # Cache statistics
    cache_hits: int = 0
    cache_misses: int = 0
    cache_hit_rate: float = 0.0

    # Tree statistics
    tree_depth: int = 0
    tree_node_count: int = 0
    branching_factor: float = 0.0

    # Action distribution
    action_stats: dict[str, dict[str, Any]] = field(default_factory=dict)

    # Optional metadata
    metadata: dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary."""
        return asdict(self)

    def to_json(self, indent: int = 2) -> str:
        """Serialize to JSON."""
        return json.dumps(self.to_dict(), indent=indent)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> ExperimentResult:
        """Create from dictionary."""
        return cls(**data)

    @classmethod
    def from_json(cls, json_str: str) -> ExperimentResult:
        """Deserialize from JSON."""
        return cls.from_dict(json.loads(json_str))
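
# Example (illustrative, with made-up values): an ExperimentResult round-trips
# through JSON, so individual runs can be persisted and reloaded:
#
#     result = ExperimentResult(experiment_id="run-001", seed=7, best_action="expand")
#     restored = ExperimentResult.from_json(result.to_json())
#     assert restored.best_action == result.best_action
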
class ExperimentTracker:
    """
    Track and analyze MCTS experiments.

    Features:
    - Store multiple experiment results
    - Statistical analysis across runs
    - Configuration comparison
    - Export to JSON/CSV
    """

    def __init__(self, name: str = "mcts_experiments"):
        """
        Initialize experiment tracker.

        Args:
            name: Name of this experiment series
        """
        self.name = name
        self.results: list[ExperimentResult] = []
        self.created_at = datetime.now().isoformat()

    def add_result(self, result: ExperimentResult) -> None:
        """
        Add an experiment result.

        Args:
            result: ExperimentResult to add
        """
        self.results.append(result)

    def create_result(
        self,
        experiment_id: str,
        config: MCTSConfig,
        mcts_stats: dict[str, Any],
        execution_time_ms: float = 0.0,
        tree_depth: int = 0,
        tree_node_count: int = 0,
        metadata: dict[str, Any] | None = None,
    ) -> ExperimentResult:
        """
        Create and add an experiment result from MCTS statistics.

        Args:
            experiment_id: Unique ID for this experiment
            config: MCTS configuration used
            mcts_stats: Statistics dict from MCTSEngine.search()
            execution_time_ms: Execution time in milliseconds
            tree_depth: Depth of MCTS tree
            tree_node_count: Total nodes in tree
            metadata: Optional additional metadata

        Returns:
            Created ExperimentResult
        """
        # Calculate branching factor
        branching_factor = 0.0
        if tree_node_count > 1 and tree_depth > 0:
            branching_factor = (tree_node_count - 1) / tree_depth

        result = ExperimentResult(
            experiment_id=experiment_id,
            config=config.to_dict(),
            seed=config.seed,
            best_action=mcts_stats.get("best_action"),
            best_action_value=mcts_stats.get("best_action_value", 0.0),
            best_action_visits=mcts_stats.get("best_action_visits", 0),
            root_visits=mcts_stats.get("root_visits", 0),
            total_iterations=mcts_stats.get("iterations", 0),
            total_simulations=mcts_stats.get("total_simulations", 0),
            execution_time_ms=execution_time_ms,
            cache_hits=mcts_stats.get("cache_hits", 0),
            cache_misses=mcts_stats.get("cache_misses", 0),
            cache_hit_rate=mcts_stats.get("cache_hit_rate", 0.0),
            tree_depth=tree_depth,
            tree_node_count=tree_node_count,
            branching_factor=branching_factor,
            action_stats=mcts_stats.get("action_stats", {}),
            metadata=metadata or {},
        )
        self.add_result(result)
        return result
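
    # Example (hypothetical values): create_result() reads only the keys used in the
    # .get() calls above, so a minimal mcts_stats dict looks roughly like:
    #
    #     {
    #         "best_action": "expand",
    #         "best_action_value": 0.73,
    #         "best_action_visits": 120,
    #         "root_visits": 500,
    #         "iterations": 500,
    #         "total_simulations": 500,
    #         "cache_hits": 40,
    #         "cache_misses": 460,
    #         "cache_hit_rate": 0.08,
    #         "action_stats": {},
    #     }
    #
    # Missing keys fall back to the defaults given in those .get() calls.
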
    def get_summary_statistics(self) -> dict[str, Any]:
        """
        Compute summary statistics across all experiments.

        Returns:
            Dictionary of summary statistics
        """
        if not self.results:
            return {"error": "No results to analyze"}

        # Extract metrics
        best_values = [r.best_action_value for r in self.results]
        best_visits = [r.best_action_visits for r in self.results]
        exec_times = [r.execution_time_ms for r in self.results]
        cache_rates = [r.cache_hit_rate for r in self.results]
        tree_depths = [r.tree_depth for r in self.results]
        node_counts = [r.tree_node_count for r in self.results]

        def compute_stats(values: list[float]) -> dict[str, float]:
            """Compute basic statistics."""
            if not values:
                return {}
            return {
                "mean": statistics.mean(values),
                "std": statistics.stdev(values) if len(values) > 1 else 0.0,
                "min": min(values),
                "max": max(values),
                "median": statistics.median(values),
            }

        # Best action consistency
        best_actions = [r.best_action for r in self.results]
        action_counts = {}
        for action in best_actions:
            action_counts[action] = action_counts.get(action, 0) + 1
        most_common_action = max(action_counts.items(), key=lambda x: x[1])
        consistency_rate = most_common_action[1] / len(best_actions)

        return {
            "num_experiments": len(self.results),
            "best_action_value_stats": compute_stats(best_values),
            "best_action_visits_stats": compute_stats(best_visits),
            "execution_time_ms_stats": compute_stats(exec_times),
            "cache_hit_rate_stats": compute_stats(cache_rates),
            "tree_depth_stats": compute_stats(tree_depths),
            "tree_node_count_stats": compute_stats(node_counts),
            "action_consistency": {
                "most_common_action": most_common_action[0],
                "consistency_rate": consistency_rate,
                "action_distribution": action_counts,
            },
        }

    def compare_configs(
        self,
        config_names: list[str] | None = None,
    ) -> dict[str, dict[str, Any]]:
        """
        Compare performance across different configurations.

        Args:
            config_names: Specific config names to compare (all if None)

        Returns:
            Dictionary mapping config names to their statistics
        """
        # Group results by configuration name
        grouped: dict[str, list[ExperimentResult]] = {}
        for result in self.results:
            if result.config is None:
                continue
            config_name = result.config.get("name", "unnamed")
            if config_names and config_name not in config_names:
                continue
            if config_name not in grouped:
                grouped[config_name] = []
            grouped[config_name].append(result)

        # Compute statistics for each group
        comparison = {}
        for name, results in grouped.items():
            values = [r.best_action_value for r in results]
            times = [r.execution_time_ms for r in results]
            visits = [r.best_action_visits for r in results]
            comparison[name] = {
                "num_runs": len(results),
                "avg_value": statistics.mean(values) if values else 0.0,
                "std_value": statistics.stdev(values) if len(values) > 1 else 0.0,
                "avg_time_ms": statistics.mean(times) if times else 0.0,
                "avg_visits": statistics.mean(visits) if visits else 0.0,
                "value_per_ms": (
                    statistics.mean(values) / statistics.mean(times) if times and statistics.mean(times) > 0 else 0.0
                ),
            }
        return comparison

    def analyze_seed_consistency(self, seed: int) -> dict[str, Any]:
        """
        Analyze consistency of results for a specific seed.

        Args:
            seed: Seed value to analyze

        Returns:
            Analysis of determinism for this seed
        """
        seed_results = [r for r in self.results if r.seed == seed]
        if not seed_results:
            return {"error": f"No results found for seed {seed}"}

        # Check if all results are identical
        best_actions = [r.best_action for r in seed_results]
        best_values = [r.best_action_value for r in seed_results]
        best_visits = [r.best_action_visits for r in seed_results]
        is_deterministic = len(set(best_actions)) == 1 and len(set(best_values)) == 1 and len(set(best_visits)) == 1

        return {
            "seed": seed,
            "num_runs": len(seed_results),
            "is_deterministic": is_deterministic,
            "unique_actions": list(set(best_actions)),
            "value_variance": statistics.variance(best_values) if len(best_values) > 1 else 0.0,
            "visits_variance": statistics.variance(best_visits) if len(best_visits) > 1 else 0.0,
        }
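
    # Example (illustrative): after repeating several runs with seed=123,
    #
    #     tracker.analyze_seed_consistency(123)["is_deterministic"]
    #
    # is True only when every run with that seed produced the same best action,
    # value, and visit count, which is what a correctly seeded MCTS should yield.
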
    def export_to_json(self, file_path: str) -> None:
        """
        Export all results to JSON file.

        Args:
            file_path: Path to output file
        """
        data = {
            "name": self.name,
            "created_at": self.created_at,
            "num_experiments": len(self.results),
            "results": [r.to_dict() for r in self.results],
            "summary": self.get_summary_statistics(),
        }
        path = Path(file_path)
        path.parent.mkdir(parents=True, exist_ok=True)
        with open(path, "w") as f:
            json.dump(data, f, indent=2)

    def export_to_csv(self, file_path: str) -> None:
        """
        Export results to CSV file for spreadsheet analysis.

        Args:
            file_path: Path to output file
        """
        if not self.results:
            return

        path = Path(file_path)
        path.parent.mkdir(parents=True, exist_ok=True)

        # Define CSV columns
        fieldnames = [
            "experiment_id",
            "timestamp",
            "seed",
            "config_name",
            "num_iterations",
            "exploration_weight",
            "best_action",
            "best_action_value",
            "best_action_visits",
            "root_visits",
            "total_simulations",
            "execution_time_ms",
            "cache_hit_rate",
            "tree_depth",
            "tree_node_count",
            "branching_factor",
        ]

        with open(path, "w", newline="") as f:
            writer = csv.DictWriter(f, fieldnames=fieldnames)
            writer.writeheader()
            for result in self.results:
                row = {
                    "experiment_id": result.experiment_id,
                    "timestamp": result.timestamp,
                    "seed": result.seed,
                    "config_name": (result.config.get("name", "unnamed") if result.config else "unknown"),
                    "num_iterations": (result.config.get("num_iterations", 0) if result.config else 0),
                    "exploration_weight": (result.config.get("exploration_weight", 0) if result.config else 0),
                    "best_action": result.best_action,
                    "best_action_value": result.best_action_value,
                    "best_action_visits": result.best_action_visits,
                    "root_visits": result.root_visits,
                    "total_simulations": result.total_simulations,
                    "execution_time_ms": result.execution_time_ms,
                    "cache_hit_rate": result.cache_hit_rate,
                    "tree_depth": result.tree_depth,
                    "tree_node_count": result.tree_node_count,
                    "branching_factor": result.branching_factor,
                }
                writer.writerow(row)

    @classmethod
    def load_from_json(cls, file_path: str) -> ExperimentTracker:
        """
        Load experiment tracker from JSON file.

        Args:
            file_path: Path to JSON file

        Returns:
            Loaded ExperimentTracker
        """
        with open(file_path) as f:
            data = json.load(f)

        tracker = cls(name=data.get("name", "loaded_experiments"))
        tracker.created_at = data.get("created_at", tracker.created_at)
        for result_data in data.get("results", []):
            tracker.results.append(ExperimentResult.from_dict(result_data))
        return tracker

    def clear(self) -> None:
        """Clear all results."""
        self.results.clear()

    def __len__(self) -> int:
        return len(self.results)

    def __repr__(self) -> str:
        return f"ExperimentTracker(name={self.name!r}, num_results={len(self.results)})"
def run_determinism_test(
    engine_factory,
    config: MCTSConfig,
    num_runs: int = 3,
) -> tuple[bool, dict[str, Any]]:
    """
    Test that MCTS produces deterministic results with the same seed.

    Args:
        engine_factory: Factory function to create MCTSEngine
        config: Configuration to test
        num_runs: Number of runs to compare

    Returns:
        Tuple of (is_deterministic, analysis_dict)
    """
    # This is a stub - a full implementation would use engine_factory to build an
    # engine, record num_runs searches into an ExperimentTracker, and call
    # analyze_seed_consistency(config.seed) to verify determinism.
    analysis = {
        "config": config.to_dict(),
        "num_runs": num_runs,
        "is_deterministic": True,  # Would be computed from actual runs
        "message": "Determinism test requires actual engine execution",
    }
    return True, analysis
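

# Minimal smoke test (illustrative only; the experiment values below are synthetic and
# do not come from a real MCTS run). Shows the intended tracker workflow end to end
# without requiring an engine or an MCTSConfig instance.
if __name__ == "__main__":
    tracker = ExperimentTracker(name="demo")
    for run_idx in range(3):
        tracker.add_result(
            ExperimentResult(
                experiment_id=f"demo-{run_idx}",
                config={"name": "demo_config", "num_iterations": 100, "exploration_weight": 1.4},
                seed=123,
                best_action="expand",
                best_action_value=0.75,
                best_action_visits=60,
                root_visits=100,
                total_iterations=100,
                execution_time_ms=10.0,
            )
        )
    print(tracker)
    print(json.dumps(tracker.get_summary_statistics(), indent=2))
    print(json.dumps(tracker.analyze_seed_consistency(123), indent=2))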