From 8caf8747db7b8e5c0ec08668c77593fd1a8b8044 Mon Sep 17 00:00:00 2001 From: Vivek Khandelwal Date: Wed, 6 Dec 2023 12:02:46 +0000 Subject: [PATCH] Add support for StableLM-3B model --- .../src/pipelines/stablelm_pipeline.py | 138 +++++++++++++++--- 1 file changed, 117 insertions(+), 21 deletions(-) diff --git a/apps/language_models/src/pipelines/stablelm_pipeline.py b/apps/language_models/src/pipelines/stablelm_pipeline.py index c51796d8c2..05673e905b 100644 --- a/apps/language_models/src/pipelines/stablelm_pipeline.py +++ b/apps/language_models/src/pipelines/stablelm_pipeline.py @@ -4,13 +4,49 @@ from io import BytesIO from pathlib import Path from apps.language_models.utils import ( - get_torch_mlir_module_bytecode, get_vmfb_from_path, ) from apps.language_models.src.pipelines.SharkLLMBase import SharkLLMBase from apps.language_models.src.model_wrappers.stablelm_model import ( StableLMModel, ) +import argparse + +parser = argparse.ArgumentParser( + prog="stablelm runner", + description="runs a StableLM model", +) + +parser.add_argument( + "--precision", "-p", default="fp16", choices=["fp32", "fp16", "int4"] +) +parser.add_argument("--device", "-d", default="cuda", help="vulkan, cpu, cuda") +parser.add_argument( + "--stablelm_vmfb_path", default=None, help="path to StableLM's vmfb" +) +parser.add_argument( + "--stablelm_mlir_path", + default=None, + help="path to StableLM's mlir file", +) +parser.add_argument( + "--use_precompiled_model", + default=True, + action=argparse.BooleanOptionalAction, + help="use the precompiled vmfb", +) +parser.add_argument( + "--load_mlir_from_shark_tank", + default=True, + action=argparse.BooleanOptionalAction, + help="download precompile mlir from shark tank", +) +parser.add_argument( + "--hf_auth_token", + type=str, + default=None, + help="Specify your own huggingface authentication token for stablelm-3B model.", +) class StopOnTokens(StoppingCriteria): @@ -29,7 +65,7 @@ def __init__( self, model_name, hf_model_path="stabilityai/stablelm-tuned-alpha-3b", - max_num_tokens=512, + max_num_tokens=256, device="cuda", precision="fp32", debug="False", @@ -51,7 +87,10 @@ def shouldStop(self, tokens): def get_src_model(self): model = AutoModelForCausalLM.from_pretrained( - self.hf_model_path, torch_dtype=torch.float32 + self.hf_model_path, + trust_remote_code=True, + torch_dtype=torch.float32, + use_auth_token="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", ) return model @@ -83,13 +122,19 @@ def compile(self): print( f"[DEBUG] mlir path {mlir_path} {'exists' if mlir_path.exists() else 'does not exist'}" ) - if mlir_path.exists(): - with open(mlir_path, "rb") as f: - bytecode = f.read() - else: + if not mlir_path.exists(): model = StableLMModel(self.get_src_model()) model_inputs = self.get_model_inputs() - ts_graph = get_torch_mlir_module_bytecode(model, model_inputs) + from shark.shark_importer import import_with_fx + + ts_graph = import_with_fx( + model, + model_inputs, + is_f16=True if self.precision in ["fp16", "int4"] else False, + precision=self.precision, + f16_input_mask=[False, False], + mlir_type="torchscript", + ) module = torch_mlir.compile( ts_graph, [*model_inputs], @@ -100,15 +145,16 @@ def compile(self): bytecode_stream = BytesIO() module.operation.write_bytecode(bytecode_stream) bytecode = bytecode_stream.getvalue() - f_ = open(tmp_model_name + ".mlir", "wb") - f_.write(bytecode) - print("Saved mlir") - f_.close() + f_ = open(mlir_path, "wb") + f_.write(bytecode) + print("Saved mlir at: ", mlir_path) + f_.close() + del bytecode from shark.shark_inference import SharkInference shark_module = SharkInference( - mlir_module=bytecode, device=self.device, mlir_dialect="tm_tensor" + mlir_module=mlir_path, device=self.device, mlir_dialect="tm_tensor" ) shark_module.compile() @@ -120,14 +166,22 @@ def compile(self): return shark_module def get_tokenizer(self): - tok = AutoTokenizer.from_pretrained(self.hf_model_path) + tok = AutoTokenizer.from_pretrained( + self.hf_model_path, + use_auth_token="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", + ) tok.add_special_tokens({"pad_token": ""}) # print("[DEBUG] Sucessfully loaded the tokenizer to the memory") return tok def generate(self, prompt): words_list = [] + import time + + start = time.time() + count = 0 for i in range(self.max_num_tokens): + count = count + 1 params = { "new_text": prompt, } @@ -145,6 +199,12 @@ def generate(self, prompt): if detok == "": break prompt = prompt + detok + end = time.time() + print( + "\n\nTime taken is {:.2f} tokens/second\n".format( + count / (end - start) + ) + ) return words_list def generate_new_token(self, params): @@ -178,10 +238,46 @@ def generate_new_token(self, params): return ret_dict -# Initialize a StopOnTokens object -system_prompt = """<|SYSTEM|># StableLM Tuned (Alpha version) -- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI. -- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user. -- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes. -- StableLM will refuse to participate in anything that could harm a human. -""" +if __name__ == "__main__": + args = parser.parse_args() + + stable_lm = SharkStableLM( + model_name="StableLM", + hf_model_path="stabilityai/stablelm-3b-4e1t", + device=args.device, + precision=args.precision, + ) + + default_prompt_text = "The weather is always wonderful" + continue_execution = True + + print("\n-----\nScript executing for the following config: \n") + print("StableLM Model: ", stable_lm.hf_model_path) + print("Precision: ", args.precision) + print("Device: ", args.device) + + while continue_execution: + use_default_prompt = input( + "\nDo you wish to use the default prompt text? Y/N ?: " + ) + if use_default_prompt in ["Y", "y"]: + prompt = default_prompt_text + else: + prompt = input("Please enter the prompt text: ") + print("\nPrompt Text: ", prompt) + + res_str = stable_lm.generate(prompt) + torch.cuda.empty_cache() + import gc + + gc.collect() + print( + "\n\n-----\nHere's the complete formatted result: \n\n", + prompt + "".join(res_str), + ) + continue_execution = input( + "\nDo you wish to run script one more time? Y/N ?: " + ) + continue_execution = ( + True if continue_execution in ["Y", "y"] else False + )