"""
Unified Training Orchestrator for LangGraph Multi-Agent MCTS with DeepMind-Style Learning.
Coordinates:
- HRM Agent
- TRM Agent
- Neural MCTS
- Policy-Value Network
- Self-play data generation
- Training loops
- Evaluation
- Checkpointing
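
Example (a minimal usage sketch, not the project's canonical entry point; it
assumes SystemConfig() can be built with defaults and that MyGameState is a
user-provided implementation of the GameState interface):

    import asyncio

    config = SystemConfig()
    orchestrator = UnifiedTrainingOrchestrator(
        config=config,
        initial_state_fn=lambda: MyGameState(),  # hypothetical GameState implementation
        board_size=9,
    )
    asyncio.run(orchestrator.train(num_iterations=10))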
"""
import time
from collections.abc import Callable
from pathlib import Path
from typing import Any
import torch
import torch.nn as nn
from torch.cuda.amp import GradScaler, autocast
from ..agents.hrm_agent import HRMLoss, create_hrm_agent
from ..agents.trm_agent import TRMLoss, create_trm_agent
from ..framework.mcts.neural_mcts import GameState, NeuralMCTS, SelfPlayCollector
from ..models.policy_value_net import (
AlphaZeroLoss,
create_policy_value_network,
)
from .performance_monitor import PerformanceMonitor, TimingContext
from .replay_buffer import Experience, PrioritizedReplayBuffer, collate_experiences
from .system_config import SystemConfig
class UnifiedTrainingOrchestrator:
"""
Complete training pipeline integrating all framework components.
This orchestrator manages:
1. Self-play data generation using MCTS
2. Neural network training (policy-value)
3. HRM agent training
4. TRM agent training
5. Evaluation and checkpointing
6. Performance monitoring
"""
def __init__(
self,
config: SystemConfig,
initial_state_fn: Callable[[], GameState],
board_size: int = 19,
):
"""
Initialize training orchestrator.
Args:
config: System configuration
initial_state_fn: Function that returns initial game state
board_size: Board/grid size for spatial games
"""
self.config = config
self.initial_state_fn = initial_state_fn
self.board_size = board_size
# Setup device
self.device = config.device
torch.manual_seed(config.seed)
# Initialize performance monitor
self.monitor = PerformanceMonitor(
window_size=100,
enable_gpu_monitoring=(self.device != "cpu"),
)
# Initialize components
self._initialize_components()
# Training state
self.current_iteration = 0
self.best_win_rate = 0.0
self.best_model_path = None
# Setup paths
self._setup_paths()
# Setup experiment tracking
if config.use_wandb:
self._setup_wandb()
def _initialize_components(self):
"""Initialize all framework components."""
print("Initializing components...")
# Policy-Value Network
self.policy_value_net = create_policy_value_network(
config=self.config.neural_net,
board_size=self.board_size,
device=self.device,
)
print(f" ✓ Policy-Value Network: {self.policy_value_net.get_parameter_count():,} parameters")
# HRM Agent
self.hrm_agent = create_hrm_agent(self.config.hrm, self.device)
print(f" ✓ HRM Agent: {self.hrm_agent.get_parameter_count():,} parameters")
# TRM Agent
self.trm_agent = create_trm_agent(
self.config.trm, output_dim=self.config.neural_net.action_size, device=self.device
)
print(f" ✓ TRM Agent: {self.trm_agent.get_parameter_count():,} parameters")
# Neural MCTS
self.mcts = NeuralMCTS(
policy_value_network=self.policy_value_net,
config=self.config.mcts,
device=self.device,
)
print(" ✓ Neural MCTS initialized")
# Self-play collector
self.self_play_collector = SelfPlayCollector(mcts=self.mcts, config=self.config.mcts)
# Optimizers
self._setup_optimizers()
# Loss functions
self.pv_loss_fn = AlphaZeroLoss(value_loss_weight=1.0)
self.hrm_loss_fn = HRMLoss(ponder_weight=0.01)
self.trm_loss_fn = TRMLoss(
task_loss_fn=nn.MSELoss(),
supervision_weight_decay=self.config.trm.supervision_weight_decay,
)
# Replay buffer
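        # (prioritized replay: alpha sets how strongly TD-error priorities skew
        #  sampling; beta anneals the importance-sampling correction from
        #  beta_start toward 1.0 over roughly beta_frames samples)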
self.replay_buffer = PrioritizedReplayBuffer(
capacity=self.config.training.buffer_size,
alpha=0.6,
beta_start=0.4,
beta_frames=self.config.training.games_per_iteration * 10,
)
# Mixed precision scaler
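        # (torch.cuda.amp.GradScaler disables itself with a warning when CUDA is unavailable)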
self.scaler = GradScaler() if self.config.use_mixed_precision else None
    def _setup_optimizers(self):
        """Set up optimizers and learning rate schedulers."""
# Policy-Value optimizer
self.pv_optimizer = torch.optim.SGD(
self.policy_value_net.parameters(),
lr=self.config.training.learning_rate,
momentum=self.config.training.momentum,
weight_decay=self.config.training.weight_decay,
)
# HRM optimizer
self.hrm_optimizer = torch.optim.Adam(self.hrm_agent.parameters(), lr=1e-3)
# TRM optimizer
self.trm_optimizer = torch.optim.Adam(self.trm_agent.parameters(), lr=1e-3)
# Learning rate scheduler for policy-value network
if self.config.training.lr_schedule == "cosine":
self.pv_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.pv_optimizer, T_max=100)
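            # T_max=100 assumes the cosine cycle spans roughly 100 training
            # iterations (the scheduler is stepped once per iteration)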
elif self.config.training.lr_schedule == "step":
self.pv_scheduler = torch.optim.lr_scheduler.StepLR(
self.pv_optimizer,
step_size=self.config.training.lr_decay_steps,
gamma=self.config.training.lr_decay_gamma,
)
else:
self.pv_scheduler = None
    def _setup_paths(self):
        """Set up directory paths."""
self.checkpoint_dir = Path(self.config.checkpoint_dir)
self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
self.data_dir = Path(self.config.data_dir)
self.data_dir.mkdir(parents=True, exist_ok=True)
self.log_dir = Path(self.config.log_dir)
self.log_dir.mkdir(parents=True, exist_ok=True)
    def _setup_wandb(self):
        """Set up Weights & Biases experiment tracking."""
try:
import wandb
wandb.init(
project=self.config.wandb_project,
entity=self.config.wandb_entity,
config=self.config.to_dict(),
name=f"run_{time.strftime('%Y%m%d_%H%M%S')}",
)
print(" ✓ Weights & Biases initialized")
except ImportError:
print(" ⚠️ wandb not installed, skipping")
self.config.use_wandb = False
async def train_iteration(self, iteration: int) -> dict[str, Any]:
"""
Execute single training iteration.
Args:
iteration: Current iteration number
Returns:
Dictionary of metrics
"""
print(f"\n{'=' * 80}")
print(f"Training Iteration {iteration}")
print(f"{'=' * 80}")
metrics = {}
# Phase 1: Self-play data generation
print("\n[1/5] Generating self-play data...")
with TimingContext(self.monitor, "self_play_generation"):
game_data = await self._generate_self_play_data()
metrics["games_generated"] = len(game_data)
print(f" Generated {len(game_data)} training examples")
# Phase 2: Policy-Value network training
print("\n[2/5] Training Policy-Value Network...")
with TimingContext(self.monitor, "pv_training"):
pv_metrics = await self._train_policy_value_network()
metrics.update(pv_metrics)
# Phase 3: HRM agent training (optional, if using HRM)
if hasattr(self, "hrm_agent"):
print("\n[3/5] Training HRM Agent...")
with TimingContext(self.monitor, "hrm_training"):
hrm_metrics = await self._train_hrm_agent()
metrics.update(hrm_metrics)
# Phase 4: TRM agent training (optional, if using TRM)
if hasattr(self, "trm_agent"):
print("\n[4/5] Training TRM Agent...")
with TimingContext(self.monitor, "trm_training"):
trm_metrics = await self._train_trm_agent()
metrics.update(trm_metrics)
# Phase 5: Evaluation
print("\n[5/5] Evaluation...")
if iteration % self.config.training.checkpoint_interval == 0:
eval_metrics = await self._evaluate()
metrics.update(eval_metrics)
# Save checkpoint if improved
if eval_metrics.get("win_rate", 0) > self.best_win_rate:
self.best_win_rate = eval_metrics["win_rate"]
self._save_checkpoint(iteration, metrics, is_best=True)
print(f" ✓ New best model! Win rate: {self.best_win_rate:.2%}")
# Log metrics
self._log_metrics(iteration, metrics)
# Performance check
self.monitor.alert_if_slow()
return metrics
async def _generate_self_play_data(self) -> list[Experience]:
"""Generate training data from self-play games."""
num_games = self.config.training.games_per_iteration
# In production, this would use parallel actors
# For simplicity, we'll do sequential self-play
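        # (a parallel variant could, for instance, launch the per-game
        #  coroutines with asyncio.gather to play several games concurrently)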
all_examples = []
for game_idx in range(num_games):
examples = await self.self_play_collector.play_game(
initial_state=self.initial_state_fn(),
temperature_threshold=self.config.mcts.temperature_threshold,
)
# Convert to Experience objects
for ex in examples:
all_examples.append(Experience(state=ex.state, policy=ex.policy_target, value=ex.value_target))
if (game_idx + 1) % 5 == 0:
print(f" Generated {game_idx + 1}/{num_games} games...")
# Add to replay buffer
self.replay_buffer.add_batch(all_examples)
return all_examples
async def _train_policy_value_network(self) -> dict[str, float]:
"""Train policy-value network on replay buffer data."""
if not self.replay_buffer.is_ready(self.config.training.batch_size):
print(" Replay buffer not ready, skipping...")
return {"policy_loss": 0.0, "value_loss": 0.0}
self.policy_value_net.train()
total_policy_loss = 0.0
total_value_loss = 0.0
num_batches = 10 # Train for 10 batches per iteration
for _ in range(num_batches):
# Sample batch
experiences, indices, weights = self.replay_buffer.sample(self.config.training.batch_size)
states, policies, values = collate_experiences(experiences)
states = states.to(self.device)
policies = policies.to(self.device)
values = values.to(self.device)
weights = torch.from_numpy(weights).to(self.device)
# Forward pass
if self.config.use_mixed_precision and self.scaler:
with autocast():
policy_logits, value_pred = self.policy_value_net(states)
loss, loss_dict = self.pv_loss_fn(policy_logits, value_pred, policies, values)
# Apply importance sampling weights
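                # (assumes pv_loss_fn returns per-sample losses; if it returns a
                #  pre-reduced scalar, the weights only rescale the batch mean)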
loss = (loss * weights).mean()
# Backward pass with mixed precision
self.pv_optimizer.zero_grad()
self.scaler.scale(loss).backward()
self.scaler.step(self.pv_optimizer)
self.scaler.update()
else:
policy_logits, value_pred = self.policy_value_net(states)
loss, loss_dict = self.pv_loss_fn(policy_logits, value_pred, policies, values)
loss = (loss * weights).mean()
self.pv_optimizer.zero_grad()
loss.backward()
self.pv_optimizer.step()
# Update priorities in replay buffer
with torch.no_grad():
td_errors = torch.abs(value_pred.squeeze() - values)
self.replay_buffer.update_priorities(indices, td_errors.cpu().numpy())
total_policy_loss += loss_dict["policy"]
total_value_loss += loss_dict["value"]
# Log losses
self.monitor.log_loss(loss_dict["policy"], loss_dict["value"], loss_dict["total"])
# Step learning rate scheduler
if self.pv_scheduler:
self.pv_scheduler.step()
avg_policy_loss = total_policy_loss / num_batches
avg_value_loss = total_value_loss / num_batches
print(f" Policy Loss: {avg_policy_loss:.4f}, Value Loss: {avg_value_loss:.4f}")
return {"policy_loss": avg_policy_loss, "value_loss": avg_value_loss}
async def _train_hrm_agent(self) -> dict[str, float]:
"""Train HRM agent (placeholder for domain-specific implementation)."""
# This would require domain-specific data and tasks
# For now, return dummy metrics
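        # A domain-specific implementation might, for example, sample reasoning
        # tasks, run self.hrm_agent on them, optimize with self.hrm_loss_fn and
        # self.hrm_optimizer, and report the measured halt-step and ponder-cost
        # statistics instead of these constants.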
return {"hrm_halt_step": 5.0, "hrm_ponder_cost": 0.1}
async def _train_trm_agent(self) -> dict[str, float]:
"""Train TRM agent (placeholder for domain-specific implementation)."""
# This would require domain-specific data and tasks
# For now, return dummy metrics
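        # Analogously, a real implementation might run self.trm_agent on domain
        # tasks, optimize with self.trm_loss_fn and self.trm_optimizer, and
        # report the observed convergence step and final residual.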
return {"trm_convergence_step": 8.0, "trm_final_residual": 0.01}
async def _evaluate(self) -> dict[str, float]:
"""Evaluate current model against baseline."""
# Simplified evaluation: play games against previous best
# In production, this would be more sophisticated
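        # A fuller evaluation might load the previous best checkpoint, play
        # config.training.evaluation_games games between the two networks via
        # self.mcts, and compute the empirical win rate of the current model.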
win_rate = 0.55 # Placeholder
return {
"win_rate": win_rate,
"eval_games": self.config.training.evaluation_games,
}
def _save_checkpoint(self, iteration: int, metrics: dict, is_best: bool = False):
"""Save model checkpoint."""
checkpoint = {
"iteration": iteration,
"policy_value_net": self.policy_value_net.state_dict(),
"hrm_agent": self.hrm_agent.state_dict(),
"trm_agent": self.trm_agent.state_dict(),
"pv_optimizer": self.pv_optimizer.state_dict(),
"hrm_optimizer": self.hrm_optimizer.state_dict(),
"trm_optimizer": self.trm_optimizer.state_dict(),
"config": self.config.to_dict(),
"metrics": metrics,
"best_win_rate": self.best_win_rate,
}
# Save regular checkpoint
path = self.checkpoint_dir / f"checkpoint_iter_{iteration}.pt"
torch.save(checkpoint, path)
print(f" ✓ Checkpoint saved: {path}")
# Save best model
if is_best:
best_path = self.checkpoint_dir / "best_model.pt"
torch.save(checkpoint, best_path)
self.best_model_path = best_path
print(f" ✓ Best model saved: {best_path}")
def _log_metrics(self, iteration: int, metrics: dict):
"""Log metrics to console and tracking systems."""
print(f"\n[Metrics Summary - Iteration {iteration}]")
for key, value in metrics.items():
if isinstance(value, float):
print(f" {key}: {value:.4f}")
else:
print(f" {key}: {value}")
# Log to wandb
if self.config.use_wandb:
try:
import wandb
wandb_metrics = self.monitor.export_to_wandb(iteration)
wandb_metrics.update(metrics)
wandb.log(wandb_metrics, step=iteration)
except Exception as e:
print(f" ⚠️ Failed to log to wandb: {e}")
async def train(self, num_iterations: int):
"""
Run complete training loop.
Args:
num_iterations: Number of training iterations
"""
print("\n" + "=" * 80)
print("Starting Training")
print("=" * 80)
print(f"Total iterations: {num_iterations}")
print(f"Device: {self.device}")
print(f"Mixed precision: {self.config.use_mixed_precision}")
start_time = time.time()
for iteration in range(1, num_iterations + 1):
self.current_iteration = iteration
try:
_ = await self.train_iteration(iteration)
# Check early stopping
if self._should_early_stop(iteration):
print("\n⚠️ Early stopping triggered")
break
except KeyboardInterrupt:
print("\n⚠️ Training interrupted by user")
break
except Exception as e:
print(f"\n❌ Error in iteration {iteration}: {e}")
import traceback
traceback.print_exc()
break
elapsed = time.time() - start_time
print(f"\n{'=' * 80}")
print(f"Training completed in {elapsed / 3600:.2f} hours")
print(f"Best win rate: {self.best_win_rate:.2%}")
print(f"{'=' * 80}\n")
# Print final performance summary
self.monitor.print_summary()
def _should_early_stop(self, iteration: int) -> bool:
"""Check early stopping criteria."""
# Placeholder: implement actual early stopping logic
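        # For example: stop once the evaluation win rate has not improved for a
        # configurable number of consecutive evaluation intervals.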
_ = iteration # noqa: F841
return False
def load_checkpoint(self, path: str):
"""Load checkpoint from file."""
checkpoint = torch.load(path, map_location=self.device, weights_only=True)
self.policy_value_net.load_state_dict(checkpoint["policy_value_net"])
self.hrm_agent.load_state_dict(checkpoint["hrm_agent"])
self.trm_agent.load_state_dict(checkpoint["trm_agent"])
self.pv_optimizer.load_state_dict(checkpoint["pv_optimizer"])
self.hrm_optimizer.load_state_dict(checkpoint["hrm_optimizer"])
self.trm_optimizer.load_state_dict(checkpoint["trm_optimizer"])
self.current_iteration = checkpoint["iteration"]
self.best_win_rate = checkpoint.get("best_win_rate", 0.0)
print(f"✓ Loaded checkpoint from iteration {self.current_iteration}")