import torch
from torch import Tensor

from layers.independent_mlp import IndependentMLPs


# Baseline model: a modified ConvNeXt with reduced downsampling, so that the last
# stage yields a spatially larger feature tensor.
class IndividualLandmarkConvNext(torch.nn.Module):
    def __init__(self, init_model: torch.nn.Module, num_landmarks: int = 8,
                 num_classes: int = 200, sl_channels: int = 1024, fl_channels: int = 2048,
                 part_dropout: float = 0.3, modulation_type: str = "original",
                 modulation_orth: bool = False, gumbel_softmax: bool = False,
                 gumbel_softmax_temperature: float = 1.0, gumbel_softmax_hard: bool = False,
                 classifier_type: str = "linear", noise_variance: float = 0.0) -> None:
        super().__init__()
        self.num_landmarks = num_landmarks
        self.num_classes = num_classes
        self.noise_variance = noise_variance
        self.stem = init_model.stem
        self.stages = init_model.stages
        self.feature_dim = sl_channels + fl_channels
        # 1x1 conv whose num_landmarks + 1 kernels act as landmark prototypes (+1 for background)
        self.fc_landmarks = torch.nn.Conv2d(self.feature_dim, num_landmarks + 1, 1, bias=False)
        self.gumbel_softmax = gumbel_softmax
        self.gumbel_softmax_temperature = gumbel_softmax_temperature
        self.gumbel_softmax_hard = gumbel_softmax_hard
        self.modulation_type = modulation_type
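        # Feature modulation: a learned reweighting of the pooled per-landmark features,
        # applied in forward(). A note we added for illustration: "original" is an
        # elementwise scale of shape [1, feature_dim, num_landmarks + 1], and the
        # "parallel_mlp*" variants appear to apply one small MLP per landmark slot,
        # judging by the IndependentMLPs signature.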
        if modulation_type == "layer_norm":
            self.modulation = torch.nn.LayerNorm([self.feature_dim, self.num_landmarks + 1])
        elif modulation_type == "original":
            self.modulation = torch.nn.Parameter(torch.ones(1, self.feature_dim, self.num_landmarks + 1))
        elif modulation_type == "parallel_mlp":
            self.modulation = IndependentMLPs(part_dim=self.num_landmarks + 1, latent_dim=self.feature_dim,
                                              num_lin_layers=1, act_layer=True, bias=True)
        elif modulation_type == "parallel_mlp_no_bias":
            self.modulation = IndependentMLPs(part_dim=self.num_landmarks + 1, latent_dim=self.feature_dim,
                                              num_lin_layers=1, act_layer=True, bias=False)
        elif modulation_type == "parallel_mlp_no_act":
            self.modulation = IndependentMLPs(part_dim=self.num_landmarks + 1, latent_dim=self.feature_dim,
                                              num_lin_layers=1, act_layer=False, bias=True)
        elif modulation_type == "parallel_mlp_no_act_no_bias":
            self.modulation = IndependentMLPs(part_dim=self.num_landmarks + 1, latent_dim=self.feature_dim,
                                              num_lin_layers=1, act_layer=False, bias=False)
        elif modulation_type == "none":
            self.modulation = torch.nn.Identity()
        else:
            raise ValueError("modulation_type not implemented")
        self.modulation_orth = modulation_orth
        self.dropout_full_landmarks = torch.nn.Dropout1d(part_dropout)
        self.classifier_type = classifier_type
        if classifier_type == "independent_mlp":
            self.fc_class_landmarks = IndependentMLPs(part_dim=self.num_landmarks, latent_dim=self.feature_dim,
                                                      num_lin_layers=1, act_layer=False, out_dim=num_classes,
                                                      bias=False, stack_dim=1)
        elif classifier_type == "linear":
            self.fc_class_landmarks = torch.nn.Linear(in_features=self.feature_dim, out_features=num_classes,
                                                      bias=False)
        else:
            raise ValueError("classifier_type not implemented")

    def forward(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor]:
        # Pretrained ConvNeXt backbone
        x = self.stem(x)
        x = self.stages[0](x)
        x = self.stages[1](x)
        l3 = self.stages[2](x)
        x = self.stages[3](l3)
        # Upsample the final stage to the stage-3 resolution and concatenate both,
        # giving a tensor with feature_dim = sl_channels + fl_channels channels
        x = torch.nn.functional.interpolate(x, size=(l3.shape[-2], l3.shape[-1]), mode='bilinear',
                                            align_corners=False)
        x = torch.cat((x, l3), dim=1)
        # Compute per-landmark attention maps as negative squared distances between each
        # pixel's feature vector b and each 1x1 kernel a of fc_landmarks, expanded as
        # (b - a)^2 = b^2 - 2ab + a^2 (a sanity check against torch.cdist is sketched
        # after the class). The original comment referenced ResNet feature maps; here
        # b comes from the ConvNeXt backbone.
        batch_size = x.shape[0]
        ab = self.fc_landmarks(x)  # [B, K+1, H, W]: inner products <a, b>
        b_sq = x.pow(2).sum(1, keepdim=True)  # [B, 1, H, W]
        b_sq = b_sq.expand(-1, self.num_landmarks + 1, -1, -1).contiguous()
        a_sq = self.fc_landmarks.weight.pow(2).sum(1).unsqueeze(1).expand(-1, batch_size, x.shape[-2],
                                                                          x.shape[-1]).contiguous()
        a_sq = a_sq.permute(1, 0, 2, 3).contiguous()
        dist = b_sq - 2 * ab + a_sq
        maps = -dist  # higher value = closer to the landmark prototype
        # Softmax over the landmark dimension, so the attention values at each pixel sum to 1
        if self.gumbel_softmax:
            maps = torch.nn.functional.gumbel_softmax(maps, dim=1, tau=self.gumbel_softmax_temperature,
                                                      hard=self.gumbel_softmax_hard)  # [B, K+1, H, W]
        else:
            maps = torch.nn.functional.softmax(maps, dim=1)  # [B, K+1, H, W]
        # Use the maps to spatially pool the features into one vector per landmark
        all_features = (maps.unsqueeze(1) * x.unsqueeze(2)).mean(-1).mean(-1).contiguous()  # [B, C, K+1]
        if self.noise_variance > 0.0:
            # Optional regularization: add noise scaled by the feature standard deviation
            all_features += torch.randn_like(all_features) * x.std().detach() * self.noise_variance
        # Modulate the features
        if self.modulation_type == "original":
            all_features_mod = all_features * self.modulation
        else:
            all_features_mod = self.modulation(all_features)
        # Classify each landmark's feature vector; the background slot (last index) is
        # dropped, and Dropout1d can zero out entire landmarks during training
        scores = self.fc_class_landmarks(
            self.dropout_full_landmarks(all_features_mod[..., :-1].permute(0, 2, 1).contiguous())
        ).permute(0, 2, 1).contiguous()  # [B, num_classes, num_landmarks]
        if self.modulation_orth:
            return all_features_mod, maps, scores, dist
        else:
            return all_features, maps, scores, dist
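

# A minimal sanity check we added; it is not part of the original module. It verifies
# that the expansion b^2 - 2ab + a^2 used in forward() equals the squared Euclidean
# distance between each pixel's feature vector and each 1x1 kernel of a conv layer,
# by comparing against torch.cdist on random tensors. The helper name
# _check_distance_trick and all shapes below are illustrative assumptions.
def _check_distance_trick() -> None:
    torch.manual_seed(0)
    b, c, h, w, k = 2, 16, 5, 7, 9
    x = torch.randn(b, c, h, w)
    conv = torch.nn.Conv2d(c, k, 1, bias=False)
    ab = conv(x)                                                   # [B, K, H, W]: <a, b> per pixel
    b_sq = x.pow(2).sum(1, keepdim=True)                           # [B, 1, H, W]: |b|^2
    a_sq = conv.weight.pow(2).sum(dim=(1, 2, 3)).view(1, k, 1, 1)  # [1, K, 1, 1]: |a|^2
    dist = b_sq - 2 * ab + a_sq                                    # broadcasts to [B, K, H, W]
    # Reference: explicit pairwise squared distances between pixel features and kernels
    feats = x.flatten(2).transpose(1, 2)                           # [B, H*W, C]
    protos = conv.weight.view(k, c).unsqueeze(0).expand(b, -1, -1).contiguous()
    ref = torch.cdist(feats, protos).pow(2)                        # [B, H*W, K]
    ref = ref.transpose(1, 2).reshape(b, k, h, w)
    assert torch.allclose(dist, ref, atol=1e-4, rtol=1e-4)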
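

# A hedged usage sketch, not from the original repo: it assumes a timm ConvNeXt
# backbone (whose `stem` / `stages` attributes this class reads) and that
# sl_channels / fl_channels match the channel counts of stages 2 and 3 of the chosen
# variant: 512 and 1024 for convnext_base. Note the defaults above (1024/2048)
# instead match a ResNet-50's layer3/layer4, consistent with the original
# "feature maps resnet" comment.
if __name__ == "__main__":
    import timm

    backbone = timm.create_model("convnext_base", pretrained=False)
    model = IndividualLandmarkConvNext(backbone, num_landmarks=8, num_classes=200,
                                       sl_channels=512, fl_channels=1024).eval()
    images = torch.randn(2, 3, 224, 224)  # dummy batch
    with torch.no_grad():
        features, maps, scores, dist = model(images)
    print(features.shape)  # [2, 1536, 9]: one pooled vector per landmark (+ background)
    print(maps.shape)      # [2, 9, 14, 14]: soft assignments at stage-3 resolution
    print(scores.shape)    # [2, 200, 8]: class scores per foreground landmark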