From bd5b2b468a85204b57c87697d060647791cb7634 Mon Sep 17 00:00:00 2001 From: Abhishek Varma Date: Wed, 16 Aug 2023 17:21:35 +0000 Subject: [PATCH] Llama65B patch for int4 fp32 Signed-off-by: Abhishek Varma --- apps/language_models/scripts/vicuna.py | 185 ++++++----- .../src/model_wrappers/vicuna_model.py | 294 +++++++++++++++++- apps/stable_diffusion/web/ui/stablelm_ui.py | 11 + 3 files changed, 402 insertions(+), 88 deletions(-) diff --git a/apps/language_models/scripts/vicuna.py b/apps/language_models/scripts/vicuna.py index 486f3b817e..ade3ad7e7a 100644 --- a/apps/language_models/scripts/vicuna.py +++ b/apps/language_models/scripts/vicuna.py @@ -108,7 +108,7 @@ "--model_name", type=str, default="vicuna", - choices=["vicuna", "llama2_7b", "llama2_70b"], + choices=["vicuna", "llama_65b", "llama2_7b", "llama2_70b"], help="Specify which model to run.", ) parser.add_argument( @@ -161,7 +161,7 @@ class VicunaBase(SharkLLMBase): def __init__( self, model_name, - hf_model_path="TheBloke/vicuna-7B-1.1-HF", + hf_model_path="elinas/llama-65b-hf-transformers-4.29", max_num_tokens=512, device="cpu", precision="int8", @@ -433,7 +433,7 @@ class ShardedVicuna(VicunaBase): def __init__( self, model_name, - hf_model_path="TheBloke/vicuna-7B-1.1-HF", + hf_model_path="elinas/llama-65b-hf-transformers-4.29", max_num_tokens=512, device="cuda", precision="fp32", @@ -1212,7 +1212,7 @@ class UnshardedVicuna(VicunaBase): def __init__( self, model_name, - hf_model_path="TheBloke/vicuna-7B-1.1-HF", + hf_model_path="elinas/llama-65b-hf-transformers-4.29", hf_auth_token: str = None, max_num_tokens=512, device="cpu", @@ -1232,7 +1232,9 @@ def __init__( "HF auth token required. Pass it using --hf_auth_token flag." ) self.hf_auth_token = hf_auth_token - if self.model_name == "llama2_7b": + if self.model_name == "llama_65b": + self.hf_model_path = "elinas/llama-65b-hf-transformers-4.29" + elif self.model_name == "llama2_7b": self.hf_model_path = "meta-llama/Llama-2-7b-chat-hf" elif self.model_name == "llama2_70b": self.hf_model_path = "meta-llama/Llama-2-70b-chat-hf" @@ -1423,21 +1425,21 @@ def compile(self, download_vmfb=False): else: compilation_prompt = "".join(["0" for _ in range(17)]) - if Path(f"first_{self.precision}.mlir").exists(): - print(f"loading first_{self.precision}.mlir") - with open(Path(f"first_{self.precision}.mlir"), "r") as f: - first_module = f.read() + if Path(f"second_{self.precision}.mlir").exists(): + print(f"loading second_{self.precision}.mlir") + with open(Path(f"second_{self.precision}.mlir"), "r") as f: + second_module = f.read() else: - # generate first vicuna - compilation_input_ids = self.tokenizer( - compilation_prompt, - return_tensors="pt", - ).input_ids - compilation_input_ids = torch.tensor( - compilation_input_ids - ).reshape([1, 19]) - firstVicunaCompileInput = (compilation_input_ids,) - model = FirstVicuna( + # generate second vicuna + compilation_input_ids = torch.zeros( + [1, 1], dtype=torch.int64 + ) + pkv = tuple( + (torch.zeros([1, 64, 19, 128], dtype=torch.float32)) + for _ in range(160) + ) + secondVicunaCompileInput = (compilation_input_ids,) + pkv + model = SecondVicuna( self.hf_model_path, self.precision, self.weight_group_size, @@ -1447,27 +1449,33 @@ def compile(self, download_vmfb=False): print(f"[DEBUG] generating torchscript graph") ts_graph = import_with_fx( model, - firstVicunaCompileInput, + secondVicunaCompileInput, is_f16=self.precision == "fp16", precision=self.precision, - f16_input_mask=[False, False], + f16_input_mask=[False] + [True] * 160, 
mlir_type="torchscript", ) del model - firstVicunaCompileInput = list(firstVicunaCompileInput) - firstVicunaCompileInput[ - 0 - ] = torch_mlir.TensorPlaceholder.like( - firstVicunaCompileInput[0], dynamic_axes=[1] - ) - - firstVicunaCompileInput = tuple(firstVicunaCompileInput) - first_module = None + if self.precision == "fp16": + secondVicunaCompileInput = get_f16_inputs( + secondVicunaCompileInput, + True, + f16_input_mask=[False] + [True] * 160, + ) + secondVicunaCompileInput = list(secondVicunaCompileInput) + for i in range(len(secondVicunaCompileInput)): + if i != 0: + secondVicunaCompileInput[ + i + ] = torch_mlir.TensorPlaceholder.like( + secondVicunaCompileInput[i], dynamic_axes=[2] + ) + secondVicunaCompileInput = tuple(secondVicunaCompileInput) print(f"[DEBUG] generating torch mlir") if self.precision in ["int4", "int8"]: - first_module = torch_mlir.compile( + second_module = torch_mlir.compile( ts_graph, - [*firstVicunaCompileInput], + [*secondVicunaCompileInput], output_type=torch_mlir.OutputType.TORCH, backend_legal_ops=[ "brevitas.matmul_rhs_group_quant" @@ -1478,47 +1486,58 @@ def compile(self, download_vmfb=False): ) print(f"[DEBUG] converting torch to linalg") run_pipeline_with_repro_report( - first_module, + second_module, "builtin.module(func.func(torch-unpack-torch-tensor),torch-backend-to-linalg-on-tensors-backend-pipeline)", description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR", ) else: - first_module = torch_mlir.compile( + second_module = torch_mlir.compile( ts_graph, - [*firstVicunaCompileInput], + [*secondVicunaCompileInput], torch_mlir.OutputType.LINALG_ON_TENSORS, use_tracing=False, verbose=False, ) + from contextlib import redirect_stdout + print("Writing : second_llama_65b_linalg_ir_before_dynamic ELIDED") + with open('second_llama_65b_linalg_ir_before_dynamic_elided.mlir', 'w') as f: + with redirect_stdout(f): + print(second_module.operation.get_asm(large_elements_limit=4)) + print("FINISHED") del ts_graph - del firstVicunaCompileInput + del secondVicunaCompileInput gc.collect() - print( - "[DEBUG] successfully generated first vicuna linalg mlir" + "[DEBUG] successfully generated second vicuna linalg mlir" ) - first_module = self.write_in_dynamic_inputs0( - str(first_module), dynamic_input_size=19 + second_module = self.write_in_dynamic_inputs1( + str(second_module) ) - if self.cache_vicunas: - with open(f"first_{self.precision}.mlir", "w+") as f: - f.write(first_module) + print("Writing : second_llama_65b_linalg_ir_after_dynamic ELIDED") + with open('second_llama_65b_linalg_ir_after_dynamic_elided.mlir', 'w') as f: + with redirect_stdout(f): + print(second_module.operation.get_asm(large_elements_limit=4)) + print("FINISHED") + # if self.cache_vicunas: + print("Writing : second_llama_65b_linalg_ir_after_dynamic") + with open(f"second_{self.precision}.mlir", "w+") as f: + f.write(second_module) - if Path(f"second_{self.precision}.mlir").exists(): - print(f"loading second_{self.precision}.mlir") - with open(Path(f"second_{self.precision}.mlir"), "r") as f: - second_module = f.read() + if Path(f"first_{self.precision}.mlir").exists(): + print(f"loading first_{self.precision}.mlir") + with open(Path(f"first_{self.precision}.mlir"), "r") as f: + first_module = f.read() else: - # generate second vicuna - compilation_input_ids = torch.zeros( - [1, 1], dtype=torch.int64 - ) - pkv = tuple( - (torch.zeros([1, 32, 19, 128], dtype=torch.float32)) - for _ in range(64) - ) - secondVicunaCompileInput = (compilation_input_ids,) + pkv - model = 
SecondVicuna( + # generate first vicuna + compilation_input_ids = self.tokenizer( + compilation_prompt, + return_tensors="pt", + ).input_ids + compilation_input_ids = torch.tensor( + compilation_input_ids + ).reshape([1, 19]) + firstVicunaCompileInput = (compilation_input_ids,) + model = FirstVicuna( self.hf_model_path, self.precision, self.weight_group_size, @@ -1528,33 +1547,27 @@ def compile(self, download_vmfb=False): print(f"[DEBUG] generating torchscript graph") ts_graph = import_with_fx( model, - secondVicunaCompileInput, + firstVicunaCompileInput, is_f16=self.precision == "fp16", precision=self.precision, - f16_input_mask=[False] + [True] * 64, + f16_input_mask=[False, False], mlir_type="torchscript", ) del model - if self.precision == "fp16": - secondVicunaCompileInput = get_f16_inputs( - secondVicunaCompileInput, - True, - f16_input_mask=[False] + [True] * 64, - ) - secondVicunaCompileInput = list(secondVicunaCompileInput) - for i in range(len(secondVicunaCompileInput)): - if i != 0: - secondVicunaCompileInput[ - i - ] = torch_mlir.TensorPlaceholder.like( - secondVicunaCompileInput[i], dynamic_axes=[2] - ) - secondVicunaCompileInput = tuple(secondVicunaCompileInput) + firstVicunaCompileInput = list(firstVicunaCompileInput) + firstVicunaCompileInput[ + 0 + ] = torch_mlir.TensorPlaceholder.like( + firstVicunaCompileInput[0], dynamic_axes=[1] + ) + + firstVicunaCompileInput = tuple(firstVicunaCompileInput) + first_module = None print(f"[DEBUG] generating torch mlir") if self.precision in ["int4", "int8"]: - second_module = torch_mlir.compile( + first_module = torch_mlir.compile( ts_graph, - [*secondVicunaCompileInput], + [*firstVicunaCompileInput], output_type=torch_mlir.OutputType.TORCH, backend_legal_ops=[ "brevitas.matmul_rhs_group_quant" @@ -1565,30 +1578,31 @@ def compile(self, download_vmfb=False): ) print(f"[DEBUG] converting torch to linalg") run_pipeline_with_repro_report( - second_module, + first_module, "builtin.module(func.func(torch-unpack-torch-tensor),torch-backend-to-linalg-on-tensors-backend-pipeline)", description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR", ) else: - second_module = torch_mlir.compile( + first_module = torch_mlir.compile( ts_graph, - [*secondVicunaCompileInput], + [*firstVicunaCompileInput], torch_mlir.OutputType.LINALG_ON_TENSORS, use_tracing=False, verbose=False, ) del ts_graph - del secondVicunaCompileInput + del firstVicunaCompileInput gc.collect() + print( - "[DEBUG] successfully generated second vicuna linalg mlir" + "[DEBUG] successfully generated first vicuna linalg mlir" ) - second_module = self.write_in_dynamic_inputs1( - str(second_module) + first_module = self.write_in_dynamic_inputs0( + str(first_module), dynamic_input_size=19 ) if self.cache_vicunas: - with open(f"second_{self.precision}.mlir", "w+") as f: - f.write(second_module) + with open(f"first_{self.precision}.mlir", "w+") as f: + f.write(first_module) combined_module = self.combine_mlir_scripts( first_module, second_module, self.vicuna_mlir_path @@ -1752,6 +1766,7 @@ def autocomplete(self, prompt): model_list = { "vicuna": "vicuna=>TheBloke/vicuna-7B-1.1-HF", + "llama_65b": "elinas/llama-65b-hf-transformers-4.29", "llama2_7b": "llama2_7b=>meta-llama/Llama-2-7b-chat-hf", "llama2_70b": "llama2_70b=>meta-llama/Llama-2-70b-chat-hf", } diff --git a/apps/language_models/src/model_wrappers/vicuna_model.py b/apps/language_models/src/model_wrappers/vicuna_model.py index b4eff51542..77d69f9607 100644 --- a/apps/language_models/src/model_wrappers/vicuna_model.py +++ 
b/apps/language_models/src/model_wrappers/vicuna_model.py @@ -147,6 +147,102 @@ def forward( i62, i63, i64, + i65, + i66, + i67, + i68, + i69, + i70, + i71, + i72, + i73, + i74, + i75, + i76, + i77, + i78, + i79, + i80, + i81, + i82, + i83, + i84, + i85, + i86, + i87, + i88, + i89, + i90, + i91, + i92, + i93, + i94, + i95, + i96, + i97, + i98, + i99, + i100, + i101, + i102, + i103, + i104, + i105, + i106, + i107, + i108, + i109, + i110, + i111, + i112, + i113, + i114, + i115, + i116, + i117, + i118, + i119, + i120, + i121, + i122, + i123, + i124, + i125, + i126, + i127, + i128, + i129, + i130, + i131, + i132, + i133, + i134, + i135, + i136, + i137, + i138, + i139, + i140, + i141, + i142, + i143, + i144, + i145, + i146, + i147, + i148, + i149, + i150, + i151, + i152, + i153, + i154, + i155, + i156, + i157, + i158, + i159, + i160, ): # input_ids = input_tuple[0] # input_tuple = torch.unbind(pkv, dim=0) @@ -277,6 +373,198 @@ def forward( i63, i64, ), + ( + i65, + i66, + ), + ( + i67, + i68, + ), + ( + i69, + i70, + ), + ( + i71, + i72, + ), + ( + i73, + i74, + ), + ( + i75, + i76, + ), + ( + i77, + i78, + ), + ( + i79, + i80, + ), + ( + i81, + i82, + ), + ( + i83, + i84, + ), + ( + i85, + i86, + ), + ( + i87, + i88, + ), + ( + i89, + i90, + ), + ( + i91, + i92, + ), + ( + i93, + i94, + ), + ( + i95, + i96, + ), + ( + i97, + i98, + ), + ( + i99, + i100, + ), + ( + i101, + i102, + ), + ( + i103, + i104, + ), + ( + i105, + i106, + ), + ( + i107, + i108, + ), + ( + i109, + i110, + ), + ( + i111, + i112, + ), + ( + i113, + i114, + ), + ( + i115, + i116, + ), + ( + i117, + i118, + ), + ( + i119, + i120, + ), + ( + i121, + i122, + ), + ( + i123, + i124, + ), + ( + i125, + i126, + ), + ( + i127, + i128 + ), + ( + i129, + i130, + ), + ( + i131, + i132, + ), + ( + i133, + i134, + ), + ( + i135, + i136, + ), + ( + i137, + i138, + ), + ( + i139, + i140, + ), + ( + i141, + i142, + ), + ( + i143, + i144, + ), + ( + i145, + i146, + ), + ( + i147, + i148, + ), + ( + i149, + i150, + ), + ( + i151, + i152, + ), + ( + i153, + i154, + ), + ( + i155, + i156, + ), + ( + i157, + i158 + ), + ( + i159, + i160 + ) ) op = self.model( input_ids=token, use_cache=True, past_key_values=past_key_values @@ -305,9 +593,9 @@ def forward(self, input_ids): # generate second vicuna compilation_input_ids = torch.zeros([1, 1], dtype=torch.int64) pkv = tuple( - (torch.zeros([1, 32, 19, 128], dtype=torch.float32)) - for _ in range(64) + (torch.zeros([1, 64, 19, 128], dtype=torch.float32)) + for _ in range(160) ) secondVicunaCompileInput = (compilation_input_ids,) + pkv second_output = self.second_vicuna(*secondVicunaCompileInput) - return second_output + return second_output \ No newline at end of file diff --git a/apps/stable_diffusion/web/ui/stablelm_ui.py b/apps/stable_diffusion/web/ui/stablelm_ui.py index 2e9e56ff56..2fa136f722 100644 --- a/apps/stable_diffusion/web/ui/stablelm_ui.py +++ b/apps/stable_diffusion/web/ui/stablelm_ui.py @@ -23,6 +23,7 @@ def user(message, history): past_key_values = None model_map = { + "llama_65b": "elinas/llama-65b-hf-transformers-4.29", "llama2_7b": "meta-llama/Llama-2-7b-chat-hf", "llama2_70b": "meta-llama/Llama-2-70b-chat-hf", "codegen": "Salesforce/codegen25-7b-multi", @@ -34,6 +35,15 @@ def user(message, history): # NOTE: Each `model_name` should have its own start message start_message = { + "llama_65b": ( + "System: You are a helpful, respectful and honest assistant. Always answer " + "as helpfully as possible, while being safe. 
Your answers should not " + "include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal " + "content. Please ensure that your responses are socially unbiased and positive " + "in nature. If a question does not make any sense, or is not factually coherent, " + "explain why instead of answering something not correct. If you don't know the " + "answer to a question, please don't share false information." + ), "llama2_7b": ( "System: You are a helpful, respectful and honest assistant. Always answer " "as helpfully as possible, while being safe. Your answers should not " @@ -160,6 +170,7 @@ def chat( "vicuna4", "vicuna1p3", "codegen", + "llama_65b", "llama2_7b", "llama2_70b", ]:
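
The core of this patch is the switch from a 64-tensor KV cache (32 layers, 32 heads) to a 160-tensor cache with [1, 64, 19, 128] placeholders, which corresponds to LLaMA-65B's 80 decoder layers and 64 attention heads. Below is a minimal illustrative sketch, not part of the patch itself, showing the flatten/regroup pattern that the extended SecondVicuna.forward signature (i1 .. i160) implements by hand. The constants NUM_LAYERS, NUM_HEADS, HEAD_DIM, and SEQ_LEN are assumptions inferred from the shapes used in the patch.

# Sketch only: mirrors how the patched SecondVicuna wrapper turns 160 flat
# past-key-value tensors back into the 80 (key, value) pairs expected by the
# HuggingFace model's past_key_values argument.
import torch

NUM_LAYERS = 80   # 160 flat tensors / 2 per layer (inferred from the signature)
NUM_HEADS = 64    # matches the [1, 64, 19, 128] placeholders in the patch
HEAD_DIM = 128
SEQ_LEN = 19      # compilation-time sequence length used by the patch

# Flat tuple of 160 tensors, analogous to secondVicunaCompileInput[1:].
flat_pkv = tuple(
    torch.zeros([1, NUM_HEADS, SEQ_LEN, HEAD_DIM], dtype=torch.float32)
    for _ in range(2 * NUM_LAYERS)
)

# Regroup into 80 (key, value) pairs, the structure the hand-written
# (i1, i2), ..., (i159, i160) tuples in SecondVicuna.forward reproduce.
past_key_values = tuple(
    (flat_pkv[2 * i], flat_pkv[2 * i + 1]) for i in range(NUM_LAYERS)
)

assert len(past_key_values) == NUM_LAYERS
assert past_key_values[0][0].shape == (1, NUM_HEADS, SEQ_LEN, HEAD_DIM)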