-
Notifications
You must be signed in to change notification settings - Fork 0
/
create_hqq.py
50 lines (46 loc) · 1.63 KB
/
create_hqq.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import argparse
import subprocess
import os
from huggingface_hub import create_repo, HfApi, ModelCard
import shutil
import torch
from hqq.engine.hf import HQQModelForCausalLM, AutoTokenizer
from hqq.models.hf.base import AutoHQQHFModel
from hqq.core.quantize import *
# Parse command-line arguments
parser = argparse.ArgumentParser(description='Quantize and process a model with HQQ.')
parser.add_argument('--model', required=True, type=str, help='The path to the model to be processed')
args = parser.parse_args()

# Assumes the Hub repo id is "org/model"; the quantized copy is published
# under the same org with an "-HQQ" suffix.
model_name = args.model.split("/")[-1]
org_name = args.model.split("/")[0]
target_name = f"{org_name}/{model_name}-HQQ"

print(f"Downloading {model_name} ...")
# Download the repo snapshot into ./<org>/<model>. subprocess.run with an
# argv list (shell=False) replaces the Popen/communicate dance and avoids
# shell-injection via the model name.
result = subprocess.run(
    ['huggingface-cli', 'download', "--local-dir", args.model, args.model],
    capture_output=True, text=True,
)
print("STDOUT:", result.stdout)
print("STDERR:", result.stderr)
# Fail fast: the original ignored the exit code and would go on to quantize
# a missing/partial local copy.
if result.returncode != 0:
    raise SystemExit(f"Download of {args.model} failed (exit code {result.returncode})")

# makedirs handles the nested "<org>/<model>-HQQ" path (os.mkdir raises
# FileNotFoundError when the org directory does not yet exist) and
# exist_ok=True makes re-runs of the script idempotent at this step.
os.makedirs(target_name, exist_ok=True)

print(f"Converting {args.model} to {target_name} ...")

# Quant config: 2-bit weights, 64 weights per quantization group.
quant_config = BaseQuantizeConfig(
    nbits=2,
    group_size=64
)

# Load the freshly downloaded local copy (hence the './' prefix) in fp16
# with FlashAttention-2, then quantize on the GPU.
model = HQQModelForCausalLM.from_pretrained(
    './' + args.model,
    cache_dir=".",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.float16
)
tokenizer = AutoTokenizer.from_pretrained(args.model)
model.quantize_model(quant_config=quant_config, device='cuda')

# Persist the quantized weights and tokenizer, then reclaim the disk space
# used by the full-precision download.
model.save_quantized(target_name)
tokenizer.save_pretrained(target_name)
shutil.rmtree(args.model)

# Re-use the upstream model card — ModelCard.load resolves the repo id on
# the Hub, so it still works after the local copy was deleted — and tag it
# as an HQQ quantization before writing it into the output directory.
card = ModelCard.load(args.model)
if card.data.tags is None:
    card.data.tags = []
card.data.tags.append("hqq")
card.save(f'{target_name}/README.md')