| from torchvision.models import resnet50, ResNet50_Weights | |
| from transformers import PreTrainedModel | |
| from .config import ResnetConfig | |
| import torch.nn as nn | |
class ResNet50(nn.Module):
    """ResNet-50 image encoder that projects each image to a 768-dim vector.

    A torchvision ResNet-50 is built with the ``IMAGENET1K_V1`` weights and
    all of its children except the last two (in torchvision: the global
    average pool and the classification layer) are reused as ``backbone``.
    The backbone output is average-pooled, flattened and mapped from 2048 to
    768 features by a linear layer.
    """

    def __init__(self):
        super().__init__()
        # The full torchvision network stays available as ``cnn``; code
        # elsewhere in this file removes that attribute after construction.
        self.cnn = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
        # ``backbone`` wraps the very same layer objects as ``cnn`` (no copy).
        self.backbone = nn.Sequential(*list(self.cnn.children())[:-2])
        # NOTE(review): the fixed 7x7 pooling window presumably targets
        # 224x224 inputs (7x7 feature map); other sizes may produce a
        # flattened width that ``fc_1`` does not accept -- confirm with callers.
        self.flaten = nn.Sequential(nn.AvgPool2d(kernel_size=7), nn.Flatten())
        self.fc_1 = nn.Linear(2048, 768)

    def forward(self, x):
        """Encode one image or a batch of images.

        Args:
            x: A 3-D tensor (single image) or a 4-D tensor (batch of images).

        Returns:
            For a 3-D input, the result with the temporary batch dimension
            removed again; for a 4-D input, one row per image with the batch
            dimension preserved -- including for a batch of size 1.
        """
        # Remember whether the batch dimension is one we added ourselves, so
        # that a genuine batch of size 1 does not lose its leading dimension.
        unbatched = x.dim() == 3
        if unbatched:
            x = x.unsqueeze(0)
        x = self.backbone(x)
        x = self.flaten(x)
        x = self.fc_1(x)
        if unbatched:
            x = x.squeeze(0)
        return x
class ResNet50AffectiveFeatureExtractor(PreTrainedModel):
    """``PreTrainedModel`` wrapper that delegates to a ``ResNet50`` encoder.

    The wrapped encoder is stored as ``self.model``; ``forward`` simply
    passes its input through to it.
    """

    config_class = ResnetConfig

    def __init__(self, config):
        super().__init__(config)
        encoder = ResNet50()
        # Drop the encoder's ``cnn`` attribute (the complete torchvision
        # network). The layers referenced by ``encoder.backbone`` are held by
        # that Sequential and stay registered.
        del encoder.cnn
        self.model = encoder

    def forward(self, tensor):
        """Return the wrapped encoder's output for ``tensor``."""
        features = self.model(tensor)
        return features