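# A LightningModule that wraps a pretrained image backbone with a small
# custom classification head for a pet classification task.
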
import torch
from torch import nn, optim
import torch.nn.functional as F
import lightning as L
from torchmetrics import Accuracy
from torch.optim.lr_scheduler import ReduceLROnPlateau

class PetClassificationModel(L.LightningModule):
    def __init__(self, base_model, config):
        super().__init__()
        self.config = config
        self.num_classes = len(self.config.idx_to_class)

        # One Accuracy instance per stage so their running state never mixes.
        metric = Accuracy(task="multiclass", num_classes=self.num_classes)
        self.train_acc = metric.clone()
        self.val_acc = metric.clone()
        self.test_acc = metric.clone()

        # Per-batch outputs, buffered for epoch-level metrics in the hooks below.
        self.training_step_outputs = []
        self.validation_step_outputs = []
        self.test_step_outputs = []

        # Lightning normally handles device placement; this is kept only so the
        # head can be built on the target device up front.
        self.device_ = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Pretrained backbone plus a small head that maps the backbone's
        # classifier output down to the number of pet classes.
        self.pretrained_model = base_model
        out_features = self.pretrained_model.get_classifier().out_features
        self.custom_layers = nn.Sequential(
            nn.Linear(out_features, 512, device=self.device_),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(512, self.num_classes, device=self.device_),
        )

    def forward(self, x):
        x = self.pretrained_model(x)
        x = self.custom_layers(x)  # project backbone output to pet-class logits
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        self.log_dict({'train_loss': loss})
        self.training_step_outputs.append({'loss': loss, 'logits': logits, 'y': y})
        return loss

    def on_train_epoch_end(self):
        # Concatenate batches and compute accuracy over the whole epoch.
        outputs = self.training_step_outputs
        logits = torch.cat([x['logits'] for x in outputs])
        y = torch.cat([x['y'] for x in outputs])
        self.train_acc(logits, y)
        self.log_dict(
            {'train_acc': self.train_acc},
            on_step=False,
            on_epoch=True,
            prog_bar=True,
        )
        self.training_step_outputs.clear()  # free memory before the next epoch

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        self.log_dict({'val_loss': loss})
        self.validation_step_outputs.append({'loss': loss, 'logits': logits, 'y': y})
        return loss

    def on_validation_epoch_end(self):
        # Concatenate batches and compute accuracy over the whole epoch.
        outputs = self.validation_step_outputs
        logits = torch.cat([x['logits'] for x in outputs])
        y = torch.cat([x['y'] for x in outputs])
        self.val_acc(logits, y)
        self.log_dict(
            {'val_acc': self.val_acc},
            on_step=False,
            on_epoch=True,
            prog_bar=True,
        )
        self.validation_step_outputs.clear()

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        self.log_dict({'test_loss': loss})
        self.test_step_outputs.append({'loss': loss, 'logits': logits, 'y': y})
        return loss

    def on_test_epoch_end(self):
        # Concatenate batches and compute accuracy over the whole epoch.
        outputs = self.test_step_outputs
        logits = torch.cat([x['logits'] for x in outputs])
        y = torch.cat([x['y'] for x in outputs])
        self.test_acc(logits, y)
        self.log_dict(
            {'test_acc': self.test_acc},
            on_step=False,
            on_epoch=True,
            prog_bar=True,
        )
        self.test_step_outputs.clear()

    def predict_step(self, batch, batch_idx):
        # Prediction batches are assumed to be (inputs, labels); only the
        # inputs go through the network.
        x, y = batch
        return self(x)

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=self.config.LEARNING_RATE)
        lr_scheduler = ReduceLROnPlateau(optimizer, mode='min', patience=3)
        lr_scheduler_dict = {
            'scheduler': lr_scheduler,
            'interval': 'epoch',
            # Lightning steps the scheduler based on this logged metric.
            'monitor': 'val_loss',
        }
        return {'optimizer': optimizer, 'lr_scheduler': lr_scheduler_dict}
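
if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. Assumptions:
    # the backbone comes from `timm` (whose models expose `get_classifier()`,
    # which __init__ relies on), and the config only needs the two fields
    # referenced above (`idx_to_class`, `LEARNING_RATE`). The class mapping,
    # epoch count, and dataloaders are illustrative placeholders.
    import timm
    from types import SimpleNamespace

    config = SimpleNamespace(
        idx_to_class={0: 'cat', 1: 'dog'},  # hypothetical two-class mapping
        LEARNING_RATE=1e-3,
    )
    base_model = timm.create_model('resnet18', pretrained=True)
    model = PetClassificationModel(base_model, config)

    trainer = L.Trainer(max_epochs=5)
    # trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)
    # trainer.test(model, dataloaders=test_loader)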