TGI: optimize continuous batching and improve export (#506)

* test(tgi): refactor tests

* test(tgi): add Llama test

* feat(tgi): avoid rebuilding cache on prefill

* feat(tgi): log disk usage when fetching/exporting model

* feat(tgi): fetch generation config during export
dacorvo authored Mar 6, 2024
1 parent 249d0b6 commit 8f84127
Showing 5 changed files with 269 additions and 126 deletions.
106 changes: 62 additions & 44 deletions text-generation-inference/server/text_generation_server/generator.py
@@ -107,7 +107,7 @@ def clear(self):
self._inputs = ""
self._generation_config = None
self._tokens = []
self._mask = []
self._mask = torch.tensor([])
self._selector = None
self._generated_tokens = 0
self._next_text_token_start = 0
@@ -182,15 +182,16 @@ def reset(self, input_ids: torch.LongTensor, attention_mask: torch.LongTensor, s
self._mask = attention_mask.clone()
self._selector = selector

def pause(self):
def pause(self, reset_on_pause: bool):
"""Mark the current slot as paused for generation.
Note that the KV cache for this slot will still be filled.
"""
# Drop the last token as it will be added back when resuming the slot
self._generated_tokens -= 1
# Subtract the number of cached tokens from the maximum number of tokens
self._generation_config.max_new_tokens -= self._generated_tokens
if reset_on_pause:
# Drop the last token as it will be added back when resuming the slot
self._generated_tokens -= 1
# Subtract the number of cached tokens from the maximum number of tokens
self._generation_config.max_new_tokens -= self._generated_tokens
self._state = Slot.State.PAUSE

def resume(self):
@@ -291,6 +292,7 @@ def __init__(
tokenizer: PreTrainedTokenizerBase,
):
self.model = model
self.rebuild_cache_on_prefill = not self.model.continuous_batching
# Specify padding options for decoder-only architecture
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.padding_side = "left"
@@ -349,29 +351,45 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
)
# Assign each request to an empty slot
logger.debug(f"Prefilling {len(batch.requests)} new request(s) with {len(empty_slots)} empty slot(s)")
new_slots = []
for request in batch.requests:
slot = empty_slots.pop()
slot.assign(request, self.model.generation_config)
new_slots.append(slot)
logger.debug(f"Request {slot.request_id} assigned to slot {slot.id}")
if self.rebuild_cache_on_prefill:
# We will clear pending slots and prefill all slots
prefill_slots = self.slots
seq_ids = None
else:
# We only need to pass inputs for the new requests
prefill_slots = new_slots
seq_ids = torch.tensor([slot.id for slot in prefill_slots])
# Reconstruct the full inputs (without padding) as seen by the model.
# This comprises:
# - the inputs for new requests,
# - the inputs and the generated text that has already been cached (i.e. excluding the last generated token)
# for unfinished requests.
inputs = [slot.cached_text for slot in self.slots]
# - only when rebuilding the cache, the inputs and the generated text that has already
# been cached (i.e. excluding the last generated token) for unfinished requests.
inputs = [slot.cached_text for slot in prefill_slots]
# Tokenize with padding
padded_inputs = self.tokenizer(inputs, return_tensors="pt", padding=True)
# If needed truncate sequences to fit into the static dimensions
seq_length = min(padded_inputs.input_ids.shape[-1], self.model.max_length)
input_ids = padded_inputs.input_ids[:, :seq_length]
attention_mask = padded_inputs.attention_mask[:, :seq_length]
# Pause previously active slots during generation and store their last token.
# Pause previously active slots during generation
next_tokens = []
for slot in active_slots:
next_tokens.append(slot.next_token)
slot.pause()
slot.pause(reset_on_pause=self.rebuild_cache_on_prefill)
if self.rebuild_cache_on_prefill:
# The slot will be reset, so we need to store its next token
next_tokens.append(slot.next_token)
# Each slot must be reset with the padded inputs and masks
for i, slot in enumerate(self.slots):
if self.rebuild_cache_on_prefill:
reset_slots = self.slots
else:
reset_slots = prefill_slots
for i, slot in enumerate(reset_slots):
if slot.state != slot.state.EMPTY:
slot_input_ids = input_ids[i : i + 1, :]
# Padded input ids are also required to set logits processors and stopping criterias
@@ -381,17 +399,17 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
slot_input_ids = slot_input_ids.squeeze(dim=0).type(torch.int64)
slot_attention_mask = attention_mask[i]
slot.reset(slot_input_ids, slot_attention_mask, selector)
# Pause previously active slots during generation.
# The KV cache of paused slots will be prefilled during generation but new tokens
# will be ignored, as they have already been generated and sent back in the last decode.
model_inputs = self.model.prepare_inputs_for_prefill(input_ids, attention_mask)
# Note: when rebuilding cache on prefill, the new tokens on paused slots will be ignored,
# as they have already been generated and sent back in the last decode.
model_inputs = self.model.prepare_inputs_for_prefill(input_ids, attention_mask, seq_ids)
logits = self.model(**model_inputs)[0]
generation, next_batch = self._generate_token(batch.id, logits, input_ids)
# Reactivate previously active slots for the next decode, and append
# back their next token.
for slot, next_token in zip(active_slots, next_tokens):
slot.append(next_token)
generation, next_batch = self._generate_token(prefill_slots, batch.id, logits, input_ids)
# Reactivate previously active slots for the next decode
for i, slot in enumerate(active_slots):
slot.resume()
if self.rebuild_cache_on_prefill:
# Append back the next token
slot.append(next_tokens[i])
logger.debug("Model ready for decoding")
return generation, next_batch

@@ -412,37 +430,37 @@ def decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBa
# just carry on with decoding. We adopt the id of the first
# batch in the list as our next batch id.
next_batch_id = batches[0].id
# Reconstruct input_ids and attention_mask from slots
input_ids = None
attention_mask = None
for i, slot in enumerate(self.slots):
active_slots = [slot for slot in self.slots if slot.state == slot.State.READY]
if len(active_slots) == 0:
raise ValueError("Unable to decode tokens for non-prefilled batches (probably due to a previous failure)")
if self.model.continuous_batching:
decode_slots = active_slots
seq_ids = torch.tensor([slot.id for slot in decode_slots])
else:
decode_slots = self.slots
seq_ids = None
# Reconstruct input_ids and attention_mask from decode slots
n_slots = len(decode_slots)
input_ids = torch.full([n_slots, 1], fill_value=self.tokenizer.eos_token_id, dtype=torch.int64)
max_length = 0
for slot in decode_slots:
max_length = max(max_length, slot.attention_mask.size(-1))
attention_mask = torch.zeros([n_slots, max_length], dtype=torch.int64)
for i, slot in enumerate(decode_slots):
if slot.state != Slot.State.EMPTY:
if input_ids is None:
# Create blank inputs covering all slots (even empty ones)
input_ids = torch.full(
[self.model.batch_size, 1], fill_value=self.tokenizer.eos_token_id, dtype=torch.int64
)
# input_ids are simply the tokens generated by the last decode or prefill requests (other tokens are cached)
input_ids[i, 0] = slot.next_token
if attention_mask is None:
# Create default mask covering all slots (even empty ones)
attention_mask = torch.zeros(
[self.model.batch_size, slot.attention_mask.size(-1)], dtype=torch.int64
)
attention_mask[:, -1] = 1
attention_mask[i, :] = slot.attention_mask
if input_ids is None:
raise ValueError("Unable to decode tokens for non-prefilled batches (probably due to a previous failure)")
model_inputs = self.model.prepare_inputs_for_decode(input_ids, attention_mask)
attention_mask[i, : slot.attention_mask.size(-1)] = slot.attention_mask
model_inputs = self.model.prepare_inputs_for_decode(input_ids, attention_mask, seq_ids)
logits = self.model(**model_inputs)[0]
return self._generate_token(next_batch_id, logits, input_ids)
return self._generate_token(decode_slots, next_batch_id, logits, input_ids)

def _generate_token(
self, next_batch_id: int, logits: torch.Tensor, input_ids: torch.LongTensor
self, slots: List[Slot], next_batch_id: int, logits: torch.Tensor, input_ids: torch.LongTensor
) -> Tuple[List[Generation], CachedBatch]:
generations = []
active_slots = False
for i, slot in enumerate(self.slots):
for i, slot in enumerate(slots):
if slot.state != Slot.State.READY:
continue
request_id = slot.request_id
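
As an illustration (not part of the commit diff), here is a minimal sketch of the slot-selection logic that generator.py gains above. The helper name select_prefill_slots and its parameters are hypothetical; in the actual code the choice is driven by self.rebuild_cache_on_prefill = not self.model.continuous_batching.

import torch

def select_prefill_slots(all_slots, new_slots, continuous_batching):
    # With continuous batching, only the newly assigned slots are prefilled
    # and their sequence ids are passed to the model, so the KV cache of the
    # already active slots is kept as-is.
    if continuous_batching:
        prefill_slots = new_slots
        seq_ids = torch.tensor([slot.id for slot in prefill_slots])
        return prefill_slots, seq_ids
    # Without continuous batching, the cache is rebuilt by prefilling every
    # slot again, and no explicit sequence ids are needed.
    return all_slots, None
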
24 changes: 22 additions & 2 deletions text-generation-inference/server/text_generation_server/model.py
@@ -1,11 +1,12 @@
import os
import shutil
import time
from typing import Optional

from huggingface_hub import snapshot_download
from huggingface_hub.constants import HF_HUB_CACHE
from loguru import logger
from transformers import AutoConfig, AutoTokenizer
from transformers import AutoConfig, AutoTokenizer, GenerationConfig

from optimum.neuron import NeuronModelForCausalLM
from optimum.neuron.utils import ModelCacheEntry, get_hub_cached_entries
@@ -50,6 +51,16 @@ def is_cached(model_id, neuron_config):
return in_cache


def log_cache_size():
path = HF_HUB_CACHE
if os.path.exists(path):
usage = shutil.disk_usage(path)
gb = 2**30
logger.info(f"Cache disk [{path}]: total = {usage.total/gb:.2f} G, free = {usage.free/gb:.2f} G")
else:
raise ValueError(f"The cache directory ({path}) does not exist.")


def fetch_model(
model_id: str,
revision: Optional[str] = None,
@@ -75,8 +86,9 @@ def fetch_model(
# Note that the model may already be present in the cache.
config = AutoConfig.from_pretrained(model_id, revision=revision)
neuron_config = getattr(config, "neuron", None)
log_cache_size()
if neuron_config is not None:
logger.info("Fetching revision {} for neuron model {}".format(revision, model_id))
logger.info(f"Fetching revision [{revision}] for neuron model {model_id} under {HF_HUB_CACHE}")
return snapshot_download(model_id, revision=revision)
# Not a neuron model: evaluate the export config and check if it has been exported locally
export_kwargs = get_export_kwargs_from_env()
@@ -99,14 +111,22 @@ def fetch_model(
logger.warning(f"{model_id} is not a neuron model: it will be exported using cached artifacts.")
start = time.time()
logger.info(f"Exporting model to neuron with config {neuron_config}.")
log_cache_size()
start = time.time()
model = NeuronModelForCausalLM.from_pretrained(model_id, export=True, **export_kwargs)
end = time.time()
logger.info(f"Model successfully exported in {end - start:.2f} s.")
logger.info(f"Saving exported model to local storage under {export_path}.")
log_cache_size()
model.save_pretrained(export_path)
logger.info(f"Saving model tokenizer under {export_path}.")
tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision)
tokenizer.save_pretrained(export_path)
try:
config = GenerationConfig.from_pretrained(model_id, revision=revision)
config.save_pretrained(export_path)
logger.info(f"Saved model default generation config under {export_path}.")
except:
logger.warning(f"No default generation config found for {model_id}.")
logger.info(f"Model successfully exported in {end - start:.2f} s under {export_path}.")
return export_path
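
As a usage illustration (not part of the commit diff), the artifacts that fetch_model now saves under export_path (neuron model, tokenizer and, when available, the default generation config) could be reloaded as sketched below. The path is a placeholder, and the from_pretrained calls are the same APIs used in the export code above.

from transformers import AutoTokenizer, GenerationConfig
from optimum.neuron import NeuronModelForCausalLM

export_path = "/tmp/exported-neuron-model"  # placeholder for the path returned by fetch_model
model = NeuronModelForCausalLM.from_pretrained(export_path)
tokenizer = AutoTokenizer.from_pretrained(export_path)
try:
    # Only present if a default generation config was found during export.
    generation_config = GenerationConfig.from_pretrained(export_path)
except OSError:
    generation_config = None
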
