Skip to content

Commit

Permalink
add skywork support for convert script.
Browse files Browse the repository at this point in the history
  • Loading branch information
yxq321 committed Nov 8, 2023
1 parent e9c1cec commit 56a5838
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 3 deletions.
12 changes: 9 additions & 3 deletions convert-baichuan-hf-to-gguf.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#!/usr/bin/env python3
# HF baichuan --> gguf conversion
# HF skywork / baichuan --> gguf conversion

from __future__ import annotations

Expand Down Expand Up @@ -110,15 +110,21 @@ def parse_args() -> argparse.Namespace:
with open(dir_model / "config.json", "r", encoding="utf-8") as f:
hparams = json.load(f)
print("hello print: ",hparams["architectures"][0])
if hparams["architectures"][0] != "BaichuanForCausalLM" and hparams["architectures"][0] != "BaiChuanForCausalLM":
if hparams["architectures"][0] != "SkyworkForCausalLM" and \
hparams["architectures"][0] != "BaichuanForCausalLM" and \
hparams["architectures"][0] != "BaiChuanForCausalLM":
print("Model architecture not supported: " + hparams["architectures"][0])

sys.exit()

# get number of model parts
num_parts = count_model_parts(dir_model)
print(f"num_parts:{num_parts}\n")
ARCH=gguf.MODEL_ARCH.BAICHUAN
if hparams["architectures"][0] == "SkyworkForCausalLM":
ARCH=gguf.MODEL_ARCH.SKYWORK
else:
ARCH = gguf.MODEL_ARCH.BAICHUAN

gguf_writer = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH], endianess=endianess)

print("gguf: get model metadata")
Expand Down
22 changes: 22 additions & 0 deletions gguf-py/gguf/gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ class MODEL_ARCH(IntEnum):
LLAMA : int = auto()
FALCON : int = auto()
BAICHUAN : int = auto()
SKYWORK : int = auto()
GPT2 : int = auto()
GPTJ : int = auto()
GPTNEOX : int = auto()
Expand Down Expand Up @@ -123,6 +124,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_ARCH.LLAMA: "llama",
MODEL_ARCH.FALCON: "falcon",
MODEL_ARCH.BAICHUAN: "baichuan",
MODEL_ARCH.SKYWORK: "skywork",
MODEL_ARCH.GPT2: "gpt2",
MODEL_ARCH.GPTJ: "gptj",
MODEL_ARCH.GPTNEOX: "gptneox",
Expand Down Expand Up @@ -213,6 +215,22 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.SKYWORK: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.ATTN_ROT_EMBD,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.STARCODER: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.POS_EMBD,
Expand Down Expand Up @@ -318,6 +336,10 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_ROT_EMBD,
],
MODEL_ARCH.SKYWORK: [ # TODO by yxq
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_ROT_EMBD,
],
MODEL_ARCH.PERSIMMON: [
MODEL_TENSOR.ROPE_FREQS,
]
Expand Down

0 comments on commit 56a5838

Please sign in to comment.