| """ | |
| Braintrust integration for experiment tracking in Neural Meta-Controller training. | |
| Provides experiment logging, metric tracking, and model versioning capabilities. | |
| """ | |
| import os | |
| from datetime import datetime | |
| from typing import Any | |
| # Check if braintrust is available | |
| try: | |
| import braintrust | |
| BRAINTRUST_AVAILABLE = True | |
| except ImportError: | |
| BRAINTRUST_AVAILABLE = False | |
| braintrust = None | |
| class BraintrustTracker: | |
| """ | |
| Experiment tracker using Braintrust API for neural meta-controller training. | |
| Provides: | |
| - Experiment creation and management | |
| - Metric logging (loss, accuracy, etc.) | |
| - Hyperparameter tracking | |
| - Model evaluation logging | |
| - Training run comparison | |
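
    Example (a minimal sketch, assuming the braintrust package is installed
    and BRAINTRUST_API_KEY is set in the environment):

        tracker = BraintrustTracker(project_name="neural-meta-controller")
        tracker.start_experiment(experiment_name="demo_run")
        tracker.log_hyperparameters({"learning_rate": 0.001})
        tracker.log_epoch_summary(epoch=1, train_loss=0.52, val_loss=0.47)
        url = tracker.end_experiment()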
| """ | |
| def __init__( | |
| self, | |
| project_name: str = "neural-meta-controller", | |
| api_key: str | None = None, | |
| auto_init: bool = True, | |
| ): | |
| """ | |
| Initialize Braintrust tracker. | |
| Args: | |
| project_name: Name of the Braintrust project | |
| api_key: Braintrust API key (if None, reads from BRAINTRUST_API_KEY env var) | |
| auto_init: Whether to initialize Braintrust client immediately | |
| """ | |
| self.project_name = project_name | |
| self._api_key = api_key or os.environ.get("BRAINTRUST_API_KEY") | |
| self._experiment: Any = None | |
| self._current_span: Any = None | |
| self._is_initialized = False | |
| self._metrics_buffer: list[dict[str, Any]] = [] | |
| if not BRAINTRUST_AVAILABLE: | |
| print("Warning: braintrust package not installed. Install with: pip install braintrust") | |
| return | |
| if auto_init and self._api_key: | |
| self._initialize() | |
| def _initialize(self) -> None: | |
| """Initialize Braintrust client with API key.""" | |
| if not BRAINTRUST_AVAILABLE: | |
| return | |
| if self._api_key: | |
| braintrust.login(api_key=self._api_key) | |
| self._is_initialized = True | |

    @property
    def is_available(self) -> bool:
        """
        Check if Braintrust is available and configured.

        Exposed as a property so call sites like ``if not self.is_available``
        test the boolean value rather than an always-truthy bound method.
        """
        return BRAINTRUST_AVAILABLE and self._is_initialized and self._api_key is not None

    def start_experiment(
        self,
        experiment_name: str | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> Any | None:
        """
        Start a new experiment run.

        Args:
            experiment_name: Optional name for the experiment (auto-generated if None)
            metadata: Optional metadata to attach to the experiment

        Returns:
            Braintrust Experiment object or None if not available
        """
        if not self.is_available:
            return None

        if experiment_name is None:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            experiment_name = f"meta_controller_training_{timestamp}"

        try:
            self._experiment = braintrust.init(
                project=self.project_name,
                experiment=experiment_name,
                metadata=metadata or {},
            )
            return self._experiment
        except Exception as e:
            print(f"Warning: Failed to start Braintrust experiment: {e}")
            return None

    def log_hyperparameters(self, params: dict[str, Any]) -> None:
        """
        Log hyperparameters for the current experiment.

        Args:
            params: Dictionary of hyperparameters
        """
        if not self.is_available or self._experiment is None:
            self._metrics_buffer.append({"type": "hyperparameters", "data": params})
            return

        try:
            self._experiment.log(metadata=params)
        except Exception as e:
            print(f"Warning: Failed to log hyperparameters: {e}")

    def log_training_step(
        self,
        epoch: int,
        step: int,
        loss: float,
        metrics: dict[str, float] | None = None,
    ) -> None:
        """
        Log a single training step.

        Args:
            epoch: Current epoch number
            step: Current step/batch number
            loss: Training loss value
            metrics: Optional additional metrics (accuracy, etc.)
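
        Example (illustrative values):
            tracker.log_training_step(
                epoch=2, step=150, loss=0.31, metrics={"accuracy": 0.88}
            )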
| """ | |
| if not self.is_available or self._experiment is None: | |
| self._metrics_buffer.append( | |
| { | |
| "type": "training_step", | |
| "epoch": epoch, | |
| "step": step, | |
| "loss": loss, | |
| "metrics": metrics or {}, | |
| } | |
| ) | |
| return | |
| try: | |
| log_data = { | |
| "input": {"epoch": epoch, "step": step}, | |
| "output": {"loss": loss}, | |
| "scores": metrics or {}, | |
| } | |
| self._experiment.log(**log_data) | |
| except Exception as e: | |
| print(f"Warning: Failed to log training step: {e}") | |

    def log_epoch_summary(
        self,
        epoch: int,
        train_loss: float,
        val_loss: float | None = None,
        train_accuracy: float | None = None,
        val_accuracy: float | None = None,
        additional_metrics: dict[str, float] | None = None,
    ) -> None:
        """
        Log summary metrics for a completed epoch.

        Args:
            epoch: Epoch number
            train_loss: Training loss for the epoch
            val_loss: Optional validation loss
            train_accuracy: Optional training accuracy
            val_accuracy: Optional validation accuracy
            additional_metrics: Optional additional metrics
        """
        if not self.is_available or self._experiment is None:
            self._metrics_buffer.append(
                {
                    "type": "epoch_summary",
                    "epoch": epoch,
                    "train_loss": train_loss,
                    "val_loss": val_loss,
                    "train_accuracy": train_accuracy,
                    "val_accuracy": val_accuracy,
                    "additional_metrics": additional_metrics or {},
                }
            )
            return

        try:
            scores = {
                "train_loss": train_loss,
            }
            if val_loss is not None:
                scores["val_loss"] = val_loss
            if train_accuracy is not None:
                scores["train_accuracy"] = train_accuracy
            if val_accuracy is not None:
                scores["val_accuracy"] = val_accuracy
            if additional_metrics:
                scores.update(additional_metrics)

            self._experiment.log(
                input={"epoch": epoch},
                output={"completed": True},
                scores=scores,
            )
        except Exception as e:
            print(f"Warning: Failed to log epoch summary: {e}")

    def log_evaluation(
        self,
        eval_type: str,
        predictions: list[str],
        ground_truth: list[str],
        metrics: dict[str, float],
    ) -> None:
        """
        Log model evaluation results.

        Args:
            eval_type: Type of evaluation (e.g., "validation", "test")
            predictions: Model predictions
            ground_truth: Ground truth labels
            metrics: Computed metrics (accuracy, precision, recall, f1, etc.)
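
        Example (illustrative agent labels and scores):
            tracker.log_evaluation(
                eval_type="validation",
                predictions=["agent_a", "agent_b"],
                ground_truth=["agent_a", "agent_a"],
                metrics={"accuracy": 0.5},
            )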
| """ | |
| if not self.is_available or self._experiment is None: | |
| self._metrics_buffer.append( | |
| { | |
| "type": "evaluation", | |
| "eval_type": eval_type, | |
| "num_samples": len(predictions), | |
| "metrics": metrics, | |
| } | |
| ) | |
| return | |
| try: | |
| self._experiment.log( | |
| input={ | |
| "eval_type": eval_type, | |
| "num_samples": len(predictions), | |
| }, | |
| output={ | |
| "predictions_sample": predictions[:10], | |
| "ground_truth_sample": ground_truth[:10], | |
| }, | |
| scores=metrics, | |
| ) | |
| except Exception as e: | |
| print(f"Warning: Failed to log evaluation: {e}") | |

    def log_model_prediction(
        self,
        input_features: dict[str, Any],
        prediction: str,
        confidence: float,
        ground_truth: str | None = None,
    ) -> None:
        """
        Log a single model prediction for analysis.

        Args:
            input_features: Input features used for prediction
            prediction: Model's predicted agent
            confidence: Prediction confidence score
            ground_truth: Optional ground truth label
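
        Example (illustrative feature names and values):
            tracker.log_model_prediction(
                input_features={"query_length": 42, "has_code": True},
                prediction="agent_a",
                confidence=0.91,
                ground_truth="agent_a",
            )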
| """ | |
| if not self.is_available or self._experiment is None: | |
| self._metrics_buffer.append( | |
| { | |
| "type": "prediction", | |
| "input": input_features, | |
| "prediction": prediction, | |
| "confidence": confidence, | |
| "ground_truth": ground_truth, | |
| } | |
| ) | |
| return | |

        try:
            scores = {"confidence": confidence}
            # Check against None explicitly so an empty-string label still counts.
            if ground_truth is not None:
                scores["correct"] = float(prediction == ground_truth)
            self._experiment.log(
                input=input_features,
                output={"prediction": prediction},
                expected=ground_truth,
                scores=scores,
            )
        except Exception as e:
            print(f"Warning: Failed to log prediction: {e}")

    def log_model_artifact(
        self,
        model_path: str,
        model_type: str,
        metrics: dict[str, float],
        metadata: dict[str, Any] | None = None,
    ) -> None:
        """
        Log a trained model artifact.

        Args:
            model_path: Path to the saved model
            model_type: Type of model (e.g., "rnn", "bert")
            metrics: Final model metrics
            metadata: Optional additional metadata
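
        Example (the path and metric values are hypothetical):
            tracker.log_model_artifact(
                model_path="checkpoints/rnn_final.pt",
                model_type="rnn",
                metrics={"val_accuracy": 0.91},
            )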
| """ | |
| if not self.is_available or self._experiment is None: | |
| self._metrics_buffer.append( | |
| { | |
| "type": "model_artifact", | |
| "model_path": model_path, | |
| "model_type": model_type, | |
| "metrics": metrics, | |
| "metadata": metadata or {}, | |
| } | |
| ) | |
| return | |
| try: | |
| self._experiment.log( | |
| input={ | |
| "model_path": model_path, | |
| "model_type": model_type, | |
| }, | |
| output={"saved": True}, | |
| scores=metrics, | |
| metadata=metadata or {}, | |
| ) | |
| except Exception as e: | |
| print(f"Warning: Failed to log model artifact: {e}") | |

    def end_experiment(self) -> str | None:
        """
        End the current experiment and return summary URL.

        Returns:
            URL to view the experiment in Braintrust dashboard, or None
        """
        if not self.is_available or self._experiment is None:
            return None

        try:
            summary = self._experiment.summarize()
            self._experiment = None
            return summary.experiment_url if hasattr(summary, "experiment_url") else None
        except Exception as e:
            print(f"Warning: Failed to end experiment: {e}")
            return None

    def get_buffered_metrics(self) -> list[dict[str, Any]]:
        """
        Get all buffered metrics (useful when Braintrust is not available).

        Returns:
            List of buffered metric dictionaries
        """
        return self._metrics_buffer.copy()

    def clear_buffer(self) -> None:
        """Clear the metrics buffer."""
        self._metrics_buffer.clear()


class BraintrustContextManager:
    """
    Context manager for Braintrust experiment tracking.

    Usage:
        with BraintrustContextManager(
            project_name="neural-meta-controller",
            experiment_name="training_run_1",
        ) as tracker:
            tracker.log_hyperparameters({"learning_rate": 0.001})
            tracker.log_epoch_summary(1, train_loss=0.5, val_loss=0.4)
    """

    def __init__(
        self,
        project_name: str = "neural-meta-controller",
        experiment_name: str | None = None,
        api_key: str | None = None,
        metadata: dict[str, Any] | None = None,
    ):
        """
        Initialize context manager.

        Args:
            project_name: Name of the Braintrust project
            experiment_name: Optional experiment name
            api_key: Optional API key
            metadata: Optional experiment metadata
        """
        self.project_name = project_name
        self.experiment_name = experiment_name
        self.api_key = api_key
        self.metadata = metadata
        self.tracker: BraintrustTracker | None = None
        self.experiment_url: str | None = None

    def __enter__(self) -> BraintrustTracker:
        """Start experiment tracking."""
        self.tracker = BraintrustTracker(
            project_name=self.project_name,
            api_key=self.api_key,
        )
        self.tracker.start_experiment(
            experiment_name=self.experiment_name,
            metadata=self.metadata,
        )
        return self.tracker

    def __exit__(self, exc_type, exc_val, exc_tb):
        """End experiment tracking."""
        if self.tracker:
            self.experiment_url = self.tracker.end_experiment()
        return False


def create_training_tracker(
    model_type: str = "rnn",
    config: dict[str, Any] | None = None,
) -> BraintrustTracker:
    """
    Create a pre-configured tracker for meta-controller training.

    Args:
        model_type: Type of model being trained ("rnn" or "bert")
        config: Optional training configuration

    Returns:
        Configured BraintrustTracker instance
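
    Example (a minimal sketch; the config keys shown are illustrative):

        tracker = create_training_tracker(
            model_type="rnn",
            config={"learning_rate": 0.001, "batch_size": 32},
        )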
| """ | |
| tracker = BraintrustTracker(project_name="neural-meta-controller") | |
| if tracker.is_available: | |
| timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") | |
| experiment_name = f"{model_type}_training_{timestamp}" | |
| metadata = { | |
| "model_type": model_type, | |
| "timestamp": timestamp, | |
| } | |
| if config: | |
| metadata.update(config) | |
| tracker.start_experiment( | |
| experiment_name=experiment_name, | |
| metadata=metadata, | |
| ) | |
| return tracker | |


__all__ = [
    "BraintrustTracker",
    "BraintrustContextManager",
    "create_training_tracker",
    "BRAINTRUST_AVAILABLE",
]
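

if __name__ == "__main__":
    # Minimal offline demo: a sketch of the buffering fallback, runnable without
    # the braintrust package or an API key. The hyperparameter and loss values
    # below are illustrative, not taken from any real training run.
    demo = BraintrustTracker(project_name="neural-meta-controller", auto_init=False)
    demo.log_hyperparameters({"learning_rate": 0.001, "batch_size": 32})
    demo.log_epoch_summary(epoch=1, train_loss=0.52, val_loss=0.47)

    # With no experiment started, both calls land in the local buffer.
    for entry in demo.get_buffered_metrics():
        print(entry["type"], entry)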