Spaces:
Runtime error
Runtime error
| # -*- coding: UTF-8 -*- | |
| '''================================================= | |
| @Project -> File pram -> gm | |
| @IDE PyCharm | |
| @Author fx221@cam.ac.uk | |
| @Date 07/02/2024 10:47 | |
| ==================================================''' | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from nets.layers import KeypointEncoder, AttentionalPropagation | |
| from nets.utils import normalize_keypoints, arange_like | |
# Small constant added to the Sinkhorn denominators so the marginal
# rescaling never divides by zero.
eps = 1e-8
def dual_softmax(M, dustbin):
    """Convert a batch of similarity matrices into dual-softmax scores.

    The matrix is first padded with one extra column and one extra row, both
    filled with ``dustbin``.  The result is the element-wise product of the
    softmax taken over the last dimension and the softmax taken over
    dimension 1, computed in log space.

    Args:
        M: 3-D tensor of similarities (three dimensions are required by the
            padding below).
        dustbin: tensor that can be expanded to the padded row/column shapes
            (e.g. a scalar tensor).

    Returns:
        Tensor with the last two dimensions each one larger than in ``M``.
    """
    batch, rows, cols = M.shape
    extra_col = dustbin.expand([batch, rows, 1])
    # The extra row must also span the column that was just appended.
    extra_row = dustbin.expand([batch, 1, cols + 1])
    padded = torch.cat([torch.cat([M, extra_col], dim=-1), extra_row], dim=-2)

    log_over_last = padded.log_softmax(dim=-1)
    log_over_dim1 = padded.log_softmax(dim=1)
    return (log_over_last + log_over_dim1).exp()
def sinkhorn(M, r, c, iteration):
    """Rescale ``softmax(M)`` towards the row marginals ``r`` and column marginals ``c``.

    Starting from a softmax over the last dimension of ``M``, the row and
    column scaling vectors are updated alternately ``iteration`` times; the
    scaled matrix is then returned.

    NOTE(review): the indentation of this file was lost in this copy; the
    final rescaling is taken to happen once, after the loop -- confirm
    against the upstream source.

    Args:
        M: score tensor whose last two dimensions are (rows, columns).
        r: target row sums, broadcastable against the row dimension.
        c: target column sums, broadcastable against the column dimension.
        iteration: number of alternating updates.

    Returns:
        The softmax of ``M`` multiplied by the final row and column scalings.
    """
    plan = M.softmax(dim=-1)
    row_scale, col_scale = torch.ones_like(r), torch.ones_like(c)
    for _ in range(iteration):
        # ``eps`` keeps the denominators strictly positive.
        row_scale = r / ((plan * col_scale.unsqueeze(-2)).sum(-1) + eps)
        col_scale = c / ((plan * row_scale.unsqueeze(-1)).sum(-2) + eps)
    return plan * row_scale.unsqueeze(-1) * col_scale.unsqueeze(-2)
def sink_algorithm(M, dustbin, iteration):
    """Pad ``M`` with a dustbin row/column and run Sinkhorn normalisation on it.

    The marginals give every real row/column a mass of 1.  The dustbin row
    gets a mass equal to the padded row count and the dustbin column a mass
    equal to the padded column count.

    The marginal tensors are created on ``M``'s device (the previous version
    hard-coded ``'cuda'``, which failed on CPU-only hosts and when ``M`` lived
    on a non-default GPU).

    Args:
        M: 3-D similarity tensor (three dimensions are required by the
            padding below).
        dustbin: tensor expandable to the padded row/column shapes
            (e.g. a scalar parameter).
        iteration: number of Sinkhorn iterations.

    Returns:
        The tensor produced by ``sinkhorn`` for the padded matrix.
    """
    batch, rows, cols = M.shape
    M = torch.cat([M, dustbin.expand([batch, rows, 1])], dim=-1)
    # The extra row must also cover the dustbin column appended above.
    M = torch.cat([M, dustbin.expand([batch, 1, cols + 1])], dim=-2)

    device = M.device
    r = torch.ones([batch, rows], device=device)
    r = torch.cat([r, torch.ones([batch, 1], device=device) * (rows + 1)], dim=-1)
    c = torch.ones([batch, cols], device=device)
    c = torch.cat([c, torch.ones([batch, 1], device=device) * (cols + 1)], dim=-1)

    return sinkhorn(M, r, c, iteration)
class AttentionalGNN(nn.Module):
    """Stack of attentional propagation layers applied to two descriptor sets.

    One ``AttentionalPropagation`` layer is created per entry of
    ``layer_names``.  A layer named ``'cross'`` lets each set attend to the
    other one; any other name lets each set attend to itself.  Every layer's
    output is added residually to its input.
    """

    def __init__(self, feature_dim: int, layer_names: list, hidden_dim: int = 256, ac_fn: str = 'relu',
                 norm_fn: str = 'bn'):
        super().__init__()
        propagation_layers = [
            AttentionalPropagation(feature_dim=feature_dim, num_heads=4, hidden_dim=hidden_dim, ac_fn=ac_fn,
                                   norm_fn=norm_fn)
            for _ in layer_names
        ]
        self.layers = nn.ModuleList(propagation_layers)
        self.names = layer_names

    def _propagate(self, desc0, desc1, stop_after=None):
        """Run the layers in order, optionally stopping after cross layer ``stop_after``.

        ``stop_after`` is compared with the position of the layer in
        ``self.layers``; the loop only stops there when that layer is a
        ``'cross'`` layer.
        """
        for idx, (layer, kind) in enumerate(zip(self.layers, self.names)):
            is_cross = kind == 'cross'
            src0, src1 = (desc1, desc0) if is_cross else (desc0, desc1)
            # Both updates are computed from the pre-update descriptors.
            update0 = layer(desc0, src0)
            update1 = layer(desc1, src1)
            desc0 = desc0 + update0
            desc1 = desc1 + update1
            if is_cross and idx == stop_after:
                break
        return [desc0], [desc1]

    def forward(self, desc0, desc1):
        """Apply every layer and return the final descriptors as two one-element lists."""
        return self._propagate(desc0, desc1)

    def predict(self, desc0, desc1, n_it=-1):
        """Like ``forward`` but stops after the ``'cross'`` layer at index ``n_it``.

        With the default ``n_it=-1`` no index ever matches, so all layers run.
        """
        return self._propagate(desc0, desc1, stop_after=n_it)
class GM(nn.Module):
    """Graph-based keypoint matcher.

    Pipeline used at inference (``produce_matches``): encode normalised
    keypoints and their scores, optionally add them to the input descriptors,
    refine both sets with an ``AttentionalGNN``, project them, build a scaled
    dot-product similarity matrix, turn it into assignment scores (Sinkhorn or
    dual softmax) and keep mutual best matches above a threshold.
    """

    default_config = {
        'descriptor_dim': 128,
        'hidden_dim': 256,
        'keypoint_encoder': [32, 64, 128, 256],
        'GNN_layers': ['self', 'cross'] * 9,  # [self, cross, self, cross, ...] 9 in total
        'sinkhorn_iterations': 20,
        'match_threshold': 0.2,
        'with_pose': False,
        'n_layers': 9,
        'n_min_tokens': 256,
        'with_sinkhorn': True,
        'ac_fn': 'relu',
        'norm_fn': 'bn',
        'weight_path': None,
    }
    required_inputs = [
        'image0', 'keypoints0', 'scores0', 'descriptors0',
        'image1', 'keypoints1', 'scores1', 'descriptors1',
    ]
    # Feature width used whenever ``descriptor_dim`` is not positive
    # (i.e. when matching relies on the keypoint encoding only).
    _FALLBACK_DIM = 128

    def __init__(self, config):
        """Build the sub-networks from ``default_config`` overridden by ``config``."""
        super().__init__()
        self.config = {**self.default_config, **config}
        print('gm: ', self.config)

        self.n_layers = self.config['n_layers']
        self.with_sinkhorn = self.config['with_sinkhorn']
        self.match_threshold = self.config['match_threshold']
        self.sinkhorn_iterations = self.config['sinkhorn_iterations']

        feature_dim = self._feature_dim()
        self.kenc = KeypointEncoder(
            feature_dim,
            self.config['keypoint_encoder'],
            ac_fn=self.config['ac_fn'],
            norm_fn=self.config['norm_fn'])
        self.gnn = AttentionalGNN(
            feature_dim=feature_dim,
            hidden_dim=self.config['hidden_dim'],
            layer_names=self.config['GNN_layers'],
            ac_fn=self.config['ac_fn'],
            norm_fn=self.config['norm_fn'],
        )
        # One 1x1 projection per entry of ``n_layers``.
        self.final_proj = nn.ModuleList([
            nn.Conv1d(feature_dim, feature_dim, kernel_size=1, bias=True)
            for _ in range(self.n_layers)])

        # Learnable score used to fill the dustbin row/column.
        bin_score = torch.nn.Parameter(torch.tensor(1.))
        self.register_parameter('bin_score', bin_score)

        self.match_net = None  # GraphLoss(config=self.config)
        self.self_prob0 = None
        self.self_prob1 = None
        self.cross_prob0 = None
        self.cross_prob1 = None
        # Not created here; must be assigned externally before descriptors
        # whose width differs from ``descriptor_dim`` can be matched.
        self.desc_compressor = None

    def _feature_dim(self) -> int:
        """Return ``descriptor_dim`` if positive, otherwise the fallback width."""
        dim = self.config['descriptor_dim']
        return dim if dim > 0 else self._FALLBACK_DIM

    def forward_train(self, data):
        """Training forward pass; not implemented in this class (returns ``None``)."""
        pass

    def _normalized_keypoints(self, data, kpts0, kpts1):
        """Fetch pre-normalised keypoints from ``data`` or normalise them by image shape.

        Raises:
            ValueError: if ``data`` holds neither normalised keypoints,
                images, nor image shapes for both views.
        """
        if 'norm_keypoints0' in data and 'norm_keypoints1' in data:
            return data['norm_keypoints0'], data['norm_keypoints1']
        if 'image0' in data and 'image1' in data:
            return (normalize_keypoints(kpts0, data['image0'].shape),
                    normalize_keypoints(kpts1, data['image1'].shape))
        if 'image_shape0' in data and 'image_shape1' in data:
            return (normalize_keypoints(kpts0, data['image_shape0']),
                    normalize_keypoints(kpts1, data['image_shape1']))
        raise ValueError('Require image shape for keypoint coordinate normalization')

    def _fuse_descriptors(self, data, enc0, enc1):
        """Transpose the input descriptors to channel-first and add the keypoint encodings.

        Descriptors whose channel count differs from ``descriptor_dim`` are
        first passed through ``self.desc_compressor`` without gradients.

        Raises:
            ValueError: if compression is needed but no compressor is set.
        """
        target_dim = self.config['descriptor_dim']
        fused = []
        for key, enc in (('descriptors0', enc0), ('descriptors1', enc1)):
            # [B, N, D] -> [B, D, N].  torch's ``transpose`` swaps exactly two
            # dimensions; the NumPy-style ``transpose(0, 2, 1)`` raises a
            # TypeError on tensors.
            desc = data[key].transpose(1, 2)
            if desc.shape[1] != target_dim:
                if self.desc_compressor is None:
                    raise ValueError(
                        f'{key} has {desc.shape[1]} channels but descriptor_dim is '
                        f'{target_dim} and no desc_compressor is set')
                with torch.no_grad():
                    desc = self.desc_compressor(desc)
            fused.append(desc + enc)
        return fused[0], fused[1]

    def produce_matches(self, data, p=0.2, n_it=-1, **kwargs):
        """Match the keypoints of two views.

        Args:
            data: dict with ``keypoints0/1`` and ``scores0/1``; additionally
                either ``norm_keypoints0/1``, ``image0/1`` or
                ``image_shape0/1`` for coordinate normalisation, and
                ``descriptors0/1`` ([B, N, D]) when ``descriptor_dim`` > 0.
            p: minimum score a mutual match needs to be kept.
            n_it: forwarded to ``self.gnn.predict`` and used to select the
                entry of ``self.final_proj``.
                NOTE(review): the GNN compares it with an index into
                ``GNN_layers`` while ``final_proj`` has ``n_layers`` entries;
                only the default ``-1`` is obviously consistent -- confirm
                before passing other values.
            **kwargs: ignored.

        Returns:
            dict with ``matches0``, ``matches1`` (``-1`` marks no match),
            ``matching_scores0`` and ``matching_scores1``.  When either view
            has no keypoints, all-unmatched results for the first batch
            element are returned together with ``skip_train=True``.

        Raises:
            ValueError: if keypoints cannot be normalised, or descriptors need
                compressing and no compressor is set.
        """
        kpts0, kpts1 = data['keypoints0'], data['keypoints1']
        scores0, scores1 = data['scores0'], data['scores1']

        if kpts0.shape[1] == 0 or kpts1.shape[1] == 0:  # no keypoints
            shape0, shape1 = kpts0.shape[:-1], kpts1.shape[:-1]
            return {
                'matches0': kpts0.new_full(shape0, -1, dtype=torch.int)[0],
                'matches1': kpts1.new_full(shape1, -1, dtype=torch.int)[0],
                'matching_scores0': kpts0.new_zeros(shape0)[0],
                'matching_scores1': kpts1.new_zeros(shape1)[0],
                'skip_train': True
            }

        norm_kpts0, norm_kpts1 = self._normalized_keypoints(data, kpts0, kpts1)

        # Keypoint MLP encoder.
        enc0, enc1 = self.encode_keypoint(norm_kpts0=norm_kpts0, norm_kpts1=norm_kpts1, scores0=scores0,
                                          scores1=scores1)

        if self.config['descriptor_dim'] > 0:
            desc0, desc1 = self._fuse_descriptors(data, enc0, enc1)
        else:
            # Keypoint-only matching: the encodings act as descriptors.
            desc0, desc1 = enc0, enc1

        desc0s, desc1s = self.gnn.predict(desc0, desc1, n_it=n_it)
        dist = self.compute_distance(desc0s[-1], desc1s[-1], layer_id=n_it)

        score = self.compute_score(dist=dist, dustbin=self.bin_score, iteration=self.sinkhorn_iterations)
        indices0, indices1, mscores0, mscores1 = self.compute_matches(scores=score, p=p)

        return {
            'matches0': indices0,  # use -1 for invalid match
            'matches1': indices1,  # use -1 for invalid match
            'matching_scores0': mscores0,
            'matching_scores1': mscores1,
        }

    def forward(self, data, mode=0):
        """Run ``produce_matches`` in eval mode, ``forward_train`` otherwise (``mode`` is unused)."""
        if not self.training:
            return self.produce_matches(data=data, n_it=-1)
        return self.forward_train(data=data)

    def encode_keypoint(self, norm_kpts0, norm_kpts1, scores0, scores1):
        """Encode both views' normalised keypoints and scores with the shared encoder."""
        return self.kenc(norm_kpts0, scores0), self.kenc(norm_kpts1, scores1)

    def compute_distance(self, desc0, desc1, layer_id=-1):
        """Project both descriptor sets and return their scaled dot-product similarity.

        Args:
            desc0: [B, D, N] descriptors of the first view.
            desc1: [B, D, M] descriptors of the second view.
            layer_id: index of the projection in ``self.final_proj``.

        Returns:
            [B, N, M] similarities divided by the square root of the feature
            width (``descriptor_dim``, or the fallback width when that is not
            positive, so the scale is always finite and real).
        """
        proj = self.final_proj[layer_id]
        mdesc0 = proj(desc0)
        mdesc1 = proj(desc1)
        dist = torch.einsum('bdn,bdm->bnm', mdesc0, mdesc1)
        return dist / self._feature_dim() ** .5

    def compute_score(self, dist, dustbin, iteration):
        """Turn similarities into assignment scores via Sinkhorn or dual softmax."""
        if self.with_sinkhorn:
            score = sink_algorithm(M=dist, dustbin=dustbin,
                                   iteration=iteration)  # [nI * nB, N, M]
        else:
            score = dual_softmax(M=dist, dustbin=dustbin)
        return score

    def compute_matches(self, scores, p=0.2):
        """Extract mutual best matches from a score matrix that includes dustbins.

        The last row and column (the dustbins) are dropped before taking the
        row-wise and column-wise maxima.  A pair is kept when both points pick
        each other and the score exceeds ``p``.

        Returns:
            ``(indices0, indices1, mscores0, mscores1)`` where rejected
            entries of the index tensors are ``-1`` and non-mutual entries of
            the score tensors are ``0``.
        """
        inner = scores[:, :-1, :-1]
        max0, max1 = inner.max(2), inner.max(1)
        indices0, indices1 = max0.indices, max1.indices
        # ``arange_like`` presumably yields 0..K-1 along dim 1 -- defined in nets.utils.
        mutual0 = arange_like(indices0, 1)[None] == indices1.gather(1, indices0)
        mutual1 = arange_like(indices1, 1)[None] == indices0.gather(1, indices1)
        zero = scores.new_tensor(0)
        mscores0 = torch.where(mutual0, max0.values, zero)
        mscores1 = torch.where(mutual1, mscores0.gather(1, indices1), zero)
        valid0 = mutual0 & (mscores0 > p)
        valid1 = mutual1 & valid0.gather(1, indices1)
        indices0 = torch.where(valid0, indices0, indices0.new_tensor(-1))
        indices1 = torch.where(valid1, indices1, indices1.new_tensor(-1))
        return indices0, indices1, mscores0, mscores1