Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
class MultiClassFocalLossWithAlpha(nn.Module):
    """Multi-class focal loss with optional per-class weights (``alpha``).

    For an element whose true class is ``t`` the loss is
    ``alpha[t] * (1 - p_t) ** gamma * -log(p_t)``, where ``p_t`` is the
    softmax probability assigned to ``t``. With ``gamma=0`` and all-ones
    ``alpha`` this equals plain cross-entropy.

    Args:
        num_classes: Number of classes ``C``; sets the length of the default
            all-ones ``alpha``.
        alpha: Optional 1-D per-class weights (list, tuple or tensor), indexed
            by class id. ``None`` weights every class by 1.
        gamma: Focusing exponent applied to ``(1 - p_t)``.
        reduction: ``"mean"``, ``"sum"`` or ``"none"``.
        device: Device the weights are initially placed on. If CUDA is
            requested but not available, the CPU is used instead.

    Raises:
        ValueError: If ``reduction`` is not one of the accepted values or
            ``alpha`` is not one-dimensional.
    """

    _REDUCTIONS = ("mean", "sum", "none")

    def __init__(self, num_classes, alpha=None, gamma=1, reduction='mean', device="cuda"):
        super().__init__()
        if reduction not in self._REDUCTIONS:
            raise ValueError(
                f"reduction must be one of {self._REDUCTIONS}, got {reduction!r}"
            )

        if alpha is None:
            weights = torch.ones(num_classes, dtype=torch.float32)
        else:
            # Float copy that never aliases a caller-owned tensor.
            weights = torch.as_tensor(alpha, dtype=torch.float32).detach().clone()
        if weights.dim() != 1:
            raise ValueError(
                f"alpha must be a 1-D sequence of per-class weights, "
                f"got shape {tuple(weights.shape)}"
            )

        target_device = torch.device(device)
        # The "cuda" default would otherwise crash on CPU-only machines.
        if target_device.type == "cuda" and not torch.cuda.is_available():
            target_device = torch.device("cpu")

        # A buffer follows module.to()/.cuda(); non-persistent keeps it out of
        # state_dict so previously saved checkpoints still load.
        self.register_buffer("alpha", weights.to(target_device), persistent=False)
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, pred, target):
        """Compute the focal loss.

        Args:
            pred: Unnormalised logits with classes on dim 1, shape ``(N, C)``
                or ``(N, C, d1, ...)``.
            target: int64 class indices of shape ``(N,)`` or
                ``(N, d1, ...)``, on the same device as ``pred``.

        Returns:
            A scalar tensor for ``"mean"``/``"sum"``; for ``"none"`` the
            per-element loss, shaped like ``target``.
        """
        log_probs = torch.log_softmax(pred, dim=1)
        # log p_t for each element's true class; result is shaped like target.
        logpt = log_probs.gather(1, target.unsqueeze(1)).squeeze(1)
        pt = logpt.exp()
        # Align the weights with the inputs even if the module was not moved.
        alpha = self.alpha.to(pred.device)[target]
        focal_loss = alpha * (1 - pt) ** self.gamma * -logpt

        if self.reduction == "mean":
            return focal_loss.mean()
        if self.reduction == "sum":
            return focal_loss.sum()
        return focal_loss