Spaces:
Runtime error
Runtime error
| # coding=utf-8 | |
| # Copyright 2022 The OpenBMB Team and The HuggingFace Inc. team. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """ PyTorch CPMAnt""" | |
| import math | |
| from typing import List, Optional, Tuple, Union | |
| import torch | |
| import torch.nn.functional as F | |
| import torch.utils.checkpoint | |
| from torch import nn | |
| from torch.nn import CrossEntropyLoss | |
| from ...activations import ACT2FN | |
| from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast | |
| from ...modeling_utils import PreTrainedModel | |
| from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging | |
| from .configuration_cpmant import CpmAntConfig | |
# Documentation identifiers. NOTE(review): not referenced by the code visible in this file; presumably used by
# docstring decorators - confirm before removing.
_CONFIG_FOR_DOC = "CpmAntConfig"
_CHECKPOINT_FOR_DOC = "openbmb/cpm-ant-10b"

# Module-level logger obtained from the library's logging utilities.
logger = logging.get_logger(__name__)

# Known pretrained checkpoints for this architecture.
CPMANT_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "openbmb/cpm-ant-10b",
    # See all CPMAnt models at https://huggingface.co/models?filter=cpmant
]
class CpmAntLayerNorm(nn.Module):
    """
    Root Mean Square (RMS) layer normalization, see https://arxiv.org/abs/1910.07467 for details.

    The input is divided by the root of the mean of its squared features (no mean subtraction, no bias) and then
    multiplied by a learned per-feature weight.
    """

    def __init__(self, config: CpmAntConfig):
        super().__init__()
        hidden_size = config.hidden_size
        self.eps = config.eps
        self.dim_norm = hidden_size
        # Allocated without initialisation; the values must be filled in by the owner of this module.
        self.weight = nn.Parameter(torch.empty(hidden_size))

    def forward(self, hidden_states: torch.Tensor):
        """
        Args:
            hidden_states (`torch.Tensor` of shape `(batch, seq_len, dim_in)`):
                Tensor to normalize over its last dimension.

        Raises:
            AssertionError: If the last dimension of `hidden_states` differs from the configured hidden size.
        """
        if hidden_states.size(-1) != self.dim_norm:
            raise AssertionError("hidden_states.size(-1) != self.dim_norm")

        input_dtype = hidden_states.dtype
        # The mean of squares is taken in float32, independently of the input dtype.
        mean_square = hidden_states.to(torch.float32).pow(2).mean(dim=-1, keepdim=True)
        inverse_rms = torch.rsqrt(mean_square + self.eps)
        normalized = (hidden_states * inverse_rms).to(input_dtype)
        return normalized * self.weight
class CpmAntAttention(nn.Module):
    """Multi-head scaled dot-product attention with an additive position bias and an explicit validity mask."""

    def __init__(self, config: CpmAntConfig):
        super().__init__()
        self.dim_model = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.dim_head = config.dim_head

        inner_dim = self.num_heads * self.dim_head
        self.project_q = nn.Linear(self.dim_model, inner_dim, bias=False)
        self.project_k = nn.Linear(self.dim_model, inner_dim, bias=False)
        self.project_v = nn.Linear(self.dim_model, inner_dim, bias=False)

        self.attention_out = nn.Linear(inner_dim, self.dim_model, bias=False)

        self.softmax = torch.nn.Softmax(dim=-1)

        if config.dropout_p is not None:
            self.dropout = torch.nn.Dropout(p=config.dropout_p)
        else:
            self.dropout = None

    def forward(
        self,
        hidden_q: torch.Tensor,
        hidden_kv: torch.Tensor,
        attention_mask: torch.BoolTensor,
        position_bias: torch.Tensor,
        output_attentions: Optional[bool] = False,
        past_key_values: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: Optional[bool] = None,
    ):
        """
        Args:
            hidden_q (`torch.Tensor` of shape `(batch, len_q, dim_model)`):
                States the queries are projected from.
            hidden_kv (`torch.Tensor` of shape `(batch, len_k, dim_model)`):
                States the keys and values are projected from.
            attention_mask (`torch.Tensor`, viewable as `(batch, 1, len_q, len_k)`):
                Positions whose entry is zero / `False` are excluded from attention. `len_k` includes any cached
                positions.
            position_bias (`torch.Tensor`, broadcastable to `(batch, num_heads, len_q, len_k)`):
                Added to the scaled attention scores before masking.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attention weights.
            past_key_values (`Tuple[torch.Tensor, torch.Tensor]`, *optional*):
                Cached key and value projections, concatenated in front of the new ones along the sequence axis.
            use_cache (`bool`, *optional*):
                If set to `True`, the (possibly extended) key and value projections are returned so they can be
                passed back as `past_key_values`.

        Returns:
            A tuple `(output, attn_weights, past_key_values)`: the projected attention output of shape `(batch,
            len_q, dim_model)`, the attention weights (or `None` unless `output_attentions`), and the key/value
            cache (or `None` unless `use_cache`).
        """
        batch_size = hidden_q.size(0)
        len_q = hidden_q.size(1)
        len_k = hidden_kv.size(1)

        query = self.project_q(hidden_q)
        key = self.project_k(hidden_kv)
        value = self.project_v(hidden_kv)

        # (batch, len, num_heads * dim_head) -> (batch, num_heads, len, dim_head)
        query = query.view(batch_size, len_q, self.num_heads, self.dim_head).permute(0, 2, 1, 3)
        key = key.view(batch_size, len_k, self.num_heads, self.dim_head).permute(0, 2, 1, 3)
        value = value.view(batch_size, len_k, self.num_heads, self.dim_head).permute(0, 2, 1, 3)

        if past_key_values is not None:
            key = torch.cat([past_key_values[0], key], dim=-2)
            value = torch.cat([past_key_values[1], value], dim=-2)
            len_k = key.size(-2)

        # (batch, num_heads, len_q, dim_head) @ (batch, num_heads, dim_head, len_k) -> (batch, num_heads, len_q, len_k)
        score = torch.matmul(query, key.transpose(-1, -2)) / math.sqrt(self.dim_head)
        score = score + position_bias

        # Computed once and reused. `logical_not` (rather than `~`) keeps 0/1 integer masks working: it is True
        # exactly where the mask entry is zero.
        masked_out = attention_mask.view(batch_size, 1, len_q, len_k).logical_not()

        score = score.masked_fill(masked_out, float("-inf"))
        score = self.softmax(score)
        # A fully masked row comes out of the softmax as NaN; this second fill resets those entries to zero.
        score = score.masked_fill(masked_out, 0)

        attn_weights = score if output_attentions else None

        if self.dropout is not None:
            score = self.dropout(score)

        # (batch, num_heads, len_q, len_k) @ (batch, num_heads, len_k, dim_head) -> (batch, num_heads, len_q, dim_head)
        score = torch.matmul(score, value)

        # Merge the heads back: (batch, num_heads, len_q, dim_head) -> (batch, len_q, num_heads * dim_head)
        score = score.permute(0, 2, 1, 3).contiguous().view(batch_size, len_q, self.num_heads * self.dim_head)

        score = self.attention_out(score)

        past_key_values = (key, value) if use_cache else None

        return score, attn_weights, past_key_values
class CpmAntSelfAttentionBlock(nn.Module):
    """Self-attention sub-layer: layer norm, attention, optional dropout, then a residual connection."""

    def __init__(self, config: CpmAntConfig):
        super().__init__()
        self.layernorm_before_attention = CpmAntLayerNorm(config)
        self.self_attention = CpmAntAttention(config)
        # Any falsy rate (`None` or `0`) leaves this block without a dropout module.
        self.dropout = torch.nn.Dropout(config.dropout_p) if config.dropout_p else None

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        position_bias: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
        past_key_values: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: Optional[bool] = None,
    ):
        """
        Args:
            hidden_states (`torch.Tensor` of shape `(batch, len_seq, dim_model)`):
                Input of the block. It is normalized and used as query, key and value source alike.
            attention_mask (`torch.Tensor`):
                Mask forwarded unchanged to the attention module.
            position_bias (`torch.Tensor`):
                Positional bias forwarded unchanged to the attention module.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attention weights.
            past_key_values (`Tuple(torch.FloatTensor)`, *optional*):
                Cached past key and value projection states.
            use_cache (`bool`, *optional*):
                If set to `True`, key value states are returned and can be used to speed up decoding (see
                `past_key_values`).

        Returns:
            A tuple `(hidden_states, attn_weights, current_key_value)`.
        """
        normed = self.layernorm_before_attention(hidden_states)
        attn_output, attn_weights, current_key_value = self.self_attention(
            normed, normed, attention_mask, position_bias, output_attentions, past_key_values, use_cache
        )

        if self.dropout is not None:
            attn_output = self.dropout(attn_output)

        # Residual connection around the (un-normalized) block input.
        return hidden_states + attn_output, attn_weights, current_key_value
class CpmAntDenseGatedACT(nn.Module):
    """Gated projection: a GELU-activated linear branch multiplied element-wise with a plain linear branch."""

    def __init__(self, config: CpmAntConfig):
        super().__init__()
        in_features, out_features = config.hidden_size, config.dim_ff
        self.w_0 = nn.Linear(in_features, out_features, bias=False)
        self.w_1 = nn.Linear(in_features, out_features, bias=False)
        self.act = torch.nn.GELU()

    def forward(self, hidden_states: torch.Tensor):
        """Transform an input tensor from one feature space to another via a nonlinear operation

        Args:
            hidden_states (`torch.Tensor` of shape `(batch, seq_len, dim_in)`)

        Returns:
            `torch.Tensor` whose last dimension has size `config.dim_ff`.
        """
        gate = self.act(self.w_0(hidden_states))
        linear = self.w_1(hidden_states)
        return gate * linear
class CpmAntFeedForward(nn.Module):
    """Feed-forward network: gated up-projection, optional dropout, then a bias-free down-projection."""

    def __init__(self, config: CpmAntConfig):
        super().__init__()
        self.w_in = CpmAntDenseGatedACT(config)
        drop_rate = config.dropout_p
        # Only `None` disables dropout here; a rate of 0 still creates the module.
        self.dropout = None if drop_rate is None else torch.nn.Dropout(drop_rate)
        self.w_out = nn.Linear(config.dim_ff, config.hidden_size, bias=False)

    def forward(self, hidden_states: torch.Tensor):
        """
        Args:
            hidden_states (`torch.Tensor` of shape `(batch, seq_len, dim_in)`)
        """
        expanded = self.w_in(hidden_states)
        if self.dropout is not None:
            expanded = self.dropout(expanded)
        return self.w_out(expanded)
class CpmAntFFNBlock(nn.Module):
    """Feed-forward sub-layer: layer norm, feed-forward network, optional dropout, then a residual connection."""

    def __init__(self, config: CpmAntConfig):
        super().__init__()
        self.layernorm_before_ffn = CpmAntLayerNorm(config)
        self.ffn = CpmAntFeedForward(config)
        # Any falsy rate (`None` or `0`) leaves this block without a dropout module.
        self.dropout = torch.nn.Dropout(config.dropout_p) if config.dropout_p else None

    def forward(
        self,
        hidden_states: torch.Tensor,
    ):
        """
        Args:
            hidden_states (`torch.Tensor` of shape `(batch, len_seq, dim_model)`):
                Hidden states before feed forward layer.
        """
        ffn_output = self.ffn(self.layernorm_before_ffn(hidden_states))
        if self.dropout is not None:
            ffn_output = self.dropout(ffn_output)
        # Residual connection around the (un-normalized) block input.
        return hidden_states + ffn_output
class CpmAntTransformerBlock(nn.Module):
    """One transformer layer: a self-attention block followed by a feed-forward block."""

    def __init__(self, config: CpmAntConfig):
        super().__init__()
        self.self_att = CpmAntSelfAttentionBlock(config)
        self.ffn = CpmAntFFNBlock(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        position_bias: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
        past_key_values: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: Optional[bool] = None,
    ):
        """
        Args:
            hidden_states (`torch.Tensor`):
                Input to the layer of shape `(batch, seq_len, dim_model)`
            attention_mask (`torch.Tensor`):
                Avoid invalid areas to participate in the calculation of shape `(batch, seq_len, seq_len)`
            position_bias (`torch.Tensor`):
                Provides position information to attention mechanism of shape `(num_heads, seq_len, seq_len)`
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers.
            past_key_values (`Tuple[torch.Tensor, torch.Tensor])`, *optional*):
                Cached past key and value projection states
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).

        Returns:
            A tuple `(hidden_states, attn_weights, current_key_value)`; the last two come straight from the
            self-attention block.
        """
        attended, attn_weights, current_key_value = self.self_att(
            hidden_states,
            attention_mask=attention_mask,
            position_bias=position_bias,
            output_attentions=output_attentions,
            past_key_values=past_key_values,
            use_cache=use_cache,
        )
        return self.ffn(attended), attn_weights, current_key_value
class CpmAntEncoder(nn.Module):
    """Stack of `config.num_hidden_layers` transformer blocks followed by a final layer norm."""

    def __init__(self, config: CpmAntConfig):
        super().__init__()
        self.num_layers = config.num_hidden_layers
        self.layers = nn.ModuleList([CpmAntTransformerBlock(config) for _ in range(self.num_layers)])

        self.output_layernorm = CpmAntLayerNorm(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        position_bias: torch.Tensor,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        past_key_values: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: Optional[bool] = None,
    ):
        """
        Args:
            hidden_states (`torch.Tensor`):
                Input to the layer of shape `(batch, seq_len, dim_model)`
            attention_mask (`torch.Tensor`):
                Avoid invalid areas to participate in the calculation of shape `(batch, seq_len, seq_len)`
            position_bias (`torch.Tensor`):
                Provides position information to attention mechanism of shape `(num_heads, seq_len, seq_len)`
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers.
            past_key_values (`Tuple[torch.Tensor, torch.Tensor])`, *optional*):
                Cached past key and value projection states, indexed by layer.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).

        Returns:
            A tuple `(hidden_states, current_key_values, all_hidden_states, all_self_attns)`. Each of the last three
            is `None` unless the corresponding flag (`use_cache`, `output_hidden_states`, `output_attentions`) is
            truthy.
        """
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        current_key_values = () if use_cache else None

        for layer_idx, layer in enumerate(self.layers):
            if output_hidden_states:
                # Record the input of every layer; the final normalized output is appended after the loop.
                all_hidden_states += (hidden_states,)

            layer_past = past_key_values[layer_idx] if past_key_values else None
            hidden_states, attn_weights, current_key_value = layer(
                hidden_states,
                attention_mask,
                position_bias,
                output_attentions=output_attentions,
                past_key_values=layer_past,
                use_cache=use_cache,
            )

            if output_attentions:
                all_self_attns += (attn_weights,)
            if current_key_value is not None:
                current_key_values = current_key_values + (current_key_value,)

        hidden_states = self.output_layernorm(hidden_states)

        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        return hidden_states, current_key_values, all_hidden_states, all_self_attns
| # Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->CPMAnt | |
class CpmAntIntermediate(nn.Module):
    """Linear projection from `hidden_size` to `intermediate_size` followed by the configured activation."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        activation = config.hidden_act
        # A string is looked up in the activation registry; anything else is used as the activation itself.
        self.intermediate_act_fn = ACT2FN[activation] if isinstance(activation, str) else activation

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the dense projection and then the activation to `hidden_states`."""
        projected = self.dense(hidden_states)
        return self.intermediate_act_fn(projected)
class CpmAntSegmentPositionEmbedding(nn.Module):
    """
    Per-head attention bias looked up from a single learned table.

    The table has `position_bias_num_buckets` rows for bucketed relative distances followed by `segment_types *
    segment_types` rows for (query segment, key segment) pairs. Pairs of positions in the same segment use the
    distance rows; all other pairs use the segment-pair rows.
    """

    def __init__(self, config: CpmAntConfig):
        super().__init__()

        self.num_heads = config.num_attention_heads
        self.num_buckets = config.position_bias_num_buckets
        self.max_distance = config.position_bias_max_distance
        self.num_segments = config.segment_types

        # Allocated without initialisation; the values must be filled in by the owner of this module.
        self.relative_attention_bias = nn.Parameter(
            torch.empty(
                config.segment_types * config.segment_types + config.position_bias_num_buckets,
                config.num_attention_heads,
            )
        )

    def forward(
        self,
        key_pos: torch.Tensor,
        query_pos: torch.Tensor,
        key_segment: torch.Tensor,
        query_segment: torch.Tensor,
    ):
        """
        Args:
            key_pos (`torch.Tensor` of shape `(batch, len_k)`):
                Key positions. Only the shape is used; distances are derived from `arange` indices.
            query_pos (`torch.Tensor` of shape `(batch, len_q)`):
                Query positions. Only the shape is used.
            key_segment (`torch.Tensor` of shape `(batch, len_k)`):
                Segment id of every key position.
            query_segment (`torch.Tensor` of shape `(batch, len_q)`):
                Segment id of every query position.

        Returns:
            `torch.Tensor` of shape `(batch, num_heads, len_q, len_k)`.

        Raises:
            AssertionError: If the batch sizes of the positions differ, or a segment tensor does not match the
                length of its position tensor.
        """
        # The bucket indices are integer bookkeeping; only the embedding lookup below needs gradients.
        with torch.no_grad():
            batch = key_pos.size(0)
            keylen = key_pos.size(1)
            querylen = query_pos.size(1)

            if key_pos.size(0) != query_pos.size(0):
                raise AssertionError(
                    f"key_pos.size(0) should be equal to query_pos.size(0), but got {key_pos.size(0)} and {query_pos.size(0)}!"
                )
            # Key and query lengths are validated separately so each mismatch reports its own message.
            if keylen != key_segment.size(1):
                raise AssertionError(
                    f"keylen should be equal to key_segment.size(1), but got {keylen} and {key_segment.size(1)}!"
                )
            if querylen != query_segment.size(1):
                raise AssertionError(
                    f"querylen should be equal to query_segment.size(1), but got {querylen} and {query_segment.size(1)}!"
                )

            # Shape the segments so that comparing / combining them broadcasts to (batch, len_q, len_k).
            key_segment = key_segment.view(batch, -1, keylen)
            query_segment = query_segment.view(batch, querylen, -1)

            relative_position_bucket = self._segment_relative_position_bucket(query_segment, key_segment)
            # Segment-pair rows are stored after the distance rows in the table.
            relative_position_bucket = relative_position_bucket + self.num_buckets

            # (len_q, len_k): bucketed signed distance `key index - query index`
            absolute_position_bucket = self._position_bucket(
                torch.arange(keylen, dtype=torch.int32, device=relative_position_bucket.device)[None, :]
                - torch.arange(querylen, dtype=torch.int32, device=relative_position_bucket.device)[:, None],
                num_buckets=self.num_buckets,
                max_distance=self.max_distance,
            )
            # (batch, len_q, len_k)
            relative_position_bucket = torch.where(
                (key_segment == query_segment),
                absolute_position_bucket[None, :, :],
                relative_position_bucket,
            )

        # (batch, len_q, len_k, num_heads)
        embeds = F.embedding(relative_position_bucket, self.relative_attention_bias)
        # (batch, num_heads, len_q, len_k)
        embeds = embeds.permute(0, 3, 1, 2).contiguous()
        return embeds

    def _segment_relative_position_bucket(self, query_segment, key_segment):
        """Map every (query segment, key segment) pair to a distinct index in `[0, num_segments ** 2)`."""
        return query_segment * self.num_segments + key_segment

    def _position_bucket(self, relative_position, num_buckets=32, max_distance=128):
        """
        Bucket signed relative distances.

        Half of the buckets serve positive distances and half serve the rest. Within each half, distances below
        `max_exact` get a bucket of their own and larger ones share logarithmically sized buckets, clamped to the
        last bucket of the half.
        """
        # always bidirectional in CPMAnt
        num_buckets //= 2
        relative_buckets = (relative_position > 0).to(torch.int32) * num_buckets
        relative_position = torch.abs(relative_position)
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact
        # Evaluated for every entry (including small distances); `torch.where` below picks the applicable value.
        relative_position_if_large = max_exact + (
            torch.log(relative_position.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.int32)
        relative_position_if_large = torch.min(
            relative_position_if_large,
            torch.full_like(relative_position_if_large, num_buckets - 1),
        )
        relative_buckets += torch.where(is_small, relative_position.to(torch.int32), relative_position_if_large)
        return relative_buckets
| # Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->CPMAnt | |
class CpmAntOutput(nn.Module):
    """Project from `intermediate_size` back to `hidden_size`, apply dropout, add a residual and layer-normalize."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        """
        Args:
            hidden_states (`torch.Tensor`):
                Tensor whose last dimension is `intermediate_size`.
            input_tensor (`torch.Tensor`):
                Residual added to the projected states before normalization.
        """
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)
class CpmAntPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = CpmAntConfig
    base_model_prefix = "cpmant"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights of `module` according to its type; unrecognized modules are left untouched."""
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.init_std)
            if module.bias is not None:
                module.bias.data.zero_()
            return

        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.init_std)
            if module.padding_idx is not None:
                # The padding row is zeroed after the random initialisation.
                module.weight.data[module.padding_idx].zero_()
            return

        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
            return

        if isinstance(module, CpmAntLayerNorm):
            # The RMS norm has a scale only, no bias.
            module.weight.data.fill_(1.0)
            return

        if isinstance(module, CpmAntSegmentPositionEmbedding):
            module.relative_attention_bias.data.normal_(mean=0.0, std=self.config.init_std)

    def _set_gradient_checkpointing(self, module, value=False):
        """Set the `gradient_checkpointing` attribute on encoder modules; other modules are ignored."""
        if not isinstance(module, CpmAntEncoder):
            return
        module.gradient_checkpointing = value
# Shared docstring fragment describing the model class and its `config` constructor argument.
CPMANT_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
    it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and
    behavior.

    Parameters:
        config ([`~CpmAntConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

# Shared docstring fragment describing the `forward` arguments.
CPMANT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.Tensor` of shape `(batch_size, seq_len)`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`CPMAntTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
            blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
class CpmAntModel(CpmAntPreTrainedModel):
    """Bare CPMAnt model: prompt tokens are prepended to the input, which is then embedded and run through the encoder."""

    def __init__(self, config: CpmAntConfig):
        super().__init__(config)
        self.encoder = CpmAntEncoder(config)
        self.segment_embedding = nn.Embedding(config.segment_types, config.hidden_size)
        # The embedding table holds the vocabulary followed by the prompt token ids.
        self.input_embedding = nn.Embedding(
            config.vocab_size + config.prompt_types * config.prompt_length, config.hidden_size
        )
        self.position_bias = CpmAntSegmentPositionEmbedding(config)
        self.prompt_length = config.prompt_length
        self.vocab_size = config.vocab_size

        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.input_embedding

    def set_input_embeddings(self, embeddings, **kwargs):
        """Replace the token embedding module; extra keyword arguments are ignored."""
        self.input_embedding = embeddings

    def _prepare_attention_mask(self, input_ids, span, context, length):
        """
        Build the `(batch, seq_len, seq_len)` attention mask.

        Args:
            input_ids (`torch.Tensor` of shape `(batch, seq_len)`):
                Token ids including the prepended prompt; only the shape and device are used.
            span (`torch.Tensor` of shape `(batch, seq_len)`):
                Span id per position; positions attend only within the same span.
            context (`torch.Tensor` of shape `(batch, seq_len)`):
                Non-zero marks positions every query may attend to; zero positions are attended causally.
            length (`torch.Tensor` of shape `(batch,)`):
                Number of non-padding input tokens per sequence (inputs are left-padded).
        """
        batch = input_ids.size(0)
        seqlen = input_ids.size(1)
        device = input_ids.device
        # directional_mask_2d[q, k] is True when k <= q
        directional_mask_2d = torch.arange(seqlen, device=device) <= torch.arange(seqlen, device=device).view(-1, 1)
        attention_mask = context[:, None, :] | (
            context[:, :, None].logical_not() & directional_mask_2d.view(1, seqlen, seqlen)
        )
        attention_mask = attention_mask & (span[:, None, :] == span[:, :, None])
        # mask for left padding: counting from the right end, only the last `length` positions are real tokens
        num_tokens = max(seqlen - self.prompt_length, 0)
        distance_from_end = torch.arange(num_tokens - 1, -1, -1, device=device)
        mask_1d = distance_from_end[None, :] < length[:, None]
        # The prompt positions are always kept.
        mask_1d = torch.cat((torch.ones(batch, self.prompt_length, device=device).bool(), mask_1d), dim=1)
        attention_mask = mask_1d.view(batch, seqlen, 1) & mask_1d.view(batch, 1, seqlen) & attention_mask
        return attention_mask

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        use_cache: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPast]:
        """
        Args:
            input_ids (`torch.Tensor` of shape `(batch_size, seq_len)`):
                Indices of input sequence tokens in the vocabulary. Id `0` is treated as padding.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers.
            past_key_values (`tuple(tuple(torch.Tensor))`, *optional*):
                Per-layer cached key and value states. `input_ids` is still the full sequence; already cached
                positions are sliced off before the encoder is run.
            use_cache (`bool`, *optional*):
                If set to `True`, key value states are returned and can be used to speed up decoding.
            return_dict (`bool`, *optional*):
                Whether or not to return a `BaseModelOutputWithPast` instead of a plain tuple.

        Raises:
            ValueError: If `input_ids` is not provided.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        if input_ids is None:
            raise ValueError("You have to specify input_ids")

        # add prompts ahead
        if input_ids.dtype != torch.int32:
            input_ids = input_ids.to(torch.int32)
        dtype, device = input_ids.dtype, input_ids.device
        # Segment 2 for real tokens, segment 0 for padding (id 0) and, below, for the prompt.
        segment = torch.where(input_ids != 0, 2, 0).to(dtype=dtype, device=device)
        length = (segment != 0).sum(-1).to(dtype=dtype, device=device)
        input_ids = torch.cat(
            (
                torch.arange(
                    self.prompt_length * 2 + self.vocab_size,
                    self.prompt_length * 3 + self.vocab_size,
                    dtype=dtype,
                    device=device,
                ).repeat(input_ids.size(0), 1),
                input_ids,
            ),
            dim=1,
        )
        batch, seq_length = input_ids.size()
        segment = torch.cat((torch.zeros(batch, self.prompt_length, dtype=dtype, device=device), segment), dim=1)
        context = torch.full((batch, seq_length), 1, dtype=dtype, device=device)
        position = torch.arange(seq_length, dtype=dtype, device=device).repeat(batch, 1)
        span = torch.full((batch, seq_length), 0, dtype=dtype, device=device)

        if past_key_values is None:
            past_length = 0
            past_key_values = tuple([None] * self.encoder.num_layers)
            input_ids = input_ids.contiguous()
            hidden_states = self.input_embedding(input_ids)
            segment_states = self.segment_embedding(segment)
            hidden_states = hidden_states + segment_states
        else:
            # Number of positions already held in the cache (sequence axis of the first layer's keys).
            past_length = past_key_values[0][0].size(-2)
            segment_states = self.segment_embedding(segment)
            hidden_states = self.input_embedding(input_ids) + segment_states[:, -1:, :]

        attention_mask = self._prepare_attention_mask(input_ids, span, context, length)
        position_bias = self.position_bias(position, position, segment, segment)

        # Keep only the query rows that are not covered by the cache.
        attention_mask = attention_mask[:, past_length:, :]
        position_bias = position_bias[:, :, past_length:, :]
        hidden_states = hidden_states[:, past_length:, :]

        hidden_states, present_key_values, all_hidden_states, all_attentions = self.encoder(
            hidden_states,
            attention_mask,
            position_bias,
            output_attentions,
            output_hidden_states,
            past_key_values,
            use_cache,
        )

        if past_length == 0:
            hidden_states = hidden_states[:, self.prompt_length :, :]
            # drop the prompt
            if all_attentions is not None:
                all_attentions = tuple(
                    attention[:, :, self.prompt_length :, self.prompt_length :] for attention in all_attentions
                )
            if all_hidden_states is not None:
                all_hidden_states = tuple(
                    hidden_state[:, self.prompt_length :, :] for hidden_state in all_hidden_states
                )

        if not return_dict:
            return tuple(
                v for v in [hidden_states, present_key_values, all_hidden_states, all_attentions] if v is not None
            )

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=present_key_values,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
        )
class CpmAntForCausalLM(CpmAntPreTrainedModel):
    """CPMAnt model with a bias-free linear language-modeling head on top."""

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: CpmAntConfig):
        super().__init__(config)
        self.cpmant = CpmAntModel(config)

        # lm_head.weight is tied to cpmant.input_embedding.weight
        self.lm_head = nn.Linear(
            config.hidden_size, config.vocab_size + config.prompt_types * config.prompt_length, bias=False
        )
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        labels: Optional[torch.Tensor] = None,
        return_dict: Optional[bool] = None,
        attention_mask: Optional[torch.Tensor] = None,  # dummy parameter for text-generation pipeline
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            input_ids (`torch.Tensor` of shape `(batch_size, seq_len)`):
                Indices of input sequence tokens in the vocabulary.

                Indices can be obtained using [`CPMAntTokenizer`]. See [`PreTrainedTokenizer.encode`] and
                [`PreTrainedTokenizer.__call__`] for details.

                [What are input IDs?](../glossary#input-ids)
            past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
                Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
                cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers.
            labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the language modeling loss. They are compared position by position with the
                logits, without shifting.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                CPMAnt will process attention mask automatically, this parameter is a dummy parameter for
                text-generation pipeline.

        Example:

        Text Generation with CpmAntForCausalLM.
        ```python
        >>> from transformers import CPMAntTokenizer, CpmAntForCausalLM

        >>> # Chinese prompt, roughly: "The weather is nice today,"
        >>> texts = "今天天气不错,"
        >>> model = CpmAntForCausalLM.from_pretrained("openbmb/cpm-ant-10b")
        >>> tokenizer = CPMAntTokenizer.from_pretrained("openbmb/cpm-ant-10b")
        >>> input_ids = tokenizer(texts, return_tensors="pt")
        >>> outputs = model.generate(**input_ids)
        >>> output_texts = tokenizer.batch_decode(outputs)
        >>> print(output_texts)
        ['今天天气不错,阳光明媚,我和妈妈一起去超市买东西。\n在超市里,我看到了一个很好玩的玩具,它的名字叫“机器人”。它有一个圆圆的脑袋,两只圆圆的眼睛,还有一个圆圆的']
        ```
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Passed positionally in the parameter order of `CpmAntModel.forward`.
        model_output = self.cpmant(
            input_ids, output_attentions, output_hidden_states, past_key_values, use_cache, return_dict
        )
        hidden_states = model_output.last_hidden_state if return_dict else model_output[0]

        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            loss_func = CrossEntropyLoss()
            loss = loss_func(logits.view(-1, logits.size(-1)), labels.view(-1))

        if not return_dict:
            output = (logits,) + model_output[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=model_output.past_key_values,
            hidden_states=model_output.hidden_states,
            attentions=model_output.attentions,
        )

    def get_input_embeddings(self):
        """Return the token embedding module of the wrapped base model."""
        return self.cpmant.input_embedding

    def set_input_embeddings(self, embeddings):
        """Replace the token embedding module of the wrapped base model."""
        self.cpmant.input_embedding = embeddings

    def get_output_embeddings(self):
        """Return the language-modeling head."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the language-modeling head."""
        self.lm_head = new_embeddings

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        """
        Build the keyword arguments for one generation step.

        Returns a dict with `input_ids` (cast to int32), `use_cache` and `past_key_values`. A missing `use_cache`
        is returned as `None`, which `forward` resolves from the configuration.
        """
        input_ids = input_ids.int()
        # save the memory usage of dummy attention mask
        # NOTE(review): this only rebinds the entry in the local `kwargs` dict, which is not returned below -
        # confirm whether it still has any effect.
        if "attention_mask" in kwargs:
            kwargs["attention_mask"] = torch.zeros(1, 1)

        return {
            "input_ids": input_ids,
            "use_cache": kwargs.get("use_cache"),
            "past_key_values": kwargs.get("past_key_values", None),
        }

    def _reorder_cache(self, past_key_values, beam_idx):
        """Re-index the batch axis of every cached key and value with `beam_idx` and return the cache as lists."""
        past_key_values = [list(each) if each is not None else each for each in past_key_values]
        for key_value_layer in past_key_values:
            key_value_layer[0] = key_value_layer[0][beam_idx]
            key_value_layer[1] = key_value_layer[1][beam_idx]
        return past_key_values