"""
Braintrust integration for experiment tracking in Neural Meta-Controller training.
Provides experiment logging, metric tracking, and model versioning capabilities.
"""
import os
from datetime import datetime
from typing import Any
# Check if braintrust is available
try:
import braintrust
BRAINTRUST_AVAILABLE = True
except ImportError:
BRAINTRUST_AVAILABLE = False
braintrust = None
class BraintrustTracker:
"""
Experiment tracker using the Braintrust API for neural meta-controller training.
Provides:
- Experiment creation and management
- Metric logging (loss, accuracy, etc.)
- Hyperparameter tracking
- Model evaluation logging
- Training run comparison
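Example (illustrative sketch; assumes the braintrust package is installed,
BRAINTRUST_API_KEY is set, and the experiment/metric names are placeholders):
    tracker = BraintrustTracker()
    tracker.start_experiment("rnn_training_run")
    tracker.log_hyperparameters({"learning_rate": 1e-3, "batch_size": 32})
    tracker.log_training_step(epoch=0, step=1, loss=0.42, metrics={"accuracy": 0.8})
    print(tracker.end_experiment())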
"""
def __init__(
self,
project_name: str = "neural-meta-controller",
api_key: str | None = None,
auto_init: bool = True,
):
"""
Initialize Braintrust tracker.
Args:
project_name: Name of the Braintrust project
api_key: Braintrust API key (if None, reads from BRAINTRUST_API_KEY env var)
auto_init: Whether to initialize the Braintrust client immediately
"""
self.project_name = project_name
self._api_key = api_key or os.environ.get("BRAINTRUST_API_KEY")
self._experiment: Any = None
self._current_span: Any = None
self._is_initialized = False
self._metrics_buffer: list[dict[str, Any]] = []
if not BRAINTRUST_AVAILABLE:
print("Warning: braintrust package not installed. Install with: pip install braintrust")
return
if auto_init and self._api_key:
self._initialize()
def _initialize(self) -> None:
"""Initialize Braintrust client with API key."""
if not BRAINTRUST_AVAILABLE:
return
if self._api_key:
braintrust.login(api_key=self._api_key)
self._is_initialized = True
@property
def is_available(self) -> bool:
"""Check if Braintrust is available and configured."""
return BRAINTRUST_AVAILABLE and self._is_initialized and self._api_key is not None
def start_experiment(
self,
experiment_name: str | None = None,
metadata: dict[str, Any] | None = None,
) -> Any | None:
"""
Start a new experiment run.
Args:
experiment_name: Optional name for the experiment (auto-generated if None)
metadata: Optional metadata to attach to the experiment
Returns:
Braintrust Experiment object or None if not available
"""
if not self.is_available:
return None
if experiment_name is None:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
experiment_name = f"meta_controller_training_{timestamp}"
try:
self._experiment = braintrust.init(
project=self.project_name,
experiment=experiment_name,
metadata=metadata or {},
)
return self._experiment
except Exception as e:
print(f"Warning: Failed to start Braintrust experiment: {e}")
return None
def log_hyperparameters(self, params: dict[str, Any]) -> None:
"""
Log hyperparameters for the current experiment.
Args:
params: Dictionary of hyperparameters
"""
if not self.is_available or self._experiment is None:
self._metrics_buffer.append({"type": "hyperparameters", "data": params})
return
try:
self._experiment.log(metadata=params)
except Exception as e:
print(f"Warning: Failed to log hyperparameters: {e}")
def log_training_step(
self,
epoch: int,
step: int,
loss: float,
metrics: dict[str, float] | None = None,
) -> None:
"""
Log a single training step.
Args:
epoch: Current epoch number
step: Current step/batch number
loss: Training loss value
metrics: Optional additional metrics (accuracy, etc.)
"""
if not self.is_available or self._experiment is None:
self._metrics_buffer.append(
{
"type": "training_step",
"epoch": epoch,
"step": step,
"loss": loss,
"metrics": metrics or {},
}
)
return
try:
log_data = {
"input": {"epoch": epoch, "step": step},
"output": {"loss": loss},
"scores": metrics or {},
}
self._experiment.log(**log_data)
except Exception as e:
print(f"Warning: Failed to log training step: {e}")
def log_epoch_summary(
self,
epoch: int,
train_loss: float,
val_loss: float | None = None,
train_accuracy: float | None = None,
val_accuracy: float | None = None,
additional_metrics: dict[str, float] | None = None,
) -> None:
"""
Log summary metrics for a completed epoch.
Args:
epoch: Epoch number
train_loss: Training loss for the epoch
val_loss: Optional validation loss
train_accuracy: Optional training accuracy
val_accuracy: Optional validation accuracy
additional_metrics: Optional additional metrics
"""
if not self.is_available or self._experiment is None:
self._metrics_buffer.append(
{
"type": "epoch_summary",
"epoch": epoch,
"train_loss": train_loss,
"val_loss": val_loss,
"train_accuracy": train_accuracy,
"val_accuracy": val_accuracy,
"additional_metrics": additional_metrics or {},
}
)
return
try:
scores = {
"train_loss": train_loss,
}
if val_loss is not None:
scores["val_loss"] = val_loss
if train_accuracy is not None:
scores["train_accuracy"] = train_accuracy
if val_accuracy is not None:
scores["val_accuracy"] = val_accuracy
if additional_metrics:
scores.update(additional_metrics)
self._experiment.log(
input={"epoch": epoch},
output={"completed": True},
scores=scores,
)
except Exception as e:
print(f"Warning: Failed to log epoch summary: {e}")
def log_evaluation(
self,
eval_type: str,
predictions: list[str],
ground_truth: list[str],
metrics: dict[str, float],
) -> None:
"""
Log model evaluation results.
Args:
eval_type: Type of evaluation (e.g., "validation", "test")
predictions: Model predictions
ground_truth: Ground truth labels
metrics: Computed metrics (accuracy, precision, recall, f1, etc.)
"""
if not self.is_available or self._experiment is None:
self._metrics_buffer.append(
{
"type": "evaluation",
"eval_type": eval_type,
"num_samples": len(predictions),
"metrics": metrics,
}
)
return
try:
self._experiment.log(
input={
"eval_type": eval_type,
"num_samples": len(predictions),
},
output={
"predictions_sample": predictions[:10],
"ground_truth_sample": ground_truth[:10],
},
scores=metrics,
)
except Exception as e:
print(f"Warning: Failed to log evaluation: {e}")
def log_model_prediction(
self,
input_features: dict[str, Any],
prediction: str,
confidence: float,
ground_truth: str | None = None,
) -> None:
"""
Log a single model prediction for analysis.
Args:
input_features: Input features used for prediction
prediction: Model's predicted agent
confidence: Prediction confidence score
ground_truth: Optional ground truth label
"""
if not self.is_available or self._experiment is None:
self._metrics_buffer.append(
{
"type": "prediction",
"input": input_features,
"prediction": prediction,
"confidence": confidence,
"ground_truth": ground_truth,
}
)
return
try:
scores = {"confidence": confidence}
if ground_truth is not None:
scores["correct"] = float(prediction == ground_truth)
self._experiment.log(
input=input_features,
output={"prediction": prediction},
expected=ground_truth,
scores=scores,
)
except Exception as e:
print(f"Warning: Failed to log prediction: {e}")
def log_model_artifact(
self,
model_path: str,
model_type: str,
metrics: dict[str, float],
metadata: dict[str, Any] | None = None,
) -> None:
"""
Log a trained model artifact.
Args:
model_path: Path to the saved model
model_type: Type of model (e.g., "rnn", "bert")
metrics: Final model metrics
metadata: Optional additional metadata
"""
if not self.is_available or self._experiment is None:
self._metrics_buffer.append(
{
"type": "model_artifact",
"model_path": model_path,
"model_type": model_type,
"metrics": metrics,
"metadata": metadata or {},
}
)
return
try:
self._experiment.log(
input={
"model_path": model_path,
"model_type": model_type,
},
output={"saved": True},
scores=metrics,
metadata=metadata or {},
)
except Exception as e:
print(f"Warning: Failed to log model artifact: {e}")
def end_experiment(self) -> str | None:
"""
End the current experiment and return summary URL.
Returns:
URL to view the experiment in Braintrust dashboard, or None
"""
if not self.is_available or self._experiment is None:
return None
try:
summary = self._experiment.summarize()
self._experiment = None
return summary.experiment_url if hasattr(summary, "experiment_url") else None
except Exception as e:
print(f"Warning: Failed to end experiment: {e}")
return None
def get_buffered_metrics(self) -> list[dict[str, Any]]:
"""
Get all buffered metrics (useful when Braintrust is not available).
Returns:
List of buffered metric dictionaries
"""
return self._metrics_buffer.copy()
def clear_buffer(self) -> None:
"""Clear the metrics buffer."""
self._metrics_buffer.clear()
class BraintrustContextManager:
"""
Context manager for Braintrust experiment tracking.
Usage:
with BraintrustContextManager(
project_name="neural-meta-controller",
experiment_name="training_run_1"
) as tracker:
tracker.log_hyperparameters({"learning_rate": 0.001})
tracker.log_epoch_summary(1, train_loss=0.5, val_loss=0.4)
"""
def __init__(
self,
project_name: str = "neural-meta-controller",
experiment_name: str | None = None,
api_key: str | None = None,
metadata: dict[str, Any] | None = None,
):
"""
Initialize context manager.
Args:
project_name: Name of the Braintrust project
experiment_name: Optional experiment name
api_key: Optional API key
metadata: Optional experiment metadata
"""
self.project_name = project_name
self.experiment_name = experiment_name
self.api_key = api_key
self.metadata = metadata
self.tracker: BraintrustTracker | None = None
self.experiment_url: str | None = None
def __enter__(self) -> BraintrustTracker:
"""Start experiment tracking."""
self.tracker = BraintrustTracker(
project_name=self.project_name,
api_key=self.api_key,
)
self.tracker.start_experiment(
experiment_name=self.experiment_name,
metadata=self.metadata,
)
return self.tracker
def __exit__(self, exc_type, exc_val, exc_tb):
"""End experiment tracking."""
if self.tracker:
self.experiment_url = self.tracker.end_experiment()
return False
def create_training_tracker(
model_type: str = "rnn",
config: dict[str, Any] | None = None,
) -> BraintrustTracker:
"""
Create a pre-configured tracker for meta-controller training.
Args:
model_type: Type of model being trained ("rnn" or "bert")
config: Optional training configuration
Returns:
Configured BraintrustTracker instance
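Example (illustrative; the config keys shown are placeholders):
    tracker = create_training_tracker(model_type="rnn", config={"epochs": 10})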
"""
tracker = BraintrustTracker(project_name="neural-meta-controller")
if tracker.is_available:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
experiment_name = f"{model_type}_training_{timestamp}"
metadata = {
"model_type": model_type,
"timestamp": timestamp,
}
if config:
metadata.update(config)
tracker.start_experiment(
experiment_name=experiment_name,
metadata=metadata,
)
return tracker
__all__ = [
"BraintrustTracker",
"BraintrustContextManager",
"create_training_tracker",
"BRAINTRUST_AVAILABLE",
]
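# Minimal usage sketch (not part of the original API surface): exercises the
# tracker end to end. If the braintrust package or BRAINTRUST_API_KEY is
# missing, every call falls back to the in-memory buffer, so this is safe to
# run offline; the loss/metric values below are placeholders.
if __name__ == "__main__":
    demo_tracker = create_training_tracker(model_type="rnn", config={"epochs": 2})
    demo_tracker.log_hyperparameters({"learning_rate": 1e-3, "batch_size": 32})
    for epoch in range(2):
        demo_tracker.log_training_step(epoch=epoch, step=0, loss=1.0 / (epoch + 1))
        demo_tracker.log_epoch_summary(
            epoch=epoch,
            train_loss=1.0 / (epoch + 1),
            val_loss=1.2 / (epoch + 1),
        )
    url = demo_tracker.end_experiment()
    if url:
        print(f"Experiment logged: {url}")
    else:
        # Offline fallback: inspect what would have been sent to Braintrust.
        print(f"Buffered {len(demo_tracker.get_buffered_metrics())} metric records")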