From cb96194030f4770b02f2e222be4527ab7febf9ed Mon Sep 17 00:00:00 2001 From: Elias Joseph Date: Wed, 21 Jun 2023 00:05:31 -0400 Subject: [PATCH] added ability to use config file to shard vicuna --- apps/language_models/scripts/vicuna.py | 20 +- .../model_wrappers/vicuna_sharded_model.py | 74 +++++- .../src/pipelines/vicuna_sharded_pipeline.py | 230 +++++++++++++++++- 3 files changed, 311 insertions(+), 13 deletions(-) diff --git a/apps/language_models/scripts/vicuna.py b/apps/language_models/scripts/vicuna.py index 3e801050c6..f96749ec05 100644 --- a/apps/language_models/scripts/vicuna.py +++ b/apps/language_models/scripts/vicuna.py @@ -3,6 +3,10 @@ from apps.language_models.src.pipelines import vicuna_pipeline as vp from apps.language_models.src.pipelines import vicuna_sharded_pipeline as vsp import torch +import json + +if __name__ == "__main__": + import gc parser = argparse.ArgumentParser( @@ -55,6 +59,12 @@ help="Run model in cli mode", ) +parser.add_argument( + "--config", + default=None, + help="configuration file", +) + if __name__ == "__main__": args, unknown = parser.parse_known_args() @@ -84,6 +94,7 @@ if args.second_vicuna_vmfb_path is None else Path(args.second_vicuna_vmfb_path) ) + vic = vp.Vicuna( "vicuna", device=args.device, @@ -95,16 +106,21 @@ load_mlir_from_shark_tank=args.load_mlir_from_shark_tank, ) else: + if args.config is not None: + config_file = open(args.config) + config_json = json.load(config_file) + config_file.close() + else: + config_json = None vic = vsp.Vicuna( "vicuna", device=args.device, precision=args.precision, + config_json=config_json, ) prompt_history = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n" prologue_prompt = "ASSISTANT:\n" - import gc - while True: # TODO: Add break condition from user input user_prompt = input("User: ") diff --git a/apps/language_models/src/model_wrappers/vicuna_sharded_model.py b/apps/language_models/src/model_wrappers/vicuna_sharded_model.py index cadba37cae..cba0b0f952 100644 --- a/apps/language_models/src/model_wrappers/vicuna_sharded_model.py +++ b/apps/language_models/src/model_wrappers/vicuna_sharded_model.py @@ -145,7 +145,7 @@ def forward( class ShardedVicunaModel(torch.nn.Module): - def __init__(self, model, layers0, layers1): + def __init__(self, model, layers0, layers1, lmhead, embedding, norm): super().__init__() self.model = model assert len(layers0) == len(model.model.layers) @@ -154,6 +154,12 @@ def __init__(self, model, layers0, layers1): self.model.model.config.output_attentions = False self.layers0 = layers0 self.layers1 = layers1 + self.norm = norm + self.embedding = embedding + self.lmhead = lmhead + self.model.model.norm = self.norm + self.model.model.embed_tokens = self.embedding + self.model.lm_head = self.lmhead def forward( self, @@ -176,3 +182,69 @@ def forward( attention_mask=attention_mask, past_key_values=past_key_values, ) + + +class LMHead(torch.nn.Module): + def __init__(self, model): + super().__init__() + self.model = model + + def forward(self, hidden_states): + output = self.model(hidden_states) + return output + + +class LMHeadCompiled(torch.nn.Module): + def __init__(self, shark_module): + super().__init__() + self.model = shark_module + + def forward(self, hidden_states): + hidden_states = hidden_states.detach() + output = self.model("forward", (hidden_states,)) + output = torch.tensor(output) + return output + + +class VicunaNorm(torch.nn.Module): + def __init__(self, model): + super().__init__() + self.model = model + + def forward(self, hidden_states): + output = self.model(hidden_states) + return output + + +class VicunaNormCompiled(torch.nn.Module): + def __init__(self, shark_module): + super().__init__() + self.model = shark_module + + def forward(self, hidden_states): + hidden_states.detach() + output = self.model("forward", (hidden_states,)) + output = torch.tensor(output) + return output + + +class VicunaEmbedding(torch.nn.Module): + def __init__(self, model): + super().__init__() + self.model = model + + def forward(self, input_ids): + output = self.model(input_ids) + return output + + +class VicunaEmbeddingCompiled(torch.nn.Module): + def __init__(self, shark_module): + super().__init__() + self.model = shark_module + + def forward(self, input_ids): + input_ids.detach() + output = self.model("forward", (input_ids,)) + output = torch.tensor(output) + return output diff --git a/apps/language_models/src/pipelines/vicuna_sharded_pipeline.py b/apps/language_models/src/pipelines/vicuna_sharded_pipeline.py index 79387af462..9cb3a428ea 100644 --- a/apps/language_models/src/pipelines/vicuna_sharded_pipeline.py +++ b/apps/language_models/src/pipelines/vicuna_sharded_pipeline.py @@ -4,6 +4,12 @@ CompiledFirstVicunaLayer, CompiledSecondVicunaLayer, ShardedVicunaModel, + LMHead, + LMHeadCompiled, + VicunaEmbedding, + VicunaEmbeddingCompiled, + VicunaNorm, + VicunaNormCompiled, ) from apps.language_models.src.pipelines.SharkLLMBase import SharkLLMBase from shark.shark_importer import import_with_fx @@ -19,9 +25,11 @@ import torch import torch_mlir import os +import json class Vicuna(SharkLLMBase): + # Class representing Sharded Vicuna Model def __init__( self, model_name, @@ -29,21 +37,25 @@ def __init__( max_num_tokens=512, device="cuda", precision="fp32", + config_json=None, ) -> None: super().__init__(model_name, hf_model_path, max_num_tokens) self.max_sequence_length = 256 self.device = device self.precision = precision self.tokenizer = self.get_tokenizer() + self.config = config_json self.shark_model = self.compile(device=device) def get_tokenizer(self): + # Retrieve the tokenizer from Huggingface tokenizer = AutoTokenizer.from_pretrained( self.hf_model_path, use_fast=False ) return tokenizer def get_src_model(self): + # Retrieve the torch model from Huggingface kwargs = {"torch_dtype": torch.float} vicuna_model = AutoModelForCausalLM.from_pretrained( self.hf_model_path, **kwargs @@ -51,6 +63,8 @@ def get_src_model(self): return vicuna_model def write_in_dynamic_inputs0(self, module, dynamic_input_size): + # Current solution for ensuring mlir files support dynamic inputs + # TODO find a more elegant way to implement this new_lines = [] for line in module.splitlines(): line = re.sub(f"{dynamic_input_size}x", "?x", line) @@ -107,6 +121,7 @@ def compile_vicuna_layer( past_key_value0=None, past_key_value1=None, ): + # Compile a hidden decoder layer of vicuna if past_key_value0 is None and past_key_value1 is None: model_inputs = (hidden_states, attention_mask, position_ids) else: @@ -126,7 +141,154 @@ def compile_vicuna_layer( ) return mlir_bytecode + def get_device_index(self, layer_string): + # Get the device index from the config file + # In the event that different device indices are assigned to + # different parts of a layer, a majority vote will be taken and + # everything will be run on the most commonly used device + if self.config is None: + return None + idx_votes = {} + for key in self.config.keys(): + if re.search(layer_string, key): + if int(self.config[key]["gpu"]) in idx_votes.keys(): + idx_votes[int(self.config[key]["gpu"])] += 1 + else: + idx_votes[int(self.config[key]["gpu"])] = 1 + device_idx = max(idx_votes, key=idx_votes.get) + return device_idx + + def compile_lmhead( + self, lmh, hidden_states, device="cpu", device_idx=None + ): + # compile the lm head of the vicuna model + # This can be used for both first and second vicuna, so only needs to be run once + mlir_path = Path(f"lmhead.mlir") + vmfb_path = Path(f"lmhead.vmfb") + if mlir_path.exists(): + f_ = open(mlir_path, "rb") + bytecode = f_.read() + f_.close() + else: + hidden_states = torch_mlir.TensorPlaceholder.like( + hidden_states, dynamic_axes=[1] + ) + + module = torch_mlir.compile( + lmh, + (hidden_states,), + torch_mlir.OutputType.LINALG_ON_TENSORS, + use_tracing=False, + verbose=False, + ) + bytecode_stream = BytesIO() + module.operation.write_bytecode(bytecode_stream) + bytecode = bytecode_stream.getvalue() + f_ = open(mlir_path, "wb") + f_.write(bytecode) + f_.close() + + shark_module = SharkInference( + bytecode, + device=device, + mlir_dialect="tm_tensor", + device_idx=device_idx, + ) + if vmfb_path.exists(): + shark_module.load_module(vmfb_path) + else: + shark_module.save_module(module_name="lmhead") + shark_module.load_module(vmfb_path) + compiled_module = LMHeadCompiled(shark_module) + return compiled_module + + def compile_norm(self, fvn, hidden_states, device="cpu", device_idx=None): + # compile the normalization layer of the vicuna model + # This can be used for both first and second vicuna, so only needs to be run once + mlir_path = Path(f"norm.mlir") + vmfb_path = Path(f"norm.vmfb") + if mlir_path.exists(): + f_ = open(mlir_path, "rb") + bytecode = f_.read() + f_.close() + else: + hidden_states = torch_mlir.TensorPlaceholder.like( + hidden_states, dynamic_axes=[1] + ) + + module = torch_mlir.compile( + fvn, + (hidden_states,), + torch_mlir.OutputType.LINALG_ON_TENSORS, + use_tracing=False, + verbose=False, + ) + bytecode_stream = BytesIO() + module.operation.write_bytecode(bytecode_stream) + bytecode = bytecode_stream.getvalue() + f_ = open(mlir_path, "wb") + f_.write(bytecode) + f_.close() + + shark_module = SharkInference( + bytecode, + device=device, + mlir_dialect="tm_tensor", + device_idx=device_idx, + ) + if vmfb_path.exists(): + shark_module.load_module(vmfb_path) + else: + shark_module.save_module(module_name="norm") + shark_module.load_module(vmfb_path) + compiled_module = VicunaNormCompiled(shark_module) + return compiled_module + + def compile_embedding(self, fve, input_ids, device="cpu", device_idx=None): + # compile the embedding layer of the vicuna model + # This can be used for both first and second vicuna, so only needs to be run once + mlir_path = Path(f"embedding.mlir") + vmfb_path = Path(f"embedding.vmfb") + if mlir_path.exists(): + f_ = open(mlir_path, "rb") + bytecode = f_.read() + f_.close() + else: + input_ids = torch_mlir.TensorPlaceholder.like( + input_ids, dynamic_axes=[1] + ) + module = torch_mlir.compile( + fve, + (input_ids,), + torch_mlir.OutputType.LINALG_ON_TENSORS, + use_tracing=False, + verbose=False, + ) + bytecode_stream = BytesIO() + module.operation.write_bytecode(bytecode_stream) + bytecode = bytecode_stream.getvalue() + f_ = open(mlir_path, "wb") + f_.write(bytecode) + f_.close() + + shark_module = SharkInference( + bytecode, + device=device, + mlir_dialect="tm_tensor", + device_idx=device_idx, + ) + if vmfb_path.exists(): + shark_module.load_module(vmfb_path) + else: + shark_module.save_module(module_name="embedding") + shark_module.load_module(vmfb_path) + compiled_module = VicunaEmbeddingCompiled(shark_module) + + return compiled_module + def compile_to_vmfb(self, inputs, layers, device="cpu", is_first=True): + # compile all layers for vmfb + # this needs to be run seperatley for first and second vicuna mlirs, modules = [], [] for idx, layer in tqdm(enumerate(layers), desc="Getting mlirs"): if is_first: @@ -198,10 +360,6 @@ def compile_to_vmfb(self, inputs, layers, device="cpu", is_first=True): verbose=False, ) - # bytecode_stream = BytesIO() - # module.operation.write_bytecode(bytecode_stream) - # bytecode = bytecode_stream.getvalue() - if is_first: module = self.write_in_dynamic_inputs0(str(module), 137) bytecode = module.encode("UTF-8") @@ -224,20 +382,25 @@ def compile_to_vmfb(self, inputs, layers, device="cpu", is_first=True): if is_first: vmfb_path = Path(f"{idx}_0.vmfb") if vmfb_path.exists(): - # print(f"Found layer {idx} vmfb") + device_idx = self.get_device_index( + f"first_vicuna.model.model.layers.{idx}[\s.$]" + ) module = SharkInference( None, device=device, - device_idx=idx % 1, + device_idx=device_idx, mlir_dialect="tm_tensor", ) module.load_module(vmfb_path) else: print(f"Compiling layer {idx} vmfb") + device_idx = self.get_device_index( + f"first_vicuna.model.model.layers.{idx}[\s.$]" + ) module = SharkInference( mlirs[idx], device=device, - device_idx=idx % 1, + device_idx=device_idx, mlir_dialect="tm_tensor", ) module.save_module( @@ -255,19 +418,25 @@ def compile_to_vmfb(self, inputs, layers, device="cpu", is_first=True): vmfb_path = Path(f"{idx}_1.vmfb") if vmfb_path.exists(): # print(f"Found layer {idx} vmfb") + device_idx = self.get_device_index( + f"second_vicuna.model.model.layers.{idx}[\s.$]" + ) module = SharkInference( None, device=device, - device_idx=idx % 1, + device_idx=device_idx, mlir_dialect="tm_tensor", ) module.load_module(vmfb_path) else: print(f"Compiling layer {idx} vmfb") + device_idx = self.get_device_index( + f"second_vicuna.model.model.layers.{idx}[\s.$]" + ) module = SharkInference( mlirs[idx], device=device, - device_idx=idx % 1, + device_idx=device_idx, mlir_dialect="tm_tensor", ) module.save_module( @@ -303,6 +472,42 @@ def get_sharded_model(self, device="cpu"): torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]), ) + norm = VicunaNorm(vicuna_model.model.norm) + device_idx = self.get_device_index( + r"vicuna\.model\.model\.norm(?:\.|\s|$)" + ) + print(device_idx) + norm = self.compile_norm( + norm, + torch.zeros([1, SAMPLE_INPUT_LEN, 4096]), + device=self.device, + device_idx=device_idx, + ) + + embeddings = VicunaEmbedding(vicuna_model.model.embed_tokens) + device_idx = self.get_device_index( + r"vicuna\.model\.model\.embed_tokens(?:\.|\s|$)" + ) + print(device_idx) + embeddings = self.compile_embedding( + embeddings, + (torch.zeros([1, SAMPLE_INPUT_LEN], dtype=torch.int64)), + device=self.device, + device_idx=device_idx, + ) + + lmhead = LMHead(vicuna_model.lm_head) + device_idx = self.get_device_index( + r"vicuna\.model\.lm_head(?:\.|\s|$)" + ) + print(device_idx) + lmhead = self.compile_lmhead( + lmhead, + torch.zeros([1, SAMPLE_INPUT_LEN, 4096]), + device=self.device, + device_idx=device_idx, + ) + layers0 = [ FirstVicunaLayer(layer) for layer in vicuna_model.model.layers ] @@ -323,7 +528,12 @@ def get_sharded_model(self, device="cpu"): shark_layers1 = [CompiledSecondVicunaLayer(m) for m in modules1] sharded_model = ShardedVicunaModel( - vicuna_model, shark_layers0, shark_layers1 + vicuna_model, + shark_layers0, + shark_layers1, + lmhead, + embeddings, + norm, ) return sharded_model