import gc
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from .pooling import (
    Attention1dPoolingHead,
    LightAttentionPoolingHead,
    MeanPooling,
    MeanPoolingHead,
    MeanPoolingProjection,
)


def rotate_half(x):
    # Split channels into two halves and build (-x2, x1): the "rotate by 90
    # degrees" term of the rotary transformation.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(x, cos, sin):
    # Truncate the cached tables to the current sequence length, then apply
    # the rotation x * cos + rotate_half(x) * sin.
    cos = cos[:, :, : x.shape[-2], :]
    sin = sin[:, :, : x.shape[-2], :]
    return (x * cos) + (rotate_half(x) * sin)
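

# Worked equation (a reference sketch, not used at runtime): with inverse
# frequencies theta_i = 10000 ** (-2i / d) and position m, each channel pair
# (x_i, x_{i + d/2}) is rotated by angle m * theta_i:
#
#     x_i'       = x_i * cos(m * theta_i) - x_{i + d/2} * sin(m * theta_i)
#     x_{i+d/2}' = x_i * sin(m * theta_i) + x_{i + d/2} * cos(m * theta_i)
#
# Note the split-halves (GPT-NeoX-style) pairing: channel i is paired with
# channel i + d/2, not with an interleaved neighbour.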


class RotaryEmbedding(nn.Module):
    """
    Rotary position embeddings based on those in
    [RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer).
    Queries and keys are transformed by rotation matrices which depend on
    their relative positions.
    """

    def __init__(self, dim: int):
        super().__init__()
        # Generate and save the inverse frequency buffer (non-trainable)
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self._seq_len_cached = None
        self._cos_cached = None
        self._sin_cached = None
    def _update_cos_sin_tables(self, x, seq_dimension=2):
        seq_len = x.shape[seq_dimension]
        # Reset the tables if the sequence length has changed, or if we're on
        # a new device (possibly due to tracing for instance)
        if seq_len != self._seq_len_cached or self._cos_cached.device != x.device:
            self._seq_len_cached = seq_len
            t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(self.inv_freq)
            freqs = torch.outer(t, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self._cos_cached = emb.cos()[None, None, :, :]
            self._sin_cached = emb.sin()[None, None, :, :]
        return self._cos_cached, self._sin_cached
    def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k, seq_dimension=-2)
        return (
            apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached),
            apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached),
        )
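

# Minimal usage sketch (illustrative only; shapes are placeholders).
# RotaryEmbedding expects q/k laid out as (batch, heads, seq_len, head_dim),
# which is the layout produced by transpose_for_scores below:
#
#     rope = RotaryEmbedding(dim=64)
#     q = torch.randn(2, 8, 128, 64)   # (batch, heads, seq, head_dim)
#     k = torch.randn(2, 8, 128, 64)
#     q_rot, k_rot = rope(q, k)        # same shapes, positions now encoded
#
# Because the rotation preserves inner products up to relative position,
# q_rot @ k_rot^T depends on offsets (i - j) rather than on i and j
# separately.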


class CrossModalAttention(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.attention_head_size = args.hidden_size // args.num_attention_head
        assert (
            self.attention_head_size * args.num_attention_head == args.hidden_size
        ), "hidden_size must be divisible by num_attention_head"
        self.num_attention_head = args.num_attention_head
        self.hidden_size = args.hidden_size

        self.query_proj = nn.Linear(args.hidden_size, args.hidden_size)
        self.key_proj = nn.Linear(args.hidden_size, args.hidden_size)
        self.value_proj = nn.Linear(args.hidden_size, args.hidden_size)

        self.dropout = nn.Dropout(args.attention_probs_dropout)
        # NOTE: out_proj is defined but never applied in forward below.
        self.out_proj = nn.Linear(args.hidden_size, args.hidden_size)
        self.rotary_embeddings = RotaryEmbedding(dim=self.attention_head_size)

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        # (batch, seq, hidden) -> (batch, heads, seq, head_dim)
        new_x_shape = x.size()[:-1] + (self.num_attention_head, self.attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)
    def forward(self, query, key, value, attention_mask=None, output_attentions=False):
        key_layer = self.transpose_for_scores(self.key_proj(key))
        value_layer = self.transpose_for_scores(self.value_proj(value))
        query_layer = self.transpose_for_scores(self.query_proj(query))

        # Scale the queries, then rotate queries and keys with rotary embeddings.
        query_layer = query_layer * self.attention_head_size**-0.5
        query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        if attention_mask is not None:
            # (batch, seq) -> (batch, 1, 1, seq), broadcast over heads and query positions
            attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
            attention_scores = attention_scores.masked_fill(attention_mask == 0, float("-inf"))

        attention_probs = F.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, value_layer)
        # (batch, heads, seq, head_dim) -> (batch, seq, hidden)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else context_layer
        return outputs
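

# Minimal usage sketch (illustrative; in the project, `args` comes from the
# training configuration, so the values below are placeholders):
#
#     from types import SimpleNamespace
#     args = SimpleNamespace(hidden_size=512, num_attention_head=8,
#                            attention_probs_dropout=0.1)
#     attn = CrossModalAttention(args)
#     struct = torch.randn(2, 128, 512)    # structure-track embeddings (query)
#     seq = torch.randn(2, 128, 512)       # PLM residue embeddings (key/value)
#     mask = torch.ones(2, 128)            # 1 = real token, 0 = padding
#     fused = attn(struct, seq, seq, mask)  # -> (2, 128, 512)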


class AdapterModel(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.args = args

        # One embedding table and one cross-attention block per structure track.
        if 'foldseek_seq' in args.structure_seq:
            self.foldseek_embedding = nn.Embedding(args.vocab_size, args.hidden_size)
            self.cross_attention_foldseek = CrossModalAttention(args)
        if 'ss8_seq' in args.structure_seq:
            self.ss_embedding = nn.Embedding(args.vocab_size, args.hidden_size)
            self.cross_attention_ss = CrossModalAttention(args)
        if 'esm3_structure_seq' in args.structure_seq:
            self.esm3_structure_embedding = nn.Embedding(args.vocab_size, args.hidden_size)
            self.cross_attention_esm3_structure = CrossModalAttention(args)

        self.layer_norm = nn.LayerNorm(args.hidden_size)

        if args.pooling_method == 'attention1d':
            self.classifier = Attention1dPoolingHead(args.hidden_size, args.num_labels, args.pooling_dropout)
        elif args.pooling_method == 'mean':
            if "PPI" in args.dataset:
                self.pooling = MeanPooling()
                self.projection = MeanPoolingProjection(args.hidden_size, args.num_labels, args.pooling_dropout)
            else:
                self.classifier = MeanPoolingHead(args.hidden_size, args.num_labels, args.pooling_dropout)
        elif args.pooling_method == 'light_attention':
            self.classifier = LightAttentionPoolingHead(args.hidden_size, args.num_labels, args.pooling_dropout)
        else:
            raise ValueError(f"pooling method {args.pooling_method} not supported")
    def plm_embedding(self, plm_model, aa_seq, attention_mask, structure_tokens=None):
        # The backbone PLM is frozen here; embeddings are extracted under no_grad.
        with torch.no_grad():
            if "ProSST" in self.args.plm_model:
                outputs = plm_model(input_ids=aa_seq, attention_mask=attention_mask, ss_input_ids=structure_tokens, output_hidden_states=True)
            elif "Prime" in self.args.plm_model or "deep" in self.args.plm_model:
                outputs = plm_model(input_ids=aa_seq, attention_mask=attention_mask, output_hidden_states=True)
            else:
                # Plain forward pass (also used when training_method == 'full').
                outputs = plm_model(input_ids=aa_seq, attention_mask=attention_mask)

            if "ProSST" in self.args.plm_model or "Prime" in self.args.plm_model:
                seq_embeds = outputs.hidden_states[-1]
            else:
                seq_embeds = outputs.last_hidden_state

        gc.collect()
        torch.cuda.empty_cache()
        return seq_embeds
    def forward(self, plm_model, batch):
        if "ProSST" in self.args.plm_model:
            aa_seq, attention_mask, stru_tokens = batch['aa_seq_input_ids'], batch['aa_seq_attention_mask'], batch['aa_seq_stru_tokens']
            seq_embeds = self.plm_embedding(plm_model, aa_seq, attention_mask, stru_tokens)
        else:
            aa_seq, attention_mask = batch['aa_seq_input_ids'], batch['aa_seq_attention_mask']
            seq_embeds = self.plm_embedding(plm_model, aa_seq, attention_mask)

        # Fuse the structure tracks into the sequence embedding one at a time:
        # each track cross-attends to the most recently fused representation,
        # is added residually, and is followed by layer norm.
        if 'foldseek_seq' in self.args.structure_seq:
            foldseek_seq = batch['foldseek_seq_input_ids']
            foldseek_embeds = self.foldseek_embedding(foldseek_seq)
            foldseek_embeds = self.cross_attention_foldseek(foldseek_embeds, seq_embeds, seq_embeds, attention_mask)
            embeds = seq_embeds + foldseek_embeds
            embeds = self.layer_norm(embeds)

        if 'ss8_seq' in self.args.structure_seq:
            ss_seq = batch['ss8_seq_input_ids']
            ss_embeds = self.ss_embedding(ss_seq)
            if 'foldseek_seq' in self.args.structure_seq:
                # cross attention with the foldseek-fused representation
                ss_embeds = self.cross_attention_ss(ss_embeds, embeds, embeds, attention_mask)
                embeds = ss_embeds + embeds
            else:
                # cross attention with sequence
                ss_embeds = self.cross_attention_ss(ss_embeds, seq_embeds, seq_embeds, attention_mask)
                embeds = ss_embeds + seq_embeds
            embeds = self.layer_norm(embeds)

        if 'esm3_structure_seq' in self.args.structure_seq:
            esm3_structure_seq = batch['esm3_structure_seq_input_ids']
            esm3_structure_embeds = self.esm3_structure_embedding(esm3_structure_seq)
            if 'foldseek_seq' in self.args.structure_seq:
                # cross attention with the foldseek-fused representation
                esm3_structure_embeds = self.cross_attention_esm3_structure(esm3_structure_embeds, embeds, embeds, attention_mask)
                embeds = esm3_structure_embeds + embeds
            elif 'ss8_seq' in self.args.structure_seq:
                # cross attention with ss8
                esm3_structure_embeds = self.cross_attention_esm3_structure(esm3_structure_embeds, ss_embeds, ss_embeds, attention_mask)
                embeds = esm3_structure_embeds + ss_embeds
            else:
                # cross attention with sequence
                esm3_structure_embeds = self.cross_attention_esm3_structure(esm3_structure_embeds, seq_embeds, seq_embeds, attention_mask)
                embeds = esm3_structure_embeds + seq_embeds
            embeds = self.layer_norm(embeds)

        if self.args.structure_seq:
            logits = self.classifier(embeds, attention_mask)
        else:
            logits = self.classifier(seq_embeds, attention_mask)
        return logits
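

# Minimal end-to-end sketch (illustrative; the hyperparameter values and the
# backbone checkpoint below are placeholders, not the project's actual config):
#
#     from types import SimpleNamespace
#     from transformers import AutoModel
#
#     args = SimpleNamespace(
#         plm_model="facebook/esm2_t33_650M_UR50D",  # hidden size 1280
#         structure_seq=["foldseek_seq"],
#         hidden_size=1280, num_attention_head=8, attention_probs_dropout=0.1,
#         vocab_size=128, num_labels=2, pooling_method="mean",
#         pooling_dropout=0.1, dataset="demo",
#     )
#     plm_model = AutoModel.from_pretrained(args.plm_model).eval()
#     adapter = AdapterModel(args)
#     batch = {
#         "aa_seq_input_ids": ...,        # (batch, seq) amino-acid token ids
#         "aa_seq_attention_mask": ...,   # (batch, seq) 1 = token, 0 = pad
#         "foldseek_seq_input_ids": ...,  # (batch, seq) structure-token ids
#     }
#     logits = adapter(plm_model, batch)  # -> (batch, num_labels)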