diff --git a/apps/language_models/scripts/vicuna.py b/apps/language_models/scripts/vicuna.py
index 3e801050c6..a8ddb6ce58 100644
--- a/apps/language_models/scripts/vicuna.py
+++ b/apps/language_models/scripts/vicuna.py
@@ -61,26 +61,22 @@
     vic = None
     if not args.sharded:
         first_vic_mlir_path = (
-            Path(f"first_vicuna_{args.precision}.mlir")
+            None
             if args.first_vicuna_mlir_path is None
             else Path(args.first_vicuna_mlir_path)
         )
         second_vic_mlir_path = (
-            Path(f"second_vicuna_{args.precision}.mlir")
+            None
             if args.second_vicuna_mlir_path is None
             else Path(args.second_vicuna_mlir_path)
         )
         first_vic_vmfb_path = (
-            Path(
-                f"first_vicuna_{args.precision}_{args.device.replace('://', '_')}.vmfb"
-            )
+            None
             if args.first_vicuna_vmfb_path is None
             else Path(args.first_vicuna_vmfb_path)
         )
         second_vic_vmfb_path = (
-            Path(
-                f"second_vicuna_{args.precision}_{args.device.replace('://', '_')}.vmfb"
-            )
+            None
             if args.second_vicuna_vmfb_path is None
             else Path(args.second_vicuna_vmfb_path)
         )
diff --git a/apps/language_models/src/pipelines/vicuna_pipeline.py b/apps/language_models/src/pipelines/vicuna_pipeline.py
index 4745504ae5..a89cd55df2 100644
--- a/apps/language_models/src/pipelines/vicuna_pipeline.py
+++ b/apps/language_models/src/pipelines/vicuna_pipeline.py
@@ -33,6 +33,7 @@ def __init__(
         first_vicuna_vmfb_path=None,
         second_vicuna_vmfb_path=None,
         load_mlir_from_shark_tank=True,
+        low_device_memory=False,
     ) -> None:
         super().__init__(model_name, hf_model_path, max_num_tokens)
         self.max_sequence_length = 256
@@ -43,12 +44,16 @@ def __init__(
         self.first_vicuna_mlir_path = first_vicuna_mlir_path
         self.second_vicuna_mlir_path = second_vicuna_mlir_path
         self.load_mlir_from_shark_tank = load_mlir_from_shark_tank
+        self.low_device_memory = low_device_memory
+        self.first_vic = None
+        self.second_vic = None
         if self.first_vicuna_mlir_path == None:
             self.first_vicuna_mlir_path = self.get_model_path()
         if self.second_vicuna_mlir_path == None:
             self.second_vicuna_mlir_path = self.get_model_path("second")
         if self.first_vicuna_vmfb_path == None:
             self.first_vicuna_vmfb_path = self.get_model_path(suffix="vmfb")
+            print(self.first_vicuna_vmfb_path)
         if self.second_vicuna_vmfb_path == None:
             self.second_vicuna_vmfb_path = self.get_model_path(
                 "second", "vmfb"
@@ -61,7 +66,7 @@ def get_model_path(self, model_number="first", suffix="mlir"):
         if suffix == "mlir":
             return Path(f"{model_number}_vicuna_{self.precision}.{suffix}")
         return Path(
-            f"{model_number}_vicuna_{safe_device}_{self.precision}.{suffix}"
+            f"{model_number}_vicuna_{self.precision}_{safe_device}.{suffix}"
         )
 
     def get_tokenizer(self):
@@ -436,12 +441,19 @@ def generate(self, prompt, cli=False):
         # TODO: refactor for cleaner integration
         import gc
 
+        if not self.low_device_memory:
+            if self.first_vic == None:
+                self.first_vic = self.compile_first_vicuna()
+            if self.second_vic == None:
+                self.second_vic = self.compile_second_vicuna()
         res = []
         res_tokens = []
         params = {
             "prompt": prompt,
             "is_first": True,
-            "fv": self.compile_first_vicuna(),
+            "fv": self.compile_first_vicuna()
+            if self.first_vic == None
+            else self.first_vic,
         }
 
         generated_token_op = self.generate_new_token(params=params)
@@ -457,18 +469,20 @@
             print(f"Assistant: {detok}", end=" ", flush=True)
 
         # Clear First Vic from Memory (main and cuda)
-        del params
-        torch.cuda.empty_cache()
-        gc.collect()
+        if self.low_device_memory:
+            del params
+            torch.cuda.empty_cache()
+            gc.collect()
 
-        sec_vic = self.compile_second_vicuna()
         for _ in range(self.max_num_tokens - 2):
             params = {
                 "prompt": None,
                 "is_first": False,
                 "logits": logits,
                 "pkv": pkv,
-                "sv": sec_vic,
+                "sv": self.compile_second_vicuna()
+                if self.second_vic == None
+                else self.second_vic,
             }
 
             generated_token_op = self.generate_new_token(params=params)
@@ -489,9 +503,10 @@
                 res.append(detok)
                 if cli:
                     print(f"{detok}", end=" ", flush=True)
-        del sec_vic, pkv, logits
-        torch.cuda.empty_cache()
-        gc.collect()
+        if self.device == "cuda":
+            del sec_vic, pkv, logits
+            torch.cuda.empty_cache()
+            gc.collect()
 
         for i in range(len(res_tokens)):
             if type(res_tokens[i]) != int:
diff --git a/shark/iree_utils/compile_utils.py b/shark/iree_utils/compile_utils.py
index 78ee1ca6d5..139aa9fea4 100644
--- a/shark/iree_utils/compile_utils.py
+++ b/shark/iree_utils/compile_utils.py
@@ -20,6 +20,7 @@
 import os
 import re
 import tempfile
+from pathlib import Path
 
 
 # Get the iree-compile arguments given device.
@@ -355,6 +356,9 @@ def load_vmfb_using_mmap(
     # OR 2. We are compiling on the fly, therefore we have the flatbuffer blob to play with.
     # (This would arise if we're invoking `compile` from a SharkInference obj)
     temp_file_to_unlink = None
+
+    if isinstance(flatbuffer_blob_or_path, Path):
+        flatbuffer_blob_or_path = flatbuffer_blob_or_path.__str__()
     if (
         isinstance(flatbuffer_blob_or_path, str)
         and ".vmfb" in flatbuffer_blob_or_path