# -*- coding: UTF-8 -*-
'''=================================================
@Project -> File   pram -> layers
@IDE     PyCharm
@Author  fx221@cam.ac.uk
@Date    29/01/2024 14:46
=================================================='''
import torch
import torch.nn as nn
import torch.nn.functional as F
from copy import deepcopy
from einops import rearrange


def MLP(channels: list, do_bn=True, ac_fn='relu', norm_fn='bn'):
    """ Multi-layer perceptron built from 1x1 convolutions over [B, C, N] tensors """
    # note: do_bn is unused; the normalization layer is selected via norm_fn
    n = len(channels)
    layers = []
    for i in range(1, n):
        layers.append(
            nn.Conv1d(channels[i - 1], channels[i], kernel_size=1, bias=True))
        if i < (n - 1):
            if norm_fn == 'in':
                layers.append(nn.InstanceNorm1d(channels[i], eps=1e-3))
            elif norm_fn == 'bn':
                layers.append(nn.BatchNorm1d(channels[i], eps=1e-3))
            if ac_fn == 'relu':
                layers.append(nn.ReLU())
            elif ac_fn == 'gelu':
                layers.append(nn.GELU())
            elif ac_fn == 'lrelu':
                layers.append(nn.LeakyReLU(negative_slope=0.1))
            # if norm_fn == 'ln':
            #     layers.append(nn.LayerNorm(channels[i]))
    return nn.Sequential(*layers)
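

# Usage sketch for MLP: a minimal shape check, assuming an arbitrary batch of
# 8 inputs with 32 channels over 100 points. The sizes and the helper name
# `_demo_mlp` are illustrative, not part of the PRAM API.
def _demo_mlp():
    mlp = MLP([32, 64, 128], ac_fn='relu', norm_fn='bn')
    x = torch.randn(8, 32, 100)   # [B, C_in, N]
    y = mlp(x)                    # [B, C_out, N]
    assert y.shape == (8, 128, 100)
    return y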


class MultiHeadedAttention(nn.Module):
    """ Multi-head scaled dot-product attention over [B, D, N] feature maps """

    def __init__(self, num_heads: int, d_model: int):
        super().__init__()
        assert d_model % num_heads == 0
        self.dim = d_model // num_heads
        self.num_heads = num_heads
        self.merge = nn.Conv1d(d_model, d_model, kernel_size=1)
        self.proj = nn.ModuleList([deepcopy(self.merge) for _ in range(3)])

    def forward(self, query, key, value, M=None):
        '''
        :param query: [B, D, N]
        :param key: [B, D, M]
        :param value: [B, D, M]
        :param M: [B, N, M] attention mask (1 = attend, 0 = ignore), optional
        :return: [B, D, N]
        '''
        batch_dim = query.size(0)
        # project q/k/v and split the channel dimension into heads: [B, head_dim, H, N]
        query, key, value = [l(x).view(batch_dim, self.dim, self.num_heads, -1)
                             for l, x in zip(self.proj, (query, key, value))]
        dim = query.shape[1]
        scores = torch.einsum('bdhn,bdhm->bhnm', query, key) / dim ** .5
        if M is not None:
            # broadcast the [B, N, M] mask over heads and suppress masked entries
            mask = (1 - M[:, None, :, :]).repeat(1, scores.shape[1], 1, 1).bool()  # [B, H, N, M]
            scores = scores.masked_fill(mask, -torch.finfo(scores.dtype).max)
        prob = F.softmax(scores, dim=-1)
        x = torch.einsum('bhnm,bdhm->bdhn', prob, value)
        self.prob = prob  # keep attention weights for later inspection
        out = self.merge(x.contiguous().view(batch_dim, self.dim * self.num_heads, -1))
        return out
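

# Usage sketch for MultiHeadedAttention: 4 heads over 128-D descriptors, with
# 50 query and 60 key/value points and an all-ones mask. The sizes and the
# helper name `_demo_attention` are illustrative assumptions.
def _demo_attention():
    attn = MultiHeadedAttention(num_heads=4, d_model=128)
    q = torch.randn(2, 128, 50)   # [B, D, N]
    k = torch.randn(2, 128, 60)   # [B, D, M]
    v = torch.randn(2, 128, 60)   # [B, D, M]
    M = torch.ones(2, 50, 60)     # 1 = attend, 0 = ignore
    out = attn(q, k, v, M=M)
    assert out.shape == (2, 128, 50)
    assert attn.prob.shape == (2, 4, 50, 60)  # per-head attention weights
    return out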


class AttentionalPropagation(nn.Module):
    """ Attention block: x attends to source, and the message is fused with x by an MLP """

    def __init__(self, feature_dim: int, num_heads: int, ac_fn='relu', norm_fn='bn'):
        super().__init__()
        self.attn = MultiHeadedAttention(num_heads, feature_dim)
        self.mlp = MLP([feature_dim * 2, feature_dim * 2, feature_dim], ac_fn=ac_fn, norm_fn=norm_fn)
        nn.init.constant_(self.mlp[-1].bias, 0.0)

    def forward(self, x, source, M=None):
        message = self.attn(x, source, source, M=M)
        self.prob = self.attn.prob
        out = self.mlp(torch.cat([x, message], dim=1))
        return out
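

# Usage sketch for AttentionalPropagation: here x attends to a separate source
# set (cross-attention); passing x as its own source would give self-attention.
# Sizes and the helper name `_demo_propagation` are illustrative assumptions.
def _demo_propagation():
    layer = AttentionalPropagation(feature_dim=128, num_heads=4)
    x = torch.randn(2, 128, 50)        # features to update, [B, D, N]
    source = torch.randn(2, 128, 60)   # features attended to, [B, D, M]
    delta = layer(x, source)           # fused message, [B, D, N]
    assert delta.shape == (2, 128, 50)
    return delta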


class KeypointEncoder(nn.Module):
    """ Joint encoding of visual appearance and location using MLPs """

    def __init__(self, input_dim, feature_dim, layers, ac_fn='relu', norm_fn='bn'):
        super().__init__()
        self.input_dim = input_dim
        self.encoder = MLP([input_dim] + layers + [feature_dim], ac_fn=ac_fn, norm_fn=norm_fn)
        nn.init.constant_(self.encoder[-1].bias, 0.0)

    def forward(self, kpts, scores=None):
        if self.input_dim == 2:
            return self.encoder(kpts.transpose(1, 2))  # [B, 2, N] -> [B, feature_dim, N]
        else:
            # concatenate keypoint coordinates and detection scores: [B, 2, N] + [B, 1, N]
            inputs = [kpts.transpose(1, 2), scores.unsqueeze(1)]
            return self.encoder(torch.cat(inputs, dim=1))
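

# Usage sketch for KeypointEncoder with input_dim=3 (x, y, score): coordinates
# are expected as [B, N, 2] and scores as [B, N]. The hidden sizes [32, 64]
# and the helper name `_demo_keypoint_encoder` are illustrative assumptions.
def _demo_keypoint_encoder():
    encoder = KeypointEncoder(input_dim=3, feature_dim=128, layers=[32, 64])
    kpts = torch.randn(2, 100, 2)   # keypoint coordinates, [B, N, 2]
    scores = torch.rand(2, 100)     # detection scores, [B, N]
    feats = encoder(kpts, scores)   # [B, feature_dim, N]
    assert feats.shape == (2, 128, 100)
    return feats


if __name__ == '__main__':
    # Run the shape-check sketches defined above.
    _demo_mlp()
    _demo_attention()
    _demo_propagation()
    _demo_keypoint_encoder()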