import torch
import torch.nn as nn
from torchvision import models
from transformers import AutoConfig, PretrainedConfig, PreTrainedModel
class EyeDiseaseEfficientNetConfig(PretrainedConfig):
    """Configuration for the ``EyeDiseaseEfficientNet`` model.

    Subclasses ``PretrainedConfig`` rather than ``AutoConfig``. ``AutoConfig``
    is only a factory: its constructor always raises and directs callers to
    ``AutoConfig.from_pretrained``, so a subclass of it could never be built.

    Args:
        num_labels: Number of output classes produced by the model head.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "EyeDiseaseEfficientNet"

    def __init__(self, num_labels=8, **kwargs):
        super().__init__(**kwargs)
        # Assigned after the base initializer so this class's default (8)
        # overrides the base-class default label count.
        self.num_labels = num_labels
class EyeDiseaseEfficientNet(PreTrainedModel):
    """EfficientNet-B4 image backbone fused with a small age/sex branch.

    The backbone's classification head is removed, so it yields pooled image
    features. These are concatenated with a 64-d embedding of a 2-feature
    age/sex vector. An MLP head then maps the result to
    ``config.num_labels`` raw scores (no softmax is applied).
    """

    config_class = EyeDiseaseEfficientNetConfig

    def __init__(self, config):
        """Build the backbone, the age/sex branch and the fusion head.

        Args:
            config: Model configuration; ``config.num_labels`` sets the
                size of the output layer.
        """
        super().__init__(config)

        # NOTE(review): ImageNet weights are fetched at construction time, so
        # from_pretrained() also downloads them before the checkpoint
        # overwrites them. `pretrained=` is deprecated since torchvision 0.13
        # in favour of `weights=`; kept here for older torchvision releases.
        self.efficientnet = models.efficientnet_b4(pretrained=True)

        # Take the pooled feature width from the stock head's final Linear
        # before discarding it, instead of hard-coding 1792 for B4.
        backbone_dim = self.efficientnet.classifier[-1].in_features
        self.efficientnet.classifier = nn.Identity()

        # Fine-tune only the last two feature stages. The backbone must be
        # frozen first; otherwise re-enabling gradients below is a no-op,
        # because every parameter starts with requires_grad=True.
        for param in self.efficientnet.parameters():
            param.requires_grad = False
        for param in self.efficientnet.features[-2:].parameters():
            param.requires_grad = True

        meta_dim = 64
        self.fc_age_sex = nn.Sequential(
            nn.Linear(2, meta_dim),
            nn.ReLU(),
            nn.Dropout(0.5),
        )
        self.fc_combined = nn.Sequential(
            nn.Linear(backbone_dim + meta_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.6),
            nn.Linear(512, config.num_labels),
        )

    def forward(self, x_img, x_age_sex):
        """Return class scores for a batch of images plus age/sex features.

        Args:
            x_img: Image batch for the EfficientNet backbone (presumably
                ``(batch, 3, H, W)`` normalized images; TODO confirm).
            x_age_sex: Metadata batch whose last dimension is 2, matching
                ``nn.Linear(2, 64)`` (assumed ``(batch, 2)``).

        Returns:
            Unnormalized scores with ``config.num_labels`` columns.
        """
        img_features = self.efficientnet(x_img)
        meta_features = self.fc_age_sex(x_age_sex)
        fused = torch.cat((img_features, meta_features), dim=1)
        return self.fc_combined(fused)