| """ | |
| Policy-Value Network using ResNet Architecture. | |
| Implements the dual-head neural network used in AlphaZero: | |
| - Policy Head: Outputs action probabilities | |
| - Value Head: Outputs state value estimation | |
| Based on: | |
| - "Mastering Chess and Shogi by Self-Play with a General RL Algorithm" (AlphaZero) | |
| - Deep Residual Learning for Image Recognition (ResNet) | |
| """ | |
| from __future__ import annotations | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from ..training.system_config import NeuralNetworkConfig | |
class ResidualBlock(nn.Module):
    """
    Residual block with batch normalization and skip connections.

    Architecture:
        Conv -> BN -> ReLU -> Conv -> BN -> Add -> ReLU
    """

    def __init__(self, channels: int, use_batch_norm: bool = True):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels) if use_batch_norm else nn.Identity()
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels) if use_batch_norm else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply residual block transformation."""
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = F.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        # Skip connection
        out = out + residual
        out = F.relu(out)

        return out

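# Shape note (illustrative sketch, not part of the original API docs): the block
# is shape-preserving, i.e. a [batch, channels, H, W] input comes out with the
# same shape, which is what makes the element-wise skip connection and the
# stacking of blocks below valid:
#
#     block = ResidualBlock(channels=64)
#     assert block(torch.zeros(2, 64, 9, 9)).shape == (2, 64, 9, 9)
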
class PolicyHead(nn.Module):
    """
    Policy head for outputting action probabilities.

    Architecture:
        Conv -> BN -> ReLU -> FC -> LogSoftmax
    """

    def __init__(
        self,
        input_channels: int,
        policy_conv_channels: int,
        action_size: int,
        board_size: int = 19,
    ):
        super().__init__()
        self.conv = nn.Conv2d(input_channels, policy_conv_channels, kernel_size=1, bias=False)
        self.bn = nn.BatchNorm2d(policy_conv_channels)
        # Assuming square board
        fc_input_size = policy_conv_channels * board_size * board_size
        self.fc = nn.Linear(fc_input_size, action_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Compute policy (action probabilities).

        Args:
            x: [batch, channels, height, width]

        Returns:
            Log probabilities: [batch, action_size]
        """
        batch_size = x.size(0)

        out = self.conv(x)
        out = self.bn(out)
        out = F.relu(out)

        # Flatten spatial dimensions
        out = out.view(batch_size, -1)

        # Fully connected layer
        out = self.fc(out)

        # Log probabilities for numerical stability
        return F.log_softmax(out, dim=1)

class ValueHead(nn.Module):
    """
    Value head for estimating state value.

    Architecture:
        Conv -> BN -> ReLU -> FC -> ReLU -> FC -> Tanh
    """

    def __init__(
        self,
        input_channels: int,
        value_conv_channels: int,
        value_fc_hidden: int,
        board_size: int = 19,
    ):
        super().__init__()
        self.conv = nn.Conv2d(input_channels, value_conv_channels, kernel_size=1, bias=False)
        self.bn = nn.BatchNorm2d(value_conv_channels)
        # Assuming square board
        fc_input_size = value_conv_channels * board_size * board_size
        self.fc1 = nn.Linear(fc_input_size, value_fc_hidden)
        self.fc2 = nn.Linear(value_fc_hidden, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Compute value estimation.

        Args:
            x: [batch, channels, height, width]

        Returns:
            Value: [batch, 1] in range [-1, 1]
        """
        batch_size = x.size(0)

        out = self.conv(x)
        out = self.bn(out)
        out = F.relu(out)

        # Flatten spatial dimensions
        out = out.view(batch_size, -1)

        # Fully connected layers
        out = self.fc1(out)
        out = F.relu(out)
        out = self.fc2(out)

        # Tanh to bound value in [-1, 1]
        return torch.tanh(out)

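# Note (added for clarity): both heads flatten a fixed-size feature map
# (conv_channels * board_size * board_size), so at run time the spatial input
# size must match the board_size the heads were constructed with; a different
# or non-square board would require rebuilding the heads.
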
class PolicyValueNetwork(nn.Module):
    """
    Combined policy-value network with ResNet backbone.

    This is the core neural network used in AlphaZero-style learning.
    """

    def __init__(self, config: NeuralNetworkConfig, board_size: int = 19):
        super().__init__()
        self.config = config
        self.board_size = board_size

        # Initial convolution
        self.conv_input = nn.Conv2d(
            config.input_channels,
            config.num_channels,
            kernel_size=3,
            padding=1,
            bias=False,
        )
        self.bn_input = nn.BatchNorm2d(config.num_channels) if config.use_batch_norm else nn.Identity()

        # Residual blocks (shared feature extractor)
        self.res_blocks = nn.ModuleList(
            [ResidualBlock(config.num_channels, config.use_batch_norm) for _ in range(config.num_res_blocks)]
        )

        # Policy head
        self.policy_head = PolicyHead(
            input_channels=config.num_channels,
            policy_conv_channels=config.policy_conv_channels,
            action_size=config.action_size,
            board_size=board_size,
        )

        # Value head
        self.value_head = ValueHead(
            input_channels=config.num_channels,
            value_conv_channels=config.value_conv_channels,
            value_fc_hidden=config.value_fc_hidden,
            board_size=board_size,
        )

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass through the network.

        Args:
            x: Input state [batch, channels, height, width]

        Returns:
            (policy_log_probs, value) tuple
            - policy_log_probs: [batch, action_size] log probabilities
            - value: [batch, 1] state value in [-1, 1]
        """
        # Initial convolution
        out = self.conv_input(x)
        out = self.bn_input(out)
        out = F.relu(out)

        # Residual blocks
        for res_block in self.res_blocks:
            out = res_block(out)

        # Split into policy and value heads
        policy = self.policy_head(out)
        value = self.value_head(out)

        return policy, value

    def predict(self, state: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Inference-mode prediction.

        Args:
            state: Input state tensor

        Returns:
            (policy_probs, value) tuple with probabilities (not log)
        """
        # Note: callers should switch the module to eval() mode first so that
        # BatchNorm layers use running statistics rather than batch statistics.
        with torch.no_grad():
            policy_log_probs, value = self.forward(state)
            policy_probs = torch.exp(policy_log_probs)
        return policy_probs, value

    def get_parameter_count(self) -> int:
        """Return total number of trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

class AlphaZeroLoss(nn.Module):
    """
    Combined loss function for AlphaZero training.

    Loss = (z - v)^2 - π^T log(p) + c||θ||^2

    Where:
    - z: actual game outcome
    - v: value prediction
    - π: MCTS visit count distribution
    - p: policy prediction
    - c: L2 regularization coefficient

    Note: the L2 term is not computed here; it is expected to be applied via the
    optimizer's weight decay.
    """

    def __init__(self, value_loss_weight: float = 1.0):
        super().__init__()
        self.value_loss_weight = value_loss_weight

    def forward(
        self,
        policy_logits: torch.Tensor,
        value: torch.Tensor,
        target_policy: torch.Tensor,
        target_value: torch.Tensor,
    ) -> tuple[torch.Tensor, dict]:
        """
        Compute AlphaZero loss.

        Args:
            policy_logits: Predicted policy log probabilities [batch, action_size]
            value: Predicted values [batch, 1]
            target_policy: Target policy from MCTS [batch, action_size]
            target_value: Target value from game outcome [batch] or [batch, 1]

        Returns:
            (total_loss, loss_dict) tuple
        """
        # Value loss: MSE between predicted value and actual outcome.
        # Flatten both tensors so [batch, 1] and [batch] shapes line up without
        # silent broadcasting.
        value_loss = F.mse_loss(value.view(-1), target_value.view(-1))

        # Policy loss: cross-entropy between MCTS policy and network policy.
        # Target policy is already normalized; policy_logits are log probabilities.
        policy_loss = -torch.sum(target_policy * policy_logits, dim=1).mean()

        # Combined loss
        total_loss = self.value_loss_weight * value_loss + policy_loss

        loss_dict = {
            "total": total_loss.item(),
            "value": value_loss.item(),
            "policy": policy_loss.item(),
        }
        return total_loss, loss_dict

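# Illustrative training-step sketch (not part of this module's API; `network`,
# `optimizer`, and the batch tensors are hypothetical placeholders). It shows
# how the c||θ||^2 term from the loss formula is usually supplied through the
# optimizer's weight_decay rather than inside AlphaZeroLoss:
#
#     criterion = AlphaZeroLoss(value_loss_weight=1.0)
#     optimizer = torch.optim.SGD(network.parameters(), lr=0.01,
#                                 momentum=0.9, weight_decay=1e-4)
#
#     policy_log_probs, value = network(states)
#     loss, loss_dict = criterion(policy_log_probs, value, target_policy, target_value)
#     optimizer.zero_grad()
#     loss.backward()
#     optimizer.step()
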
def create_policy_value_network(
    config: NeuralNetworkConfig,
    board_size: int = 19,
    device: str = "cpu",
) -> PolicyValueNetwork:
    """
    Factory function to create and initialize a policy-value network.

    Args:
        config: Network configuration
        board_size: Board/grid size (square boards assumed)
        device: Device to place the model on

    Returns:
        Initialized PolicyValueNetwork
    """
    network = PolicyValueNetwork(config, board_size)

    # He (Kaiming) initialization for convolutional and linear layers
    def init_weights(m):
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    network.apply(init_weights)
    network = network.to(device)
    return network

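# Example usage sketch (illustrative; the NeuralNetworkConfig values are
# placeholders, but the field names match how the config is consumed above):
#
#     config = NeuralNetworkConfig(...)  # input_channels, num_channels, num_res_blocks, ...
#     network = create_policy_value_network(config, board_size=9, device="cpu")
#     network.eval()
#     x = torch.zeros(1, config.input_channels, 9, 9)
#     policy_probs, value = network.predict(x)  # [1, action_size], [1, 1]
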
# Example: Simpler MLP-based policy-value network for non-spatial tasks
class MLPPolicyValueNetwork(nn.Module):
    """
    MLP-based policy-value network for non-spatial state representations.

    Useful for tasks where state is not naturally represented as an image.
    """

    def __init__(
        self,
        state_dim: int,
        action_size: int,
        hidden_dims: list[int] | None = None,
        use_batch_norm: bool = True,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.state_dim = state_dim
        self.action_size = action_size

        if hidden_dims is None:
            hidden_dims = [512, 256]

        # Shared feature extractor
        layers = []
        prev_dim = state_dim
        for hidden_dim in hidden_dims:
            layers.append(nn.Linear(prev_dim, hidden_dim))
            if use_batch_norm:
                layers.append(nn.BatchNorm1d(hidden_dim))
            layers.append(nn.ReLU())
            if dropout > 0:
                layers.append(nn.Dropout(dropout))
            prev_dim = hidden_dim
        self.shared_network = nn.Sequential(*layers)

        # Policy head
        self.policy_head = nn.Sequential(
            nn.Linear(prev_dim, prev_dim // 2),
            nn.ReLU(),
            nn.Linear(prev_dim // 2, action_size),
        )

        # Value head
        self.value_head = nn.Sequential(
            nn.Linear(prev_dim, prev_dim // 2),
            nn.ReLU(),
            nn.Linear(prev_dim // 2, 1),
            nn.Tanh(),
        )

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass.

        Args:
            x: Input state [batch, state_dim]

        Returns:
            (policy_log_probs, value) tuple
        """
        # Shared features
        features = self.shared_network(x)

        # Policy
        policy_logits = self.policy_head(features)
        policy_log_probs = F.log_softmax(policy_logits, dim=1)

        # Value
        value = self.value_head(features)

        return policy_log_probs, value

    def get_parameter_count(self) -> int:
        """Return total number of trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

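
if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch, not part of the training pipeline):
    # exercises the MLP variant and AlphaZeroLoss with random data. All shapes
    # and hyperparameters below are arbitrary example values.
    torch.manual_seed(0)

    state_dim, action_size, batch_size = 32, 10, 8
    net = MLPPolicyValueNetwork(state_dim=state_dim, action_size=action_size)
    criterion = AlphaZeroLoss(value_loss_weight=1.0)

    # Random states, a normalized fake MCTS policy, and outcomes in [-1, 1]
    states = torch.randn(batch_size, state_dim)
    target_policy = torch.softmax(torch.randn(batch_size, action_size), dim=1)
    target_value = torch.empty(batch_size).uniform_(-1, 1)

    policy_log_probs, value = net(states)
    loss, loss_dict = criterion(policy_log_probs, value, target_policy, target_value)

    print(f"parameters: {net.get_parameter_count():,}")
    print(f"policy shape: {tuple(policy_log_probs.shape)}, value shape: {tuple(value.shape)}")
    print(f"loss: {loss_dict}")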