# langgraph-mcts-demo/src/training/experiment_tracker.py
"""
Experiment Tracking Integration Module.
Provides a unified interface for:
- Braintrust experiment tracking
- Weights & Biases (W&B) logging
- Metric collection and visualization
- Model artifact versioning
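
Example (illustrative sketch; the import path mirrors this file's location, and
both trackers fall back to offline mode when API keys are unset):

    from src.training.experiment_tracker import (
        TrainingMetrics,
        UnifiedExperimentTracker,
    )

    tracker = UnifiedExperimentTracker(project_name="mcts-neural-meta-controller")
    tracker.init_experiment("demo_run", config={"lr": 1e-3})
    tracker.log_metrics(TrainingMetrics(epoch=0, step=1, train_loss=0.9))
    tracker.finish()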
"""
import logging
import os
import time
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Any
logger = logging.getLogger(__name__)
@dataclass
class ExperimentConfig:
"""Configuration for experiment tracking."""
project_name: str
experiment_name: str
tags: list[str] = field(default_factory=list)
description: str = ""
save_artifacts: bool = True
log_frequency: int = 1 # Log every N steps
@dataclass
class TrainingMetrics:
"""Standard training metrics."""
epoch: int
step: int
train_loss: float
val_loss: float | None = None
accuracy: float | None = None
learning_rate: float | None = None
timestamp: float = field(default_factory=time.time)
custom_metrics: dict[str, float] = field(default_factory=dict)
class BraintrustTracker:
"""
Braintrust experiment tracking integration.
Provides:
- Experiment initialization and management
- Metric logging with automatic visualization
- Hyperparameter tracking
- Model evaluation scoring
- Artifact versioning
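
    Example (sketch; runs in offline mode when BRAINTRUST_API_KEY is unset):

        tracker = BraintrustTracker()
        tracker.init_experiment("rnn_meta_controller_v2", tags=["baseline"])
        tracker.log_metric("train_loss", 0.42, step=1)
        tracker.end_experiment()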
"""
def __init__(self, api_key: str | None = None, project_name: str = "mcts-neural-meta-controller"):
"""
Initialize Braintrust tracker.
Args:
api_key: Braintrust API key (or from BRAINTRUST_API_KEY env var)
project_name: Project name in Braintrust
"""
self.api_key = api_key or os.getenv("BRAINTRUST_API_KEY")
self.project_name = project_name
self._experiment = None
self._experiment_id = None
        self._metrics_buffer: list[dict[str, Any]] = []
        # Initialized up front so offline log_* calls work even if
        # init_experiment() was never called.
        self._experiment_config: dict[str, Any] = {}
        self._initialized = False
if not self.api_key:
logger.warning("BRAINTRUST_API_KEY not set. Using offline mode.")
self._offline_mode = True
else:
self._offline_mode = False
self._initialize_client()
def _initialize_client(self):
"""Initialize Braintrust client."""
try:
import braintrust
braintrust.login(api_key=self.api_key)
self._bt = braintrust
self._initialized = True
logger.info(f"Braintrust client initialized for project: {self.project_name}")
except ImportError:
logger.error("braintrust library not installed. Run: pip install braintrust")
self._offline_mode = True
except Exception as e:
logger.error(f"Failed to initialize Braintrust: {e}")
self._offline_mode = True
def init_experiment(
self,
name: str,
description: str = "",
tags: list[str] | None = None,
metadata: dict[str, Any] | None = None,
) -> str:
"""
Initialize a new experiment.
Args:
name: Experiment name (e.g., "rnn_meta_controller_v2")
description: Experiment description
tags: List of tags for filtering
metadata: Additional metadata
Returns:
Experiment ID
"""
if self._offline_mode:
exp_id = f"offline_{int(time.time())}"
logger.info(f"Created offline experiment: {exp_id}")
self._experiment_id = exp_id
self._experiment_config = {
"name": name,
"description": description,
"tags": tags or [],
"metadata": metadata or {},
"start_time": datetime.now().isoformat(),
}
return exp_id
try:
self._experiment = self._bt.init(
project=self.project_name,
experiment=name,
)
self._experiment_id = self._experiment.id
logger.info(f"Created Braintrust experiment: {name} (ID: {self._experiment_id})")
return self._experiment_id
        except Exception as e:
            logger.error(f"Failed to create experiment: {e}")
            # Switch to offline mode *before* retrying so the fallback call
            # takes the offline branch and cannot recurse indefinitely.
            self._offline_mode = True
            return self.init_experiment(name, description, tags, metadata)
def log_hyperparameters(self, params: dict[str, Any]):
"""
Log hyperparameters for the experiment.
Args:
params: Dictionary of hyperparameters
"""
logger.info(f"Logging hyperparameters: {params}")
if self._offline_mode:
self._experiment_config["hyperparameters"] = params
return
try:
if self._experiment:
# Braintrust uses metadata for hyperparameters
self._experiment.log(
input="hyperparameters",
output=params,
metadata={"type": "hyperparameters"},
)
except Exception as e:
logger.error(f"Failed to log hyperparameters: {e}")
def log_metric(
self,
name: str,
value: float,
step: int | None = None,
timestamp: float | None = None,
):
"""
Log a single metric.
Args:
name: Metric name
value: Metric value
step: Optional step number
timestamp: Optional timestamp
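
        Example (sketch):

            tracker.log_metric("val_loss", 0.37, step=100)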
"""
        # Compare against None explicitly: `step or ...` would treat a valid
        # step=0 (or timestamp=0.0) as unset.
        metric_data = {
            "name": name,
            "value": value,
            "step": step if step is not None else len(self._metrics_buffer),
            "timestamp": timestamp if timestamp is not None else time.time(),
        }
self._metrics_buffer.append(metric_data)
if self._offline_mode:
logger.debug(f"Metric logged (offline): {name}={value}")
return
try:
if self._experiment:
self._experiment.log(
input=f"metric_{name}",
output={"value": value},
scores={name: value},
metadata={"step": step},
)
except Exception as e:
logger.error(f"Failed to log metric {name}: {e}")
def log_training_step(self, metrics: TrainingMetrics):
"""
Log a complete training step.
Args:
metrics: TrainingMetrics object
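
        Example (sketch; ``mcts_value_loss`` is an illustrative custom metric):

            tracker.log_training_step(TrainingMetrics(
                epoch=2,
                step=200,
                train_loss=0.31,
                val_loss=0.35,
                custom_metrics={"mcts_value_loss": 0.12},
            ))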
"""
self.log_metric("train_loss", metrics.train_loss, step=metrics.step)
if metrics.val_loss is not None:
self.log_metric("val_loss", metrics.val_loss, step=metrics.step)
if metrics.accuracy is not None:
self.log_metric("accuracy", metrics.accuracy, step=metrics.step)
if metrics.learning_rate is not None:
self.log_metric("learning_rate", metrics.learning_rate, step=metrics.step)
for key, value in metrics.custom_metrics.items():
self.log_metric(key, value, step=metrics.step)
def log_evaluation(
self,
input_data: Any,
output: Any,
expected: Any,
scores: dict[str, float],
metadata: dict[str, Any] | None = None,
):
"""
Log an evaluation result.
Args:
input_data: Input to the model
output: Model output
expected: Expected output
scores: Dictionary of scores (e.g., accuracy, f1)
metadata: Additional metadata
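
        Example (sketch; the input/output values are illustrative):

            tracker.log_evaluation(
                input_data={"state": [0.1, 0.2]},
                output="expand",
                expected="expand",
                scores={"accuracy": 1.0},
            )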
"""
if self._offline_mode:
logger.info(f"Evaluation logged (offline): scores={scores}")
return
try:
if self._experiment:
self._experiment.log(
input=input_data,
output=output,
expected=expected,
scores=scores,
metadata=metadata or {},
)
except Exception as e:
logger.error(f"Failed to log evaluation: {e}")
def log_artifact(self, path: str | Path, name: str | None = None):
"""
Log a model artifact.
Args:
path: Path to artifact file
name: Optional artifact name
"""
path = Path(path)
if not path.exists():
logger.warning(f"Artifact not found: {path}")
return
logger.info(f"Logging artifact: {path}")
if self._offline_mode:
if "artifacts" not in self._experiment_config:
self._experiment_config["artifacts"] = []
self._experiment_config["artifacts"].append(str(path))
return
        # This integration has no dedicated Braintrust artifact upload yet;
        # record the artifact path as a regular log entry instead.
try:
if self._experiment:
self._experiment.log(
input=f"artifact_{name or path.name}",
output={"path": str(path), "name": name or path.name},
metadata={"artifact_path": str(path), "artifact_name": name or path.name},
)
except Exception as e:
logger.error(f"Failed to log artifact: {e}")
def get_summary(self) -> dict[str, Any]:
"""
Get experiment summary.
Returns:
Dictionary with experiment summary
"""
if self._offline_mode:
return {
"id": self._experiment_id,
"config": self._experiment_config,
"metrics_count": len(self._metrics_buffer),
"offline": True,
}
return {
"id": self._experiment_id,
"project": self.project_name,
"metrics_count": len(self._metrics_buffer),
"offline": False,
}
    def end_experiment(self) -> dict[str, Any]:
        """End the current experiment and return its summary."""
summary = self.get_summary()
logger.info(f"Experiment ended: {summary}")
if not self._offline_mode and self._experiment:
try:
# Braintrust experiments auto-close, but we'll try explicit close if available
if hasattr(self._experiment, "close"):
self._experiment.close()
elif hasattr(self._experiment, "flush"):
self._experiment.flush()
except Exception as e:
logger.error(f"Failed to end experiment: {e}")
self._experiment = None
self._experiment_id = None
self._metrics_buffer = []
return summary
class WandBTracker:
"""
Weights & Biases experiment tracking integration.
Provides:
- Real-time metric visualization
- Hyperparameter sweep management
- Model artifact logging
- Collaborative experiment comparison
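
    Example (sketch; runs in offline mode when WANDB_API_KEY is unset, and
    ``my-team`` is an illustrative entity name):

        wb = WandBTracker(entity="my-team")
        wb.init_run("sweep_3", config={"hidden_size": 128})
        wb.log({"train_loss": 0.5}, step=1)
        wb.finish()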
"""
def __init__(
self,
api_key: str | None = None,
project_name: str = "mcts-neural-meta-controller",
entity: str | None = None,
):
"""
Initialize W&B tracker.
Args:
api_key: W&B API key (or from WANDB_API_KEY env var)
project_name: Project name in W&B
entity: W&B entity (team or username)
"""
self.api_key = api_key or os.getenv("WANDB_API_KEY")
self.project_name = project_name
self.entity = entity
        self._run = None
        # Initialized up front so offline update_config() works even if
        # init_run() was never called.
        self._run_config: dict[str, Any] = {}
        self._initialized = False
self._offline_mode = os.getenv("WANDB_MODE") == "offline"
if not self.api_key and not self._offline_mode:
logger.warning("WANDB_API_KEY not set. Using offline mode.")
self._offline_mode = True
os.environ["WANDB_MODE"] = "offline"
else:
self._initialize_client()
def _initialize_client(self):
"""Initialize W&B client."""
try:
import wandb
if self.api_key:
wandb.login(key=self.api_key)
self._wandb = wandb
self._initialized = True
logger.info(f"W&B client initialized for project: {self.project_name}")
except ImportError:
logger.error("wandb library not installed. Run: pip install wandb")
self._offline_mode = True
except Exception as e:
logger.error(f"Failed to initialize W&B: {e}")
self._offline_mode = True
def init_run(
self,
name: str,
config: dict[str, Any] | None = None,
tags: list[str] | None = None,
notes: str = "",
):
"""
Initialize a new W&B run.
Args:
name: Run name
config: Configuration dictionary
tags: List of tags
notes: Run notes/description
Returns:
Run object
"""
if self._offline_mode:
logger.info(f"W&B run initialized (offline mode): {name}")
self._run_config = config or {}
return None
try:
self._run = self._wandb.init(
project=self.project_name,
entity=self.entity,
name=name,
config=config,
tags=tags,
notes=notes,
)
logger.info(f"W&B run initialized: {name}")
return self._run
except Exception as e:
logger.error(f"Failed to initialize W&B run: {e}")
self._offline_mode = True
return None
def log(self, metrics: dict[str, Any], step: int | None = None):
"""
Log metrics to W&B.
Args:
metrics: Dictionary of metrics
step: Optional step number
"""
if self._offline_mode:
logger.debug(f"W&B metrics (offline): {metrics}")
return
try:
if self._run:
self._wandb.log(metrics, step=step)
except Exception as e:
logger.error(f"Failed to log to W&B: {e}")
def log_training_step(self, metrics: TrainingMetrics):
"""
Log a complete training step to W&B.
Args:
metrics: TrainingMetrics object
"""
log_data = {
"epoch": metrics.epoch,
"train_loss": metrics.train_loss,
}
if metrics.val_loss is not None:
log_data["val_loss"] = metrics.val_loss
if metrics.accuracy is not None:
log_data["accuracy"] = metrics.accuracy
if metrics.learning_rate is not None:
log_data["learning_rate"] = metrics.learning_rate
log_data.update(metrics.custom_metrics)
self.log(log_data, step=metrics.step)
def update_config(self, config: dict[str, Any]):
"""
Update run configuration.
Args:
config: Configuration updates
"""
if self._offline_mode:
self._run_config.update(config)
return
try:
if self._run:
self._wandb.config.update(config)
except Exception as e:
logger.error(f"Failed to update W&B config: {e}")
def watch_model(self, model, log_freq: int = 100):
"""
Watch model gradients and parameters.
Args:
model: PyTorch model
log_freq: Logging frequency
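
        Example (sketch; assumes ``model`` is a ``torch.nn.Module``):

            wb.watch_model(model, log_freq=50)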
"""
if self._offline_mode:
return
try:
if self._run:
self._wandb.watch(model, log="all", log_freq=log_freq)
except Exception as e:
logger.error(f"Failed to watch model: {e}")
def log_artifact(self, path: str | Path, name: str, artifact_type: str = "model"):
"""
Log artifact to W&B.
Args:
path: Path to artifact
name: Artifact name
artifact_type: Type of artifact (model, dataset, etc.)
"""
if self._offline_mode:
logger.info(f"Artifact logged (offline): {path}")
return
try:
artifact = self._wandb.Artifact(name, type=artifact_type)
artifact.add_file(str(path))
if self._run:
self._run.log_artifact(artifact)
logger.info(f"Artifact logged: {name}")
except Exception as e:
logger.error(f"Failed to log artifact: {e}")
def finish(self):
"""Finish the W&B run."""
if self._offline_mode:
logger.info("W&B run finished (offline)")
return
try:
if self._run:
self._run.finish()
logger.info("W&B run finished")
except Exception as e:
logger.error(f"Failed to finish W&B run: {e}")
class UnifiedExperimentTracker:
"""
Unified experiment tracker that coordinates both Braintrust and W&B.
    Provides a single interface for:
- Dual logging to both platforms
- Fallback handling
- Consistent metric tracking
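
    Example (sketch; the checkpoint path is illustrative):

        tracker = UnifiedExperimentTracker(project_name="mcts-neural-meta-controller")
        tracker.init_experiment("joint_run", config={"lr": 1e-3}, tags=["demo"])
        tracker.log_metrics(TrainingMetrics(epoch=0, step=1, train_loss=0.8))
        tracker.log_artifact("checkpoints/model.pt", name="model")
        summary = tracker.finish()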
"""
def __init__(
self,
braintrust_api_key: str | None = None,
wandb_api_key: str | None = None,
project_name: str = "mcts-neural-meta-controller",
):
"""
Initialize unified tracker.
Args:
braintrust_api_key: Braintrust API key
wandb_api_key: W&B API key
project_name: Project name for both platforms
"""
self.bt = BraintrustTracker(api_key=braintrust_api_key, project_name=project_name)
self.wandb = WandBTracker(api_key=wandb_api_key, project_name=project_name)
self.project_name = project_name
def init_experiment(
self,
name: str,
config: dict[str, Any] | None = None,
description: str = "",
tags: list[str] | None = None,
):
"""
Initialize experiment on both platforms.
Args:
name: Experiment/run name
config: Configuration dictionary
description: Description
tags: List of tags
"""
self.bt.init_experiment(name, description, tags)
self.wandb.init_run(name, config, tags, description)
if config:
self.bt.log_hyperparameters(config)
logger.info(f"Unified experiment initialized: {name}")
def log_metrics(self, metrics: TrainingMetrics):
"""
Log training metrics to both platforms.
Args:
metrics: TrainingMetrics object
"""
self.bt.log_training_step(metrics)
self.wandb.log_training_step(metrics)
def log_evaluation(
self,
input_data: Any,
output: Any,
expected: Any,
scores: dict[str, float],
):
"""
        Log an evaluation: the full record goes to Braintrust, the scores to W&B.
Args:
input_data: Input data
output: Model output
expected: Expected output
scores: Evaluation scores
"""
self.bt.log_evaluation(input_data, output, expected, scores)
self.wandb.log(scores)
def log_artifact(self, path: str | Path, name: str):
"""
Log artifact to both platforms.
Args:
path: Path to artifact
name: Artifact name
"""
self.bt.log_artifact(path, name)
self.wandb.log_artifact(path, name)
def finish(self):
"""End tracking on both platforms."""
bt_summary = self.bt.end_experiment()
self.wandb.finish()
logger.info("Unified experiment ended")
return bt_summary