# Composition of the VisionTransformer class from timm with extra features: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py
import torch
import torch.nn as nn
from typing import Tuple, Union, Sequence, Any

from timm.layers import trunc_normal_
from timm.models.vision_transformer import Block, Attention

from layers.transformer_layers import BlockWQKVReturn, AttentionWQKVReturn
from utils.misc_utils import compute_attention


class BaselineViT(torch.nn.Module):
    """
    Composition of the timm VisionTransformer with the following modifications:
    - Replace Block with BlockWQKVReturn and Attention with AttentionWQKVReturn
      so that the raw q/k/v tensors of the attention layers can be returned
    - Option to use only the class token, only the patch tokens, or both
      (concatenated) for classification
    A usage sketch is provided at the end of this file.
    """

    def __init__(self, init_model: torch.nn.Module, num_classes: int,
                 class_tokens_only: bool = False,
                 patch_tokens_only: bool = False, return_transformer_qkv: bool = False) -> None:
        super().__init__()
        self.num_classes = num_classes
        self.class_tokens_only = class_tokens_only
        self.patch_tokens_only = patch_tokens_only
        self.num_prefix_tokens = init_model.num_prefix_tokens
        self.num_reg_tokens = init_model.num_reg_tokens
        self.has_class_token = init_model.has_class_token
        self.no_embed_class = init_model.no_embed_class
        self.cls_token = init_model.cls_token
        self.reg_token = init_model.reg_token
        self.patch_embed = init_model.patch_embed
        self.pos_embed = init_model.pos_embed
        self.pos_drop = init_model.pos_drop
        self.part_embed = nn.Identity()
        self.patch_prune = nn.Identity()
        self.norm_pre = init_model.norm_pre
        self.blocks = init_model.blocks
        self.norm = init_model.norm
        self.fc_norm = init_model.fc_norm
        if class_tokens_only or patch_tokens_only:
            self.head = nn.Linear(init_model.embed_dim, num_classes)
        else:
            self.head = nn.Linear(init_model.embed_dim * 2, num_classes)
        self.h_fmap = int(self.patch_embed.img_size[0] // self.patch_embed.patch_size[0])
        self.w_fmap = int(self.patch_embed.img_size[1] // self.patch_embed.patch_size[1])
        self.return_transformer_qkv = return_transformer_qkv
        self.convert_blocks_and_attention()
        self._init_weights_head()

    def convert_blocks_and_attention(self):
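        """
        Swap the existing timm Block / Attention modules to BlockWQKVReturn /
        AttentionWQKVReturn by reassigning their __class__ in place: the
        pretrained parameters are untouched, only the forward methods change,
        gaining the option to return the raw q/k/v tensors.
        """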
        for module in self.modules():
            if isinstance(module, Block):
                module.__class__ = BlockWQKVReturn
            elif isinstance(module, Attention):
                module.__class__ = AttentionWQKVReturn

    def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:
        pos_embed = self.pos_embed
        to_cat = []
        if self.cls_token is not None:
            to_cat.append(self.cls_token.expand(x.shape[0], -1, -1))
        if self.reg_token is not None:
            to_cat.append(self.reg_token.expand(x.shape[0], -1, -1))
        if self.no_embed_class:
            # deit-3, updated JAX (big vision)
            # position embedding does not overlap with class token, add then concat
            x = x + pos_embed
            if to_cat:
                x = torch.cat(to_cat + [x], dim=1)
        else:
            # original timm, JAX, and deit vit impl
            # pos_embed has entry for class token, concat then add
            if to_cat:
                x = torch.cat(to_cat + [x], dim=1)
            x = x + pos_embed
        return self.pos_drop(x)

    def _init_weights_head(self):
        trunc_normal_(self.head.weight, std=.02)
        if self.head.bias is not None:
            nn.init.constant_(self.head.bias, 0.)

    def forward(self, x: torch.Tensor) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        x = self.patch_embed(x)
        # Position embedding
        x = self._pos_embed(x)
        x = self.part_embed(x)
        x = self.patch_prune(x)
        # Forward pass through the transformer blocks
        x = self.norm_pre(x)
        if self.return_transformer_qkv:
            # Keep the q/k/v tensors of the last attention block
            for blk in self.blocks:
                x, qkv = blk(x, return_qkv=True)
        else:
            x = self.blocks(x)
        x = self.norm(x)
        # Classification head
        x = self.fc_norm(x)
        if self.class_tokens_only:  # only use the class token
            x = x[:, 0, :]
        elif self.patch_tokens_only:  # only use the patch tokens
            x = x[:, self.num_prefix_tokens:, :].mean(dim=1)
        else:  # concatenate the class token and the mean of the patch tokens
            x = torch.cat([x[:, 0, :], x[:, self.num_prefix_tokens:, :].mean(dim=1)], dim=1)
        x = self.head(x)
        if self.return_transformer_qkv:
            return x, qkv
        else:
            return x

    def get_specific_intermediate_layer(
            self,
            x: torch.Tensor,
            n: int = 1,
            return_qkv: bool = False,
            return_att_weights: bool = False,
    ):
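        """
        Return the token output of block `n` after a partial forward pass.
        If `n == -1`, return the tokens right before the first block (after
        patch embedding, position embedding and `norm_pre`).
        Note: `return_qkv` and `return_att_weights` are only honoured when the
        model was built with `return_transformer_qkv=True`.
        """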
        num_blocks = len(self.blocks)
        attn_weights = []
        if n >= num_blocks:
            raise ValueError(f"n must be less than {num_blocks}")
        # forward pass
        x = self.patch_embed(x)
        x = self._pos_embed(x)
        x = self.norm_pre(x)
        if n == -1:
            if return_qkv:
                raise ValueError("n cannot be -1 if return_qkv is True")
            else:
                return x
        for i, blk in enumerate(self.blocks):
            if self.return_transformer_qkv:
                x, qkv = blk(x, return_qkv=True)
                if return_att_weights:
                    attn_weight, _ = compute_attention(qkv)
                    attn_weights.append(attn_weight.detach())
            else:
                x = blk(x)
            if i == n:
                output = x.clone()
                if self.return_transformer_qkv and return_qkv:
                    qkv_output = qkv.clone()
                break
        if self.return_transformer_qkv and return_qkv and return_att_weights:
            return output, qkv_output, attn_weights
        elif self.return_transformer_qkv and return_qkv:
            return output, qkv_output
        elif self.return_transformer_qkv and return_att_weights:
            return output, attn_weights
        else:
            return output

    def _intermediate_layers(
            self,
            x: torch.Tensor,
            n: Union[int, Sequence] = 1,
    ):
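        """
        Run the blocks once and collect the token outputs of the selected
        blocks: the last `n` blocks when `n` is an int, otherwise the blocks
        whose indices appear in `n`. The corresponding q/k/v tensors are
        collected as well when `return_transformer_qkv` is enabled.
        """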
        outputs, num_blocks = [], len(self.blocks)
        if self.return_transformer_qkv:
            qkv_outputs = []
        take_indices = set(range(num_blocks - n, num_blocks) if isinstance(n, int) else n)
        # forward pass
        x = self.patch_embed(x)
        x = self._pos_embed(x)
        x = self.norm_pre(x)
        for i, blk in enumerate(self.blocks):
            if self.return_transformer_qkv:
                x, qkv = blk(x, return_qkv=True)
            else:
                x = blk(x)
            if i in take_indices:
                outputs.append(x)
                if self.return_transformer_qkv:
                    qkv_outputs.append(qkv)
        if self.return_transformer_qkv:
            return outputs, qkv_outputs
        else:
            return outputs

    def get_intermediate_layers(
            self,
            x: torch.Tensor,
            n: Union[int, Sequence] = 1,
            reshape: bool = False,
            return_prefix_tokens: bool = False,
            norm: bool = False,
    ) -> Union[tuple, Tuple[tuple, Any]]:
        """ Intermediate layer accessor (NOTE: This is a WIP experiment).
        Inspired by the DINO / DINOv2 interface.
        """
        # take the last n blocks if n is an int; if n is a sequence, select blocks by matching indices
        if self.return_transformer_qkv:
            outputs, qkv = self._intermediate_layers(x, n)
        else:
            outputs = self._intermediate_layers(x, n)
        if norm:
            outputs = [self.norm(out) for out in outputs]
        prefix_tokens = [out[:, 0:self.num_prefix_tokens] for out in outputs]
        outputs = [out[:, self.num_prefix_tokens:] for out in outputs]
        if reshape:
            grid_size = self.patch_embed.grid_size
            outputs = [
                out.reshape(x.shape[0], grid_size[0], grid_size[1], -1).permute(0, 3, 1, 2).contiguous()
                for out in outputs
            ]
        if return_prefix_tokens:
            return_out = tuple(zip(outputs, prefix_tokens))
        else:
            return_out = tuple(outputs)
        if self.return_transformer_qkv:
            return return_out, qkv
        else:
            return return_out
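

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not used by the module itself).
# It assumes a recent timm release whose VisionTransformer exposes
# num_prefix_tokens / reg_token / no_embed_class; the model name, image size
# and num_classes below are arbitrary placeholders.
if __name__ == "__main__":
    import timm

    base = timm.create_model("vit_base_patch16_224", pretrained=False)
    model = BaselineViT(base, num_classes=10, return_transformer_qkv=True).eval()

    images = torch.randn(2, 3, 224, 224)
    with torch.no_grad():
        # forward() returns (logits, qkv of the last block) because
        # return_transformer_qkv=True; otherwise it returns only the logits
        logits, qkv = model(images)
        # Feature maps of the last two blocks, reshaped to (B, C, h_fmap, w_fmap),
        # plus the collected q/k/v tensors
        feature_maps, qkv_list = model.get_intermediate_layers(images, n=2, reshape=True, norm=True)
    print(logits.shape)  # torch.Size([2, 10])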