Merge branch 'main' into external_weight_fixes
dan-garvey committed Nov 30, 2023
2 parents 1e1edfc + 666e601 commit 33e5d63
Showing 19 changed files with 809 additions and 548 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test-models.yml
@@ -147,7 +147,7 @@ jobs:
PYTHON=python${{ matrix.python-version }} ./setup_venv.sh
source shark.venv/bin/activate
pytest --update_tank -k vulkan
python build_tools/stable_diffusion_testing.py --device=vulkan
python build_tools/stable_diffusion_testing.py --device=vulkan --no-exit_on_fail
- name: Validate Vulkan Models (Windows)
if: matrix.suite == 'vulkan' && matrix.os == '7950x'
140 changes: 121 additions & 19 deletions apps/language_models/scripts/vicuna.py
@@ -1,8 +1,10 @@
import argparse
from dataclasses import dataclass
import json
import re
import gc
from io import BytesIO
from os import environ
from pathlib import Path
from statistics import mean, stdev
from tqdm import tqdm
@@ -143,6 +145,12 @@
default=[],
help="Extra command line arguments passed to the IREE compiler. This can be specified multiple times to pass multiple arguments."
)
parser.add_argument(
"--enable_tracing",
default=False,
action=argparse.BooleanOptionalAction,
help="Enable profiling with Tracy. The script will wait for Tracy to connect and flush the profiling data after each token."
)

# Microbenchmarking options.
parser.add_argument(
@@ -1937,12 +1945,107 @@ def create_prompt(model_name, history):
return msg


def miliseconds_to_seconds(ms: float) -> float:
return ms / 1000.0


@dataclass
class BenchmarkRunInfo:
num_prompt_tokens : int
prefill_time_ms : float
token_times_ms : list[float]

def get_prefill_speed(self) -> float:
seconds = miliseconds_to_seconds(self.prefill_time_ms)
if seconds == 0.0:
return float('inf')
return self.num_prompt_tokens / seconds

def num_generated_tokens(self) -> int:
return len(self.token_times_ms)

def get_decode_time_ms(self) -> float:
return sum(self.token_times_ms)

def get_decode_speed(self) -> float:
seconds = miliseconds_to_seconds(self.get_decode_time_ms())
if seconds == 0.0:
return float('inf')
return self.num_generated_tokens() / seconds

def get_e2e_time_ms(self) -> float:
return self.prefill_time_ms + self.get_decode_time_ms()

def get_e2e_decode_speed(self) -> float:
seconds = miliseconds_to_seconds(self.get_e2e_time_ms())
if seconds == 0.0:
return float('inf')
return self.num_generated_tokens() / seconds

def get_e2e_token_processing_speed(self) -> float:
seconds = miliseconds_to_seconds(self.get_e2e_time_ms())
if seconds == 0.0:
return float('inf')
return (self.num_prompt_tokens + self.num_generated_tokens()) / seconds

def print(self) -> None:
total_tokens = self.num_prompt_tokens + self.num_generated_tokens()
print(f"Num tokens: {self.num_prompt_tokens:} (prompt), {self.num_generated_tokens()} (generated), {total_tokens} (total)")
print(f"Prefill: {self.prefill_time_ms:.2f} ms, {self.get_prefill_speed():.2f} tokens/s")
print(f"Decode: {self.get_decode_time_ms():.2f} ms, {self.get_decode_speed():.2f} tokens/s")
print(f"Decode end-2-end: {self.get_e2e_decode_speed():.2f} tokens/s (w/o prompt), {self.get_e2e_token_processing_speed():.2f} tokens/s (w/ prompt)")


def print_aggregate_stats(run_infos: list[BenchmarkRunInfo]) -> None:
num_iterations = len(run_infos)
print(f'Number of iterations: {num_iterations}')
if num_iterations == 0:
return

if len(run_infos) == 1:
run_infos[0].print()
return

total_tokens = run_infos[0].num_prompt_tokens + run_infos[0].num_generated_tokens()
print(f"Num tokens: {run_infos[0].num_prompt_tokens} (prompt), {run_infos[0].num_generated_tokens()} (generated), {total_tokens} (total)")

def avg_and_stdev(data):
x = list(data)
return mean(x), stdev(x)

avg_prefill_ms, stdev_prefill = avg_and_stdev(x.prefill_time_ms for x in run_infos)
avg_prefill_speed = mean(x.get_prefill_speed() for x in run_infos)
print(f"Prefill: avg. {avg_prefill_ms:.2f} ms (stdev {stdev_prefill:.2f}), avg. {avg_prefill_speed:.2f} tokens/s")

avg_decode_ms, stdev_decode = avg_and_stdev(x.get_decode_time_ms() for x in run_infos)
avg_decode_speed = mean(x.get_decode_speed() for x in run_infos)
print(f"Decode: avg. {avg_decode_ms:.2f} ms (stdev {stdev_decode:.2f}), avg. {avg_decode_speed:.2f} tokens/s")

avg_e2e_decode_speed = mean(x.get_e2e_decode_speed() for x in run_infos)
avg_e2e_processing_speed = mean(x.get_e2e_token_processing_speed() for x in run_infos)
print(f"Decode end-2-end: avg. {avg_e2e_decode_speed:.2f} tokens/s (w/o prompt), avg. {avg_e2e_processing_speed:.2f} (w/ prompt)")


def enable_tracy_tracing():
# Make Tracy wait for a capture to be collected before exiting.
environ["TRACY_NO_EXIT"] = "1"

if "IREE_PY_RUNTIME" not in environ or environ["IREE_PY_RUNTIME"] != "tracy":
print("ERROR: Tracing enabled but tracy iree runtime not used.", file=sys.stderr)
print("Set the IREE_PY_RUNTIME=tracy environment variable.", file=sys.stderr)
sys.exit(1)


if __name__ == "__main__":
args, unknown = parser.parse_known_args()

_extra_args = list(args.Xiree_compile)

device_id = None

if args.enable_tracing:
enable_tracy_tracing()

# Process vulkan target triple.
# TODO: This feature should just be in a common utils for other LLMs and in general
# any model run via SHARK for Vulkan backend.
@@ -2035,8 +2138,7 @@ def create_prompt(model_name, history):

iteration = 0

prefill_times = []
avg_decode_speed = []
benchmark_run_infos = []

while True:
# TODO: Add break condition from user input
@@ -2052,35 +2154,35 @@ def create_prompt(model_name, history):
prompt = args.system_prompt + user_prompt
history = [[user_prompt, ""]]

token_count = 0
total_time_ms = 0.001 # In order to avoid divide by zero error
prefill_time = 0
prompt_token_count = len(vic.tokenizer(prompt).input_ids)
total_time_ms = 0.0 # In order to avoid divide by zero error
prefill_time_ms = 0
is_first = True
token_times_ms = []

for text, msg, exec_time in vic.generate(prompt, cli=True):
if args.enable_tracing:
vic.shark_model.shark_runner.iree_config.device.flush_profiling()

if msg is None:
if is_first:
prefill_time = exec_time
# Note that the prefill time is reported in seconds, while the decoded token times are in ms.
prefill_time_ms = exec_time * 1000
is_first = False
else:
total_time_ms += exec_time
token_count += 1
token_times_ms.append(exec_time)
elif "formatted" in msg:
history[-1][1] = text
tokens_per_sec = (token_count / total_time_ms) * 1000
prefill_times.append(prefill_time)
avg_decode_speed.append(tokens_per_sec)

print("\nResponse:", text.strip())
print(f"\nNum tokens: {token_count}")
print(f"Prefill: {prefill_time:.2f} seconds")
print(f"Decode: {tokens_per_sec:.2f} tokens/s")
print(f"\nResponse:\n{text.strip()}\n")
run_info = BenchmarkRunInfo(prompt_token_count, prefill_time_ms, token_times_ms)
run_info.print()
benchmark_run_infos.append(run_info)

else:
sys.exit(
"unexpected message from the vicuna generate call, exiting."
)

if args.enable_microbenchmark:
print("\n### Final Statistics ###")
print("Number of iterations:", iteration - 1)
print(f"Prefill: avg. {mean(prefill_times):.2f} s, stdev {stdev(prefill_times):.2f}")
print(f"Decode: avg. {mean(avg_decode_speed):.2f} tokens/s, stdev {stdev(avg_decode_speed):.2f}")
print_aggregate_stats(benchmark_run_infos)
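
For context, here is a minimal standalone sketch (not part of the commit) showing how the new BenchmarkRunInfo metrics are derived; the prompt size and the timings below are made-up example values, not measurements.

# Illustrative only: assumed example timings showing how the BenchmarkRunInfo
# metrics added to vicuna.py are computed.
num_prompt_tokens = 32
prefill_time_ms = 250.0
token_times_ms = [40.0, 42.0, 38.0]  # per-token decode times in ms

prefill_speed = num_prompt_tokens / (prefill_time_ms / 1000.0)   # prompt tokens/s
decode_time_ms = sum(token_times_ms)
decode_speed = len(token_times_ms) / (decode_time_ms / 1000.0)   # generated tokens/s
e2e_time_ms = prefill_time_ms + decode_time_ms
e2e_decode_speed = len(token_times_ms) / (e2e_time_ms / 1000.0)                        # w/o prompt
e2e_processing_speed = (num_prompt_tokens + len(token_times_ms)) / (e2e_time_ms / 1000.0)  # w/ prompt

print(f"Prefill: {prefill_time_ms:.2f} ms, {prefill_speed:.2f} tokens/s")
print(f"Decode: {decode_time_ms:.2f} ms, {decode_speed:.2f} tokens/s")
print(f"Decode end-2-end: {e2e_decode_speed:.2f} tokens/s (w/o prompt), {e2e_processing_speed:.2f} tokens/s (w/ prompt)")

With these assumed timings the sketch prints a prefill speed of 128.00 tokens/s and a decode speed of 25.00 tokens/s, in the same format as the per-run output above.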
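
Similarly, a hedged sketch (also not from the commit) of how the new --enable_tracing path might be driven, assuming the script is launched from the repository root with a python on PATH; the flag and the IREE_PY_RUNTIME=tracy requirement come from enable_tracy_tracing() above, while the attached Tracy capture tooling is outside the scope of this commit.

# Illustrative invocation only: run vicuna.py under the Tracy-instrumented IREE
# Python runtime so the per-token flush_profiling() calls have a profiler to feed.
import os
import subprocess

env = dict(os.environ, IREE_PY_RUNTIME="tracy")  # enable_tracy_tracing() exits without this
subprocess.run(
    ["python", "apps/language_models/scripts/vicuna.py", "--enable_tracing"],
    env=env,
    check=True,
)

Note that the script itself sets TRACY_NO_EXIT=1, so the process waits for a Tracy capture to be collected before exiting.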
