diff --git a/open_diloco/model.py b/open_diloco/model.py new file mode 100644 index 0000000..c4b047e --- /dev/null +++ b/open_diloco/model.py @@ -0,0 +1,796 @@ +# copied and adapted from : https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama/model.py + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +# Llama 2 is licensed under the LLAMA 2 Community License, +# Copyright (c) Meta Platforms, Inc. All Rights Reserved. + + +import math + +from functools import partial + + +import torch +import torch.nn.functional as F +from torch import nn + +import triton +import triton.language as tl + +from torch.distributed._tensor import Partial, Replicate, Shard +from torch.distributed._tensor.experimental import local_map +from transformers.modeling_outputs import CausalLMOutputWithPast + +from dataclasses import dataclass +from typing import Optional, Tuple +from typing_extensions import Self + + +@dataclass +class ModelArgs: + dim: int = 4096 + n_layers: int = 32 + n_heads: int = 32 + n_kv_heads: Optional[int] = None + vocab_size: int = -1 # defined later by tokenizer + multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2 + ffn_dim_multiplier: Optional[float] = None + norm_eps: float = 1e-5 + rope_theta: float = 10000 + + # max_batch_size: int = 32 + max_seq_len: int = 2048 + # If `True`, then each transformer block init uses its layer ID, and if + # `False`, each uses the total number of transformer blocks + depth_init: bool = True + norm_type: str = "rmsnorm" + + @classmethod + def from_name(cls, name: str) -> Self: + return { + "2M": ModelArgs(dim=256, n_layers=8, n_heads=8, vocab_size=32_000), + "150M": ModelArgs(dim=1024, n_layers=12, n_heads=16, vocab_size=32_000), + "1B": ModelArgs(dim=2048, n_layers=18, n_heads=16, vocab_size=32_000), + "7B": ModelArgs(dim=4096, n_layers=32, n_heads=32), + }[name] + + +def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor: + """ + Precompute the frequency tensor for complex exponentials (cis) with given dimensions. + + This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' + and the end index 'end'. The 'theta' parameter scales the frequencies. + The returned tensor contains complex values in complex64 data type. + + Args: + dim (int): Dimension of the frequency tensor. + end (int): End index for precomputing frequencies. + theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0. + + Returns: + torch.Tensor: Precomputed frequency tensor with complex exponentials. + """ + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device) + freqs = torch.outer(t, freqs).float() + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + return freqs_cis + + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor: + """ + Reshape frequency tensor for broadcasting it with another tensor. + + This function reshapes the frequency tensor to have the same shape as the target tensor 'x' + for the purpose of broadcasting the frequency tensor during element-wise operations. + + The input freqs_cis tensor is assumed to be of shape (max_seqlen, dim), + and the first seqlen elements will be sliced, but dim must match x. 
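+
+    For example (illustrative shapes): when x is the complex view of the queries
+    with shape (bs, seqlen, n_heads, head_dim // 2), freqs_cis[:seqlen] has shape
+    (seqlen, head_dim // 2) and is viewed as (1, seqlen, 1, head_dim // 2) so it
+    broadcasts over the batch and head dimensions.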
+ + Args: + freqs_cis (torch.Tensor): Frequency tensor to be reshaped. + x (torch.Tensor): Target tensor for broadcasting compatibility. + + Returns: + torch.Tensor: Reshaped frequency tensor. + """ + ndim = x.ndim + assert 0 <= 1 < ndim + seqlen = x.shape[1] + freqs_cis = freqs_cis[0:seqlen] + assert freqs_cis.shape == (seqlen, x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary embeddings to input tensors using the given frequency tensor. + + This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided + frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor + is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are + returned as real tensors. + + Args: + xq (torch.Tensor): Query tensor to apply rotary embeddings. + xk (torch.Tensor): Key tensor to apply rotary embeddings. + freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. + """ + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: + """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" + bs, slen, n_kv_heads, head_dim = x.shape + if n_rep == 1: + return x + return ( + torch.unsqueeze(x, dim=3) + .expand(bs, slen, n_kv_heads, n_rep, head_dim) + .reshape(bs, slen, n_kv_heads * n_rep, head_dim) + ) + + +class Attention(nn.Module): + """ + Multi-head attention module. + + Args: + model_args (ModelArgs): Model configuration arguments. + + Attributes: + n_kv_heads (int): Number of key and value heads. + n_heads (int): Number of query heads. + n_rep (int): Number of repetitions for local heads. + head_dim (int): Dimension size of each attention head. + wq (Linear): Linear transformation for queries. + wk (Linear): Linear transformation for keys. + wv (Linear): Linear transformation for values. + wo (Linear): Linear transformation for output. 
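+
+        Example (illustrative, using the "150M" config where head_dim = 1024 / 16 = 64):
+            >>> args = ModelArgs.from_name("150M")
+            >>> attn = Attention(args)
+            >>> freqs_cis = precompute_freqs_cis(args.dim // args.n_heads, 16)
+            >>> x = torch.randn(2, 16, args.dim)
+            >>> attn(x, freqs_cis).shape
+            torch.Size([2, 16, 1024])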
+ + """ + + def __init__(self, model_args: ModelArgs): + super().__init__() + self.n_heads = model_args.n_heads + self.n_kv_heads = model_args.n_heads if model_args.n_kv_heads is None else model_args.n_kv_heads + self.n_rep = self.n_heads // self.n_kv_heads + self.head_dim = model_args.dim // model_args.n_heads + + self.wq = nn.Linear(model_args.dim, model_args.n_heads * self.head_dim, bias=False) + self.wk = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False) + self.wv = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False) + self.wo = nn.Linear(model_args.n_heads * self.head_dim, model_args.dim, bias=False) + + def init_weights(self, init_std: float): + for linear in (self.wq, self.wk, self.wv): + nn.init.trunc_normal_(linear.weight, mean=0.0, std=0.02) + nn.init.trunc_normal_(self.wo.weight, mean=0.0, std=init_std) + + def forward( + self, + x: torch.Tensor, + freqs_cis: torch.Tensor, + ): + """ + Forward pass of the attention module. + + Args: + x (torch.Tensor): Input tensor. + freqs_cis (torch.Tensor): Precomputed frequency tensor. + + Returns: + torch.Tensor: Output tensor after attention. + + """ + bs, seqlen, _ = x.shape + xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) + + # Use -1 instead of `n_heads` (or `n_kv_heads`) to infer the actual + # local heads from sizes of xq, xk, and xv as TP may have sharded them + # after the above linear ops. + xq = xq.view(bs, seqlen, -1, self.head_dim) + xk = xk.view(bs, seqlen, -1, self.head_dim) + xv = xv.view(bs, seqlen, -1, self.head_dim) + + xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) + + # repeat k/v heads if n_kv_heads < n_heads + keys = repeat_kv(xk, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) + values = repeat_kv(xv, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) + + xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) + xk = keys.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) + xv = values.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) + + # we use casual mask for training + output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True) + output = output.transpose(1, 2).contiguous() # (bs, seqlen, n_local_heads, head_dim) + output = output.view(bs, seqlen, -1) + return self.wo(output) + + +class FeedForward(nn.Module): + """ + FeedForward module + + Args: + dim (int): Input dimension. + hidden_dim (int): Hidden dimension of the feedforward layer. + multiple_of (int): Value to ensure hidden dimension is a multiple of this value. + ffn_dim_multiplier (Optional[float]): Custom multiplier for hidden dimension. Defaults to None. + + Attributes: + w1 (Linear): Linear transformation for the first layer. + w2 (Linear): Linear transformation for the second layer. + w3 (Linear): Linear transformation for the third layer. 
+ + """ + + def __init__( + self, + dim: int, + hidden_dim: int, + multiple_of: int, + ffn_dim_multiplier: Optional[float], + ): + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + # custom dim factor multiplier + if ffn_dim_multiplier is not None: + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + + def forward(self, x): + return self.w2(F.silu(self.w1(x)) * self.w3(x)) + + def init_weights(self, init_std: float): + nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02) + for linear in (self.w2, self.w3): + nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std) + + +class TransformerBlock(nn.Module): + """ + TransformerBlock Module + + Args: + layer_id (int): Identifier for the layer. + model_args (ModelArgs): Model configuration arguments. + + Attributes: + n_heads (int): Number of attention heads. + dim (int): Dimension size of the model. + head_dim (int): Dimension size of each attention head. + attention (Attention): Attention module. + feed_forward (FeedForward): FeedForward module. + layer_id (int): Identifier for the layer. + attention_norm (RMSNorm): Layer normalization for attention output. + ffn_norm (RMSNorm): Layer normalization for feedforward output. + + """ + + def __init__(self, layer_id: int, model_args: ModelArgs): + super().__init__() + self.n_heads = model_args.n_heads + self.dim = model_args.dim + self.attention = Attention(model_args) + self.feed_forward = FeedForward( + dim=model_args.dim, + hidden_dim=4 * model_args.dim, + multiple_of=model_args.multiple_of, + ffn_dim_multiplier=model_args.ffn_dim_multiplier, + ) + self.layer_id = layer_id + self.num_layers = model_args.n_layers + + self.attention_norm = build_norm(model_args.norm_type, dim=model_args.dim, eps=model_args.norm_eps) + self.ffn_norm = build_norm(model_args.norm_type, dim=model_args.dim, eps=model_args.norm_eps) + + if model_args.depth_init: + self.weight_init_std = 0.02 / (2 * (self.layer_id + 1)) ** 0.5 + else: + self.weight_init_std = 0.02 / (2 * self.num_layers) ** 0.5 + + def forward( + self, + x: torch.Tensor, + freqs_cis: torch.Tensor, + ): + """ + Perform a forward pass through the TransformerBlock. + + Args: + x (torch.Tensor): Input tensor. + freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies. + + Returns: + torch.Tensor: Output tensor after applying attention and feedforward layers. + + """ + h = x + self.attention(self.attention_norm(x), freqs_cis) + out = h + self.feed_forward(self.ffn_norm(h)) + return out + + def init_weights(self): + for norm in (self.attention_norm, self.ffn_norm): + norm.reset_parameters() + self.attention.init_weights(self.weight_init_std) + self.feed_forward.init_weights(self.weight_init_std) + + +class Transformer(nn.Module): + """ + Transformer Module + + Args: + model_args (ModelArgs): Model configuration arguments. + + Attributes: + model_args (ModelArgs): Model configuration arguments. + vocab_size (int): Vocabulary size. + n_layers (int): Number of layers in the model. + tok_embeddings (ParallelEmbedding): Token embeddings. + layers (torch.nn.ModuleList): List of Transformer blocks. + norm (RMSNorm): Layer normalization for the model output. + output (ColumnParallelLinear): Linear layer for final output. + freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies. 
+ + """ + + def __init__(self, model_args: ModelArgs): + super().__init__() + self.model_args = model_args + self.vocab_size = model_args.vocab_size + self.n_layers = model_args.n_layers + + self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim) + + # TODO persistent should be set to false, since this buffer can be recomputed. + # however, we set it to true for 2 reasons. (1) due to pytorch/pytorch#123411, + # compile or pipeline-tracer will not correctly handle non-persistent buffers, + # so we need to fix that. (2) if we initialize pipeline-parallel models from + # a seed checkpoint rather than calling init_weights, we need freqs_cis to be + # initialized by the checkpoint, or we need to add a separate initializer for + # just the non-persistent buffers that is called after loading checkpoints. + self.register_buffer("freqs_cis", self._precompute_freqs_cis(), persistent=True) + + self.layers = torch.nn.ModuleDict() + for layer_id in range(model_args.n_layers): + self.layers[str(layer_id)] = TransformerBlock(layer_id, model_args) + + self.norm = build_norm(model_args.norm_type, dim=model_args.dim, eps=model_args.norm_eps) + + self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False) + self.init_weights() + + def init_weights(self): + """ + [Note: On ``init_weights`` vs. ``reset_parameters``] + Modules may define ``reset_parameters`` to initialize parameter values. + ``reset_parameters`` is meant to only initialize directly owned + parameters/buffers, not those of their child modules, and it can be + used to give the initial values for these tensors. + Separately, users may want custom initialization for their modules, + different from that in ``reset_parameters``. For this, we define + ``init_weights``. We only call it in the constructor of this + ``Transformer`` root module to avoid reinitializing tensors. + """ + with torch.device(self.freqs_cis.device): + self.freqs_cis = self._precompute_freqs_cis() + if self.tok_embeddings is not None: + nn.init.normal_(self.tok_embeddings.weight) + for layer in self.layers.values(): + if layer is not None: + layer.init_weights() + if self.norm is not None: + self.norm.reset_parameters() + final_out_std = self.model_args.dim**-0.5 + cutoff_factor = 3 + if self.output is not None: + nn.init.trunc_normal_( + self.output.weight, + mean=0.0, + std=final_out_std, + a=-cutoff_factor * final_out_std, + b=cutoff_factor * final_out_std, + ) + + def _precompute_freqs_cis(self) -> torch.Tensor: + return precompute_freqs_cis( + self.model_args.dim // self.model_args.n_heads, + # Need to compute until at least the max token limit for generation + # (use 2x max sequence length to be safe) + self.model_args.max_seq_len * 2, + self.model_args.rope_theta, + ) + + def forward(self, tokens: torch.Tensor): + """ + Perform a forward pass through the Transformer model. + + Args: + tokens (torch.Tensor): Input token indices. + + Returns: + torch.Tensor: Output logits after applying the Transformer model. + + """ + # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stages + h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens + + for layer in self.layers.values(): + h = layer(h, self.freqs_cis) + + h = self.norm(h) if self.norm else h + output = self.output(h).float() if self.output else h + return output + + @classmethod + def from_model_args(cls, model_args: ModelArgs) -> "Transformer": + """ + Initialize a Transformer model from a ModelArgs object. 
+ + Args: + model_args (ModelArgs): Model configuration arguments. + + Returns: + Transformer: Transformer model. + + """ + return cls(model_args) + + +class TransformerHF(nn.Module): + def __init__(self, config: ModelArgs): + super().__init__() + self.model = Transformer(config) + + def forward(self, input_ids: torch.LongTensor) -> CausalLMOutputWithPast: + return CausalLMOutputWithPast(logits=self.model(input_ids)) + + +def build_norm(norm_type: str, dim: int, eps: float = 1e-6): + """ + Builds the specified normalization layer based on the norm_type. + + Args: + norm_type (str): The type of normalization layer to build. + Supported types: 1. rmsnorm 2. fused_rmsnorm 3. layernorm 4. np_layernorm + dim (int): The dimension of the normalization layer. + eps (float, optional): The epsilon value for numerical stability. Defaults to 1e-6. + + Returns: + The built normalization layer. + + Raises: + NotImplementedError: If an unknown norm_type is provided. + """ + norm_type = norm_type.lower() # Normalize to lowercase + + if norm_type == "layernorm": + return nn.LayerNorm(dim, eps=eps, bias=False) + elif norm_type == "np_layernorm": + return nn.LayerNorm(dim, eps=eps, elementwise_affine=False, bias=False) + elif norm_type == "rmsnorm": + return RMSNorm(dim, eps=eps) + elif norm_type == "compiled_rmsnorm": + import warnings + + warnings.warn("compiled_rmsnorm is currently experimental and not ready to use yet.") + return RMSNorm(dim, eps=eps, compile=True) + elif norm_type == "fused_rmsnorm": + return FusedRMSNorm(dim, eps=eps) + else: + raise NotImplementedError(f"Unknown norm_type: '{norm_type}'") + + +class FusedRMSNorm(nn.Module): + """Fused RMS Norm, wraps a fused Triton Kernel""" + + def __init__( + self, + dim: int, + eps: float = 1e-6, + ): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + self.fused_rms_norm_fn = fused_rms_norm_fn + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """leverages Triton Fused RMS Norm kernel""" + return self.fused_rms_norm_fn( + x, + self.weight, + eps=self.eps, + ) + + def reset_parameters(self): + torch.nn.init.ones_(self.weight) # type: ignore + + +class RMSNorm(nn.Module): + """ + Initialize the RMSNorm normalization layer. + + Args: + dim (int): The dimension of the input tensor. + eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. + + Attributes: + eps (float): A small value added to the denominator for numerical stability. + weight (nn.Parameter): Learnable scaling parameter. 
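+
+        Computes x * rsqrt(mean(x^2) + eps) * weight; unlike LayerNorm, no mean
+        is subtracted and there is no bias term. Illustrative usage:
+            >>> norm = RMSNorm(dim=1024)
+            >>> norm(torch.randn(2, 16, 1024)).shape
+            torch.Size([2, 16, 1024])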
+ + """ + + def __init__(self, dim: int, eps: float = 1e-6, compile: bool = False): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + self.rmsnorm_fn = torch.compile(self.compute_rmsnorm, fullgraph=True) if compile else self.compute_rmsnorm + + @staticmethod + def compute_rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float): + def _norm(x, eps): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) + + output = _norm(x.float(), eps).type_as(x) + return output * weight + + def forward(self, x: torch.Tensor): + return self.rmsnorm_fn(x, self.weight, self.eps) + + def reset_parameters(self): + torch.nn.init.ones_(self.weight) # type: ignore + + +# FusedRMSNorm in Triton + +# Credit +# Tri Dao's Triton LayerNorm: https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/layer_norm.py +# Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16), + triton.Config({}, num_warps=32), + ], + key=["N"], +) +@triton.jit +def _rms_norm_fwd_kernel( + X, + stride_x, + Y, + stride_y, + W, + Rstd, + eps, + M, # num rows + N, # num cols + block_N: tl.constexpr, +): + row = tl.program_id(0) + cols = tl.arange(0, block_N) + + # Load input data and weights + mask = cols < N + x = tl.load(X + row * stride_x + cols, mask=mask, other=0.0).to(tl.float32) + w = tl.load(W + cols, mask=mask, other=0.0).to(tl.float32) + + # Compute mean and variance + xbar = tl.where(cols < N, x, 0.0) + var = tl.sum(xbar * xbar, axis=0) / N + rstd = 1 / tl.sqrt(var + eps) + + # Store the reciprocal standard deviation + tl.store(Rstd + row, rstd) + + # Normalize and apply linear transformation + x_hat = x * rstd + y = x_hat * w + + # Write output + tl.store(Y + row * stride_y + cols, y, mask=mask) + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16), + triton.Config({}, num_warps=32), + ], + key=["N"], +) +@triton.jit +def _rms_norm_bwd_kernel_sm( + X, + stride_x, + W, + DY, + stride_dy, + DX, + stride_dx, + Rstd, + DW, + eps, + M, # num rows + N, # num cols + rows_per_program, + block_N: tl.constexpr, +): + row_block_id = tl.program_id(0) + row_start = row_block_id * rows_per_program + cols = tl.arange(0, block_N) + mask = cols < N + + # Load weights + w = tl.load(W + cols, mask=mask, other=0.0).to(tl.float32) + + # Accumulate gradients for weights + dw = tl.zeros((block_N,), dtype=tl.float32) + + row_end = min(row_start + rows_per_program, M) + for row in range(row_start, row_end): + # Load input, output gradient, and reciprocal standard deviation + x = tl.load(X + row * stride_x + cols, mask=mask, other=0.0).to(tl.float32) + dy = tl.load(DY + row * stride_dy + cols, mask=mask, other=0.0).to(tl.float32) + rstd = tl.load(Rstd + row) + + # Compute normalized input and gradients + x_hat = x * rstd + wdy = w * dy + dw += dy * x_hat + c1 = tl.sum(x_hat * wdy, axis=0) / N + dx = (wdy - x_hat * c1) * rstd + + # Store input gradient + tl.store(DX + row * stride_dx + cols, dx, mask=mask) + + # Store weight gradients + tl.store(DW + row_block_id * N + cols, dw, mask=mask) + + +class TritonFusedRMSNorm(torch.autograd.Function): + @partial( + local_map, + out_placements=[Shard(1)], + 
in_placements=(None, [Shard(1)], [Replicate()], None), + ) + @staticmethod + def forward(ctx, x, weight, eps): + x_shape_start = x.shape + + # Flatten input + x = x.view(-1, x.shape[-1]) + if x.stride(-1) != 1: + x = x.contiguous() + if weight.stride(-1) != 1: + weight = weight.contiguous() + + M, N = x.shape + y = torch.empty_like(x) + rstd = torch.empty((M,), dtype=torch.float32, device=x.device) + + max_size = 65536 // x.element_size() + block_N = min(max_size, triton.next_power_of_2(N)) + + if N > block_N: + raise ValueError(f"N {N} must be <= {block_N=}") + + grid = lambda meta: (M,) # noqa: E731 + _rms_norm_fwd_kernel[grid]( + x, + x.stride(0), + y, + y.stride(0), + weight, + rstd, + eps, + M, + N, + block_N, + ) + + ctx.eps = eps + ctx.save_for_backward(x, weight, rstd) + ctx.x_shape_start = x_shape_start + + y = y.reshape(x_shape_start) + return y + + @partial( + local_map, + out_placements=([Shard(1)], [Partial()], None), + in_placements=(None, [Shard(1)]), + ) + @staticmethod + def backward(ctx, dy): + x, weight, rstd = ctx.saved_tensors + eps = ctx.eps + x_shape_start = ctx.x_shape_start + + # Flatten input and output gradients + dy = dy.view(-1, dy.shape[-1]) + if dy.stride(-1) != 1: + dy = dy.contiguous() + + M, N = dy.shape + dx = torch.empty_like(x) + dw = torch.empty_like(weight) + + sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count + _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device) + + max_size = 65536 // x.element_size() + block_N = min(max_size, triton.next_power_of_2(N)) + rows_per_sm = math.ceil(M / sm_count) + + if N > block_N: + raise ValueError(f"N {N} must be <= {block_N=}") + + grid = lambda meta: (sm_count,) # noqa: E731 + _rms_norm_bwd_kernel_sm[grid]( + x, + x.stride(0), + weight, + dy, + dy.stride(0), + dx, + dx.stride(0), + rstd, + _dw, + eps, + M, + N, + rows_per_sm, + block_N, + ) + dw = _dw.sum(0).to(weight.dtype) + dx = dx.view(x_shape_start) + return dx, dw, None + + +# expose fusedRMSNorm as a function +def fused_rms_norm_fn( + x, + weight, + eps=1e-6, +): + return TritonFusedRMSNorm.apply( + x, + weight, + eps, + ) diff --git a/open_diloco/train_fsdp.py b/open_diloco/train_fsdp.py index 17ef139..f2ad18c 100644 --- a/open_diloco/train_fsdp.py +++ b/open_diloco/train_fsdp.py @@ -12,6 +12,7 @@ from contextlib import nullcontext import datetime from typing import Any, Literal +from einops import rearrange import fsspec from pydantic import model_validator @@ -22,11 +23,13 @@ from datasets.distributed import split_dataset_by_node from fsspec.generic import GenericFileSystem from torch.distributed import destroy_process_group, init_process_group +import torch.nn.functional as F + from torchdata.stateful_dataloader import StatefulDataLoader from transformers import ( AutoTokenizer, - DataCollatorForLanguageModeling, + LlamaTokenizer, LlamaConfig, LlamaForCausalLM, get_cosine_schedule_with_warmup, @@ -50,9 +53,11 @@ from open_diloco.utils import ( ActivationNormMetric, FakeTokenizedDataset, + collate_causal_mask, get_compression_kwargs, get_sharding_strategy, ) +from open_diloco.model import ModelArgs, TransformerHF TIMEOUT_NCCL_MINUTES = os.environ.get("TIMEOUT_NCCL_MINUTES", 120) @@ -115,6 +120,7 @@ def cast_str_to_list(cls, values: dict[str, Any]) -> dict[str, Any]: class Config(BaseConfig): path_model: str = "PrimeIntellect/llama-150m-fresh" + torch_titan_llama: bool = False torch_compile: bool = True attn_implementation: str = "sdpa" # Data @@ -142,7 +148,9 @@ class Config(BaseConfig): max_steps: int | 
None = None -def get_dataloader(tokenizer, world_size, rank, local_rank, config: Config) -> StatefulDataLoader: +def get_dataloader( + tokenizer: LlamaTokenizer, world_size: int, rank: int, local_rank: int, config: Config +) -> StatefulDataLoader: if config.fake_data: train_dataset = FakeTokenizedDataset(config.seq_length, TEST_VOCAB_SIZE) else: @@ -157,9 +165,9 @@ def tokenize_function(data): ) return outputs - tokenized_datasets = ds.map(tokenize_function, batched=True, remove_columns=["text", "timestamp", "url"])[ - "train" - ] + tokenized_datasets = ds.map( + tokenize_function, batched=True, remove_columns=["text", "timestamp", "url", "attention_mask"] + )["train"] if config.hv is not None: train_dataset = split_dataset_by_node( @@ -171,7 +179,7 @@ def tokenize_function(data): else: train_dataset = split_dataset_by_node(tokenized_datasets, world_size=world_size, rank=rank) - data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False) + data_collator = collate_causal_mask(config.seq_length, tokenizer.pad_token_id, ignore_index=-100) return StatefulDataLoader( train_dataset, @@ -183,8 +191,12 @@ def tokenize_function(data): def get_model(config: Config) -> LlamaForCausalLM: # Load model - config_model = LlamaConfig.from_pretrained(config.path_model, attn_implementation=config.attn_implementation) - return LlamaForCausalLM.from_pretrained(pretrained_model_name_or_path=config.path_model, config=config_model) + if config.torch_titan_llama: + config_model = ModelArgs.from_name(config.path_model) + return TransformerHF(config=config_model) + else: + config_model = LlamaConfig.from_pretrained(config.path_model, attn_implementation=config.attn_implementation) + return LlamaForCausalLM.from_pretrained(pretrained_model_name_or_path=config.path_model, config=config_model) def train(config: Config): @@ -392,8 +404,13 @@ def scheduler_fn(opt): batch[key] = batch[key].to("cuda") with model.no_sync() if is_accumulating else nullcontext(): - outputs = model(**batch) - loss = outputs.loss / gradient_accumulation_steps + logits = model(input_ids=batch["input_ids"]).logits.contiguous() + labels = batch["labels"].contiguous() + + flatten_logits = rearrange(logits, "b seq vocab -> (b seq) vocab") + flatten_labels = rearrange(labels, "b seq -> (b seq)") + + loss = F.cross_entropy(flatten_logits, flatten_labels, ignore_index=-100) / gradient_accumulation_steps loss_batch += loss.detach() diff --git a/open_diloco/utils.py b/open_diloco/utils.py index b85fe58..17d9365 100644 --- a/open_diloco/utils.py +++ b/open_diloco/utils.py @@ -173,5 +173,32 @@ def __init__(self, seq_len: int, vocab_size: int): def __iter__(self) -> Generator[dict[str, Any], Any, None]: while True: input_ids = torch.randint(3, self.vocab_size, (self.seq_len,)).tolist() - attention_mask = [1] * self.seq_len - yield {"input_ids": input_ids, "attention_mask": attention_mask} + yield {"input_ids": input_ids} + + +def collate_causal_mask(max_seq_length: int = -1, pad_id: int = 0, ignore_index: int = -100) -> callable: + return partial(_collate_fn_causal_mask, max_seq_length=max_seq_length, pad_id=pad_id, ignore_index=ignore_index) + + +def _collate_fn_causal_mask( + samples: list[dict[str, torch.LongTensor]], max_seq_length: int = -1, pad_id: int = 0, ignore_index: int = -100 +) -> dict[str, torch.LongTensor]: + assert samples[0].keys() == {"input_ids"} + + batched = {"input_ids": [], "labels": []} + + if max_seq_length > 0: + max_seq_length += 1 # this makes sure that the effective seqlen is correct + + for sample in 
samples:
+        input_ids = torch.Tensor(sample["input_ids"]).long()
+
+        if len(input_ids) < max_seq_length:
+            input_ids = torch.cat([input_ids, torch.full((max_seq_length - len(input_ids),), pad_id)])
+        elif len(input_ids) > max_seq_length:
+            input_ids = input_ids[:max_seq_length]
+
+        # next-token prediction: the model sees tokens [0, T-1) and is trained to
+        # predict tokens [1, T), so inputs drop the last token and labels are the
+        # sequence shifted left by one
+        labels = input_ids[1:].clone()
+        # mask padded positions so they do not contribute to the loss
+        # (assumes pad_id never appears as a genuine target token)
+        labels[labels == pad_id] = ignore_index
+
+        batched["input_ids"].append(input_ids[:-1])
+        batched["labels"].append(labels)
+
+    return {"input_ids": torch.stack(batched["input_ids"], dim=0), "labels": torch.stack(batched["labels"], dim=0)}
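+
+
+# Example (illustrative): with max_seq_length=4 and pad_id=0, a 3-token sample
+# [5, 6, 7] is first padded to [5, 6, 7, 0, 0] and then shifted, giving
+#   input_ids = [[5, 6, 7, 0]]
+#   labels    = [[6, 7, -100, -100]]
+# so the loss in train_fsdp.py (F.cross_entropy with ignore_index=-100) skips
+# the padded positions.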