# create_awq.py
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer
import argparse
# Parse command-line arguments
parser = argparse.ArgumentParser(description='Quantize and process a model with AWQ.')
parser.add_argument('--model', type=str, required=True,
                    help='Path or Hugging Face ID of the model to quantize')
args = parser.parse_args()

model_path = args.model
# Name the output after the last path component, e.g. "org/Some-Model" -> "Some-Model-AWQ"
quant_name = model_path.rstrip("/").split("/")[-1] + "-AWQ"
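# AWQ settings: zero_point=True enables asymmetric (zero-point) quantization,
# w_bit=4 quantizes weights to 4 bits in groups of 128 (q_group_size), and
# version "GEMM" selects AutoAWQ's GEMM kernels ("GEMV" is the alternative,
# typically preferred for batch-size-1 decoding).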
quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM" }
# Load the unquantized model and its tokenizer
model = AutoAWQForCausalLM.from_pretrained(
    model_path, safetensors=True, low_cpu_mem_usage=True
)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
# Quantize
model.quantize(
tokenizer,
quant_config=quant_config,
)
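# With no calib_data argument, quantize() falls back to AutoAWQ's built-in
# default calibration dataset ("pileval" at the time of writing); pass
# calib_data=... to calibrate on your own samples instead.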
# Save the quantized weights (sharded into 10 GB files) and the tokenizer
model.save_quantized(quant_name, safetensors=True, shard_size="10GB")
tokenizer.save_pretrained(quant_name)
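# Example usage (the model ID below is illustrative, not from this repo):
#   python create_awq.py --model mistralai/Mistral-7B-v0.1
# writes the quantized model to ./Mistral-7B-v0.1-AWQ.
#
# A minimal sketch of loading the result back, assuming a recent autoawq release:
#   from awq import AutoAWQForCausalLM
#   model = AutoAWQForCausalLM.from_quantized("Mistral-7B-v0.1-AWQ", fuse_layers=True)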