diff --git a/.github/workflows/test-models.yml b/.github/workflows/test-models.yml index f647d303d9..2af40a6a1a 100644 --- a/.github/workflows/test-models.yml +++ b/.github/workflows/test-models.yml @@ -147,7 +147,7 @@ jobs: PYTHON=python${{ matrix.python-version }} ./setup_venv.sh source shark.venv/bin/activate pytest --update_tank -k vulkan - python build_tools/stable_diffusion_testing.py --device=vulkan + python build_tools/stable_diffusion_testing.py --device=vulkan --no-exit_on_fail - name: Validate Vulkan Models (Windows) if: matrix.suite == 'vulkan' && matrix.os == '7950x' diff --git a/apps/language_models/scripts/vicuna.py b/apps/language_models/scripts/vicuna.py index 81211654fb..72a36a6741 100644 --- a/apps/language_models/scripts/vicuna.py +++ b/apps/language_models/scripts/vicuna.py @@ -1,8 +1,10 @@ import argparse +from dataclasses import dataclass import json import re import gc from io import BytesIO +from os import environ from pathlib import Path from statistics import mean, stdev from tqdm import tqdm @@ -143,6 +145,12 @@ default=[], help="Extra command line arguments passed to the IREE compiler. This can be specified multiple times to pass multiple arguments." ) +parser.add_argument( + "--enable_tracing", + default=False, + action=argparse.BooleanOptionalAction, + help="Enable profiling with Tracy. The script will wait for Tracy to connect and flush the profiling data after each token." +) # Microbenchmarking options. parser.add_argument( @@ -1937,12 +1945,107 @@ def create_prompt(model_name, history): return msg +def miliseconds_to_seconds(ms: float) -> float: + return ms / 1000.0 + + +@dataclass +class BenchmarkRunInfo: + num_prompt_tokens : int + prefill_time_ms : float + token_times_ms : list[float] + + def get_prefill_speed(self) -> float: + seconds = miliseconds_to_seconds(self.prefill_time_ms) + if seconds == 0.0: + return float('inf') + return self.num_prompt_tokens / seconds + + def num_generated_tokens(self) -> int: + return len(self.token_times_ms) + + def get_decode_time_ms(self) -> float: + return sum(self.token_times_ms) + + def get_decode_speed(self) -> float: + seconds = miliseconds_to_seconds(self.get_decode_time_ms()) + if seconds == 0.0: + return float('inf') + return self.num_generated_tokens() / seconds + + def get_e2e_time_ms(self) -> float: + return self.prefill_time_ms + self.get_decode_time_ms() + + def get_e2e_decode_speed(self) -> float: + seconds = miliseconds_to_seconds(self.get_e2e_time_ms()) + if seconds == 0.0: + return float('inf') + return self.num_generated_tokens() / seconds + + def get_e2e_token_processing_speed(self) -> float: + seconds = miliseconds_to_seconds(self.get_e2e_time_ms()) + if seconds == 0.0: + return float('inf') + return (self.num_prompt_tokens + self.num_generated_tokens()) / seconds + + def print(self) -> None: + total_tokens = self.num_prompt_tokens + self.num_generated_tokens() + print(f"Num tokens: {self.num_prompt_tokens:} (prompt), {self.num_generated_tokens()} (generated), {total_tokens} (total)") + print(f"Prefill: {self.prefill_time_ms:.2f} ms, {self.get_prefill_speed():.2f} tokens/s") + print(f"Decode: {self.get_decode_time_ms():.2f} ms, {self.get_decode_speed():.2f} tokens/s") + print(f"Decode end-2-end: {self.get_e2e_decode_speed():.2f} tokens/s (w/o prompt), {self.get_e2e_token_processing_speed():.2f} tokens/s (w/ prompt)") + + +def print_aggregate_stats(run_infos: list[BenchmarkRunInfo]) -> None: + num_iterations = len(run_infos) + print(f'Number of iterations: {num_iterations}') + if num_iterations == 0: + return + + if len(run_infos) == 1: + run_infos[0].print() + return + + total_tokens = run_infos[0].num_prompt_tokens + run_infos[0].num_generated_tokens() + print(f"Num tokens: {run_infos[0].num_prompt_tokens} (prompt), {run_infos[0].num_generated_tokens()} (generated), {total_tokens} (total)") + + def avg_and_stdev(data): + x = list(data) + return mean(x), stdev(x) + + avg_prefill_ms, stdev_prefill = avg_and_stdev(x.prefill_time_ms for x in run_infos) + avg_prefill_speed = mean(x.get_prefill_speed() for x in run_infos) + print(f"Prefill: avg. {avg_prefill_ms:.2f} ms (stdev {stdev_prefill:.2f}), avg. {avg_prefill_speed:.2f} tokens/s") + + avg_decode_ms, stdev_decode = avg_and_stdev(x.get_decode_time_ms() for x in run_infos) + avg_decode_speed = mean(x.get_decode_speed() for x in run_infos) + print(f"Decode: avg. {avg_decode_ms:.2f} ms (stdev {stdev_decode:.2f}), avg. {avg_decode_speed:.2f} tokens/s") + + avg_e2e_decode_speed = mean(x.get_e2e_decode_speed() for x in run_infos) + avg_e2e_processing_speed = mean(x.get_e2e_token_processing_speed() for x in run_infos) + print(f"Decode end-2-end: avg. {avg_e2e_decode_speed:.2f} tokens/s (w/o prompt), avg. {avg_e2e_processing_speed:.2f} (w/ prompt)") + + +def enable_tracy_tracing(): + # Make tracy wait for a caputre to be collected before exiting. + environ["TRACY_NO_EXIT"] = "1" + + if "IREE_PY_RUNTIME" not in environ or environ["IREE_PY_RUNTIME"] != "tracy": + print("ERROR: Tracing enabled but tracy iree runtime not used.", file=sys.stderr) + print("Set the IREE_PY_RUNTIME=tracy environment variable.", file=sys.stderr) + sys.exit(1) + + if __name__ == "__main__": args, unknown = parser.parse_known_args() _extra_args = list(args.Xiree_compile) device_id = None + + if args.enable_tracing: + enable_tracy_tracing() + # Process vulkan target triple. # TODO: This feature should just be in a common utils for other LLMs and in general # any model run via SHARK for Vulkan backend. @@ -2035,8 +2138,7 @@ def create_prompt(model_name, history): iteration = 0 - prefill_times = [] - avg_decode_speed = [] + benchmark_run_infos = [] while True: # TODO: Add break condition from user input @@ -2052,28 +2154,30 @@ def create_prompt(model_name, history): prompt = args.system_prompt + user_prompt history = [[user_prompt, ""]] - token_count = 0 - total_time_ms = 0.001 # In order to avoid divide by zero error - prefill_time = 0 + prompt_token_count = len(vic.tokenizer(prompt).input_ids) + total_time_ms = 0.0 # In order to avoid divide by zero error + prefill_time_ms = 0 is_first = True + token_times_ms = [] + for text, msg, exec_time in vic.generate(prompt, cli=True): + if args.enable_tracing: + vic.shark_model.shark_runner.iree_config.device.flush_profiling() + if msg is None: if is_first: - prefill_time = exec_time + # Note that the prefill time is in seconds, and all the decoded tokens in ms. + prefill_time_ms = exec_time * 1000 is_first = False else: - total_time_ms += exec_time - token_count += 1 + token_times_ms.append(exec_time) elif "formatted" in msg: history[-1][1] = text - tokens_per_sec = (token_count / total_time_ms) * 1000 - prefill_times.append(prefill_time) - avg_decode_speed.append(tokens_per_sec) - - print("\nResponse:", text.strip()) - print(f"\nNum tokens: {token_count}") - print(f"Prefill: {prefill_time:.2f} seconds") - print(f"Decode: {tokens_per_sec:.2f} tokens/s") + print(f"\nResponse:\n{text.strip()}\n") + run_info = BenchmarkRunInfo(prompt_token_count, prefill_time_ms, token_times_ms) + run_info.print() + benchmark_run_infos.append(run_info) + else: sys.exit( "unexpected message from the vicuna generate call, exiting." @@ -2081,6 +2185,4 @@ def create_prompt(model_name, history): if args.enable_microbenchmark: print("\n### Final Statistics ###") - print("Number of iterations:", iteration - 1) - print(f"Prefill: avg. {mean(prefill_times):.2f} s, stdev {stdev(prefill_times):.2f}") - print(f"Decode: avg. {mean(avg_decode_speed):.2f} tokens/s, stdev {stdev(avg_decode_speed):.2f}") + print_aggregate_stats(benchmark_run_infos) diff --git a/apps/language_models/src/model_wrappers/falcon_sharded_model.py b/apps/language_models/src/model_wrappers/falcon_sharded_model.py index aefd9dc152..b5a16c3d40 100644 --- a/apps/language_models/src/model_wrappers/falcon_sharded_model.py +++ b/apps/language_models/src/model_wrappers/falcon_sharded_model.py @@ -69,24 +69,100 @@ def forward(self, hidden_states): return torch.tensor(new_hidden_states) -class DecoderLayer(torch.nn.Module): +class FourWayShardingDecoderLayer(torch.nn.Module): def __init__(self, decoder_layer_model, falcon_variant): super().__init__() self.model = decoder_layer_model + self.falcon_variant = falcon_variant def forward(self, hidden_states, attention_mask): - output = self.model.forward( - hidden_states=hidden_states, - alibi=None, - attention_mask=attention_mask, - use_cache=True, + new_pkvs = [] + for layer in self.model: + outputs = layer( + hidden_states=hidden_states, + alibi=None, + attention_mask=attention_mask, + use_cache=True, + ) + hidden_states = outputs[0] + new_pkvs.append( + ( + outputs[-1][0], + outputs[-1][1], + ) + ) + + ( + (new_pkv00, new_pkv01), + (new_pkv10, new_pkv11), + (new_pkv20, new_pkv21), + (new_pkv30, new_pkv31), + (new_pkv40, new_pkv41), + (new_pkv50, new_pkv51), + (new_pkv60, new_pkv61), + (new_pkv70, new_pkv71), + (new_pkv80, new_pkv81), + (new_pkv90, new_pkv91), + (new_pkv100, new_pkv101), + (new_pkv110, new_pkv111), + (new_pkv120, new_pkv121), + (new_pkv130, new_pkv131), + (new_pkv140, new_pkv141), + (new_pkv150, new_pkv151), + (new_pkv160, new_pkv161), + (new_pkv170, new_pkv171), + (new_pkv180, new_pkv181), + (new_pkv190, new_pkv191), + ) = new_pkvs + result = ( + hidden_states, + new_pkv00, + new_pkv01, + new_pkv10, + new_pkv11, + new_pkv20, + new_pkv21, + new_pkv30, + new_pkv31, + new_pkv40, + new_pkv41, + new_pkv50, + new_pkv51, + new_pkv60, + new_pkv61, + new_pkv70, + new_pkv71, + new_pkv80, + new_pkv81, + new_pkv90, + new_pkv91, + new_pkv100, + new_pkv101, + new_pkv110, + new_pkv111, + new_pkv120, + new_pkv121, + new_pkv130, + new_pkv131, + new_pkv140, + new_pkv141, + new_pkv150, + new_pkv151, + new_pkv160, + new_pkv161, + new_pkv170, + new_pkv171, + new_pkv180, + new_pkv181, + new_pkv190, + new_pkv191, ) - return (output[0], output[1][0], output[1][1]) + return result -class CompiledDecoderLayer(torch.nn.Module): +class CompiledFourWayShardingDecoderLayer(torch.nn.Module): def __init__( - self, layer_id, device_idx, falcon_variant, device, precision + self, layer_id, device_idx, falcon_variant, device, precision, model ): super().__init__() self.layer_id = layer_id @@ -94,6 +170,7 @@ def __init__( self.falcon_variant = falcon_variant self.device = device self.precision = precision + self.model = model def forward( self, @@ -109,19 +186,7 @@ def forward( torch.cuda.empty_cache() gc.collect() - from pathlib import Path - from apps.language_models.utils import get_vmfb_from_path - self.falcon_vmfb_path = Path( - f"falcon_{self.falcon_variant}_layer_{self.layer_id}_{self.precision}_{self.device}.vmfb" - ) - print("vmfb path for layer: ", self.falcon_vmfb_path) - self.model = get_vmfb_from_path( - self.falcon_vmfb_path, - self.device, - "linalg", - device_id=self.device_index, - ) if self.model is None: raise ValueError("Layer vmfb not found") @@ -131,29 +196,101 @@ def forward( if alibi is not None or layer_past is not None: raise ValueError("Past Key Values and alibi should be None") else: - new_hidden_states, pkv1, pkv2 = self.model( + output = self.model( "forward", ( hidden_states, attention_mask, ), ) - del self.model - - return tuple( - [ - torch.tensor(new_hidden_states), - tuple( - [ - torch.tensor(pkv1), - torch.tensor(pkv2), - ] - ), - ] - ) + + result = ( + torch.tensor(output[0]), + ( + torch.tensor(output[1]), + torch.tensor(output[2]), + ), + ( + torch.tensor(output[3]), + torch.tensor(output[4]), + ), + ( + torch.tensor(output[5]), + torch.tensor(output[6]), + ), + ( + torch.tensor(output[7]), + torch.tensor(output[8]), + ), + ( + torch.tensor(output[9]), + torch.tensor(output[10]), + ), + ( + torch.tensor(output[11]), + torch.tensor(output[12]), + ), + ( + torch.tensor(output[13]), + torch.tensor(output[14]), + ), + ( + torch.tensor(output[15]), + torch.tensor(output[16]), + ), + ( + torch.tensor(output[17]), + torch.tensor(output[18]), + ), + ( + torch.tensor(output[19]), + torch.tensor(output[20]), + ), + ( + torch.tensor(output[21]), + torch.tensor(output[22]), + ), + ( + torch.tensor(output[23]), + torch.tensor(output[24]), + ), + ( + torch.tensor(output[25]), + torch.tensor(output[26]), + ), + ( + torch.tensor(output[27]), + torch.tensor(output[28]), + ), + ( + torch.tensor(output[29]), + torch.tensor(output[30]), + ), + ( + torch.tensor(output[31]), + torch.tensor(output[32]), + ), + ( + torch.tensor(output[33]), + torch.tensor(output[34]), + ), + ( + torch.tensor(output[35]), + torch.tensor(output[36]), + ), + ( + torch.tensor(output[37]), + torch.tensor(output[38]), + ), + ( + torch.tensor(output[39]), + torch.tensor(output[40]), + ), + ) + return result -class EightDecoderLayer(torch.nn.Module): +class TwoWayShardingDecoderLayer(torch.nn.Module): def __init__(self, decoder_layer_model, falcon_variant): super().__init__() self.model = decoder_layer_model @@ -175,163 +312,138 @@ def forward(self, hidden_states, attention_mask): outputs[-1][1], ) ) - if self.falcon_variant == "7b": - ( - (new_pkv00, new_pkv01), - (new_pkv10, new_pkv11), - (new_pkv20, new_pkv21), - (new_pkv30, new_pkv31), - (new_pkv40, new_pkv41), - (new_pkv50, new_pkv51), - (new_pkv60, new_pkv61), - (new_pkv70, new_pkv71), - ) = new_pkvs - result = ( - hidden_states, - new_pkv00, - new_pkv01, - new_pkv10, - new_pkv11, - new_pkv20, - new_pkv21, - new_pkv30, - new_pkv31, - new_pkv40, - new_pkv41, - new_pkv50, - new_pkv51, - new_pkv60, - new_pkv61, - new_pkv70, - new_pkv71, - ) - elif self.falcon_variant == "40b": - ( - (new_pkv00, new_pkv01), - (new_pkv10, new_pkv11), - (new_pkv20, new_pkv21), - (new_pkv30, new_pkv31), - (new_pkv40, new_pkv41), - (new_pkv50, new_pkv51), - (new_pkv60, new_pkv61), - (new_pkv70, new_pkv71), - (new_pkv80, new_pkv81), - (new_pkv90, new_pkv91), - (new_pkv100, new_pkv101), - (new_pkv110, new_pkv111), - (new_pkv120, new_pkv121), - (new_pkv130, new_pkv131), - (new_pkv140, new_pkv141), - ) = new_pkvs - result = ( - hidden_states, - new_pkv00, - new_pkv01, - new_pkv10, - new_pkv11, - new_pkv20, - new_pkv21, - new_pkv30, - new_pkv31, - new_pkv40, - new_pkv41, - new_pkv50, - new_pkv51, - new_pkv60, - new_pkv61, - new_pkv70, - new_pkv71, - new_pkv80, - new_pkv81, - new_pkv90, - new_pkv91, - new_pkv100, - new_pkv101, - new_pkv110, - new_pkv111, - new_pkv120, - new_pkv121, - new_pkv130, - new_pkv131, - new_pkv140, - new_pkv141, - ) - elif self.falcon_variant == "180b": - ( - (new_pkv00, new_pkv01), - (new_pkv10, new_pkv11), - (new_pkv20, new_pkv21), - (new_pkv30, new_pkv31), - (new_pkv40, new_pkv41), - (new_pkv50, new_pkv51), - (new_pkv60, new_pkv61), - (new_pkv70, new_pkv71), - (new_pkv80, new_pkv81), - (new_pkv90, new_pkv91), - (new_pkv100, new_pkv101), - (new_pkv110, new_pkv111), - (new_pkv120, new_pkv121), - (new_pkv130, new_pkv131), - (new_pkv140, new_pkv141), - (new_pkv150, new_pkv151), - (new_pkv160, new_pkv161), - (new_pkv170, new_pkv171), - (new_pkv180, new_pkv181), - (new_pkv190, new_pkv191), - ) = new_pkvs - result = ( - hidden_states, - new_pkv00, - new_pkv01, - new_pkv10, - new_pkv11, - new_pkv20, - new_pkv21, - new_pkv30, - new_pkv31, - new_pkv40, - new_pkv41, - new_pkv50, - new_pkv51, - new_pkv60, - new_pkv61, - new_pkv70, - new_pkv71, - new_pkv80, - new_pkv81, - new_pkv90, - new_pkv91, - new_pkv100, - new_pkv101, - new_pkv110, - new_pkv111, - new_pkv120, - new_pkv121, - new_pkv130, - new_pkv131, - new_pkv140, - new_pkv141, - new_pkv150, - new_pkv151, - new_pkv160, - new_pkv161, - new_pkv170, - new_pkv171, - new_pkv180, - new_pkv181, - new_pkv190, - new_pkv191, - ) - else: - raise ValueError( - "Unsupported Falcon variant: ", self.falcon_variant - ) + + ( + (new_pkv00, new_pkv01), + (new_pkv10, new_pkv11), + (new_pkv20, new_pkv21), + (new_pkv30, new_pkv31), + (new_pkv40, new_pkv41), + (new_pkv50, new_pkv51), + (new_pkv60, new_pkv61), + (new_pkv70, new_pkv71), + (new_pkv80, new_pkv81), + (new_pkv90, new_pkv91), + (new_pkv100, new_pkv101), + (new_pkv110, new_pkv111), + (new_pkv120, new_pkv121), + (new_pkv130, new_pkv131), + (new_pkv140, new_pkv141), + (new_pkv150, new_pkv151), + (new_pkv160, new_pkv161), + (new_pkv170, new_pkv171), + (new_pkv180, new_pkv181), + (new_pkv190, new_pkv191), + (new_pkv200, new_pkv201), + (new_pkv210, new_pkv211), + (new_pkv220, new_pkv221), + (new_pkv230, new_pkv231), + (new_pkv240, new_pkv241), + (new_pkv250, new_pkv251), + (new_pkv260, new_pkv261), + (new_pkv270, new_pkv271), + (new_pkv280, new_pkv281), + (new_pkv290, new_pkv291), + (new_pkv300, new_pkv301), + (new_pkv310, new_pkv311), + (new_pkv320, new_pkv321), + (new_pkv330, new_pkv331), + (new_pkv340, new_pkv341), + (new_pkv350, new_pkv351), + (new_pkv360, new_pkv361), + (new_pkv370, new_pkv371), + (new_pkv380, new_pkv381), + (new_pkv390, new_pkv391), + ) = new_pkvs + result = ( + hidden_states, + new_pkv00, + new_pkv01, + new_pkv10, + new_pkv11, + new_pkv20, + new_pkv21, + new_pkv30, + new_pkv31, + new_pkv40, + new_pkv41, + new_pkv50, + new_pkv51, + new_pkv60, + new_pkv61, + new_pkv70, + new_pkv71, + new_pkv80, + new_pkv81, + new_pkv90, + new_pkv91, + new_pkv100, + new_pkv101, + new_pkv110, + new_pkv111, + new_pkv120, + new_pkv121, + new_pkv130, + new_pkv131, + new_pkv140, + new_pkv141, + new_pkv150, + new_pkv151, + new_pkv160, + new_pkv161, + new_pkv170, + new_pkv171, + new_pkv180, + new_pkv181, + new_pkv190, + new_pkv191, + new_pkv200, + new_pkv201, + new_pkv210, + new_pkv211, + new_pkv220, + new_pkv221, + new_pkv230, + new_pkv231, + new_pkv240, + new_pkv241, + new_pkv250, + new_pkv251, + new_pkv260, + new_pkv261, + new_pkv270, + new_pkv271, + new_pkv280, + new_pkv281, + new_pkv290, + new_pkv291, + new_pkv300, + new_pkv301, + new_pkv310, + new_pkv311, + new_pkv320, + new_pkv321, + new_pkv330, + new_pkv331, + new_pkv340, + new_pkv341, + new_pkv350, + new_pkv351, + new_pkv360, + new_pkv361, + new_pkv370, + new_pkv371, + new_pkv380, + new_pkv381, + new_pkv390, + new_pkv391, + ) return result -class CompiledEightDecoderLayer(torch.nn.Module): +class CompiledTwoWayShardingDecoderLayer(torch.nn.Module): def __init__( - self, layer_id, device_idx, falcon_variant, device, precision + self, layer_id, device_idx, falcon_variant, device, precision, model ): super().__init__() self.layer_id = layer_id @@ -339,6 +451,7 @@ def __init__( self.falcon_variant = falcon_variant self.device = device self.precision = precision + self.model = model def forward( self, @@ -354,19 +467,7 @@ def forward( torch.cuda.empty_cache() gc.collect() - from pathlib import Path - from apps.language_models.utils import get_vmfb_from_path - self.falcon_vmfb_path = Path( - f"falcon_{self.falcon_variant}_layer_{self.layer_id}_{self.precision}_{self.device}.vmfb" - ) - print("vmfb path for layer: ", self.falcon_vmfb_path) - self.model = get_vmfb_from_path( - self.falcon_vmfb_path, - self.device, - "linalg", - device_id=self.device_index, - ) if self.model is None: raise ValueError("Layer vmfb not found") @@ -383,196 +484,170 @@ def forward( attention_mask, ), ) - del self.model - if self.falcon_variant == "7b": - result = ( - torch.tensor(output[0]), - ( - torch.tensor(output[1]), - torch.tensor(output[2]), - ), - ( - torch.tensor(output[3]), - torch.tensor(output[4]), - ), - ( - torch.tensor(output[5]), - torch.tensor(output[6]), - ), - ( - torch.tensor(output[7]), - torch.tensor(output[8]), - ), - ( - torch.tensor(output[9]), - torch.tensor(output[10]), - ), - ( - torch.tensor(output[11]), - torch.tensor(output[12]), - ), - ( - torch.tensor(output[13]), - torch.tensor(output[14]), - ), - ( - torch.tensor(output[15]), - torch.tensor(output[16]), - ), - ) - elif self.falcon_variant == "40b": - result = ( - torch.tensor(output[0]), - ( - torch.tensor(output[1]), - torch.tensor(output[2]), - ), - ( - torch.tensor(output[3]), - torch.tensor(output[4]), - ), - ( - torch.tensor(output[5]), - torch.tensor(output[6]), - ), - ( - torch.tensor(output[7]), - torch.tensor(output[8]), - ), - ( - torch.tensor(output[9]), - torch.tensor(output[10]), - ), - ( - torch.tensor(output[11]), - torch.tensor(output[12]), - ), - ( - torch.tensor(output[13]), - torch.tensor(output[14]), - ), - ( - torch.tensor(output[15]), - torch.tensor(output[16]), - ), - ( - torch.tensor(output[17]), - torch.tensor(output[18]), - ), - ( - torch.tensor(output[19]), - torch.tensor(output[20]), - ), - ( - torch.tensor(output[21]), - torch.tensor(output[22]), - ), - ( - torch.tensor(output[23]), - torch.tensor(output[24]), - ), - ( - torch.tensor(output[25]), - torch.tensor(output[26]), - ), - ( - torch.tensor(output[27]), - torch.tensor(output[28]), - ), - ( - torch.tensor(output[29]), - torch.tensor(output[30]), - ), - ) - elif self.falcon_variant == "180b": - result = ( - torch.tensor(output[0]), - ( - torch.tensor(output[1]), - torch.tensor(output[2]), - ), - ( - torch.tensor(output[3]), - torch.tensor(output[4]), - ), - ( - torch.tensor(output[5]), - torch.tensor(output[6]), - ), - ( - torch.tensor(output[7]), - torch.tensor(output[8]), - ), - ( - torch.tensor(output[9]), - torch.tensor(output[10]), - ), - ( - torch.tensor(output[11]), - torch.tensor(output[12]), - ), - ( - torch.tensor(output[13]), - torch.tensor(output[14]), - ), - ( - torch.tensor(output[15]), - torch.tensor(output[16]), - ), - ( - torch.tensor(output[17]), - torch.tensor(output[18]), - ), - ( - torch.tensor(output[19]), - torch.tensor(output[20]), - ), - ( - torch.tensor(output[21]), - torch.tensor(output[22]), - ), - ( - torch.tensor(output[23]), - torch.tensor(output[24]), - ), - ( - torch.tensor(output[25]), - torch.tensor(output[26]), - ), - ( - torch.tensor(output[27]), - torch.tensor(output[28]), - ), - ( - torch.tensor(output[29]), - torch.tensor(output[30]), - ), - ( - torch.tensor(output[31]), - torch.tensor(output[32]), - ), - ( - torch.tensor(output[33]), - torch.tensor(output[34]), - ), - ( - torch.tensor(output[35]), - torch.tensor(output[36]), - ), - ( - torch.tensor(output[37]), - torch.tensor(output[38]), - ), - ( - torch.tensor(output[39]), - torch.tensor(output[40]), - ), - ) - else: - raise ValueError( - "Unsupported Falcon variant: ", self.falcon_variant - ) + result = ( + torch.tensor(output[0]), + ( + torch.tensor(output[1]), + torch.tensor(output[2]), + ), + ( + torch.tensor(output[3]), + torch.tensor(output[4]), + ), + ( + torch.tensor(output[5]), + torch.tensor(output[6]), + ), + ( + torch.tensor(output[7]), + torch.tensor(output[8]), + ), + ( + torch.tensor(output[9]), + torch.tensor(output[10]), + ), + ( + torch.tensor(output[11]), + torch.tensor(output[12]), + ), + ( + torch.tensor(output[13]), + torch.tensor(output[14]), + ), + ( + torch.tensor(output[15]), + torch.tensor(output[16]), + ), + ( + torch.tensor(output[17]), + torch.tensor(output[18]), + ), + ( + torch.tensor(output[19]), + torch.tensor(output[20]), + ), + ( + torch.tensor(output[21]), + torch.tensor(output[22]), + ), + ( + torch.tensor(output[23]), + torch.tensor(output[24]), + ), + ( + torch.tensor(output[25]), + torch.tensor(output[26]), + ), + ( + torch.tensor(output[27]), + torch.tensor(output[28]), + ), + ( + torch.tensor(output[29]), + torch.tensor(output[30]), + ), + ( + torch.tensor(output[31]), + torch.tensor(output[32]), + ), + ( + torch.tensor(output[33]), + torch.tensor(output[34]), + ), + ( + torch.tensor(output[35]), + torch.tensor(output[36]), + ), + ( + torch.tensor(output[37]), + torch.tensor(output[38]), + ), + ( + torch.tensor(output[39]), + torch.tensor(output[40]), + ), + ( + torch.tensor(output[41]), + torch.tensor(output[42]), + ), + ( + torch.tensor(output[43]), + torch.tensor(output[44]), + ), + ( + torch.tensor(output[45]), + torch.tensor(output[46]), + ), + ( + torch.tensor(output[47]), + torch.tensor(output[48]), + ), + ( + torch.tensor(output[49]), + torch.tensor(output[50]), + ), + ( + torch.tensor(output[51]), + torch.tensor(output[52]), + ), + ( + torch.tensor(output[53]), + torch.tensor(output[54]), + ), + ( + torch.tensor(output[55]), + torch.tensor(output[56]), + ), + ( + torch.tensor(output[57]), + torch.tensor(output[58]), + ), + ( + torch.tensor(output[59]), + torch.tensor(output[60]), + ), + ( + torch.tensor(output[61]), + torch.tensor(output[62]), + ), + ( + torch.tensor(output[63]), + torch.tensor(output[64]), + ), + ( + torch.tensor(output[65]), + torch.tensor(output[66]), + ), + ( + torch.tensor(output[67]), + torch.tensor(output[68]), + ), + ( + torch.tensor(output[69]), + torch.tensor(output[70]), + ), + ( + torch.tensor(output[71]), + torch.tensor(output[72]), + ), + ( + torch.tensor(output[73]), + torch.tensor(output[74]), + ), + ( + torch.tensor(output[75]), + torch.tensor(output[76]), + ), + ( + torch.tensor(output[77]), + torch.tensor(output[78]), + ), + ( + torch.tensor(output[79]), + torch.tensor(output[80]), + ), + ) return result diff --git a/apps/language_models/src/pipelines/falcon_pipeline.py b/apps/language_models/src/pipelines/falcon_pipeline.py index e6e43d1331..586f822b8b 100644 --- a/apps/language_models/src/pipelines/falcon_pipeline.py +++ b/apps/language_models/src/pipelines/falcon_pipeline.py @@ -6,10 +6,10 @@ CompiledLNFEmbeddingLayer, LMHeadEmbeddingLayer, CompiledLMHeadEmbeddingLayer, - DecoderLayer, - EightDecoderLayer, - CompiledDecoderLayer, - CompiledEightDecoderLayer, + FourWayShardingDecoderLayer, + TwoWayShardingDecoderLayer, + CompiledFourWayShardingDecoderLayer, + CompiledTwoWayShardingDecoderLayer, ShardedFalconModel, ) from apps.language_models.src.pipelines.SharkLLMBase import SharkLLMBase @@ -94,6 +94,13 @@ action=argparse.BooleanOptionalAction, help="Run model as sharded", ) +parser.add_argument( + "--num_shards", + type=int, + default=4, + choices=[2, 4], + help="Number of shards.", +) class ShardedFalcon(SharkLLMBase): @@ -122,6 +129,10 @@ def __init__( --hf_auth_token flag. You can ask for the access to the model here: https://huggingface.co/tiiuae/falcon-180B-chat.""" ) + + if args.sharded and "180b" not in self.model_name: + raise ValueError("Sharding supported only for Falcon-180B") + self.hf_auth_token = hf_auth_token self.max_padding_length = 100 self.device = device @@ -131,7 +142,7 @@ def __init__( self.debug = debug self.tokenizer = self.get_tokenizer() self.src_model = self.get_src_model() - self.shark_model = self.compile(compressed=args.compressed) + self.shark_model = self.compile() def get_tokenizer(self): tokenizer = AutoTokenizer.from_pretrained( @@ -146,20 +157,17 @@ def get_tokenizer(self): def get_src_model(self): print("Loading src model: ", self.model_name) kwargs = { - "torch_dtype": torch.float, + "torch_dtype": torch.float32, "trust_remote_code": True, "token": self.hf_auth_token, } if self.precision == "int4": quantization_config = GPTQConfig(bits=4, disable_exllama=True) kwargs["quantization_config"] = quantization_config - kwargs["load_gptq_on_cpu"] = True kwargs["device_map"] = "cpu" falcon_model = AutoModelForCausalLM.from_pretrained( self.hf_model_path, **kwargs ) - if self.precision == "int4": - falcon_model = falcon_model.to(torch.float32) return falcon_model def compile_layer( @@ -288,28 +296,14 @@ def compile_layer( return shark_module, device_idx - def compile(self, compressed=False): + def compile(self): sample_input_ids = torch.zeros([100], dtype=torch.int64) - sample_attention_mask = torch.zeros( - [1, 1, 100, 100], dtype=torch.float32 - ) - num_group_layers = 1 - if "7b" in self.model_name: - num_in_features = 4544 - if compressed: - num_group_layers = 8 - elif "40b" in self.model_name: - num_in_features = 8192 - if compressed: - num_group_layers = 15 - else: - num_in_features = 14848 - sample_attention_mask = sample_attention_mask.to(dtype=torch.bool) - if compressed: - num_group_layers = 20 - + sample_attention_mask = torch.zeros([1, 1, 100, 100], dtype=torch.bool) + num_group_layers = int( + 20 * (4 / args.num_shards) + ) # 4 is the number of default shards sample_hidden_states = torch.zeros( - [1, 100, num_in_features], dtype=torch.float32 + [1, 100, 14848], dtype=torch.float32 ) # Determine number of available devices @@ -319,6 +313,10 @@ def compile(self, compressed=False): haldriver = ireert.get_driver(self.device) num_devices = len(haldriver.query_available_devices()) + if num_devices < 2: + raise ValueError( + "Cannot run Falcon-180B on a single ROCM device." + ) lm_head = LMHeadEmbeddingLayer(self.src_model.lm_head) print("Compiling Layer lm_head") @@ -326,7 +324,9 @@ def compile(self, compressed=False): lm_head, [sample_hidden_states], "lm_head", - device_idx=0 % num_devices if self.device == "rocm" else None, + device_idx=(0 % num_devices) % args.num_shards + if self.device == "rocm" + else None, ) shark_lm_head = CompiledLMHeadEmbeddingLayer(shark_lm_head) @@ -338,7 +338,9 @@ def compile(self, compressed=False): word_embedding, [sample_input_ids], "word_embeddings", - device_idx=1 % num_devices if self.device == "rocm" else None, + device_idx=(1 % num_devices) % args.num_shards + if self.device == "rocm" + else None, ) shark_word_embedding = CompiledWordEmbeddingsLayer( shark_word_embedding @@ -350,7 +352,9 @@ def compile(self, compressed=False): ln_f, [sample_hidden_states], "ln_f", - device_idx=2 % num_devices if self.device == "rocm" else None, + device_idx=(2 % num_devices) % args.num_shards + if self.device == "rocm" + else None, ) shark_ln_f = CompiledLNFEmbeddingLayer(shark_ln_f) @@ -360,24 +364,21 @@ def compile(self, compressed=False): ): device_idx = i % num_devices if self.device == "rocm" else None layer_id = i - pytorch_class = DecoderLayer - compiled_class = CompiledDecoderLayer - if compressed: - layer_id = ( - str(i * num_group_layers) - + "_" - + str((i + 1) * num_group_layers) - ) - pytorch_class = EightDecoderLayer - compiled_class = CompiledEightDecoderLayer + layer_id = ( + str(i * num_group_layers) + + "_" + + str((i + 1) * num_group_layers) + ) + pytorch_class = FourWayShardingDecoderLayer + compiled_class = CompiledFourWayShardingDecoderLayer + if args.num_shards == 2: + pytorch_class = TwoWayShardingDecoderLayer + compiled_class = CompiledTwoWayShardingDecoderLayer print("Compiling Layer {}".format(layer_id)) - if compressed: - layer_i = self.src_model.transformer.h[ - i * num_group_layers : (i + 1) * num_group_layers - ] - else: - layer_i = self.src_model.transformer.h[i] + layer_i = self.src_model.transformer.h[ + i * num_group_layers : (i + 1) * num_group_layers + ] pytorch_layer_i = pytorch_class( layer_i, args.falcon_variant_to_use @@ -388,13 +389,13 @@ def compile(self, compressed=False): layer_id, device_idx=device_idx, ) - del shark_module shark_layer_i = compiled_class( layer_id, device_idx, args.falcon_variant_to_use, self.device, self.precision, + shark_module, ) shark_layers.append(shark_layer_i) diff --git a/apps/shark_studio/api/llm.py b/apps/shark_studio/api/llm.py index dc4b90c872..46ef9ec6bf 100644 --- a/apps/shark_studio/api/llm.py +++ b/apps/shark_studio/api/llm.py @@ -19,10 +19,21 @@ "stop_token": 2, "max_tokens": 4096, "system_prompt": """[INST] <>Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. <>""", - } + }, + "Trelis/Llama-2-7b-chat-hf-function-calling-v2": { + "initializer": stateless_llama.export_transformer_model, + "hf_model_name": "Trelis/Llama-2-7b-chat-hf-function-calling-v2", + "stop_token": 2, + "max_tokens": 4096, + "system_prompt": """[INST] <>Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. <>""", + }, } +def safe_name(name): + return name.replace("/", "_").replace("-", "_") + + class LanguageModel: def __init__( self, @@ -31,21 +42,31 @@ def __init__( device=None, precision="fp32", external_weights=None, - external_weight_file=None, use_system_prompt=True, ): print(llm_model_map[model_name]) self.hf_model_name = llm_model_map[model_name]["hf_model_name"] - self.tempfile_name = get_resource_path("llm.torch.tempfile") - self.vmfb_name = get_resource_path("llm.vmfb.tempfile") + self.tempfile_name = get_resource_path( + f"{safe_name(self.hf_model_name)}.mlir.tempfile" + ) + self.vmfb_name = get_resource_path( + f"{safe_name(self.hf_model_name)}.vmfb.tempfile" + ) self.device = device self.precision = precision self.max_tokens = llm_model_map[model_name]["max_tokens"] self.iree_module_dict = None - self.external_weight_file = external_weight_file + self.external_weight_file = None + if external_weights is not None: + self.external_weight_file = ( + f"{safe_name(self.hf_model_name)}.{external_weights}" + ) self.use_system_prompt = use_system_prompt self.global_iter = 0 - if os.path.exists(self.vmfb_name): + if os.path.exists(self.vmfb_name) and ( + os.path.exists(self.external_weight_file) + or external_weights is None + ): self.iree_module_dict = dict() ( self.iree_module_dict["vmfb"], @@ -56,12 +77,12 @@ def __init__( device, device_idx=0, rt_flags=[], - external_weight_file=external_weight_file, + external_weight_file=self.external_weight_file, ) self.tokenizer = AutoTokenizer.from_pretrained( self.hf_model_name, use_fast=False, - use_auth_token=hf_auth_token, + token=hf_auth_token, ) elif not os.path.exists(self.tempfile_name): self.torch_ir, self.tokenizer = llm_model_map[model_name][ @@ -71,7 +92,7 @@ def __init__( hf_auth_token, compile_to="torch", external_weights=external_weights, - external_weight_file=external_weight_file, + external_weight_file=self.external_weight_file, ) with open(self.tempfile_name, "w+") as f: f.write(self.torch_ir) @@ -82,7 +103,7 @@ def __init__( self.tokenizer = AutoTokenizer.from_pretrained( self.hf_model_name, use_fast=False, - use_auth_token=hf_auth_token, + token=hf_auth_token, ) self.compile() @@ -160,11 +181,10 @@ def format_out(results): if __name__ == "__main__": lm = LanguageModel( - "llama2_7b", - hf_auth_token="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", + "Trelis/Llama-2-7b-chat-hf-function-calling-v2", + hf_auth_token=None, device="cpu-task", external_weights="safetensors", - external_weight_file="llama2_7b.safetensors", ) print("model loaded") for i in lm.chat("hi, what are you?"): diff --git a/apps/shark_studio/web/ui/chat.py b/apps/shark_studio/web/ui/chat.py index 31186f4d4d..4fb9138e1a 100644 --- a/apps/shark_studio/web/ui/chat.py +++ b/apps/shark_studio/web/ui/chat.py @@ -25,9 +25,6 @@ def user(message, history): def create_prompt(model_name, history, prompt_prefix): return "" system_message = "" - if prompt_prefix: - system_message = start_message[model_name] - if "llama2" in model_name: B_INST, E_INST = "[INST]", "[/INST]" B_SYS, E_SYS = "<>\n", "\n<>\n\n" @@ -104,7 +101,6 @@ def chat_fn( device=device, precision=precision, external_weights="safetensors", - external_weight_file="llama2_7b.safetensors", use_system_prompt=prompt_prefix, ) history[-1][-1] = "Getting the model ready... Done" diff --git a/apps/stable_diffusion/src/utils/stencils/stencil_utils.py b/apps/stable_diffusion/src/utils/stencils/stencil_utils.py index 685d542405..41800f0cb7 100644 --- a/apps/stable_diffusion/src/utils/stencils/stencil_utils.py +++ b/apps/stable_diffusion/src/utils/stencils/stencil_utils.py @@ -1,6 +1,10 @@ import numpy as np from PIL import Image import torch +import os +from pathlib import Path +import torchvision +import time from apps.stable_diffusion.src.utils.stencils import ( CannyDetector, OpenposeDetector, @@ -10,6 +14,33 @@ stencil = {} +def save_img(img): + from apps.stable_diffusion.src.utils import ( + get_generated_imgs_path, + get_generated_imgs_todays_subdir, + ) + + subdir = Path( + get_generated_imgs_path(), get_generated_imgs_todays_subdir() + ) + os.makedirs(subdir, exist_ok=True) + if isinstance(img, Image.Image): + img.save( + os.path.join( + subdir, "controlnet_" + str(int(time.time())) + ".png" + ) + ) + elif isinstance(img, np.ndarray): + img = Image.fromarray(img) + img.save(os.path.join(subdir, str(int(time.time())) + ".png")) + else: + converter = torchvision.transforms.ToPILImage() + for i in img: + converter(i).save( + os.path.join(subdir, str(int(time.time())) + ".png") + ) + + def HWC3(x): assert x.dtype == np.uint8 if x.ndim == 2: @@ -161,6 +192,7 @@ def hint_canny( detected_map = stencil["canny"]( input_image, low_threshold, high_threshold ) + save_img(detected_map) detected_map = HWC3(detected_map) return detected_map @@ -176,6 +208,7 @@ def hint_openpose( stencil["openpose"] = OpenposeDetector() detected_map, _ = stencil["openpose"](input_image) + save_img(detected_map) detected_map = HWC3(detected_map) return detected_map @@ -187,6 +220,7 @@ def hint_scribble(image: Image.Image): detected_map = np.zeros_like(input_image, dtype=np.uint8) detected_map[np.min(input_image, axis=2) < 127] = 255 + save_img(detected_map) return detected_map @@ -199,5 +233,6 @@ def hint_zoedepth(image: Image.Image): stencil["depth"] = ZoeDetector() detected_map = stencil["depth"](input_image) + save_img(detected_map) detected_map = HWC3(detected_map) return detected_map diff --git a/apps/stable_diffusion/src/utils/utils.py b/apps/stable_diffusion/src/utils/utils.py index 4bccfa54e9..0516247a8e 100644 --- a/apps/stable_diffusion/src/utils/utils.py +++ b/apps/stable_diffusion/src/utils/utils.py @@ -118,7 +118,7 @@ def compile_through_fx( is_f16=False, f16_input_mask=None, use_tuned=False, - save_dir=tempfile.gettempdir(), + save_dir="", debug=False, generate_vmfb=True, extra_args=None, diff --git a/apps/stable_diffusion/web/index.py b/apps/stable_diffusion/web/index.py index 32d40e3dad..a6688ed104 100644 --- a/apps/stable_diffusion/web/index.py +++ b/apps/stable_diffusion/web/index.py @@ -75,11 +75,11 @@ # Setup to use shark_tmp for gradio's temporary image files and clear any # existing temporary images there if they exist. Then we can import gradio. # It has to be in this order or gradio ignores what we've set up. - from apps.stable_diffusion.web.utils.gradio_configs import ( - config_gradio_tmp_imgs_folder, + from apps.stable_diffusion.web.utils.tmp_configs import ( + config_tmp, ) - config_gradio_tmp_imgs_folder() + config_tmp() import gradio as gr # Create custom models folders if they don't exist diff --git a/apps/stable_diffusion/web/ui/css/sd_dark_theme.css b/apps/stable_diffusion/web/ui/css/sd_dark_theme.css index 00ebebc027..b5cb5bb1ec 100644 --- a/apps/stable_diffusion/web/ui/css/sd_dark_theme.css +++ b/apps/stable_diffusion/web/ui/css/sd_dark_theme.css @@ -105,6 +105,18 @@ body { background-color: var(--background-fill-primary); } +.generating.svelte-zlszon.svelte-zlszon { + border: none; +} + +.generating { + border: none !important; +} + +#chatbot { + height: 100% !important; +} + /* display in full width for desktop devices */ @media (min-width: 1536px) { diff --git a/apps/stable_diffusion/web/ui/stablelm_ui.py b/apps/stable_diffusion/web/ui/stablelm_ui.py index 098f33bce1..0df3f8442d 100644 --- a/apps/stable_diffusion/web/ui/stablelm_ui.py +++ b/apps/stable_diffusion/web/ui/stablelm_ui.py @@ -6,6 +6,7 @@ AutoModelForCausalLM, ) from apps.stable_diffusion.web.ui.utils import available_devices +from shark.iree_utils.compile_utils import clean_device_info from datetime import datetime as dt import json import sys @@ -132,27 +133,6 @@ def get_default_config(): c.split_into_layers() -def clean_device_info(raw_device): - # return appropriate device and device_id for consumption by LLM pipeline - # Multiple devices only supported for vulkan and rocm (as of now). - # default device must be selected for all others - - device_id = None - device = ( - raw_device - if "=>" not in raw_device - else raw_device.split("=>")[1].strip() - ) - if "://" in device: - device, device_id = device.split("://") - device_id = int(device_id) # using device index in webui - - if device not in ["rocm", "vulkan"]: - device_id = None - - return device, device_id - - model_vmfb_key = "" @@ -456,7 +436,7 @@ def view_json_file(file_obj): json_view_button.click( fn=view_json_file, inputs=[config_file], outputs=[json_view] ) - chatbot = gr.Chatbot(height=500) + chatbot = gr.Chatbot(elem_id="chatbot") with gr.Row(): with gr.Column(): msg = gr.Textbox( diff --git a/apps/stable_diffusion/web/utils/gradio_configs.py b/apps/stable_diffusion/web/utils/tmp_configs.py similarity index 76% rename from apps/stable_diffusion/web/utils/gradio_configs.py rename to apps/stable_diffusion/web/utils/tmp_configs.py index ae8e6283f8..3e6ba46bfe 100644 --- a/apps/stable_diffusion/web/utils/gradio_configs.py +++ b/apps/stable_diffusion/web/utils/tmp_configs.py @@ -5,11 +5,25 @@ shark_tmp = os.path.join(os.getcwd(), "shark_tmp/") -def config_gradio_tmp_imgs_folder(): - # create shark_tmp if it does not exist - if not os.path.exists(shark_tmp): - os.mkdir(shark_tmp) +def clear_tmp_mlir(): + cleanup_start = time() + print( + "Clearing .mlir temporary files from a prior run. This may take some time..." + ) + mlir_files = [ + filename + for filename in os.listdir(shark_tmp) + if os.path.isfile(os.path.join(shark_tmp, filename)) + and filename.endswith(".mlir") + ] + for filename in mlir_files: + os.remove(shark_tmp + filename) + print( + f"Clearing .mlir temporary files took {time() - cleanup_start:.4f} seconds." + ) + +def clear_tmp_imgs(): # tell gradio to use a directory under shark_tmp for its temporary # image files unless somewhere else has been set if "GRADIO_TEMP_DIR" not in os.environ: @@ -52,3 +66,12 @@ def config_gradio_tmp_imgs_folder(): ) else: print("No temporary images files to clear.") + + +def config_tmp(): + # create shark_tmp if it does not exist + if not os.path.exists(shark_tmp): + os.mkdir(shark_tmp) + + clear_tmp_mlir() + clear_tmp_imgs() diff --git a/requirements.txt b/requirements.txt index 1d49b5e025..fc6aad4c97 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,7 +4,9 @@ setuptools wheel +# TURBINE shark-turbine @ git+https://github.com/nod-ai/SHARK-Turbine.git@main +turbine-models @ git+https://github.com/nod-ai/SHARK-Turbine.git@main#egg=turbine-models&subdirectory=python/turbine_models # SHARK Runner tqdm @@ -19,9 +21,9 @@ pytest-forked Pillow parameterized -#shark-turbine @ git+https://github.com/nod-ai/SHARK-Turbine.git@main # Add transformers, diffusers and scipy since it most commonly used -tokenizers==0.13.3 +tokenizers +transformers diffusers #accelerate is now required for diffusers import from ckpt. accelerate @@ -50,5 +52,6 @@ pydantic==2.4.1 # pin until pyinstaller-hooks-contrib works with beta versions pefile pyinstaller -# vicuna quantization -brevitas @ git+https://github.com/Xilinx/brevitas.git@56edf56a3115d5ac04f19837b388fd7d3b1ff7ea +# For quantized GPTQ models +optimum +auto_gptq diff --git a/setup_venv.ps1 b/setup_venv.ps1 index 7957a2c2bf..09489bf4cc 100644 --- a/setup_venv.ps1 +++ b/setup_venv.ps1 @@ -89,7 +89,7 @@ else {python -m venv .\shark.venv\} python -m pip install --upgrade pip pip install wheel pip install -r requirements.txt -pip install --pre torch-mlir torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu -f https://llvm.github.io/torch-mlir/package-index/ +pip install --pre torch-mlir torchvision torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu -f https://llvm.github.io/torch-mlir/package-index/ pip install --upgrade -f https://nod-ai.github.io/SRT/pip-release-links.html iree-compiler iree-runtime Write-Host "Building SHARK..." pip install -e . -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SRT/pip-release-links.html diff --git a/setup_venv.sh b/setup_venv.sh index 62c6513a85..5248e275fc 100755 --- a/setup_venv.sh +++ b/setup_venv.sh @@ -151,10 +151,6 @@ if [[ $(uname -s) = 'Linux' && ! -z "${IMPORTER}" ]]; then fi fi -if [[ -z "${NO_BREVITAS}" ]]; then - $PYTHON -m pip install git+https://github.com/Xilinx/brevitas.git@dev -fi - if [[ -z "${CONDA_PREFIX}" && "$SKIP_VENV" != "1" ]]; then echo "${Green}Before running examples activate venv with:" echo " ${Green}source $VENV_DIR/bin/activate" diff --git a/shark/iree_utils/compile_utils.py b/shark/iree_utils/compile_utils.py index a9750466db..00a1a10a24 100644 --- a/shark/iree_utils/compile_utils.py +++ b/shark/iree_utils/compile_utils.py @@ -31,24 +31,7 @@ # Get the iree-compile arguments given device. def get_iree_device_args(device, extra_args=[]): print("Configuring for device:" + device) - device_uri = device.split("://") - if len(device_uri) > 1: - if device_uri[0] not in ["vulkan", "rocm"]: - print( - f"Specific device selection only supported for vulkan and rocm." - f"Proceeding with {device} as device." - ) - # device_uri can be device_num or device_path. - # assuming number of devices for a single driver will be not be >99 - if len(device_uri[1]) <= 2: - # expected to be device index in range 0 - 99 - device_num = int(device_uri[1]) - else: - # expected to be device path - device_num = device_uri[1] - - else: - device_num = 0 + device, device_num = clean_device_info(device) if "cpu" in device: from shark.iree_utils.cpu_utils import get_iree_cpu_args @@ -64,27 +47,50 @@ def get_iree_device_args(device, extra_args=[]): + stack_size_flag + ["--iree-global-opt-enable-quantized-matmul-reassociation"] ) - if device_uri[0] == "cuda": + if device == "cuda": from shark.iree_utils.gpu_utils import get_iree_gpu_args return get_iree_gpu_args() - if device_uri[0] == "vulkan": + if device == "vulkan": from shark.iree_utils.vulkan_utils import get_iree_vulkan_args return get_iree_vulkan_args( device_num=device_num, extra_args=extra_args ) - if device_uri[0] == "metal": + if device == "metal": from shark.iree_utils.metal_utils import get_iree_metal_args return get_iree_metal_args(extra_args=extra_args) - if device_uri[0] == "rocm": + if device == "rocm": from shark.iree_utils.gpu_utils import get_iree_rocm_args return get_iree_rocm_args(device_num=device_num, extra_args=extra_args) return [] +def clean_device_info(raw_device): + # return appropriate device and device_id for consumption by Studio pipeline + # Multiple devices only supported for vulkan and rocm (as of now). + # default device must be selected for all others + + device_id = None + device = ( + raw_device + if "=>" not in raw_device + else raw_device.split("=>")[1].strip() + ) + if "://" in device: + device, device_id = device.split("://") + if len(device_id) <= 2: + device_id = int(device_id) + + if device not in ["rocm", "vulkan"]: + device_id = "" + if device in ["rocm", "vulkan"] and device_id == None: + device_id = 0 + return device, device_id + + # Get the iree-compiler arguments given frontend. def get_iree_frontend_args(frontend): if frontend in ["torch", "pytorch", "linalg", "tm_tensor"]: diff --git a/shark/iree_utils/gpu_utils.py b/shark/iree_utils/gpu_utils.py index 729e5977d9..36be96485e 100644 --- a/shark/iree_utils/gpu_utils.py +++ b/shark/iree_utils/gpu_utils.py @@ -95,6 +95,7 @@ def get_devices_info_from_dump(dump): print("could not execute `iree-run-module --dump_devices=rocm`") if dump_device_info is not None: + device_num = 0 if device_num is None else device_num device_arch_pairs = get_devices_info_from_dump(dump_device_info[0]) if len(device_arch_pairs) > device_num: # can find arch in the list arch_in_device_dump = device_arch_pairs[device_num][1] diff --git a/shark/iree_utils/vulkan_utils.py b/shark/iree_utils/vulkan_utils.py index bf787a3a8f..859c4ba833 100644 --- a/shark/iree_utils/vulkan_utils.py +++ b/shark/iree_utils/vulkan_utils.py @@ -38,15 +38,24 @@ def get_all_vulkan_devices(): @functools.cache def get_vulkan_device_name(device_num=0): - vulkaninfo_list = get_all_vulkan_devices() - if len(vulkaninfo_list) == 0: - raise ValueError("No device name found in VulkanInfo!") - if len(vulkaninfo_list) > 1: - print("Following devices found:") - for i, dname in enumerate(vulkaninfo_list): - print(f"{i}. {dname}") - print(f"Choosing device: {vulkaninfo_list[device_num]}") - return vulkaninfo_list[device_num] + if isinstance(device_num, int): + vulkaninfo_list = get_all_vulkan_devices() + + if len(vulkaninfo_list) == 0: + raise ValueError("No device name found in VulkanInfo!") + if len(vulkaninfo_list) > 1: + print("Following devices found:") + for i, dname in enumerate(vulkaninfo_list): + print(f"{i}. {dname}") + print(f"Choosing device: vulkan://{device_num}") + vulkan_device_name = vulkaninfo_list[device_num] + else: + from iree.runtime import get_driver + + vulkan_device_driver = get_driver(device_num) + vulkan_device_name = vulkan_device_driver.query_available_devices()[0] + print(vulkan_device_name) + return vulkan_device_name def get_os_name(): diff --git a/shark/shark_importer.py b/shark/shark_importer.py index 7082bf3813..3d585a8d0d 100644 --- a/shark/shark_importer.py +++ b/shark/shark_importer.py @@ -800,15 +800,17 @@ def save_mlir( model_name, mlir_dialect="linalg", frontend="torch", - dir=tempfile.gettempdir(), + dir="", ): model_name_mlir = ( model_name + "_" + frontend + "_" + mlir_dialect + ".mlir" ) if dir == "": - dir = tempfile.gettempdir() + dir = os.path.join(".", "shark_tmp") mlir_path = os.path.join(dir, model_name_mlir) print(f"saving {model_name_mlir} to {dir}") + if not os.path.exists(dir): + os.makedirs(dir) if frontend == "torch": with open(mlir_path, "wb") as mlir_file: mlir_file.write(mlir_module)