Haojun/inference #142

Open · wants to merge 14 commits into base: main
4 changes: 3 additions & 1 deletion .gitignore
@@ -162,4 +162,6 @@ cython_debug/
.vscode

checkpoints/
-wandb/
+
+# wandb
+**/wandb/**
2 changes: 1 addition & 1 deletion README.md
@@ -62,7 +62,7 @@ pip install nanotron

Also nice to have: `pip install transformers datasets python-etcd tensorboardX`

-We also support a set of flavors that you can install using `pip install -e [$FLAVOR]`:
+We also support a set of flavors that you can install using `pip install -e .[$FLAVOR]`:
- `dev`: Used if you are developing in `nanotron`. In particular, it installs our linter tooling. On top of that, you have to run `pre-commit install` afterwards.
- `test`: We use `pytest` to run our test suite. To run tests in parallel, it installs `pytest-xdist`, which you can leverage by running `pytest -n 12 tests` (12 is the number of parallel workers).

13 changes: 7 additions & 6 deletions run_generate.py
@@ -16,6 +16,7 @@
from nanotron import distributed as dist
from nanotron import logging
from nanotron.config import (
+    Config,
    GenerationArgs,
    LoggingArgs,
    ParallelismArgs,
@@ -68,17 +69,17 @@ def main():

    assert args.ckpt_path.exists(), f"Checkpoint path {args.ckpt_path} does not exist"

-    config = get_config_from_file((args.ckpt_path / "config.yaml").as_posix())
+    config = get_config_from_file((args.ckpt_path / "config.yaml").as_posix(), config_class=Config)
    model_config = config.model.model_config
    tokenizer_path = config.tokenizer.tokenizer_name_or_path

    parallel_config = ParallelismArgs(
-        dp=args.dp or config.parallelism.dp,
-        pp=args.pp or config.parallelism.pp,
-        tp=args.tp or config.parallelism.tp,
+        dp=config.parallelism.dp,
+        pp=config.parallelism.pp,
+        tp=config.parallelism.tp,
        pp_engine=OneForwardOneBackwardPipelineEngine(),
-        tp_mode=TensorParallelLinearMode.ALL_REDUCE,
-        tp_linear_async_communication=True,
+        tp_mode=TensorParallelLinearMode.REDUCE_SCATTER,
+        tp_linear_async_communication=config.parallelism.tp_linear_async_communication,
    )

    # Initialise all process groups
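For context, this change makes generation mirror the parallel layout recorded in the checkpoint instead of taking `--dp/--pp/--tp` overrides from the CLI. A minimal sketch of that flow, assuming the `nanotron.config` exports used in the diff (`Config`, `ParallelismArgs`, `get_config_from_file`); the checkpoint directory is hypothetical and the optional engine/linear-mode fields are left at their defaults:

```python
from pathlib import Path

from nanotron.config import Config, ParallelismArgs, get_config_from_file

ckpt_path = Path("checkpoints/10")  # hypothetical checkpoint directory

# Parse the checkpoint's config.yaml into the typed Config dataclass,
# exactly as run_generate.py now does with config_class=Config.
config = get_config_from_file((ckpt_path / "config.yaml").as_posix(), config_class=Config)

# Reuse the parallelism recorded at training time rather than CLI flags,
# so the generation topology always matches the saved checkpoint shards.
parallel_config = ParallelismArgs(
    dp=config.parallelism.dp,
    pp=config.parallelism.pp,
    tp=config.parallelism.tp,
    tp_linear_async_communication=config.parallelism.tp_linear_async_communication,
)
```

Dropping the CLI overrides avoids accidentally loading, say, a tp=2 checkpoint into a tp=1 process grid.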
1 change: 1 addition & 0 deletions src/nanotron/config/models_config.py
@@ -50,6 +50,7 @@ class LlamaConfig:
    tie_word_embeddings: bool = False
    use_cache: bool = True
    vocab_size: int = 32000
+    for_inference: bool = False  # if True, replace TritonRMSNorm with the plain RMSNorm for a deterministic output; keep TritonRMSNorm for training since it is faster

    def __post_init__(self):
        # NOTE: the user doesn't set self._init_method; ModelArgs will set it
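To illustrate the new flag, here is a hedged sketch of toggling it on the model config. It assumes the remaining `LlamaConfig` fields keep their defaults; in `run_generate.py` the same object is reachable as `config.model.model_config`:

```python
from nanotron.config.models_config import LlamaConfig

# Training default: keep the fused TritonRMSNorm, which is faster but not
# bitwise-reproducible from run to run.
train_config = LlamaConfig()
assert train_config.for_inference is False

# Generation / testing: ask the model to build the plain RMSNorm so repeated
# runs over the same checkpoint produce identical outputs.
inference_config = LlamaConfig(for_inference=True)
assert inference_config.for_inference is True
```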
38 changes: 35 additions & 3 deletions src/nanotron/models/llama.py
@@ -17,6 +17,13 @@
from typing import Dict, Optional, Union

import torch
+
+from flash_attn import bert_padding
+from flash_attn.flash_attn_interface import (
+    flash_attn_varlen_func,
+    flash_attn_with_kvcache,
+)
+from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding
from torch import nn

from nanotron import distributed as dist
@@ -46,6 +53,23 @@
logger = logging.get_logger(__name__)


+class RMSNorm(nn.Module):
+    def __init__(self, dim: int, eps: float = 1e-6):
+        super().__init__()
+        self.eps = eps
+        self.weight = nn.Parameter(torch.zeros(dim))
+
+    def _norm(self, x):
+        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+    def forward(self, x):
+        output = self._norm(x.float())
+        # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
+        # See https://github.com/huggingface/transformers/pull/29402
+        output = output * (1.0 + self.weight.float())
+        return output.type_as(x)


class RotaryEmbedding(nn.Module):
    def __init__(self, dim: int, end: int, theta: float = 10000.0):
        super().__init__()
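Two details of the new `RMSNorm` are easy to miss: the normalization runs in float32, and the learned scale is stored as an offset from 1 (`torch.zeros(dim)` init), so a freshly constructed layer is a pure RMS normalization. A small self-contained check of that property, with the math copied from the class above into a standalone function (shapes and tolerance are illustrative only):

```python
import torch


def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Same math as RMSNorm.forward above: normalize in float32, then apply
    # the Gemma-style scale, where the parameter stores (scale - 1).
    xf = x.float()
    normed = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + eps)
    return (normed * (1.0 + weight.float())).type_as(x)


x = torch.randn(2, 4, 8).half()
w = torch.zeros(8)  # matches nn.Parameter(torch.zeros(dim)) in the class above

y = rms_norm(x, w)
# With zero-initialized weights the layer only rescales each vector to unit RMS.
rms = y.float().pow(2).mean(-1).sqrt()
assert torch.allclose(rms, torch.ones_like(rms), atol=1e-2)
```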
@@ -607,15 +631,23 @@ def __init__(
        layer_idx: int,
    ):
        super().__init__()
-        self.input_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        # TritonRMSNorm is faster but produces non-deterministic results, so keep it for training.
+        if not config.for_inference:
+            self.input_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        else:
+            # Replace TritonRMSNorm with RMSNorm for a deterministic output.
+            self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.attn = CausalSelfAttention(
            config=config,
            parallel_config=parallel_config,
            tp_pg=tp_pg,
            layer_idx=layer_idx,
        )

-        self.post_attention_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        if not config.for_inference:
+            self.post_attention_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        else:
+            # Replace TritonRMSNorm with RMSNorm for a deterministic output.
+            self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.mlp = MLP(config=config, parallel_config=parallel_config, tp_pg=tp_pg)

    def forward(
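The decoder-layer branches above trade speed for reproducibility: `TritonRMSNorm` (fused kernel) stays the default for training, while `for_inference=True` swaps in the plain `RMSNorm` so the same checkpoint and prompt always yield the same logits. A small sketch of that reproducibility property, re-declaring only the pure-PyTorch norm from this diff (input sizes are illustrative, and the Triton path is not exercised here):

```python
import torch
from torch import nn


class RMSNorm(nn.Module):
    """Pure-PyTorch RMSNorm as added in this PR (the weight stores scale - 1)."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.zeros(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps)
        return (out * (1.0 + self.weight.float())).type_as(x)


norm = RMSNorm(dim=64)  # what LlamaDecoderLayer builds when config.for_inference is True
x = torch.randn(2, 8, 64)

# Two passes over the same input are bitwise identical, which is the "fixed
# output" the for_inference flag is after; the fused Triton kernel gives up
# this guarantee in exchange for speed.
assert torch.equal(norm(x), norm(x))
```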