Llama 3.2 1B Instruct on TPU v4, bumping transformers to 4.45.2 #109

Open · wants to merge 8 commits into base: main
3 changes: 3 additions & 0 deletions README.md
@@ -29,6 +29,9 @@ We currently support a few LLM models targeting text generation scenarios:

## Installation

For installation on a TPU v4, use the `install-on-TPU-v4.sh` script. Make sure you do NOT install the Pallas or JetStream extras, as both target TPU v5e and are not supported on v4.

Via package:
`optimum-tpu` comes with a handy PyPI package compatible with your usual Python dependency management tools.

`pip install optimum-tpu -f https://storage.googleapis.com/libtpu-releases/index.html`
4 changes: 2 additions & 2 deletions examples/text-generation/generation.py
@@ -58,7 +58,7 @@ def summary(values: List[float]):
def main():
parser = argparse.ArgumentParser(description="Text generation example")
parser.add_argument("--model_id", type=str,
default="google/gemma-2b",
default="meta-llama/Llama-3.2-1B-Instruct",
help="Model ID (e.g.: google/gemma-2b, mistralai/Mistral-7B-v0.3)")
parser.add_argument("--max_new_tokens", type=int, default=20, help="Number of tokens to generate")
parser.add_argument("--max_cache_length", type=int, default=256, help="Maximum cache length for the model")
@@ -72,7 +72,7 @@ def main():
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch_dtype)
device = model.device
model = model.eval()
print(f"✅ Model loaded in {time.time() - prg_start} seconds.")
print(f"✅ Model loaded in {time.time() - prg_start} seconds on {device=}.")

tokenizer = AutoTokenizer.from_pretrained(model_id)
# Set pad token for cases where it is None, e.g. for Mistral
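For context, this is roughly what the example does with its new defaults, stripped of the static-cache and XLA-device plumbing in `generation.py`; the prompt string is made up for illustration, and the model requires accepting the Llama 3.2 license plus `huggingface-cli login`:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-3.2-1B-Instruct"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16).eval()
tokenizer = AutoTokenizer.from_pretrained(model_id)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # e.g. needed for Mistral

inputs = tokenizer("Here is a funny thing:", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))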
25 changes: 25 additions & 0 deletions install-on-TPU-v4.sh
@@ -0,0 +1,25 @@
sudo apt remove unattended-upgrades
sudo apt update
export PJRT_DEVICE=TPU
export PATH="/home/artuskg/.local/bin:$PATH"  # adjust to your own home directory
export DBG_COMPILE=True
pip install build
pip install --upgrade setuptools
sudo apt install python3.10-venv

git clone https://github.com/huggingface/optimum-tpu.git

cd optimum-tpu
make
make build_dist_install_tools
make build_dist

python -m venv optimum_tpu_env
source optimum_tpu_env/bin/activate

pip install torch==2.4.0 torch_xla[tpu]==2.4.0 torchvision -f https://storage.googleapis.com/libtpu-releases/index.html
pip uninstall torchvision  # it might insist on 2.4.1
pip install -e .

huggingface-cli login
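Once the script has run and the virtual environment is active, a quick sanity check (a minimal sketch, assuming the torch/torch_xla 2.4.0 install above succeeded) is to confirm the TPU is visible to torch_xla:

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()          # should report an XLA device backed by the TPU, e.g. xla:0
x = torch.ones(2, 2).to(device)   # run a trivial op through the XLA backend
print(device, (x + x).cpu())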

62 changes: 40 additions & 22 deletions optimum/tpu/distributed_model.py
@@ -1,8 +1,13 @@
# ruff: noqa: E402
import os
from enum import Enum

import time
from loguru import logger
import sys

# Set the logger to show DEBUG messages
logger.remove() # Remove default logger
logger.add(sys.stdout, level="DEBUG") # Re-add with DEBUG level


os.environ["PJRT_DEVICE"] = "TPU"
@@ -23,59 +28,57 @@ class ModelCommand(Enum):


def _mp_fn(rank, model_id, root_mailbox: RootMailbox, sample_fn: callable):
logger.debug(f"[Rank {rank}] Starting _mp_fn")
device = xm.xla_device()
world_size = xm.xrt_world_size()
# create agent mailbox out of root's one
mailbox = AgentMailbox(root_mailbox)

logger.debug(
f"Rank {rank} on {device} real device {xm.xla_real_devices([device])} ordinal {xm.get_ordinal()} "
+ f"world size {world_size}"
)

# Model loading and sharding should happen here
logger.debug(f"[Rank {rank}] Loading model")
model = AutoModelForCausalLM.from_pretrained(model_id)
model = model.eval()
model.to(device)
logger.debug(f"[Rank {rank}] Model loaded and moved to {device=}")

def get_next_token(inputs):
# move inputs to device in a new dict to avoid conflicts
model_inputs = {}
for key, value in inputs.items():
model_inputs[key] = value.to(device)
logger.debug(f"[Rank {rank}] Starting get_next_token")
model_inputs = {k: v.to(device) for k, v in inputs.items()}
logger.debug(f"[Rank {rank}] Running model inference")
outputs = model(**model_inputs, return_dict=False)[0]
xm.mark_step()
# consider adding a rendezvous here
if rank == 0:
logger.debug(f"Rank {rank} getting tokens")
logger.debug(f"[Rank {rank}] Sampling next token")
next_token = sample_fn(outputs)
xm.mark_step()
logger.debug(f"Rank {rank} sending next_tokens {next_token.shape}")
# Data needs to be moved to CPU before setting it
logger.debug(f"[Rank {rank}] Sending next token")
mailbox.send(next_token.cpu())
logger.debug(f"[Rank {rank}] Finished get_next_token")

while True:
if rank == 0:
mailbox.agent_ready.set()
logger.debug(f"Rank {rank} waiting for commands")
logger.debug(f"[Rank {rank}] Waiting for commands")
mailbox.receive()
# Wait for rank 0 to receive command
xm.rendezvous("start")

logger.debug(f"Rank {rank} waiting for command at rendezvous")
logger.debug(f"[Rank {rank}] Received command")
command, data = mailbox.command_data
inputs = data[0] if data else None
if command == ModelCommand.PREFILL:
logger.debug(f"Rank {rank} PREFILL")
logger.debug(f"[Rank {rank}] Executing PREFILL")
get_next_token(inputs)
elif command == ModelCommand.DECODE:
logger.debug(f"Rank {rank} DECODE")
logger.debug(f"[Rank {rank}] Executing DECODE")
get_next_token(inputs)
elif command == ModelCommand.LEAVE:
logger.debug(f"Rank {rank} LEAVE")
# Set model to ready
logger.debug(f"[Rank {rank}] Executing LEAVE")
mailbox.agent_ready.set()
break
logger.debug(f"[Rank {rank}] Exiting _mp_fn")


def model_loop_fn(*args):
@@ -85,28 +88,43 @@ def model_loop_fn(*args):

class DistributedModel:
def __init__(self, model_id: str, sample_fn: callable):
logger.debug(f"Initializing DistributedModel with model_id: {model_id}")
start_time = time.time()
manager = mp.Manager()
self.mailbox = RootMailbox(manager)

self.model_loop = mp.Process(target=model_loop_fn, args=(model_id, self.mailbox, sample_fn))
self.model_loop.start()
logger.debug(f"DistributedModel initialization completed in {time.time() - start_time:.2f} seconds")

def prefill(self, **model_args):
logger.debug("Starting prefill operation")
start_time = time.time()
assert self.mailbox is not None, "DistributedModel is not initialized"
return self.mailbox.send(ModelCommand.PREFILL, model_args)[0]
result = self.mailbox.send(ModelCommand.PREFILL, model_args)[0]
logger.debug(f"Prefill operation completed in {time.time() - start_time:.2f} seconds")
return result

def decode(self, **model_args):
logger.debug("Starting decode operation")
start_time = time.time()
assert self.mailbox is not None, "DistributedModel is not initialized"
return self.mailbox.send(ModelCommand.PREFILL, model_args)[0]
result = self.mailbox.send(ModelCommand.DECODE, model_args)[0]
logger.debug(f"Decode operation completed in {time.time() - start_time:.2f} seconds")
return result

def leave(self):
if self.mailbox is None:
logger.debug("DistributedModel already left")
return
logger.debug("Initiating leave operation")
start_time = time.time()
self.mailbox.send(ModelCommand.LEAVE)
logger.debug("Joining...")
logger.debug("Joining model loop...")
self.model_loop.join()
logger.debug("Model loop finished")
logger.debug(f"Model loop finished in {time.time() - start_time:.2f} seconds")
self.mailbox = None

def __del__(self):
logger.debug("DistributedModel destructor called")
self.leave()
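For reference, a usage sketch of the class above; the greedy `sample_fn`, prompt, and model id are illustrative assumptions, not part of this PR:

import torch
from transformers import AutoTokenizer
from optimum.tpu.distributed_model import DistributedModel

def sample_greedy(logits):
    # take the most likely token at the last position of each sequence
    return torch.argmax(logits[:, -1], dim=-1, keepdim=True)

model_id = "meta-llama/Llama-3.2-1B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = DistributedModel(model_id, sample_greedy)

inputs = tokenizer("The capital of France is", return_tensors="pt")
next_token = model.prefill(**inputs)   # first generated token, computed on the TPU workers
print(tokenizer.decode(next_token[0], skip_special_tokens=True))
model.leave()                          # shut the worker processes down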
91 changes: 72 additions & 19 deletions optimum/tpu/modeling_llama.py
@@ -61,6 +61,9 @@
if is_flash_attn_2_available():
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
# print("FA2 available")
#else:
# print("FA2 MISSING")


logger = logging.get_logger(__name__)
@@ -101,33 +104,86 @@ def forward(self, hidden_states):


class LlamaRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
def __init__(
self,
dim,
max_position_embeddings=2048,
base=10000,
device=None,
scaling_factor=1.0,
rope_type="default",
config: Optional[LlamaConfig] = None,
):
super().__init__()
self.scaling_factor = scaling_factor
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
# For BC we register cos and sin cached
self.scaling_factor = scaling_factor
self.rope_type = rope_type
self.config = config
self.max_seq_len_cached = max_position_embeddings
self.original_max_seq_len = max_position_embeddings

if rope_type == "llama3":
assert config is not None, "Config must be provided for llama3 rope type"
inv_freq = self._compute_llama3_inv_freq(device)
else:
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32).to(device) / dim))

self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq
self.attention_scaling = 1.0 # Default scaling

def _compute_llama3_inv_freq(self, device):
factor = self.config.rope_scaling["factor"]
low_freq_factor = self.config.rope_scaling["low_freq_factor"]
high_freq_factor = self.config.rope_scaling["high_freq_factor"]
old_context_len = self.config.rope_scaling["original_max_position_embeddings"]

pos_freqs = self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
inv_freq_extrapolation = 1.0 / pos_freqs
inv_freq_interpolation = 1.0 / (factor * pos_freqs)

low_freq_wavelen = old_context_len / low_freq_factor
high_freq_wavelen = old_context_len / high_freq_factor
wavelen = 2 * math.pi / inv_freq_extrapolation

inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq_extrapolation / factor, inv_freq_extrapolation)
smooth_factor = torch.clip((old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor), 0, 1)
smoothed_inv_freq = (1 - smooth_factor) * inv_freq_interpolation + smooth_factor * inv_freq_extrapolation
# smooth only the medium band: high_freq_wavelen <= wavelen <= low_freq_wavelen
inv_freq_llama = torch.where((wavelen <= low_freq_wavelen) & (wavelen >= high_freq_wavelen), smoothed_inv_freq, inv_freq_llama)

return inv_freq_llama

@torch.no_grad()
def forward(self, x, position_ids):
# x: [bs, num_attention_heads, seq_len, head_size]
if self.rope_type == "llama3":
self._update_llama3_inv_freq(position_ids, x.device)

inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
# Force float32 since bfloat16 loses precision on long contexts
# See https://github.com/huggingface/transformers/pull/29285

device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()

cos = cos * self.attention_scaling
sin = sin * self.attention_scaling
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)

def _update_llama3_inv_freq(self, position_ids, device):
seq_len = torch.max(position_ids) + 1
if seq_len > self.max_seq_len_cached:
self.inv_freq = self._compute_llama3_inv_freq(device)
self.max_seq_len_cached = seq_len
elif seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len:
self.inv_freq = self.original_inv_freq
self.max_seq_len_cached = self.original_max_seq_len
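To see what the smoothing above does to the frequencies, here is a standalone sketch; `dim`, `base` and the scaling numbers are illustrative values in the spirit of a Llama 3 style config, not read from an actual checkpoint:

import math
import torch

def llama3_inv_freq(dim=64, base=500000.0, factor=32.0,
                    low_freq_factor=1.0, high_freq_factor=4.0, old_context_len=8192):
    pos_freqs = base ** (torch.arange(0, dim, 2).float() / dim)
    inv_extra = 1.0 / pos_freqs          # unscaled (extrapolated) frequencies
    inv_inter = inv_extra / factor       # scaled (interpolated) frequencies
    wavelen = 2 * math.pi / inv_extra
    low_wl = old_context_len / low_freq_factor
    high_wl = old_context_len / high_freq_factor
    # long wavelengths are fully scaled, short ones left untouched
    inv_freq = torch.where(wavelen > low_wl, inv_inter, inv_extra)
    # the band in between is blended smoothly
    smooth = torch.clip((old_context_len / wavelen - low_freq_factor)
                        / (high_freq_factor - low_freq_factor), 0, 1)
    blended = (1 - smooth) * inv_inter + smooth * inv_extra
    return torch.where((wavelen <= low_wl) & (wavelen >= high_wl), blended, inv_freq)

print(llama3_inv_freq()[:4])  # short-wavelength entries stay close to 1 / pos_freqs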


class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
"""LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
@@ -338,27 +394,24 @@ def _init_rope(self):
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
base=self.rope_theta,
config=self.config,
)
else:
scaling_type = self.config.rope_scaling["type"]
scaling_factor = self.config.rope_scaling["factor"]
if scaling_type == "linear":
self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
scaling_type = self.config.rope_scaling.get("rope_type", self.config.rope_scaling.get("type", "default"))
scaling_factor = self.config.rope_scaling.get("factor", 1.0)
if scaling_type in ["linear", "dynamic", "llama3"]:
self.rotary_emb = LlamaRotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
scaling_factor=scaling_factor,
base=self.rope_theta,
)
elif scaling_type == "dynamic":
self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
scaling_factor=scaling_factor,
base=self.rope_theta,
rope_type=scaling_type,
config=self.config,
)
else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
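For reference, this is roughly the shape of `config.rope_scaling` that the lookup above has to handle for a Llama 3.2-class checkpoint; the numbers are illustrative, not read from the actual config.json:

rope_scaling = {
    "rope_type": "llama3",                     # newer configs use "rope_type" ...
    "factor": 32.0,
    "low_freq_factor": 1.0,
    "high_freq_factor": 4.0,
    "original_max_position_embeddings": 8192,
}
# ... while older checkpoints expose e.g. {"type": "linear", "factor": 2.0},
# hence the fallback from "rope_type" to "type" to "default" above.
scaling_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default"))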


def forward(
self,
hidden_states: torch.Tensor,
11 changes: 6 additions & 5 deletions pyproject.toml
@@ -42,7 +42,7 @@ keywords = [
]

dependencies = [
"transformers == 4.41.1",
"transformers == 4.45.2",
"torch == 2.4.0",
"torch-xla[tpu] == 2.4.0",
'typer == 0.6.1',
@@ -61,10 +61,11 @@ tests = ["pytest", "safetensors"]
quality = ["black", "ruff", "isort"]
# Jetstream/Pytorch support is experimental for now, it needs to be installed manually.
# Pallas is pulled because it will install a compatible version of jax[tpu].
jetstream-pt = [
"jetstream-pt",
"torch-xla[pallas] == 2.4.0"
]
# Pallas and Jetstream are not supported before TPU v5e, so these extras are commented out for v4 and earlier.
#jetstream-pt = [
# "jetstream-pt",
# "torch-xla[pallas] == 2.4.0"
#]

[project.urls]
Homepage = "https://hf.co/hardware"