import math
from dataclasses import dataclass
from typing import Tuple, Optional, Literal

from einops import rearrange
import torch
from torch import nn
import torch.nn.functional as F
import torch.distributed as dist

from kernel import act_quant, fp8_gemm, fp8_index


world_size = 1
rank = 0
block_size = 128


@dataclass
class ModelArgs:
    """
    Data class for defining model arguments and hyperparameters.

    Attributes:
        max_batch_size (int): Maximum batch size.
        max_seq_len (int): Maximum sequence length.
        dtype (Literal["bf16", "fp8"]): Data type for computations.
        scale_fmt (Optional[str]): Format for quantization scale.
        vocab_size (int): Vocabulary size.
        dim (int): Model dimension.
        inter_dim (int): Intermediate dimension for MLP layers.
        moe_inter_dim (int): Intermediate dimension for MoE layers.
        n_layers (int): Number of transformer layers.
        n_dense_layers (int): Number of dense layers in the model.
        n_heads (int): Number of attention heads.
        n_routed_experts (int): Number of routed experts for MoE layers.
        n_shared_experts (int): Number of shared experts for MoE layers.
        n_activated_experts (int): Number of activated experts in MoE layers.
        n_expert_groups (int): Number of expert groups.
        n_limited_groups (int): Number of limited groups for MoE routing.
        score_func (Literal["softmax", "sigmoid"]): Scoring function for MoE routing.
        route_scale (float): Scaling factor for routing scores.
        q_lora_rank (int): LoRA rank for query projections.
        kv_lora_rank (int): LoRA rank for key-value projections.
        qk_nope_head_dim (int): Dimension for query-key projections without positional embeddings.
        qk_rope_head_dim (int): Dimension for query-key projections with rotary embeddings.
        v_head_dim (int): Dimension for value projections.
        original_seq_len (int): Original sequence length.
        rope_theta (float): Base for rotary positional encoding.
        rope_factor (float): Scaling factor for extended sequence lengths.
        beta_fast (int): Fast beta correction factor.
        beta_slow (int): Slow beta correction factor.
        mscale (float): Scaling factor for extended attention.
        index_n_heads (int): Number of heads used by the indexer.
        index_head_dim (int): Dimension for index head.
        index_topk (int): Top-k for index head.
    """
    max_batch_size: int = 8
    max_seq_len: int = 4096 * 4
    dtype: Literal["bf16", "fp8"] = "bf16"
    scale_fmt: Optional[str] = None
    vocab_size: int = 102400
    dim: int = 2048
    inter_dim: int = 10944
    moe_inter_dim: int = 1408
    n_layers: int = 27
    n_dense_layers: int = 1
    n_heads: int = 16
    # moe
    n_routed_experts: int = 64
    n_shared_experts: int = 2
    n_activated_experts: int = 6
    n_expert_groups: int = 1
    n_limited_groups: int = 1
    score_func: Literal["softmax", "sigmoid"] = "softmax"
    route_scale: float = 1.
    # mla
    q_lora_rank: int = 0
    kv_lora_rank: int = 512
    qk_nope_head_dim: int = 128
    qk_rope_head_dim: int = 64
    v_head_dim: int = 128
    # yarn
    original_seq_len: int = 4096
    rope_theta: float = 10000.0
    rope_factor: float = 40
    beta_fast: int = 32
    beta_slow: int = 1
    mscale: float = 1.
    # index
    index_n_heads: int = 64
    index_head_dim: int = 128
    index_topk: int = 2048
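
# Illustrative arithmetic on these defaults: qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
# = 128 + 64 = 192 per attention head, max_seq_len = 4 * original_seq_len = 16384 (which
# activates the YaRN branch in precompute_freqs_cis and the mscale correction in MLA), and
# the indexer keeps at most index_topk = 2048 cached tokens per query.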


class ParallelEmbedding(nn.Module):
    """
    Embedding layer with parallelism support across distributed processes.

    Args:
        vocab_size (int): Vocabulary size.
        dim (int): Embedding dimension.
    """
    def __init__(self, vocab_size: int, dim: int):
        super().__init__()
        self.vocab_size = vocab_size
        self.dim = dim
        assert vocab_size % world_size == 0, f"Vocabulary size must be divisible by world size (world_size={world_size})"
        self.part_vocab_size = (vocab_size // world_size)
        self.vocab_start_idx = rank * self.part_vocab_size
        self.vocab_end_idx = self.vocab_start_idx + self.part_vocab_size
        self.weight = nn.Parameter(torch.empty(self.part_vocab_size, self.dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for parallel embedding layer.

        Args:
            x (torch.Tensor): Input tensor containing token indices.

        Returns:
            torch.Tensor: Embedded representations.
        """
        if world_size > 1:
            mask = (x < self.vocab_start_idx) | (x >= self.vocab_end_idx)
            x = x - self.vocab_start_idx
            x[mask] = 0
        y = F.embedding(x, self.weight)
        if world_size > 1:
            y[mask] = 0
            dist.all_reduce(y)
        return y


def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None,
           scale_fmt: Optional[str] = None) -> torch.Tensor:
    """
    Applies a linear transformation to the incoming data: y = xA^T + b.

    This function supports specialized implementations based on quantization
    and tensor formats.

    Args:
        x (torch.Tensor): The input tensor.
        weight (torch.Tensor): The weight tensor, possibly stored in FP8
            (`torch.float8_e4m3fn`) with a block-wise `scale` attribute.
        bias (Optional[torch.Tensor]): The bias tensor to be added. Default is None.
        scale_fmt (Optional[str]): The format of scaling factors.

    Returns:
        torch.Tensor: The result of the linear transformation, which may involve
        quantization-aware computations depending on the input parameters.

    Notes:
        - If `weight` is not FP8, a standard `F.linear` is applied.
        - If `weight` is FP8, `x` is quantized block-wise with `act_quant` and the
          product is computed with `fp8_gemm` using the weight's block scales.
    """
    assert bias is None
    if weight.dtype != torch.float8_e4m3fn:
        return F.linear(x, weight)
    else:
        x, scale = act_quant(x, block_size, scale_fmt)
        return fp8_gemm(x, scale, weight, weight.scale)
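
# Note on the FP8 path (inferred from how `kernel.act_quant` / `kernel.fp8_gemm` are used
# here, not from their implementation): activations are assumed to be quantized in
# 1 x 128 blocks along the last dimension, yielding one float32 scale per block, while
# FP8 weights carry 128 x 128 block scales in `weight.scale` (see `Linear.__init__`,
# which allocates ceil(out/128) x ceil(in/128) scale entries). `fp8_gemm` is assumed to
# combine both scale grids while accumulating the matmul.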


class Linear(nn.Module):
    """
    Custom linear layer with support for quantized weights and optional bias.

    Args:
        in_features (int): Number of input features.
        out_features (int): Number of output features.
        bias (bool): Whether to include a bias term. Defaults to False.
        dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`.
    """
    dtype = torch.bfloat16
    scale_fmt: Optional[str] = None

    def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.empty(out_features, in_features, dtype=dtype or Linear.dtype))
        if self.weight.element_size() == 1:
            scale_out_features = (out_features + block_size - 1) // block_size
            scale_in_features = (in_features + block_size - 1) // block_size
            self.weight.scale = self.scale = nn.Parameter(torch.empty(scale_out_features, scale_in_features, dtype=torch.float32))
        else:
            self.register_parameter("scale", None)
        if bias:
            self.bias = nn.Parameter(torch.empty(out_features))
        else:
            self.register_parameter("bias", None)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for the custom linear layer.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Transformed tensor after linear computation.
        """
        return linear(x, self.weight, self.bias, self.scale_fmt)


class ColumnParallelLinear(Linear):
    """
    Linear layer with column parallelism, splitting output features across distributed processes.

    Args:
        in_features (int): Number of input features.
        out_features (int): Total number of output features.
        bias (bool): Whether to include a bias term. Defaults to False.
        dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`.
    """
    def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
        assert out_features % world_size == 0, f"Output features must be divisible by world size (world_size={world_size})"
        self.part_out_features = out_features // world_size
        super().__init__(in_features, self.part_out_features, bias, dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for column parallel linear layer.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Transformed tensor with column-parallel computation.
        """
        y = linear(x, self.weight, self.bias, self.scale_fmt)
        return y


class RowParallelLinear(Linear):
    """
    Linear layer with row parallelism, splitting input features across distributed processes.

    Args:
        in_features (int): Total number of input features.
        out_features (int): Number of output features.
        bias (bool): Whether to include a bias term. Defaults to False.
        reduce_output (bool): Whether to all-reduce the partial outputs across processes. Defaults to True.
        dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`.
    """
    def __init__(self, in_features: int, out_features: int, bias: bool = False, reduce_output = True, dtype = None):
        assert in_features % world_size == 0, f"Input features must be divisible by world size (world_size={world_size})"
        self.part_in_features = in_features // world_size
        self.reduce_output = reduce_output
        super().__init__(self.part_in_features, out_features, bias, dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for row parallel linear layer.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Transformed tensor with row-parallel computation.
        """
        y = linear(x, self.weight, None, self.scale_fmt)
        if self.reduce_output and world_size > 1:
            y = y.float()
            dist.all_reduce(y)
        if self.bias is not None:
            y += self.bias
        return y.type_as(x)
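
# Tensor-parallel pattern used above: ColumnParallelLinear shards the output features, so
# each rank produces a disjoint slice of the result, while RowParallelLinear shards the
# input features and all-reduces the partial sums (in float32) so every rank ends up with
# the full output. Stacking a column-parallel projection followed by a row-parallel one
# (as MLP and MLA do below) therefore needs only a single all-reduce per pair.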


class RMSNorm(nn.Module):
    """
    Root Mean Square Layer Normalization (RMSNorm).

    Args:
        dim (int): Dimension of the input tensor.
        eps (float): Epsilon value for numerical stability. Defaults to 1e-6.
    """
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))

    def forward(self, x: torch.Tensor, residual: Optional[torch.Tensor] = None):
        """
        Forward pass for RMSNorm.

        Args:
            x (torch.Tensor): Input tensor.
            residual (Optional[torch.Tensor]): Optional residual stream added to `x`
                before normalization.

        Returns:
            torch.Tensor: Normalized tensor with the same shape as the input, or a tuple
            of (normalized tensor, updated residual) when `residual` is given.
        """
        dtype = x.dtype
        if residual is None:
            x = x.float()
            var = x.pow(2).mean(-1, keepdim=True)
            x = x * torch.rsqrt(var + self.eps)
            return (self.weight * x).to(dtype)
        else:
            x = residual = x.float() + residual.float()
            var = x.pow(2).mean(-1, keepdim=True)
            x = x * torch.rsqrt(var + self.eps)
            return (self.weight * x).to(dtype), residual.to(dtype)


class LayerNorm(nn.Module):
    """
    Layer Normalization.
    """
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        self.bias = nn.Parameter(torch.zeros(dim, dtype=torch.float32))

    def forward(self, x: torch.Tensor):
        return F.layer_norm(x.float(), (self.dim,), self.weight, self.bias, self.eps).type_as(x)


def precompute_freqs_cis(args: ModelArgs) -> torch.Tensor:
    """
    Precomputes frequency-based complex exponential values for rotary positional embeddings.

    Args:
        args (ModelArgs): Model arguments containing positional embedding parameters.

    Returns:
        torch.Tensor: Precomputed complex exponential values for positional embeddings.
    """
    dim = args.qk_rope_head_dim
    seqlen = args.max_seq_len
    beta_fast = args.beta_fast
    beta_slow = args.beta_slow
    base = args.rope_theta
    factor = args.rope_factor

    def find_correction_dim(num_rotations, dim, base, max_seq_len):
        """
        Computes the correction dimension for a given number of rotations in the rotary positional embedding.

        Args:
            num_rotations (float): Number of rotations to compute the correction for.
            dim (int): Dimensionality of the embedding space.
            base (float): Base value for the exponential computation.
            max_seq_len (int): Maximum sequence length.

        Returns:
            float: The correction dimension based on the input parameters.
        """
        return dim * math.log(max_seq_len / (num_rotations * 2 * math.pi)) / (2 * math.log(base))

    def find_correction_range(low_rot, high_rot, dim, base, max_seq_len):
        """
        Computes the range of correction dimensions for rotary positional embeddings.

        Args:
            low_rot (float): Lower bound for the number of rotations.
            high_rot (float): Upper bound for the number of rotations.
            dim (int): Dimensionality of the embedding space.
            base (float): Base value for the exponential computation.
            max_seq_len (int): Maximum sequence length.

        Returns:
            Tuple[int, int]: The range of correction dimensions (low, high), clamped to valid indices.
        """
        low = math.floor(find_correction_dim(low_rot, dim, base, max_seq_len))
        high = math.ceil(find_correction_dim(high_rot, dim, base, max_seq_len))
        return max(low, 0), min(high, dim-1)

    def linear_ramp_factor(min, max, dim):
        """
        Computes a linear ramp function used to smooth values between a minimum and maximum range.

        Args:
            min (float): Minimum value for the ramp function.
            max (float): Maximum value for the ramp function.
            dim (int): Dimensionality of the ramp tensor.

        Returns:
            torch.Tensor: A tensor of shape (dim,) with values linearly interpolated between 0 and 1,
                clamped to the range [0, 1].
        """
        if min == max:
            max += 0.001
        linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
        ramp_func = torch.clamp(linear_func, 0, 1)
        return ramp_func

    freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    if seqlen > args.original_seq_len:
        low, high = find_correction_range(beta_fast, beta_slow, dim, base, args.original_seq_len)
        smooth = 1 - linear_ramp_factor(low, high, dim // 2)
        freqs = freqs / factor * (1 - smooth) + freqs * smooth
    t = torch.arange(seqlen)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    return freqs_cis
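
# YaRN-style extension as implemented above: when max_seq_len exceeds original_seq_len,
# low-frequency dimensions (those completing fewer than ~beta_slow rotations over the
# original context) are slowed down by 1/rope_factor, high-frequency dimensions (more
# than ~beta_fast rotations) are left untouched, and the dimensions in between are
# blended with the linear ramp. At or below original_seq_len the plain RoPE frequencies
# are used unchanged.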


def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    """
    Applies rotary positional embeddings to the input tensor.

    Args:
        x (torch.Tensor): Input tensor with positional embeddings to be applied.
        freqs_cis (torch.Tensor): Precomputed complex exponential values for positional embeddings.

    Returns:
        torch.Tensor: Tensor with rotary embeddings applied.
    """
    dtype = x.dtype
    x = torch.view_as_complex(x.float().view(*x.shape[:-1], -1, 2))
    freqs_cis = freqs_cis.view(1, x.size(1), 1, x.size(-1))
    y = torch.view_as_real(x * freqs_cis).flatten(3)
    return y.to(dtype)


def rotate_activation(x: torch.Tensor) -> torch.Tensor:
    """
    Applies a normalized fast Hadamard transform along the last dimension
    (used on the indexer's queries and keys before FP8 quantization).
    """
    assert x.dtype == torch.bfloat16
    from fast_hadamard_transform import hadamard_transform
    hidden_size = x.size(-1)
    return hadamard_transform(x, scale=hidden_size ** -0.5)


class Indexer(torch.nn.Module):
    """
    Indexer that scores cached tokens with a lightweight FP8 attention-like head and
    returns the top-k token indices used to sparsify the main attention.
    """
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.dim: int = args.dim
        self.n_heads: int = args.index_n_heads
        self.n_local_heads = args.index_n_heads // world_size
        self.head_dim: int = args.index_head_dim
        self.rope_head_dim: int = args.qk_rope_head_dim
        self.index_topk: int = args.index_topk
        self.q_lora_rank: int = args.q_lora_rank
        self.wq_b = Linear(self.q_lora_rank, self.n_heads * self.head_dim)
        self.wk = Linear(self.dim, self.head_dim)
        self.k_norm = LayerNorm(self.head_dim)
        self.weights_proj = Linear(self.dim, self.n_heads, dtype=torch.get_default_dtype())
        self.softmax_scale = self.head_dim ** -0.5
        self.scale_fmt = args.scale_fmt
        self.register_buffer("k_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.head_dim, dtype=torch.float8_e4m3fn), persistent=False)
        self.register_buffer("k_scale_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.head_dim // block_size, dtype=torch.float32), persistent=False)

    def forward(self, x: torch.Tensor, qr: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
        bsz, seqlen, _ = x.size()
        end_pos = start_pos + seqlen
        q = self.wq_b(qr)
        q = rearrange(q, 'b s (h d) -> b s h d', d=self.head_dim)
        q_pe, q_nope = torch.split(q, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1)
        q_pe = apply_rotary_emb(q_pe, freqs_cis)
        q = torch.cat([q_pe, q_nope], dim=-1)
        k = self.wk(x)
        k = self.k_norm(k)
        k_pe, k_nope = torch.split(k, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1)
        k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis).squeeze(2)
        k = torch.cat([k_pe, k_nope], dim=-1)
        q = rotate_activation(q)
        k = rotate_activation(k)
        q_fp8, q_scale = act_quant(q, block_size, self.scale_fmt)
        k_fp8, k_scale = act_quant(k, block_size, self.scale_fmt)
        self.k_cache[:bsz, start_pos:end_pos] = k_fp8
        self.k_scale_cache[:bsz, start_pos:end_pos] = k_scale
        weights = self.weights_proj(x) * self.n_heads ** -0.5
        weights = weights.unsqueeze(-1) * q_scale * self.softmax_scale
        index_score = fp8_index(q_fp8.contiguous(), weights, self.k_cache[:bsz, :end_pos].contiguous(), self.k_scale_cache[:bsz, :end_pos].contiguous())
        if mask is not None:
            index_score += mask
        topk_indices = index_score.topk(min(self.index_topk, end_pos), dim=-1)[1]
        if world_size > 1:
            # Sanity check: every rank must select the same token indices.
            topk_indices_ = topk_indices.clone()
            dist.broadcast(topk_indices_, src=0)
            assert torch.all(topk_indices == topk_indices_), f"{topk_indices=} {topk_indices_=}"
        return topk_indices
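
# How the indexer is used: `fp8_index` is assumed to compute, per query token, a weighted
# similarity score against every cached key (the per-head mixing weights come from
# `weights_proj`). The causal mask, when given, is added to those scores, the top
# `index_topk` positions are kept, and `MLA.forward` turns the returned indices into an
# additive mask (0 for selected tokens, -inf otherwise) so that the softmax attention
# only sees the selected tokens.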


def weight_dequant(weight, scale):
    """
    Dequantizes an FP8 weight stored with 128 x 128 block scales back to the default
    dtype, one block at a time.
    """
    shape = weight.shape
    assert weight.dim() == 2
    weight = weight.view(shape[0] // block_size, block_size, shape[1] // block_size, block_size).transpose(1, 2).contiguous().view(-1, block_size * block_size)
    weight = (weight.float() * scale.view(-1, 1).float()).to(torch.get_default_dtype()).view(shape[0] // block_size, shape[1] // block_size, block_size, block_size).transpose(1, 2).contiguous().view(shape)
    return weight


class MLA(nn.Module):
    """
    Multi-Head Latent Attention (MLA) Layer.

    Attributes:
        dim (int): Dimensionality of the input features.
        n_heads (int): Number of attention heads.
        n_local_heads (int): Number of local attention heads for distributed systems.
        q_lora_rank (int): Rank for low-rank query projection.
        kv_lora_rank (int): Rank for low-rank key/value projection.
        qk_nope_head_dim (int): Dimensionality of non-positional query/key projections.
        qk_rope_head_dim (int): Dimensionality of rotary-positional query/key projections.
        qk_head_dim (int): Total dimensionality of query/key projections.
        v_head_dim (int): Dimensionality of value projections.
        softmax_scale (float): Scaling factor for softmax in attention computation.
    """
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.dim = args.dim
        self.n_heads = args.n_heads
        self.n_local_heads = args.n_heads // world_size
        self.q_lora_rank = args.q_lora_rank
        self.kv_lora_rank = args.kv_lora_rank
        self.qk_nope_head_dim = args.qk_nope_head_dim
        self.qk_rope_head_dim = args.qk_rope_head_dim
        self.qk_head_dim = args.qk_nope_head_dim + args.qk_rope_head_dim
        self.v_head_dim = args.v_head_dim
        self.wq_a = Linear(self.dim, self.q_lora_rank)
        self.q_norm = RMSNorm(self.q_lora_rank)
        self.wq_b = ColumnParallelLinear(self.q_lora_rank, self.n_heads * self.qk_head_dim)
        self.wkv_a = Linear(self.dim, self.kv_lora_rank + self.qk_rope_head_dim)
        self.kv_norm = RMSNorm(self.kv_lora_rank)
        self.wkv_b = ColumnParallelLinear(self.kv_lora_rank, self.n_heads * (self.qk_nope_head_dim + self.v_head_dim))
        self.wo = RowParallelLinear(self.n_heads * self.v_head_dim, self.dim)
        self.softmax_scale = self.qk_head_dim ** -0.5
        if args.max_seq_len > args.original_seq_len:
            mscale = 0.1 * args.mscale * math.log(args.rope_factor) + 1.0
            self.softmax_scale = self.softmax_scale * mscale * mscale
        self.indexer = Indexer(args)
        self.register_buffer("kv_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.kv_lora_rank), persistent=False)
        self.register_buffer("pe_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.qk_rope_head_dim), persistent=False)
        self.dequant_wkv_b = None

    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
        """
        Forward pass for the Multi-Head Latent Attention (MLA) Layer.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim).
            start_pos (int): Starting position in the sequence for caching.
            freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings.
            mask (Optional[torch.Tensor]): Mask tensor to exclude certain positions from attention.

        Returns:
            torch.Tensor: Output tensor with the same shape as the input.
        """
        bsz, seqlen, _ = x.size()
        end_pos = start_pos + seqlen
        qr = self.q_norm(self.wq_a(x))
        q = self.wq_b(qr)
        q = q.view(bsz, seqlen, self.n_local_heads, self.qk_head_dim)
        q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        q_pe = apply_rotary_emb(q_pe, freqs_cis)
        kv = self.wkv_a(x)
        kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        kv = self.kv_norm(kv)
        k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis)
        self.kv_cache[:bsz, start_pos:end_pos] = kv
        self.pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2)
        if mask is not None:  # MHA prefill
            q = torch.cat([q_nope, q_pe], dim=-1)
            kv = self.wkv_b(kv)
            kv = kv.view(bsz, seqlen, self.n_local_heads, self.qk_nope_head_dim + self.v_head_dim)
            k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
            k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_local_heads, -1)], dim=-1)
            scores = torch.einsum("bshd,bthd->bsht", q.float(), k.float()) * self.softmax_scale
            # indexer
            topk_indices = self.indexer(x, qr, start_pos, freqs_cis, mask)
            index_mask = torch.full((bsz, seqlen, seqlen), float("-inf"), device=x.device).scatter_(-1, topk_indices, 0)
            index_mask += mask
            scores += index_mask.unsqueeze(2)
            scores = scores.softmax(dim=-1, dtype=torch.float32)
            x = torch.einsum("bsht,bthd->bshd", scores.type_as(x), v)
        else:  # MHA decode
            if self.dequant_wkv_b is None and self.wkv_b.scale is not None:
                self.dequant_wkv_b = weight_dequant(self.wkv_b.weight, self.wkv_b.scale)
            wkv_b = self.wkv_b.weight if self.dequant_wkv_b is None else self.dequant_wkv_b
            wkv_b = wkv_b.view(self.n_local_heads, -1, self.kv_lora_rank)
            q_nope = torch.einsum("bshd,hdc->bshc", q_nope, wkv_b[:, :self.qk_nope_head_dim])
            scores = (torch.einsum("bshc,btc->bsht", q_nope.float(), self.kv_cache[:bsz, :end_pos].float()) +
                      torch.einsum("bshr,btr->bsht", q_pe.float(), self.pe_cache[:bsz, :end_pos].float())) * self.softmax_scale
            # indexer
            topk_indices = self.indexer(x, qr, start_pos, freqs_cis, mask)
            index_mask = torch.full((bsz, 1, end_pos), float("-inf"), device=x.device).scatter_(-1, topk_indices, 0)
            scores += index_mask.unsqueeze(2)
            scores = scores.softmax(dim=-1, dtype=torch.float32)
            x = torch.einsum("bsht,btc->bshc", scores.type_as(x), self.kv_cache[:bsz, :end_pos])
            x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim:])
        x = self.wo(x.flatten(2))
        return x
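
# Decode-path note: instead of materializing per-token keys and values, the query is
# projected through the relevant slice of wkv_b ("weight absorption"), so attention
# scores are computed directly against the compressed kv_cache of width kv_lora_rank
# (plus the small pe_cache for the rotary part), and the value projection is applied
# only after the weighted sum. With the defaults above this keeps the per-token cache
# at kv_lora_rank + qk_rope_head_dim = 512 + 64 entries.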


class MLP(nn.Module):
    """
    Multi-Layer Perceptron (MLP) used as a feed-forward layer.

    Attributes:
        w1 (nn.Module): Linear layer for input-to-hidden transformation.
        w2 (nn.Module): Linear layer for hidden-to-output transformation.
        w3 (nn.Module): Additional linear layer for feature transformation.
    """
    def __init__(self, dim: int, inter_dim: int, reduce_output: bool = True):
        """
        Initializes the MLP layer.

        Args:
            dim (int): Input and output dimensionality.
            inter_dim (int): Hidden layer dimensionality.
            reduce_output (bool): Whether the row-parallel output projection performs its all-reduce. Defaults to True.
        """
        super().__init__()
        self.w1 = ColumnParallelLinear(dim, inter_dim)
        self.w2 = RowParallelLinear(inter_dim, dim, reduce_output=reduce_output)
        self.w3 = ColumnParallelLinear(dim, inter_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for the MLP layer.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output tensor after MLP computation.
        """
        return self.w2((F.silu(self.w1(x).float()) * self.w3(x).float()).type_as(x))


class Gate(nn.Module):
    """
    Gating mechanism for routing inputs in a mixture-of-experts (MoE) model.

    Attributes:
        dim (int): Dimensionality of input features.
        topk (int): Number of top experts activated for each input.
        n_groups (int): Number of groups for routing.
        topk_groups (int): Number of groups to route inputs to.
        score_func (str): Scoring function ('softmax' or 'sigmoid').
        route_scale (float): Scaling factor for routing weights.
        weight (torch.nn.Parameter): Learnable weights for the gate.
        bias (Optional[torch.nn.Parameter]): Optional bias term for the gate.
    """
    def __init__(self, args: ModelArgs):
        """
        Initializes the Gate module.

        Args:
            args (ModelArgs): Model arguments containing gating parameters.
        """
        super().__init__()
        self.dim = args.dim
        self.topk = args.n_activated_experts
        self.n_groups = args.n_expert_groups
        self.topk_groups = args.n_limited_groups
        self.score_func = args.score_func
        self.route_scale = args.route_scale
        self.weight = nn.Parameter(torch.empty(args.n_routed_experts, args.dim))
        self.bias = nn.Parameter(torch.empty(args.n_routed_experts, dtype=torch.float32)) if self.dim == 7168 else None

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass for the gating mechanism.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Routing weights and selected expert indices.
        """
        scores = linear(x.float(), self.weight.float())
        if self.score_func == "softmax":
            scores = scores.softmax(dim=-1)
        else:
            scores = scores.sigmoid()
        original_scores = scores
        if self.bias is not None:
            scores = scores + self.bias
        if self.n_groups > 1:
            scores = scores.view(x.size(0), self.n_groups, -1)
            if self.bias is None:
                group_scores = scores.amax(dim=-1)
            else:
                group_scores = scores.topk(2, dim=-1)[0].sum(dim=-1)
            indices = group_scores.topk(self.topk_groups, dim=-1)[1]
            mask = scores.new_ones(x.size(0), self.n_groups, dtype=bool).scatter_(1, indices, False)
            scores = scores.masked_fill_(mask.unsqueeze(-1), float("-inf")).flatten(1)
        indices = scores.topk(self.topk, dim=-1)[1]
        weights = original_scores.gather(1, indices)
        if self.score_func == "sigmoid":
            weights /= weights.sum(dim=-1, keepdim=True)
        weights *= self.route_scale
        return weights, indices
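
# Routing note: with n_expert_groups > 1 the experts are partitioned into groups, only
# the top `n_limited_groups` groups (ranked by their strongest experts) stay eligible,
# and the per-token top-k experts are then picked inside those groups. When a bias is
# present (the 7168-dim configuration), it influences expert selection only; the
# returned weights are gathered from the unbiased scores.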


class Expert(nn.Module):
    """
    Expert layer for Mixture-of-Experts (MoE) models.

    Attributes:
        w1 (nn.Module): Linear layer for input-to-hidden transformation.
        w2 (nn.Module): Linear layer for hidden-to-output transformation.
        w3 (nn.Module): Additional linear layer for feature transformation.
    """
    def __init__(self, dim: int, inter_dim: int):
        """
        Initializes the Expert layer.

        Args:
            dim (int): Input and output dimensionality.
            inter_dim (int): Hidden layer dimensionality.
        """
        super().__init__()
        self.w1 = Linear(dim, inter_dim)
        self.w2 = Linear(inter_dim, dim)
        self.w3 = Linear(dim, inter_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for the Expert layer.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output tensor after expert computation.
        """
        return self.w2((F.silu(self.w1(x).float()) * self.w3(x).float()).type_as(x))


class MoE(nn.Module):
    """
    Mixture-of-Experts (MoE) module.

    Attributes:
        dim (int): Dimensionality of input features.
        n_routed_experts (int): Total number of experts in the model.
        n_local_experts (int): Number of experts handled locally in distributed systems.
        n_activated_experts (int): Number of experts activated for each input.
        gate (nn.Module): Gating mechanism to route inputs to experts.
        experts (nn.ModuleList): List of expert modules.
        shared_experts (nn.Module): Shared experts applied to all inputs.
    """
    def __init__(self, args: ModelArgs):
        """
        Initializes the MoE module.

        Args:
            args (ModelArgs): Model arguments containing MoE parameters.
        """
        super().__init__()
        self.dim = args.dim
        assert args.n_routed_experts % world_size == 0, f"Number of experts must be divisible by world size (world_size={world_size})"
        self.n_routed_experts = args.n_routed_experts
        self.n_local_experts = args.n_routed_experts // world_size
        self.n_activated_experts = args.n_activated_experts
        self.experts_start_idx = rank * self.n_local_experts
        self.experts_end_idx = self.experts_start_idx + self.n_local_experts
        self.gate = Gate(args)
        self.experts = nn.ModuleList([Expert(args.dim, args.moe_inter_dim) if self.experts_start_idx <= i < self.experts_end_idx else None
                                      for i in range(self.n_routed_experts)])
        self.shared_experts = MLP(args.dim, args.n_shared_experts * args.moe_inter_dim, reduce_output=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for the MoE module.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output tensor after expert routing and computation.
        """
        shape = x.size()
        x = x.view(-1, self.dim)
        weights, indices = self.gate(x)
        y = torch.zeros_like(x, dtype=torch.float32)
        counts = torch.bincount(indices.flatten(), minlength=self.n_routed_experts).tolist()
        for i in range(self.experts_start_idx, self.experts_end_idx):
            if counts[i] == 0:
                continue
            expert = self.experts[i]
            idx, top = torch.where(indices == i)
            y[idx] += expert(x[idx]) * weights[idx, top, None]
        y += self.shared_experts(x)
        if world_size > 1:
            dist.all_reduce(y)
        return y.type_as(x).view(shape)
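
# Expert parallelism: each rank instantiates only its slice of the routed experts
# (experts_start_idx..experts_end_idx); tokens routed to remote experts contribute zero
# locally, and when world_size > 1 the final all_reduce sums the partial outputs. The
# shared experts skip their own reduction (reduce_output=False) so the sum is done once.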


class Block(nn.Module):
    """
    Transformer block combining attention and feed-forward layers.

    Attributes:
        attn (nn.Module): Attention layer (MLA).
        ffn (nn.Module): Feed-forward network (MLP or MoE).
        attn_norm (nn.Module): Layer normalization for attention.
        ffn_norm (nn.Module): Layer normalization for feed-forward network.
    """
    def __init__(self, layer_id: int, args: ModelArgs):
        """
        Initializes the Transformer block.

        Args:
            layer_id (int): Layer index in the transformer.
            args (ModelArgs): Model arguments containing block parameters.
        """
        super().__init__()
        self.attn = MLA(args)
        self.ffn = MLP(args.dim, args.inter_dim) if layer_id < args.n_dense_layers else MoE(args)
        self.attn_norm = RMSNorm(args.dim)
        self.ffn_norm = RMSNorm(args.dim)

    def forward(self, x: torch.Tensor, residual: Optional[torch.Tensor], start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass for the Transformer block.

        Args:
            x (torch.Tensor): Input tensor.
            residual (Optional[torch.Tensor]): Residual stream from the previous block, or None for the first block.
            start_pos (int): Starting position in the sequence.
            freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings.
            mask (Optional[torch.Tensor]): Mask tensor to exclude certain positions from attention.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Output tensor and updated residual stream.
        """
        if residual is None:
            x, residual = self.attn_norm(x), x
        else:
            x, residual = self.attn_norm(x, residual)
        x = self.attn(x, start_pos, freqs_cis, mask)
        x, residual = self.ffn_norm(x, residual)
        x = self.ffn(x)
        return x, residual


class Transformer(nn.Module):
    """
    Transformer model with positional embeddings, multiple layers, and output projection.

    Attributes:
        max_seq_len (int): Maximum sequence length for the transformer.
        embed (nn.Module): Embedding layer for input tokens.
        layers (torch.nn.ModuleList): List of transformer blocks.
        norm (nn.Module): Layer normalization applied after all blocks.
        head (nn.Module): Output projection layer mapping to vocabulary size.
        freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings.
    """
    def __init__(self, args: ModelArgs):
        """
        Initializes the Transformer model.

        Args:
            args (ModelArgs): Model arguments containing transformer parameters.
        """
        global world_size, rank
        world_size = dist.get_world_size() if dist.is_initialized() else 1
        rank = dist.get_rank() if dist.is_initialized() else 0
        Linear.dtype = torch.float8_e4m3fn if args.dtype == "fp8" else torch.bfloat16
        Linear.scale_fmt = args.scale_fmt
        super().__init__()
        self.max_seq_len = args.max_seq_len
        self.embed = ParallelEmbedding(args.vocab_size, args.dim)
        self.layers = torch.nn.ModuleList()
        for layer_id in range(args.n_layers):
            self.layers.append(Block(layer_id, args))
        self.norm = RMSNorm(args.dim)
        # lm_head in the checkpoint is stored in bf16, while the parameter here is stored in fp32 for easier computation of logits later.
        self.head = ColumnParallelLinear(args.dim, args.vocab_size, dtype=torch.float32)
        self.register_buffer("freqs_cis", precompute_freqs_cis(args), persistent=False)

    def forward(self, tokens: torch.Tensor, start_pos: int = 0):
        """
        Forward pass for the Transformer model.

        Args:
            tokens (torch.Tensor): Input tensor of token IDs with shape (batch_size, seq_len).
            start_pos (int, optional): Starting position in the sequence for rotary embeddings. Defaults to 0.

        Returns:
            torch.Tensor: Logits tensor of shape (batch_size, vocab_size).
        """
        seqlen = tokens.size(1)
        freqs_cis = self.freqs_cis[start_pos:start_pos+seqlen]
        mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device).triu_(1) if seqlen > 1 else None
        h, residual = self.embed(tokens), None
        for layer in self.layers:
            h, residual = layer(h, residual, start_pos, freqs_cis, mask)
        h, _ = self.norm(h, residual)
        logits = self.head(h[:, -1].float())
        if world_size > 1:
            all_logits = [torch.empty_like(logits) for _ in range(world_size)]
            dist.all_gather(all_logits, logits)
            logits = torch.cat(all_logits, dim=-1)
        return logits


if __name__ == "__main__":
    torch.set_default_dtype(torch.bfloat16)
    torch.set_default_device("cuda")
    torch.manual_seed(0)
    args = ModelArgs()
    x = torch.randint(0, args.vocab_size, (2, 128))
    model = Transformer(args)
    print(model(x).size())
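
# Running the demo above requires a CUDA device, the local `kernel` module providing
# act_quant / fp8_gemm / fp8_index, and the `fast_hadamard_transform` package that
# rotate_activation imports lazily; without them the import or the first forward pass
# will fail.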