diff --git a/apps/language_models/scripts/vicuna.py b/apps/language_models/scripts/vicuna.py
index 3e801050c6..f96749ec05 100644
--- a/apps/language_models/scripts/vicuna.py
+++ b/apps/language_models/scripts/vicuna.py
@@ -3,6 +3,10 @@
 from apps.language_models.src.pipelines import vicuna_pipeline as vp
 from apps.language_models.src.pipelines import vicuna_sharded_pipeline as vsp
 import torch
+import json
+
+if __name__ == "__main__":
+    import gc
 
 
 parser = argparse.ArgumentParser(
@@ -55,6 +59,12 @@
     help="Run model in cli mode",
 )
 
+parser.add_argument(
+    "--config",
+    default=None,
+    help="configuration file",
+)
+
 if __name__ == "__main__":
     args, unknown = parser.parse_known_args()
 
@@ -84,6 +94,7 @@
             if args.second_vicuna_vmfb_path is None
             else Path(args.second_vicuna_vmfb_path)
         )
+
        vic = vp.Vicuna(
            "vicuna",
            device=args.device,
@@ -95,16 +106,21 @@
             load_mlir_from_shark_tank=args.load_mlir_from_shark_tank,
         )
     else:
+        if args.config is not None:
+            config_file = open(args.config)
+            config_json = json.load(config_file)
+            config_file.close()
+        else:
+            config_json = None
         vic = vsp.Vicuna(
             "vicuna",
             device=args.device,
             precision=args.precision,
+            config_json=config_json,
         )
     prompt_history = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
     prologue_prompt = "ASSISTANT:\n"
 
-    import gc
-
     while True:
         # TODO: Add break condition from user input
         user_prompt = input("User: ")
diff --git a/apps/language_models/src/pipelines/vicuna_sharded_pipeline.py b/apps/language_models/src/pipelines/vicuna_sharded_pipeline.py
index 79387af462..e9feb43c63 100644
--- a/apps/language_models/src/pipelines/vicuna_sharded_pipeline.py
+++ b/apps/language_models/src/pipelines/vicuna_sharded_pipeline.py
@@ -19,6 +19,7 @@
 import torch
 import torch_mlir
 import os
+import json
 
 
 class Vicuna(SharkLLMBase):
@@ -29,12 +30,14 @@ def __init__(
         self,
         max_num_tokens=512,
         device="cuda",
         precision="fp32",
+        config_json=None,
     ) -> None:
         super().__init__(model_name, hf_model_path, max_num_tokens)
         self.max_sequence_length = 256
         self.device = device
         self.precision = precision
         self.tokenizer = self.get_tokenizer()
+        self.config = config_json
         self.shark_model = self.compile(device=device)
 
     def get_tokenizer(self):
@@ -220,24 +223,43 @@ def compile_to_vmfb(self, inputs, layers, device="cpu", is_first=True):
                 f_.close()
                 mlirs.append(bytecode)
 
+        def get_device_index(layer_string):
+            if self.config is None:
+                return None
+            idx_votes = {}
+            for key in self.config.keys():
+                if re.search(layer_string, key):
+                    if int(self.config[key]["gpu"]) in idx_votes.keys():
+                        idx_votes[int(self.config[key]["gpu"])] += 1
+                    else:
+                        idx_votes[int(self.config[key]["gpu"])] = 1
+            device_idx = max(idx_votes, key=idx_votes.get)
+            return device_idx
+
         for idx, layer in tqdm(enumerate(layers), desc="compiling modules"):
             if is_first:
                 vmfb_path = Path(f"{idx}_0.vmfb")
                 if vmfb_path.exists():
                     # print(f"Found layer {idx} vmfb")
+                    device_idx = get_device_index(
+                        rf"first_vicuna.model.model.layers.{idx}[\s.]"
+                    )
                     module = SharkInference(
                         None,
                         device=device,
-                        device_idx=idx % 1,
+                        device_idx=device_idx,
                         mlir_dialect="tm_tensor",
                     )
                     module.load_module(vmfb_path)
                 else:
                     print(f"Compiling layer {idx} vmfb")
+                    device_idx = get_device_index(
+                        rf"first_vicuna.model.model.layers.{idx}[\s.]"
+                    )
                     module = SharkInference(
                         mlirs[idx],
                         device=device,
-                        device_idx=idx % 1,
+                        device_idx=device_idx,
                         mlir_dialect="tm_tensor",
                     )
                     module.save_module(
@@ -255,19 +277,25 @@ def compile_to_vmfb(self, inputs, layers, device="cpu", is_first=True):
                 vmfb_path = Path(f"{idx}_1.vmfb")
                 if vmfb_path.exists():
                     # print(f"Found layer {idx} vmfb")
+                    device_idx = get_device_index(
+                        rf"second_vicuna.model.model.layers.{idx}[\s.]"
+                    )
                     module = SharkInference(
                         None,
                         device=device,
-                        device_idx=idx % 1,
+                        device_idx=device_idx,
                         mlir_dialect="tm_tensor",
                     )
                     module.load_module(vmfb_path)
                 else:
                     print(f"Compiling layer {idx} vmfb")
+                    device_idx = get_device_index(
+                        rf"second_vicuna.model.model.layers.{idx}[\s.]"
+                    )
                     module = SharkInference(
                         mlirs[idx],
                         device=device,
-                        device_idx=idx % 1,
+                        device_idx=device_idx,
                         mlir_dialect="tm_tensor",
                     )
                     module.save_module(
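
The diff never fixes a schema for the file passed via --config; get_device_index() only expects a JSON object whose keys name model layers (matched with re.search against patterns like first_vicuna.model.model.layers.<idx>[\s.]) and whose values carry a "gpu" device index. A minimal sketch of producing such a file follows; the ".self_attn"/".mlp" key suffixes are illustrative assumptions, not taken from this diff:

```python
import json

# Hypothetical per-layer placement file for the new --config flag.
# Keys must re.search-match "first_vicuna.model.model.layers.<idx>[\s.]"
# (or the second_vicuna equivalent); the ".self_attn"/".mlp" suffixes
# here are illustrative assumptions, not part of the diff.
config = {
    "first_vicuna.model.model.layers.0.self_attn": {"gpu": "0"},
    "first_vicuna.model.model.layers.0.mlp": {"gpu": "0"},
    "second_vicuna.model.model.layers.0.self_attn": {"gpu": "1"},
    "second_vicuna.model.model.layers.0.mlp": {"gpu": "1"},
}

with open("config.json", "w") as f:
    json.dump(config, f, indent=2)
```

The resulting file is then passed on the command line, e.g. `python apps/language_models/scripts/vicuna.py --config config.json`.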
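Since get_device_index() is the heart of the change, here is a standalone rendition of its majority-vote behavior under the same hypothetical key suffixes; each config entry that matches a layer's pattern casts one vote for its "gpu" index, and the index with the most votes wins:

```python
import re

# Standalone rendition of the voting logic added in compile_to_vmfb(),
# for illustration only; `config` mirrors the --config JSON above.
def get_device_index(config, layer_string):
    if config is None:
        return None
    idx_votes = {}
    for key in config:
        if re.search(layer_string, key):
            gpu = int(config[key]["gpu"])
            idx_votes[gpu] = idx_votes.get(gpu, 0) + 1
    # The diff's version calls max() unconditionally, so a layer that is
    # missing from the config would raise ValueError here as well.
    return max(idx_votes, key=idx_votes.get)

config = {
    "first_vicuna.model.model.layers.0.self_attn": {"gpu": "0"},
    "first_vicuna.model.model.layers.0.mlp": {"gpu": "1"},
    "first_vicuna.model.model.layers.0.input_layernorm": {"gpu": "1"},
}
# Device 1 gets two of the three votes for layer 0:
print(get_device_index(config, r"first_vicuna.model.model.layers.0[\s.]"))  # 1
```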