Add support for StableLM-3B model
vivekkhandelwal1 committed Dec 6, 2023
1 parent dfdd3b1 commit 6b1c6ef
Showing 1 changed file with 129 additions and 21 deletions.
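
For orientation, here is a minimal usage sketch of the pipeline this commit enables. It simply mirrors the __main__ block added at the bottom of the diff; the import path assumes the SHARK repository root is on sys.path, and the device/precision values are just the new parser defaults, not requirements.

# Illustrative sketch, not part of the commit: mirrors the __main__ block added below.
# Assumes the SHARK repo is importable; values are the argparse defaults from this diff.
from apps.language_models.src.pipelines.stablelm_pipeline import SharkStableLM

stable_lm = SharkStableLM(
    model_name="StableLM",
    hf_model_path="stabilityai/stablelm-3b-4e1t",  # the newly supported 3B model
    device="cuda",       # or "cpu" / "vulkan"
    precision="fp16",    # parser choices: fp32, fp16, int4
)

prompt = "The weather is always wonderful"
tokens = stable_lm.generate(prompt)
print(prompt + "".join(tokens))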
150 changes: 129 additions & 21 deletions apps/language_models/src/pipelines/stablelm_pipeline.py
@@ -4,13 +4,49 @@
from io import BytesIO
from pathlib import Path
from apps.language_models.utils import (
get_torch_mlir_module_bytecode,
get_vmfb_from_path,
)
from apps.language_models.src.pipelines.SharkLLMBase import SharkLLMBase
from apps.language_models.src.model_wrappers.stablelm_model import (
StableLMModel,
)
import argparse

parser = argparse.ArgumentParser(
prog="stablelm runner",
description="runs a StableLM model",
)

parser.add_argument(
"--precision", "-p", default="fp16", choices=["fp32", "fp16", "int4"]
)
parser.add_argument("--device", "-d", default="cuda", help="vulkan, cpu, cuda")
parser.add_argument(
"--stablelm_vmfb_path", default=None, help="path to StableLM's vmfb"
)
parser.add_argument(
"--stablelm_mlir_path",
default=None,
help="path to StableLM's mlir file",
)
parser.add_argument(
"--use_precompiled_model",
default=True,
action=argparse.BooleanOptionalAction,
help="use the precompiled vmfb",
)
parser.add_argument(
"--load_mlir_from_shark_tank",
default=True,
action=argparse.BooleanOptionalAction,
help="download precompile mlir from shark tank",
)
parser.add_argument(
"--hf_auth_token",
type=str,
default=None,
help="Specify your own huggingface authentication token for stablelm-3B model.",
)


class StopOnTokens(StoppingCriteria):
@@ -29,7 +65,7 @@ def __init__(
self,
model_name,
hf_model_path="stabilityai/stablelm-tuned-alpha-3b",
- max_num_tokens=512,
+ max_num_tokens=256,
device="cuda",
precision="fp32",
debug="False",
@@ -51,7 +87,10 @@ def shouldStop(self, tokens):

def get_src_model(self):
model = AutoModelForCausalLM.from_pretrained(
- self.hf_model_path, torch_dtype=torch.float32
+ self.hf_model_path,
+ trust_remote_code=True,
+ torch_dtype=torch.float32,
+ use_auth_token="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",
)
return model

@@ -83,13 +122,19 @@ def compile(self):
print(
f"[DEBUG] mlir path {mlir_path} {'exists' if mlir_path.exists() else 'does not exist'}"
)
- if mlir_path.exists():
-     with open(mlir_path, "rb") as f:
-         bytecode = f.read()
- else:
+ if not mlir_path.exists():
model = StableLMModel(self.get_src_model())
model_inputs = self.get_model_inputs()
- ts_graph = get_torch_mlir_module_bytecode(model, model_inputs)
+ from shark.shark_importer import import_with_fx
+
+ ts_graph = import_with_fx(
+     model,
+     model_inputs,
+     is_f16=True if self.precision in ["fp16", "int4"] else False,
+     precision=self.precision,
+     f16_input_mask=[False, False],
+     mlir_type="torchscript",
+ )
module = torch_mlir.compile(
ts_graph,
[*model_inputs],
@@ -100,15 +145,16 @@
bytecode_stream = BytesIO()
module.operation.write_bytecode(bytecode_stream)
bytecode = bytecode_stream.getvalue()
- f_ = open(tmp_model_name + ".mlir", "wb")
- f_.write(bytecode)
- print("Saved mlir")
- f_.close()
+ f_ = open(mlir_path, "wb")
+ f_.write(bytecode)
+ print("Saved mlir at: ", mlir_path)
+ f_.close()
+ del bytecode

from shark.shark_inference import SharkInference

shark_module = SharkInference(
- mlir_module=bytecode, device=self.device, mlir_dialect="tm_tensor"
+ mlir_module=mlir_path, device=self.device, mlir_dialect="tm_tensor"
)
shark_module.compile()

@@ -120,14 +166,22 @@ def compile(self):
return shark_module

def get_tokenizer(self):
- tok = AutoTokenizer.from_pretrained(self.hf_model_path)
+ tok = AutoTokenizer.from_pretrained(
+     self.hf_model_path,
+     use_auth_token="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",
+ )
tok.add_special_tokens({"pad_token": "<PAD>"})
# print("[DEBUG] Sucessfully loaded the tokenizer to the memory")
return tok

def generate(self, prompt):
words_list = []
import time

start = time.time()
count = 0
for i in range(self.max_num_tokens):
count = count + 1
params = {
"new_text": prompt,
}
@@ -145,6 +199,12 @@ def generate(self, prompt):
if detok == "":
break
prompt = prompt + detok
end = time.time()
print(
    "\n\nThroughput: {:.2f} tokens/second\n".format(
        count / (end - start)
    )
)
return words_list

def generate_new_token(self, params):
@@ -178,10 +238,58 @@ def generate_new_token(self, params):
return ret_dict


- # Initialize a StopOnTokens object
- system_prompt = """<|SYSTEM|># StableLM Tuned (Alpha version)
- - StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
- - StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
- - StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
- - StableLM will refuse to participate in anything that could harm a human.
- """
if __name__ == "__main__":
args = parser.parse_args()

# Initialize a StopOnTokens object
system_prompt = """<|SYSTEM|># StableLM Tuned (Alpha version)
- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
- StableLM will refuse to participate in anything that could harm a human.
"""

stable_lm = SharkStableLM(
model_name="StableLM",
hf_model_path="stabilityai/stablelm-3b-4e1t",
device=args.device,
precision=args.precision,
)

default_prompt_text = "The weather is always wonderful"
continue_execution = True

print("\n-----\nScript executing for the following config: \n")
print("StableLM Model: ", stable_lm.hf_model_path)
print("Precision: ", args.precision)
print("Device: ", args.device)

while continue_execution:
use_default_prompt = input(
"\nDo you wish to use the default prompt text? Y/N ?: "
)
if use_default_prompt in ["Y", "y"]:
prompt = default_prompt_text
else:
prompt = input("Please enter the prompt text: ")
print("\nPrompt Text: ", prompt)

# prompt_template = f"""A helpful assistant who helps the user with any questions asked.
# User: {prompt}
# Assistant:"""

res_str = stable_lm.generate(prompt)
torch.cuda.empty_cache()
import gc

gc.collect()
print(
"\n\n-----\nHere's the complete formatted result: \n\n",
prompt + "".join(res_str),
)
continue_execution = input(
"\nDo you wish to run script one more time? Y/N ?: "
)
continue_execution = (
True if continue_execution in ["Y", "y"] else False
)
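
The reworked compile() caches the imported MLIR on disk and hands SharkInference the file path rather than an in-memory buffer (hence the added del bytecode). Below is a condensed sketch of that flow, not part of the commit: the import_with_fx and SharkInference usage mirrors the hunks above, while the torch_mlir output type and the helper name are assumptions, since that part of the file is collapsed in this diff.

# Illustrative sketch of the compile-and-cache flow introduced above.
# compile_cached is a hypothetical helper; the torch_mlir output type is assumed.
from io import BytesIO
from pathlib import Path

import torch_mlir


def compile_cached(model, model_inputs, precision, device, mlir_path: Path):
    if not mlir_path.exists():
        from shark.shark_importer import import_with_fx

        # Import the model to TorchScript-backed MLIR, casting to f16 when requested.
        ts_graph = import_with_fx(
            model,
            model_inputs,
            is_f16=precision in ["fp16", "int4"],
            precision=precision,
            f16_input_mask=[False, False],
            mlir_type="torchscript",
        )
        module = torch_mlir.compile(
            ts_graph,
            [*model_inputs],
            torch_mlir.OutputType.LINALG_ON_TENSORS,  # assumed; collapsed in the diff
            use_tracing=False,
        )
        stream = BytesIO()
        module.operation.write_bytecode(stream)
        mlir_path.write_bytes(stream.getvalue())  # cache the MLIR for later runs
        del stream

    from shark.shark_inference import SharkInference

    # Passing the path instead of the bytecode lets the in-memory buffer be freed,
    # which is what the diff does with `del bytecode`.
    shark_module = SharkInference(
        mlir_module=mlir_path, device=device, mlir_dialect="tm_tensor"
    )
    shark_module.compile()
    return shark_module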
