ianshank
feat: add personality output and bug fixes
40ee6b4
"""
Tiny Recursive Model (TRM) Agent.
Implements recursive refinement with:
- Deep supervision at all recursion levels
- Convergence detection
- Memory-efficient recursion
- Iterative improvement mechanism
Based on principles from:
- "Recursive Refinement Networks"
- "Deep Supervision for Neural Networks"
"""
from __future__ import annotations
from dataclasses import dataclass
import torch
import torch.nn as nn
from ..training.system_config import TRMConfig
@dataclass
class TRMOutput:
    """
    Result bundle produced by one TRM recursive-refinement pass.

    Field order is part of the positional constructor signature.
    """

    # Prediction taken from the last recursion step that was executed.
    final_prediction: torch.Tensor
    # One prediction per executed recursion step, in execution order.
    intermediate_predictions: list[torch.Tensor]
    # Number of recursion steps actually executed.
    recursion_depth: int
    # True when the residual norm fell below the convergence threshold.
    converged: bool
    # 1-based step at which convergence was detected (the requested depth if it never was).
    convergence_step: int
    # Mean L2 residual norm for each step on which convergence was checked.
    residual_norms: list[float]
class RecursiveBlock(nn.Module):
    """
    Weight-shared refinement step applied at every recursion level.

    Computes ``x + residual_scale * transform(x)``. Here ``transform`` is an MLP
    (latent_dim -> hidden_dim -> latent_dim, GELU + dropout, optional
    LayerNorm after each Linear) and ``residual_scale`` is a learned scalar.
    """

    def __init__(self, config: TRMConfig):
        super().__init__()
        self.config = config

        def norm_or_identity(width: int) -> nn.Module:
            # Identity keeps the Sequential indices stable when LayerNorm is disabled.
            return nn.LayerNorm(width) if config.use_layer_norm else nn.Identity()

        layers = [
            nn.Linear(config.latent_dim, config.hidden_dim),
            norm_or_identity(config.hidden_dim),
            nn.GELU(),
            nn.Dropout(config.dropout),
            nn.Linear(config.hidden_dim, config.latent_dim),
            norm_or_identity(config.latent_dim),
        ]
        self.transform = nn.Sequential(*layers)

        # Learned gate on the residual branch, starting at 1.
        self.residual_scale = nn.Parameter(torch.ones(1))

    def forward(self, x: torch.Tensor, iteration: int = 0) -> torch.Tensor:  # noqa: ARG002
        """
        Run one refinement step.

        Args:
            x: Tensor whose last dimension is ``latent_dim``.
            iteration: Recursion index; currently unused and reserved for
                iteration-dependent behavior.

        Returns:
            Tensor with the same shape as ``x``.
        """
        update = self.transform(x)
        return x + self.residual_scale * update
class DeepSupervisionHead(nn.Module):
    """
    Small MLP readout that maps a latent state to an output prediction.

    Architecture: Linear(latent_dim -> latent_dim // 2) -> ReLU ->
    Linear(latent_dim // 2 -> output_dim). One instance can be attached per
    recursion level so that every level receives a training signal.
    """

    def __init__(self, latent_dim: int, output_dim: int):
        super().__init__()
        bottleneck = latent_dim // 2
        self.head = nn.Sequential(
            nn.Linear(latent_dim, bottleneck),
            nn.ReLU(),
            nn.Linear(bottleneck, output_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` (last dim ``latent_dim``) to ``output_dim`` features."""
        prediction = self.head(x)
        return prediction
class TRMAgent(nn.Module):
    """
    Tiny Recursive Model for iterative refinement.

    Features:
    - Shared weights across recursions (parameter efficiency)
    - Deep supervision at all levels
    - Automatic convergence detection
    - Residual connections for stable gradients
    """

    def __init__(self, config: TRMConfig, output_dim: int | None = None, device: str = "cpu"):
        super().__init__()
        self.config = config
        self.device = device
        # A falsy output_dim (None or 0) falls back to latent_dim.
        self.output_dim = output_dim or config.latent_dim

        # Initial encoding: latent_dim -> hidden_dim -> latent_dim.
        self.encoder = nn.Sequential(
            nn.Linear(config.latent_dim, config.hidden_dim),
            nn.LayerNorm(config.hidden_dim) if config.use_layer_norm else nn.Identity(),
            nn.GELU(),
            nn.Linear(config.hidden_dim, config.latent_dim),
            nn.LayerNorm(config.latent_dim) if config.use_layer_norm else nn.Identity(),
        )

        # Shared recursive block (same weights at every recursion level).
        self.recursive_block = RecursiveBlock(config)

        if config.deep_supervision:
            # One head per configured recursion level.
            self.supervision_heads = nn.ModuleList(
                [DeepSupervisionHead(config.latent_dim, self.output_dim) for _ in range(config.num_recursions)]
            )
        else:
            # Single output head shared by all levels.
            self.output_head = DeepSupervisionHead(config.latent_dim, self.output_dim)

        self.to(device)

    def _predict(self, latent: torch.Tensor, step: int) -> torch.Tensor:
        """Return the prediction for 0-based recursion ``step`` from ``latent``."""
        if not self.config.deep_supervision:
            return self.output_head(latent)
        # ``output_head`` does not exist in deep-supervision mode. When running
        # deeper than the configured number of heads, reuse the deepest head.
        head_index = min(step, len(self.supervision_heads) - 1)
        return self.supervision_heads[head_index](latent)

    def forward(
        self,
        x: torch.Tensor,
        num_recursions: int | None = None,
        check_convergence: bool = True,
        convergence_threshold: float | None = None,
    ) -> TRMOutput:
        """
        Process input through recursive refinement.

        Args:
            x: Input tensor [batch, ..., latent_dim]
            num_recursions: Maximum number of recursions; ``None`` uses
                ``config.num_recursions``.
            check_convergence: Whether to stop early once the mean L2 norm of
                the latent update falls below the threshold (only checked from
                step ``config.min_recursions`` onwards).
            convergence_threshold: Per-call override of
                ``config.convergence_threshold``; ``None`` uses the config value.

        Returns:
            TRMOutput with final and intermediate predictions.

        Raises:
            ValueError: If the effective number of recursions is below 1.
        """
        if num_recursions is None:
            num_recursions = self.config.num_recursions
        if num_recursions < 1:
            raise ValueError(f"num_recursions must be >= 1, got {num_recursions}")
        threshold = self.config.convergence_threshold if convergence_threshold is None else convergence_threshold

        latent = self.encoder(x)
        # Detached snapshot: only used to measure the step size, never backpropagated.
        previous_latent = latent.detach()

        intermediate_predictions: list[torch.Tensor] = []
        residual_norms: list[float] = []
        converged = False
        convergence_step = num_recursions

        for step in range(num_recursions):
            latent = self.recursive_block(latent, iteration=step)
            intermediate_predictions.append(self._predict(latent, step))

            current_latent = latent.detach()
            if check_convergence and step >= self.config.min_recursions:
                # Mean over all leading dims of the per-vector L2 norm of the update.
                residual_norm = (
                    torch.linalg.vector_norm(current_latent - previous_latent, ord=2, dim=-1).mean().item()
                )
                residual_norms.append(residual_norm)
                if residual_norm < threshold:
                    converged = True
                    convergence_step = step + 1
                    break
            previous_latent = current_latent

        return TRMOutput(
            final_prediction=intermediate_predictions[-1],
            intermediate_predictions=intermediate_predictions,
            recursion_depth=len(intermediate_predictions),
            converged=converged,
            convergence_step=convergence_step,
            residual_norms=residual_norms,
        )

    async def refine_solution(
        self,
        initial_prediction: torch.Tensor,
        num_recursions: int | None = None,
        convergence_threshold: float | None = None,
    ) -> tuple[torch.Tensor, dict]:
        """
        Refine an initial prediction through recursive processing.

        The threshold override is passed straight to ``forward``; the shared
        config object is never mutated.

        Args:
            initial_prediction: Initial solution [batch, ..., latent_dim]
            num_recursions: Maximum recursions (optional)
            convergence_threshold: Convergence threshold (optional)

        Returns:
            refined_solution: Final refined prediction
            info: Dictionary with refinement metadata
        """
        output = self.forward(
            initial_prediction,
            num_recursions=num_recursions,
            check_convergence=True,
            convergence_threshold=convergence_threshold,
        )

        info = {
            "converged": output.converged,
            "convergence_step": output.convergence_step,
            "total_recursions": output.recursion_depth,
            "final_residual": output.residual_norms[-1] if output.residual_norms else None,
            "refinement_path": output.residual_norms,
        }
        return output.final_prediction, info

    def get_parameter_count(self) -> int:
        """Return total number of trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
class TRMLoss(nn.Module):
    """
    Deep-supervision loss for TRM.

    The final prediction's loss is scaled by ``final_weight``. Every earlier
    prediction at index ``i`` (0 = shallowest) is scaled by
    ``supervision_weight_decay ** (k - i)``, where ``k`` is the number of
    earlier predictions. The level just before the final one gets weight
    ``supervision_weight_decay``, and shallower levels get geometrically
    smaller weights.
    """

    def __init__(
        self,
        task_loss_fn: nn.Module,
        supervision_weight_decay: float = 0.5,
        final_weight: float = 1.0,
    ):
        """
        Initialize TRM loss.

        Args:
            task_loss_fn: Base loss called as ``task_loss_fn(prediction, targets)``
                (e.g. MSE, CrossEntropy); must return a scalar tensor.
            supervision_weight_decay: Geometric factor applied per level of
                distance from the final prediction.
            final_weight: Weight for the final prediction loss.
        """
        super().__init__()
        self.task_loss_fn = task_loss_fn
        self.supervision_weight_decay = supervision_weight_decay
        self.final_weight = final_weight

    def forward(self, trm_output: TRMOutput, targets: torch.Tensor) -> tuple[torch.Tensor, dict]:
        """
        Compute the weighted deep-supervision loss.

        Args:
            trm_output: Output from a TRM forward pass.
            targets: Ground-truth targets shared by every level.

        Returns:
            total_loss: Differentiable combined loss.
            loss_dict: Python scalars for logging ("total", "final",
                "intermediate_mean") plus recursion metadata.
        """
        final_loss = self.task_loss_fn(trm_output.final_prediction, targets)
        total_loss = self.final_weight * final_loss

        earlier = trm_output.intermediate_predictions[:-1]
        depth = len(earlier)
        level_values: list[float] = []
        for level, prediction in enumerate(earlier):
            level_loss = self.task_loss_fn(prediction, targets)
            level_values.append(level_loss.item())
            # Farther from the final step -> larger exponent -> smaller weight.
            total_loss = total_loss + self.supervision_weight_decay ** (depth - level) * level_loss

        mean_intermediate = sum(level_values) / len(level_values) if level_values else 0.0
        loss_dict = {
            "total": total_loss.item(),
            "final": final_loss.item(),
            "intermediate_mean": mean_intermediate,
            "recursion_depth": trm_output.recursion_depth,
            "converged": trm_output.converged,
            "convergence_step": trm_output.convergence_step,
        }
        return total_loss, loss_dict
def create_trm_agent(config: TRMConfig, output_dim: int | None = None, device: str = "cpu") -> TRMAgent:
    """
    Build a TRMAgent and re-initialize all of its Linear layers.

    Every ``nn.Linear`` weight is re-drawn with Kaiming (He) normal
    initialization (``mode="fan_out"``, ``nonlinearity="relu"``), and every
    Linear bias is zeroed. Other parameters (LayerNorm, the residual scale)
    keep their default initialization.

    Args:
        config: TRM configuration.
        output_dim: Output dimension (defaults to latent_dim).
        device: Device to place the model on.

    Returns:
        The initialized TRMAgent.
    """
    agent = TRMAgent(config, output_dim, device)

    def _kaiming_init_linear(module: nn.Module) -> None:
        if not isinstance(module, nn.Linear):
            return
        nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
        if module.bias is not None:
            nn.init.zeros_(module.bias)

    agent.apply(_kaiming_init_linear)
    return agent
# Integration utilities: wrappers for using TRM inside larger pipelines
class TRMRefinementWrapper:
    """
    Wrapper for using TRM as an inference-only refinement step in pipelines.

    Provides a clean interface for integrating TRM into larger systems. The
    wrapped agent is switched to eval mode on construction, and every call
    runs with autograd disabled.
    """

    def __init__(self, trm_agent: TRMAgent, device: str = "cpu"):
        self.trm_agent = trm_agent
        self.device = device
        self.trm_agent.eval()

    async def refine(
        self,
        predictions: torch.Tensor,
        num_iterations: int = 10,
        return_path: bool = False,
    ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
        """
        Refine predictions using TRM.

        Args:
            predictions: Initial predictions to refine (moved to ``self.device``).
            num_iterations: Maximum number of refinement iterations.
            return_path: Whether to also return the intermediate predictions.

        Returns:
            refined_predictions or (refined_predictions, refinement_path)
        """
        predictions = predictions.to(self.device)
        # Decorating an ``async def`` with ``@torch.no_grad()`` only disables
        # grad while the coroutine object is created, not while it runs.
        # Grad is therefore disabled explicitly around the forward pass.
        with torch.no_grad():
            output = self.trm_agent(predictions, num_recursions=num_iterations, check_convergence=True)

        if return_path:
            return output.final_prediction, output.intermediate_predictions
        return output.final_prediction

    def get_refinement_stats(self, predictions: torch.Tensor) -> dict:
        """
        Run one refinement pass and summarize its convergence behavior.

        Args:
            predictions: Initial predictions (moved to ``self.device``, as in ``refine``).

        Returns:
            Dict with "converged", "steps_to_convergence", "final_residual"
            (None if no residual was recorded) and "total_refinement_iterations".
        """
        predictions = predictions.to(self.device)
        with torch.no_grad():
            output = self.trm_agent(predictions, check_convergence=True)

        return {
            "converged": output.converged,
            "steps_to_convergence": output.convergence_step,
            "final_residual": (output.residual_norms[-1] if output.residual_norms else None),
            "total_refinement_iterations": output.recursion_depth,
        }