Add support for Quantized StableLM-3B model
vivekkhandelwal1 committed Dec 7, 2023
1 parent 8caf874 · commit 0cbd292
Showing 1 changed file with 18 additions and 1 deletion.
19 changes: 18 additions & 1 deletion apps/language_models/src/pipelines/stablelm_pipeline.py
@@ -73,6 +73,14 @@ def __init__(
         super().__init__(model_name, hf_model_path, max_num_tokens)
         self.max_sequence_len = 256
         self.device = device
+        if precision != "int4" and args.hf_auth_token is None:
+            raise ValueError(
+                """HF auth token required for StableLM-3B. Pass it using the
+                --hf_auth_token flag. You can request access to the model
+                here: https://huggingface.co/stabilityai/stablelm-3b-4e1t"""
+            )
+        self.hf_auth_token = args.hf_auth_token
+
         self.precision = precision
         self.debug = debug
         self.tokenizer = self.get_tokenizer()
@@ -86,11 +94,20 @@ def shouldStop(self, tokens):
         return False
 
     def get_src_model(self):
+        kwargs = {}
+        if self.precision == "int4":
+            self.hf_model_path = "yichunkuo/stablelm-3b-4e1t-gptq"
+            from transformers import GPTQConfig
+
+            quantization_config = GPTQConfig(bits=4, disable_exllama=True)
+            kwargs["quantization_config"] = quantization_config
+            kwargs["device_map"] = "cpu"
         model = AutoModelForCausalLM.from_pretrained(
             self.hf_model_path,
             trust_remote_code=True,
             torch_dtype=torch.float32,
-            use_auth_token="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",
+            use_auth_token=self.hf_auth_token,
+            **kwargs,
         )
         return model
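
For context, a minimal standalone sketch of the quantized load path this commit adds, assuming a transformers build with GPTQ support and the optimum/auto-gptq backends installed. The checkpoint name and the GPTQConfig arguments come from the diff above; everything else is illustrative:

import torch
from transformers import AutoModelForCausalLM, GPTQConfig

# Pre-quantized 4-bit GPTQ checkpoint referenced by the pipeline.
hf_model_path = "yichunkuo/stablelm-3b-4e1t-gptq"

# disable_exllama=True avoids the CUDA-only exllama kernels so the
# checkpoint can be loaded on CPU (matching device_map="cpu").
quantization_config = GPTQConfig(bits=4, disable_exllama=True)

model = AutoModelForCausalLM.from_pretrained(
    hf_model_path,
    trust_remote_code=True,
    torch_dtype=torch.float32,
    quantization_config=quantization_config,
    device_map="cpu",
)

Note that the int4 path skips the new auth-token check, presumably because the pre-quantized checkpoint is not gated; every other precision loads the base StableLM-3B weights and therefore requires --hf_auth_token.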

