# This file contains the implementation of the IndependentMLPs class.
import torch


class IndependentMLPs(torch.nn.Module):
| """ | |
| This class implements the MLP used for classification with the option to use an additional independent MLP layer | |
| """ | |
    def __init__(self, part_dim, latent_dim, bias=False, num_lin_layers=1,
                 act_layer=True, out_dim=None, stack_dim=-1):
        """
        :param part_dim: Number of parts
        :param latent_dim: Latent dimension
        :param bias: Whether to use a bias in the linear layers
        :param num_lin_layers: Number of linear layers per part
        :param act_layer: Whether to follow each linear layer with an activation
        :param out_dim: Output dimension (default: None, which falls back to latent_dim)
        :param stack_dim: Dimension along which the per-part outputs are stacked (default: -1)
        """
        super().__init__()
        self.bias = bias
        self.latent_dim = latent_dim
        if out_dim is None:
            out_dim = latent_dim
        self.out_dim = out_dim
        self.part_dim = part_dim
        self.stack_dim = stack_dim
        # Build one small MLP per part; the parts share no parameters.
        layer_stack = torch.nn.ModuleList()
        for i in range(part_dim):
            layer_stack.append(torch.nn.Sequential())
            for j in range(num_lin_layers):
                # The first layer maps latent_dim -> out_dim; any further
                # layers map out_dim -> out_dim so the shapes chain correctly.
                in_dim = latent_dim if j == 0 else self.out_dim
                layer_stack[i].add_module(f"fc_{j}", torch.nn.Linear(in_dim, self.out_dim, bias=bias))
                if act_layer:
                    layer_stack[i].add_module(f"act_{j}", torch.nn.GELU())
        self.feature_layers = layer_stack
        self.reset_weights()
    def __repr__(self):
        return f"IndependentMLPs(part_dim={self.part_dim}, latent_dim={self.latent_dim}, bias={self.bias})"
    def reset_weights(self):
        """ Initialize linear weights with a truncated normal distribution and biases with zeros. """
        for layer in self.feature_layers:
            for m in layer.modules():
                if isinstance(m, torch.nn.Linear):
                    # Truncated normal initialization, as commonly used for transformer heads
                    torch.nn.init.trunc_normal_(m.weight, std=0.02)
                    if m.bias is not None:
                        torch.nn.init.zeros_(m.bias)
    def forward(self, x):
        """ Input x has shape batch x latent_dim x part_dim when stack_dim is -1,
            and batch x part_dim x latent_dim otherwise. """
        outputs = []
        for i, layer in enumerate(self.feature_layers):
            if self.stack_dim == -1:
                in_ = x[..., i]  # Select the features of part i from the last dimension
            else:
                in_ = x[:, i, ...]  # Select the features of part i from dimension 1
            out = layer(in_)  # Apply part i's own MLP
            outputs.append(out)
        x = torch.stack(outputs, dim=self.stack_dim)  # Re-stack the per-part outputs
        return x
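

# A minimal usage sketch (an illustrative assumption, not part of the original
# module): a batch of 8 samples with latent_dim=256 and part_dim=4, stacked on
# the last dimension, run through the default single-layer configuration.
if __name__ == "__main__":
    mlps = IndependentMLPs(part_dim=4, latent_dim=256)
    x = torch.randn(8, 256, 4)  # batch x latent_dim x part_dim (stack_dim=-1)
    y = mlps(x)
    print(mlps)     # IndependentMLPs(part_dim=4, latent_dim=256, bias=False)
    print(y.shape)  # torch.Size([8, 256, 4])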