# Source: eye_disease_classifier / model_file.py
# Author: alexakup05
# Commit: "Create model_file.py" (dedea34, verified)
import torch
import torch.nn as nn
from torchvision import models
from transformers import AutoConfig, PretrainedConfig, PreTrainedModel
class EyeDiseaseEfficientNetConfig(PretrainedConfig):
    """Configuration for the ``EyeDiseaseEfficientNet`` model.

    The class derives from ``PretrainedConfig``. ``AutoConfig`` is a factory
    that is only meant to be used through ``AutoConfig.from_pretrained`` and
    raises when its constructor is called, so it cannot serve as a base
    class for a custom configuration.

    Args:
        num_labels: Size of the model's output layer (number of classes).
            Defaults to 8.
        **kwargs: Any other configuration options; forwarded unchanged to
            ``PretrainedConfig``.
    """

    model_type = "EyeDiseaseEfficientNet"

    def __init__(self, num_labels=8, **kwargs):
        # Hand num_labels to the base class instead of assigning it after
        # the fact: the base constructor keeps it in sync with
        # id2label/label2id, so a label map loaded from a saved config is
        # not overwritten by this default.
        super().__init__(num_labels=num_labels, **kwargs)
class EyeDiseaseEfficientNet(PreTrainedModel):
    """EfficientNet-B4 classifier that fuses image features with age/sex.

    The image goes through an EfficientNet-B4 backbone whose classification
    head has been replaced by ``nn.Identity``. A two-value metadata vector
    goes through a small MLP. Both embeddings are concatenated and mapped
    to ``config.num_labels`` raw scores (no softmax/sigmoid is applied).

    Optional config attribute:
        pretrained_backbone (bool, default ``True``): when true the backbone
            is initialised with torchvision's ImageNet weights, which may
            trigger a download. Set it to ``False`` when the weights are
            about to be loaded from a checkpoint anyway.
    """

    config_class = EyeDiseaseEfficientNetConfig

    def __init__(self, config):
        """Build the backbone, the metadata branch and the fused head.

        Args:
            config: Model configuration; ``config.num_labels`` sets the
                width of the output layer.
        """
        super().__init__(config)

        # ``weights=`` replaces the deprecated ``pretrained=True`` argument;
        # IMAGENET1K_V1 is the checkpoint the legacy flag resolved to.
        use_pretrained = getattr(config, "pretrained_backbone", True)
        weights = (
            models.EfficientNet_B4_Weights.IMAGENET1K_V1
            if use_pretrained
            else None
        )
        self.efficientnet = models.efficientnet_b4(weights=weights)

        # Read the pooled feature width (1792 for B4) from the stock head
        # before discarding it, rather than hard-coding the number.
        image_features = self.efficientnet.classifier[-1].in_features
        self.efficientnet.classifier = nn.Identity()

        # NOTE(review): nothing above freezes the backbone, so every
        # parameter already has requires_grad=True and this loop changes
        # nothing. If the intent was to fine-tune only the last two feature
        # stages, the remaining parameters must be frozen first - confirm
        # with the training code before changing it.
        for param in self.efficientnet.features[-2:].parameters():
            param.requires_grad = True

        metadata_features = 64
        self.fc_age_sex = nn.Sequential(
            nn.Linear(2, metadata_features),
            nn.ReLU(),
            nn.Dropout(0.5),
        )
        self.fc_combined = nn.Sequential(
            nn.Linear(image_features + metadata_features, 512),
            nn.ReLU(),
            nn.Dropout(0.6),
            nn.Linear(512, config.num_labels),
        )

    def forward(self, x_img, x_age_sex):
        """Return raw class scores for a batch of images plus metadata.

        Args:
            x_img: Image batch fed to the EfficientNet backbone
                (presumably ``(batch, 3, H, W)`` floats - confirm against
                the preprocessing pipeline).
            x_age_sex: Tensor of shape ``(batch, 2)`` consumed by
                ``nn.Linear(2, 64)``; the order and scaling of the two
                values is not visible here - verify against the caller.

        Returns:
            Tensor of shape ``(batch, config.num_labels)`` with unnormalised
            scores.
        """
        image_embedding = self.efficientnet(x_img)
        metadata_embedding = self.fc_age_sex(x_age_sex)
        fused = torch.cat((image_embedding, metadata_embedding), dim=1)
        return self.fc_combined(fused)