Llama 3.2 1B Instruct on TPU v4, bumping transformers to 4.45.2 #109

Open
wants to merge 8 commits into base: main
Changes from 1 commit
1 change: 1 addition & 0 deletions install-on-TPU-v4.sh
@@ -22,3 +22,4 @@ pip install -e .

huggingface-cli login
gsutil cp -r gs://entropix/huggingface_hub ~/.cache/huggingface/hub
pip install transformers==4.45.2
62 changes: 35 additions & 27 deletions optimum/tpu/distributed_model.py
@@ -1,7 +1,7 @@
# ruff: noqa: E402
import os
from enum import Enum

import time
from loguru import logger


@@ -23,59 +23,52 @@ class ModelCommand(Enum):


def _mp_fn(rank, model_id, root_mailbox: RootMailbox, sample_fn: callable):
logger.info(f"[Rank {rank}] Starting _mp_fn")
device = xm.xla_device()
world_size = xm.xrt_world_size()
# create agent mailbox out of root's one
mailbox = AgentMailbox(root_mailbox)

logger.debug(
f"Rank {rank} on {device} real device {xm.xla_real_devices([device])} ordinal {xm.get_ordinal()} "
+ f"world size {world_size}"
)

# Model loading and sharding should happen here
logger.info(f"[Rank {rank}] Loading model")
model = AutoModelForCausalLM.from_pretrained(model_id)
model = model.eval()
model.to(device)
logger.info(f"[Rank {rank}] Model loaded and moved to device")

def get_next_token(inputs):
# move inputs to device in a new dict to avoid conflicts
model_inputs = {}
for key, value in inputs.items():
model_inputs[key] = value.to(device)
logger.info(f"[Rank {rank}] Starting get_next_token")
model_inputs = {k: v.to(device) for k, v in inputs.items()}
logger.info(f"[Rank {rank}] Running model inference")
outputs = model(**model_inputs, return_dict=False)[0]
xm.mark_step()
# consider adding a rendezvous here
if rank == 0:
logger.debug(f"Rank {rank} getting tokens")
logger.info(f"[Rank {rank}] Sampling next token")
next_token = sample_fn(outputs)
xm.mark_step()
logger.debug(f"Rank {rank} sending next_tokens {next_token.shape}")
# Data needs to be moved to CPU before setting it
logger.info(f"[Rank {rank}] Sending next token")
mailbox.send(next_token.cpu())
logger.info(f"[Rank {rank}] Finished get_next_token")

while True:
if rank == 0:
mailbox.agent_ready.set()
logger.debug(f"Rank {rank} waiting for commands")
logger.info(f"[Rank {rank}] Waiting for commands")
mailbox.receive()
# Wait for rank 0 to receive command
xm.rendezvous("start")

logger.debug(f"Rank {rank} waiting for command at rendezvous")
logger.info(f"[Rank {rank}] Received command")
command, data = mailbox.command_data
inputs = data[0] if data else None
if command == ModelCommand.PREFILL:
logger.debug(f"Rank {rank} PREFILL")
logger.info(f"[Rank {rank}] Executing PREFILL")
get_next_token(inputs)
elif command == ModelCommand.DECODE:
logger.debug(f"Rank {rank} DECODE")
logger.info(f"[Rank {rank}] Executing DECODE")
get_next_token(inputs)
elif command == ModelCommand.LEAVE:
logger.debug(f"Rank {rank} LEAVE")
# Set model to ready
logger.info(f"[Rank {rank}] Executing LEAVE")
mailbox.agent_ready.set()
break
logger.info(f"[Rank {rank}] Exiting _mp_fn")


def model_loop_fn(*args):
@@ -85,28 +78,43 @@ def model_loop_fn(*args):

class DistributedModel:
def __init__(self, model_id: str, sample_fn: callable):
logger.info(f"Initializing DistributedModel with model_id: {model_id}")
start_time = time.time()
manager = mp.Manager()
self.mailbox = RootMailbox(manager)

self.model_loop = mp.Process(target=model_loop_fn, args=(model_id, self.mailbox, sample_fn))
self.model_loop.start()
logger.info(f"DistributedModel initialization completed in {time.time() - start_time:.2f} seconds")

def prefill(self, **model_args):
logger.info("Starting prefill operation")
start_time = time.time()
assert self.mailbox is not None, "DistributedModel is not initialized"
return self.mailbox.send(ModelCommand.PREFILL, model_args)[0]
result = self.mailbox.send(ModelCommand.PREFILL, model_args)[0]
logger.info(f"Prefill operation completed in {time.time() - start_time:.2f} seconds")
return result

def decode(self, **model_args):
logger.info("Starting decode operation")
start_time = time.time()
assert self.mailbox is not None, "DistributedModel is not initialized"
return self.mailbox.send(ModelCommand.PREFILL, model_args)[0]
result = self.mailbox.send(ModelCommand.PREFILL, model_args)[0]
logger.info(f"Decode operation completed in {time.time() - start_time:.2f} seconds")
return result

def leave(self):
if self.mailbox is None:
logger.info("DistributedModel already left")
return
logger.info("Initiating leave operation")
start_time = time.time()
self.mailbox.send(ModelCommand.LEAVE)
logger.debug("Joining...")
logger.info("Joining model loop...")
self.model_loop.join()
logger.debug("Model loop finished")
logger.info(f"Model loop finished in {time.time() - start_time:.2f} seconds")
self.mailbox = None

def __del__(self):
logger.info("DistributedModel destructor called")
self.leave()
91 changes: 72 additions & 19 deletions optimum/tpu/modeling_llama.py
@@ -61,6 +61,9 @@
if is_flash_attn_2_available():
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
# print("FA2 available")
#else:
# print("FA2 MISSING")


logger = logging.get_logger(__name__)
@@ -101,33 +104,86 @@ def forward(self, hidden_states):


class LlamaRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
def __init__(
self,
dim,
max_position_embeddings=2048,
base=10000,
device=None,
scaling_factor=1.0,
rope_type="default",
config: Optional[LlamaConfig] = None,
):
super().__init__()
self.scaling_factor = scaling_factor
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
# For BC we register cos and sin cached
self.scaling_factor = scaling_factor
self.rope_type = rope_type
self.config = config
self.max_seq_len_cached = max_position_embeddings
self.original_max_seq_len = max_position_embeddings

if rope_type == "llama3":
assert config is not None, "Config must be provided for llama3 rope type"
inv_freq = self._compute_llama3_inv_freq(device)
else:
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32).to(device) / dim))

self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq
self.attention_scaling = 1.0 # Default scaling

def _compute_llama3_inv_freq(self, device):
factor = self.config.rope_scaling["factor"]
low_freq_factor = self.config.rope_scaling["low_freq_factor"]
high_freq_factor = self.config.rope_scaling["high_freq_factor"]
old_context_len = self.config.rope_scaling["original_max_position_embeddings"]

pos_freqs = self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
inv_freq_extrapolation = 1.0 / pos_freqs
inv_freq_interpolation = 1.0 / (factor * pos_freqs)

low_freq_wavelen = old_context_len / low_freq_factor
high_freq_wavelen = old_context_len / high_freq_factor
wavelen = 2 * math.pi / inv_freq_extrapolation

inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq_extrapolation / factor, inv_freq_extrapolation)
smooth_factor = torch.clip((old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor), 0, 1)
smoothed_inv_freq = (1 - smooth_factor) * inv_freq_interpolation + smooth_factor * inv_freq_extrapolation
inv_freq_llama = torch.where((wavelen < high_freq_wavelen) & (wavelen > low_freq_wavelen), smoothed_inv_freq, inv_freq_llama)

return inv_freq_llama

@torch.no_grad()
def forward(self, x, position_ids):
# x: [bs, num_attention_heads, seq_len, head_size]
if self.rope_type == "llama3":
self._update_llama3_inv_freq(position_ids, x.device)

inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
# Force float32 since bfloat16 loses precision on long contexts
# See https://github.com/huggingface/transformers/pull/29285

device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()

cos = cos * self.attention_scaling
sin = sin * self.attention_scaling
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)

def _update_llama3_inv_freq(self, position_ids, device):
seq_len = torch.max(position_ids) + 1
if seq_len > self.max_seq_len_cached:
self.inv_freq = self._compute_llama3_inv_freq(device)
self.max_seq_len_cached = seq_len
elif seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len:
self.inv_freq = self.original_inv_freq
self.max_seq_len_cached = self.original_max_seq_len


class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
"""LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
@@ -338,27 +394,24 @@ def _init_rope(self):
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
base=self.rope_theta,
config=self.config,
)
else:
scaling_type = self.config.rope_scaling["type"]
scaling_factor = self.config.rope_scaling["factor"]
if scaling_type == "linear":
self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
scaling_type = self.config.rope_scaling.get("rope_type", self.config.rope_scaling.get("type", "default"))
scaling_factor = self.config.rope_scaling.get("factor", 1.0)
if scaling_type in ["linear", "dynamic", "llama3"]:
self.rotary_emb = LlamaRotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
scaling_factor=scaling_factor,
base=self.rope_theta,
)
elif scaling_type == "dynamic":
self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
scaling_factor=scaling_factor,
base=self.rope_theta,
rope_type=scaling_type,
config=self.config,
)
else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")


def forward(
self,
hidden_states: torch.Tensor,
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -42,7 +42,7 @@ keywords = [
]

dependencies = [
"transformers == 4.41.1",
"transformers == 4.45.2",
"torch == 2.4.0",
"torch-xla[tpu] == 2.4.0",
'typer == 0.6.1',
60 changes: 60 additions & 0 deletions tests/akg.py
@@ -0,0 +1,60 @@
import os
import torch
from transformers import AutoTokenizer
from optimum.tpu.distributed_model import DistributedModel
from loguru import logger
import sys

# Remove default handler
logger.remove()

# Add a handler to write to file
logger.add("distributed_model.log", rotation="100 MB", level="DEBUG")

# Add a handler to write to stderr
logger.add(sys.stderr, level="INFO")

def sample_greedy(logits):
next_logits = logits[:, -1]
next_token_id = torch.argmax(next_logits, dim=-1)[:, None].int()
return next_token_id

def _test_distributed_model_generation(model_id, max_new_tokens=20):
Collaborator:
For tests, please create one test similar to tests/test_distributed_model.py (or modify the existing one). To launch it, you can use pytest: python -m pytest -sv /path/to/test_mytest.py::test_my_test_function.

print(f"Beginning test with model: {model_id}")
os.environ["TOKENIZERS_PARALLELISM"] = "false"
tokenizer = AutoTokenizer.from_pretrained(model_id)
text = ["Running something in parallel means"]
inputs = tokenizer(text, return_tensors="pt")
input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]
tokens = input_ids.clone()

print("Initializing DistributedModel...")
model = DistributedModel(model_id, sample_greedy)

print("Generating tokens...")
for _ in range(max_new_tokens):
pos_ids = torch.arange(tokens.shape[1], device=tokens.device).unsqueeze(0)
next_token = model.prefill(input_ids=tokens, attention_mask=attention_mask, position_ids=pos_ids)
tokens = torch.cat([tokens, next_token], dim=-1)
attention_mask = torch.cat([attention_mask, torch.ones_like(next_token)], dim=-1)

# Optional: Break if EOS token is generated
if next_token.item() == tokenizer.eos_token_id:
break

decoded_text = tokenizer.batch_decode(tokens, skip_special_tokens=True)
print("\n------------------------------------------")
print("Generated text:")
print(decoded_text[0])
print("------------------------------------------")

if __name__ == "__main__":
print("Script started")
try:
_test_distributed_model_generation("meta-llama/Meta-Llama-3.1-8B", max_new_tokens=200)
except Exception as e:
print(f"An error occurred: {str(e)}")
import traceback
traceback.print_exc()
print("Script completed")
12 changes: 12 additions & 0 deletions tests/test-torch-xla.py
@@ -0,0 +1,12 @@
import torch
import torch_xla.core.xla_model as xm

devices = xm.get_xla_supported_devices()
print(f'PyTorch can access {len(devices)} TPU cores')

# Example tensor operations on TPU
dev = xm.xla_device()
print(f"PyTorich device: {dev}")
t1 = torch.randn(3,3,device=dev)
t2 = torch.randn(3,3,device=dev)
print(t1 + t2)