Add support for Quantized StableLM-3B model
vivekkhandelwal1 committed Dec 8, 2023
1 parent 8caf874 commit 27edd5c
Showing 1 changed file with 24 additions and 7 deletions.
apps/language_models/src/pipelines/stablelm_pipeline.py (31 changes: 24 additions & 7 deletions)
@@ -73,6 +73,14 @@ def __init__(
         super().__init__(model_name, hf_model_path, max_num_tokens)
         self.max_sequence_len = 256
         self.device = device
+        if precision != "int4" and args.hf_auth_token is None:
+            raise ValueError(
+                """HF auth token required for StableLM-3B. Pass it using the
+                --hf_auth_token flag. You can request access to the model
+                here: https://huggingface.co/stabilityai/stablelm-zephyr-3b."""
+            )
+        self.hf_auth_token = args.hf_auth_token
+
         self.precision = precision
         self.debug = debug
         self.tokenizer = self.get_tokenizer()
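
The new check only demands a token when precision is not int4, because the int4 path (below) switches to a public GPTQ checkpoint. For context, a minimal sketch of the flag this code relies on; the actual parser lives elsewhere in this file, and the help text here is an assumption:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--hf_auth_token",
        type=str,
        default=None,
        help="Hugging Face auth token; required unless precision is int4, "
        "since only the gated stabilityai checkpoint needs authentication.",
    )
    args = parser.parse_args()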
@@ -86,12 +94,23 @@ def shouldStop(self, tokens):
         return False
 
     def get_src_model(self):
+        kwargs = {}
+        if self.precision == "int4":
+            self.hf_model_path = "TheBloke/stablelm-zephyr-3b-GPTQ"
+            from transformers import GPTQConfig
+
+            quantization_config = GPTQConfig(bits=4, disable_exllama=True)
+            kwargs["quantization_config"] = quantization_config
+            kwargs["device_map"] = "cpu"
+        print("[DEBUG] Loading Model")
         model = AutoModelForCausalLM.from_pretrained(
             self.hf_model_path,
             trust_remote_code=True,
             torch_dtype=torch.float32,
-            use_auth_token="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",
+            use_auth_token=self.hf_auth_token,
+            **kwargs,
         )
+        print("[DEBUG] Model loaded successfully")
         return model
 
     def get_model_inputs(self):
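
For reference, a standalone sketch of the int4 load path introduced above, outside the pipeline class. It assumes a transformers build with GPTQ support and the optimum/auto-gptq backends installed; the repo id and all arguments are taken directly from the diff:

    import torch
    from transformers import AutoModelForCausalLM, GPTQConfig

    # disable_exllama=True selects kernels that work without the ExLlama CUDA
    # extension, matching the pipeline's device_map="cpu" choice.
    quantization_config = GPTQConfig(bits=4, disable_exllama=True)
    model = AutoModelForCausalLM.from_pretrained(
        "TheBloke/stablelm-zephyr-3b-GPTQ",
        trust_remote_code=True,
        torch_dtype=torch.float32,
        quantization_config=quantization_config,
        device_map="cpu",
    )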
@@ -100,9 +119,7 @@ def get_model_inputs(self):
         return input_ids, attention_mask
 
     def compile(self):
-        tmp_model_name = (
-            f"stableLM_linalg_{self.precision}_seqLen{self.max_sequence_len}"
-        )
+        tmp_model_name = f"{self.model_name}_linalg_{self.precision}_seqLen{self.max_sequence_len}"
 
         # device = "cuda" # "cpu"
         # TODO: vmfb and mlir name should include precision and device
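
The artifact name now derives from self.model_name rather than the hard-coded "stableLM" prefix. A worked example, assuming the constructor arguments set at the bottom of this file and precision="int4":

    model_name = "stablelm_zephyr_3b"
    precision = "int4"
    max_sequence_len = 256
    tmp_model_name = f"{model_name}_linalg_{precision}_seqLen{max_sequence_len}"
    assert tmp_model_name == "stablelm_zephyr_3b_linalg_int4_seqLen256"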
@@ -168,7 +185,7 @@ def compile(self):
     def get_tokenizer(self):
         tok = AutoTokenizer.from_pretrained(
             self.hf_model_path,
-            use_auth_token="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",
+            use_auth_token=self.hf_auth_token,
         )
         tok.add_special_tokens({"pad_token": "<PAD>"})
         # print("[DEBUG] Successfully loaded the tokenizer to the memory")
@@ -242,8 +259,8 @@ def generate_new_token(self, params):
     args = parser.parse_args()
 
     stable_lm = SharkStableLM(
-        model_name="StableLM",
-        hf_model_path="stabilityai/stablelm-3b-4e1t",
+        model_name="stablelm_zephyr_3b",
+        hf_model_path="stabilityai/stablelm-zephyr-3b",
         device=args.device,
         precision=args.precision,
     )