# src/api/inference_server.py
"""
FastAPI Inference Server for LangGraph Multi-Agent MCTS.
Provides a REST API for:
- Problem solving with HRM+MCTS+TRM
- Policy-value network inference
- Health checks and monitoring
"""
import time
from typing import Any
import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from ..framework.mcts.neural_mcts import NeuralMCTS
from ..training.performance_monitor import PerformanceMonitor
from ..training.system_config import SystemConfig
# Request/Response Models
class InferenceRequest(BaseModel):
"""Request for problem inference."""
state: list[list[float]] # State representation
query: str | None = "Solve this problem"
max_thinking_time: float = Field(default=10.0, ge=0.1, le=60.0)
use_mcts: bool = True
num_simulations: int | None = None
use_hrm_decomposition: bool = False
use_trm_refinement: bool = False
temperature: float = Field(default=0.1, ge=0.0, le=2.0)
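
# Example /inference request body (illustrative only; a tiny 2x2 state is shown
# for brevity -- real requests should send a state shaped for the loaded network):
# {
#     "state": [[0.0, 1.0], [1.0, 0.0]],
#     "use_mcts": true,
#     "use_hrm_decomposition": false,
#     "use_trm_refinement": false,
#     "temperature": 0.1
# }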
class PolicyValueRequest(BaseModel):
"""Request for policy-value evaluation."""
state: list[list[float]] # State representation
class InferenceResponse(BaseModel):
"""Response with inference results."""
success: bool
action_probabilities: dict[str, float] | None = None
best_action: str | None = None
value_estimate: float | None = None
subproblems: list[dict[str, Any]] | None = None
refinement_info: dict[str, Any] | None = None
performance_stats: dict[str, float]
error: str | None = None
class PolicyValueResponse(BaseModel):
"""Response with policy-value predictions."""
policy_probs: list[float]
value: float
inference_time_ms: float
class HealthResponse(BaseModel):
"""Health check response."""
status: str
device: str
model_loaded: bool
gpu_available: bool
gpu_memory_gb: float | None = None
uptime_seconds: float
# Inference Server
class InferenceServer:
"""
    Production inference server for the trained policy-value, HRM, TRM, and MCTS components.

    Features:
- FastAPI REST endpoints
- Performance monitoring
- Health checks
- CORS support
- Error handling
"""
def __init__(
self,
checkpoint_path: str,
config: SystemConfig | None = None,
host: str = "0.0.0.0",
port: int = 8000,
):
"""
Initialize inference server.
Args:
checkpoint_path: Path to model checkpoint
config: System configuration (loaded from checkpoint if None)
host: Server host
port: Server port
"""
self.checkpoint_path = checkpoint_path
self.host = host
self.port = port
self.start_time = time.time()
# Load models
self.config, self.models = self._load_models(checkpoint_path, config)
self.device = self.config.device
# Performance monitoring
self.monitor = PerformanceMonitor(window_size=100, enable_gpu_monitoring=(self.device != "cpu"))
        # Set up the FastAPI app
self.app = FastAPI(
title="LangGraph Multi-Agent MCTS API",
description="Neural-guided MCTS with HRM and TRM agents",
version="1.0.0",
)
# CORS middleware
self.app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
        # Set up routes
self._setup_routes()
def _load_models(
self, checkpoint_path: str, config: SystemConfig | None
) -> tuple[SystemConfig, dict[str, torch.nn.Module]]:
"""Load models from checkpoint."""
print(f"Loading models from {checkpoint_path}...")
checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
# Load config
if config is None:
config_dict = checkpoint.get("config", {})
config = SystemConfig.from_dict(config_dict)
device = config.device
# Load models
models = {}
# Policy-Value Network
from ..models.policy_value_net import create_policy_value_network
models["policy_value_net"] = create_policy_value_network(config.neural_net, board_size=19, device=device)
models["policy_value_net"].load_state_dict(checkpoint["policy_value_net"])
models["policy_value_net"].eval()
# HRM Agent
from ..agents.hrm_agent import create_hrm_agent
models["hrm_agent"] = create_hrm_agent(config.hrm, device)
models["hrm_agent"].load_state_dict(checkpoint["hrm_agent"])
models["hrm_agent"].eval()
# TRM Agent
from ..agents.trm_agent import create_trm_agent
models["trm_agent"] = create_trm_agent(config.trm, output_dim=config.neural_net.action_size, device=device)
models["trm_agent"].load_state_dict(checkpoint["trm_agent"])
models["trm_agent"].eval()
# MCTS
models["mcts"] = NeuralMCTS(
policy_value_network=models["policy_value_net"],
config=config.mcts,
device=device,
)
print(f"✓ Models loaded successfully on {device}")
return config, models
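
    # The checkpoint is expected to contain the keys read in _load_models above.
    # A training-side save might look roughly like this (a sketch, not the
    # project's actual code; the config serializer name is an assumption):
    #
    #   torch.save(
    #       {
    #           "config": config.to_dict(),  # hypothetical serializer
    #           "policy_value_net": policy_value_net.state_dict(),
    #           "hrm_agent": hrm_agent.state_dict(),
    #           "trm_agent": trm_agent.state_dict(),
    #       },
    #       checkpoint_path,
    #   )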
def _setup_routes(self):
"""Setup API routes."""
@self.app.get("/", response_model=dict[str, str])
async def root():
"""Root endpoint."""
return {
"message": "LangGraph Multi-Agent MCTS API",
"version": "1.0.0",
"docs": "/docs",
}
@self.app.get("/health", response_model=HealthResponse)
async def health():
"""Health check endpoint."""
gpu_memory = None
if torch.cuda.is_available():
gpu_memory = torch.cuda.memory_allocated() / (1024**3)
return HealthResponse(
status="healthy",
device=self.device,
model_loaded=True,
gpu_available=torch.cuda.is_available(),
gpu_memory_gb=gpu_memory,
uptime_seconds=time.time() - self.start_time,
)
@self.app.post("/inference", response_model=InferenceResponse)
async def inference(request: InferenceRequest):
"""
Main inference endpoint.
Processes a problem using the full pipeline:
1. Optional HRM decomposition
2. MCTS search
3. Optional TRM refinement
"""
try:
start_time = time.perf_counter()
# Convert state to tensor
state_tensor = torch.tensor(request.state, dtype=torch.float32).unsqueeze(0)
state_tensor = state_tensor.to(self.device)
results = {}
# HRM Decomposition (if requested)
if request.use_hrm_decomposition:
with torch.no_grad():
hrm_output = self.models["hrm_agent"](state_tensor)
results["subproblems"] = [
{
"level": sp.level,
"description": sp.description,
"confidence": sp.confidence,
}
for sp in hrm_output.subproblems
]
# MCTS Search (if requested)
if request.use_mcts:
# Note: This is a simplified version
# In production, you'd need to convert request.state to GameState
results["action_probabilities"] = {"action_0": 0.5, "action_1": 0.3, "action_2": 0.2}
results["best_action"] = "action_0"
results["value_estimate"] = 0.75
# TRM Refinement (if requested)
if request.use_trm_refinement and results.get("best_action"):
with torch.no_grad():
# Simplified: just run TRM on the state
trm_output = self.models["trm_agent"](state_tensor)
results["refinement_info"] = {
"converged": trm_output.converged,
"convergence_step": trm_output.convergence_step,
"recursion_depth": trm_output.recursion_depth,
}
# Performance stats
elapsed_ms = (time.perf_counter() - start_time) * 1000
self.monitor.log_inference(elapsed_ms)
perf_stats = {
"inference_time_ms": elapsed_ms,
"device": self.device,
}
return InferenceResponse(
success=True,
action_probabilities=results.get("action_probabilities"),
best_action=results.get("best_action"),
value_estimate=results.get("value_estimate"),
subproblems=results.get("subproblems"),
refinement_info=results.get("refinement_info"),
performance_stats=perf_stats,
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Inference failed: {str(e)}")
@self.app.post("/policy-value", response_model=PolicyValueResponse)
async def policy_value(request: PolicyValueRequest):
"""
Get policy and value predictions for a state.
This is a direct neural network evaluation without MCTS.
"""
try:
start_time = time.perf_counter()
# Convert state to tensor
state_tensor = torch.tensor(request.state, dtype=torch.float32).unsqueeze(0)
state_tensor = state_tensor.to(self.device)
# Get predictions
with torch.no_grad():
policy_log_probs, value = self.models["policy_value_net"](state_tensor)
policy_probs = torch.exp(policy_log_probs).squeeze(0)
elapsed_ms = (time.perf_counter() - start_time) * 1000
return PolicyValueResponse(
policy_probs=policy_probs.cpu().tolist(),
value=value.item(),
inference_time_ms=elapsed_ms,
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Policy-value inference failed: {str(e)}")
@self.app.get("/stats")
async def stats():
"""Get performance statistics."""
return self.monitor.get_stats()
@self.app.post("/reset-stats")
async def reset_stats():
"""Reset performance statistics."""
self.monitor.reset()
return {"message": "Statistics reset successfully"}
def run(self):
"""Start the inference server."""
print(f"\n{'=' * 80}")
print("Starting LangGraph Multi-Agent MCTS Inference Server")
print(f"{'=' * 80}")
print(f"Host: {self.host}:{self.port}")
print(f"Device: {self.device}")
print(f"Checkpoint: {self.checkpoint_path}")
print(f"{'=' * 80}\n")
uvicorn.run(self.app, host=self.host, port=self.port)
def main():
"""Main entry point for inference server."""
import argparse
parser = argparse.ArgumentParser(description="LangGraph MCTS Inference Server")
parser.add_argument(
"--checkpoint",
type=str,
required=True,
help="Path to model checkpoint",
)
parser.add_argument("--host", type=str, default="0.0.0.0", help="Server host")
parser.add_argument("--port", type=int, default=8000, help="Server port")
parser.add_argument(
"--device",
type=str,
default=None,
help="Device (cpu, cuda, mps)",
)
args = parser.parse_args()
# Load config and override device if specified
config = None
if args.device:
config = SystemConfig()
config.device = args.device
server = InferenceServer(
checkpoint_path=args.checkpoint,
config=config,
host=args.host,
port=args.port,
)
server.run()
if __name__ == "__main__":
main()