From 56a58389875379bd0a9a80d2a0627f5efaeeba13 Mon Sep 17 00:00:00 2001
From: yxq321 <56987208+yxq321@users.noreply.github.com>
Date: Wed, 8 Nov 2023 17:39:34 +0800
Subject: [PATCH] add skywork support to the convert script

---
 convert-baichuan-hf-to-gguf.py | 12 +++++++++---
 gguf-py/gguf/gguf.py           | 22 ++++++++++++++++++++++
 2 files changed, 31 insertions(+), 3 deletions(-)

diff --git a/convert-baichuan-hf-to-gguf.py b/convert-baichuan-hf-to-gguf.py
index 67ccbe99f132a..a37e514d4960b 100755
--- a/convert-baichuan-hf-to-gguf.py
+++ b/convert-baichuan-hf-to-gguf.py
@@ -1,5 +1,5 @@
 #!/usr/bin/env python3
-# HF baichuan --> gguf conversion
+# HF skywork / baichuan --> gguf conversion
 
 from __future__ import annotations
 
@@ -110,7 +110,9 @@ def parse_args() -> argparse.Namespace:
 with open(dir_model / "config.json", "r", encoding="utf-8") as f:
     hparams = json.load(f)
 print("hello print: ",hparams["architectures"][0])
-if hparams["architectures"][0] != "BaichuanForCausalLM" and hparams["architectures"][0] != "BaiChuanForCausalLM":
+if hparams["architectures"][0] != "SkyworkForCausalLM" and \
+   hparams["architectures"][0] != "BaichuanForCausalLM" and \
+   hparams["architectures"][0] != "BaiChuanForCausalLM":
     print("Model architecture not supported: " + hparams["architectures"][0])
     sys.exit()
 
@@ -118,7 +120,11 @@
 # get number of model parts
 num_parts = count_model_parts(dir_model)
 print(f"num_parts:{num_parts}\n")
-ARCH=gguf.MODEL_ARCH.BAICHUAN
+if hparams["architectures"][0] == "SkyworkForCausalLM":
+    ARCH = gguf.MODEL_ARCH.SKYWORK
+else:
+    ARCH = gguf.MODEL_ARCH.BAICHUAN
+
 gguf_writer = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH], endianess=endianess)
 
 print("gguf: get model metadata")
diff --git a/gguf-py/gguf/gguf.py b/gguf-py/gguf/gguf.py
index a2271d225d001..11b9a3b8a38d6 100644
--- a/gguf-py/gguf/gguf.py
+++ b/gguf-py/gguf/gguf.py
@@ -84,6 +84,7 @@ class MODEL_ARCH(IntEnum):
     LLAMA    : int = auto()
     FALCON   : int = auto()
     BAICHUAN : int = auto()
+    SKYWORK  : int = auto()
     GPT2     : int = auto()
     GPTJ     : int = auto()
     GPTNEOX  : int = auto()
@@ -123,6 +124,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.LLAMA:    "llama",
     MODEL_ARCH.FALCON:   "falcon",
     MODEL_ARCH.BAICHUAN: "baichuan",
+    MODEL_ARCH.SKYWORK:  "skywork",
     MODEL_ARCH.GPT2:     "gpt2",
     MODEL_ARCH.GPTJ:     "gptj",
     MODEL_ARCH.GPTNEOX:  "gptneox",
@@ -213,6 +215,22 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_DOWN,
         MODEL_TENSOR.FFN_UP,
     ],
+    MODEL_ARCH.SKYWORK: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ROPE_FREQS,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.ATTN_ROT_EMBD,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+    ],
     MODEL_ARCH.STARCODER: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.POS_EMBD,
@@ -318,6 +336,10 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.ROPE_FREQS,
         MODEL_TENSOR.ATTN_ROT_EMBD,
     ],
+    MODEL_ARCH.SKYWORK: [  # TODO: confirm the skip list for skywork (mirrors LLAMA/BAICHUAN for now)
+        MODEL_TENSOR.ROPE_FREQS,
+        MODEL_TENSOR.ATTN_ROT_EMBD,
+    ],
     MODEL_ARCH.PERSIMMON: [
         MODEL_TENSOR.ROPE_FREQS,
     ]
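
As a sanity check on the convert-script changes, the new architecture selection can be exercised on its own. Below is a minimal sketch, not part of the patch: it assumes the patched gguf-py is importable, and the "Skywork-13B-base" directory name is a hypothetical placeholder for any HF checkout containing a config.json. The dict lookup is a table-driven equivalent of the if/else chain the patch adds:

    import json
    from pathlib import Path

    import gguf  # the gguf-py patched above

    # Hypothetical checkpoint directory; substitute a real HF checkout.
    dir_model = Path("Skywork-13B-base")

    with open(dir_model / "config.json", "r", encoding="utf-8") as f:
        hparams = json.load(f)

    # Table-driven equivalent of the patch's architecture checks.
    arch_map = {
        "SkyworkForCausalLM":  gguf.MODEL_ARCH.SKYWORK,
        "BaichuanForCausalLM": gguf.MODEL_ARCH.BAICHUAN,
        "BaiChuanForCausalLM": gguf.MODEL_ARCH.BAICHUAN,
    }
    ARCH = arch_map.get(hparams["architectures"][0])
    if ARCH is None:
        raise SystemExit("Model architecture not supported: " + hparams["architectures"][0])

    print(gguf.MODEL_ARCH_NAMES[ARCH])  # "skywork" for a Skywork checkpoint

Folding the check and the ARCH assignment into one mapping like this would also remove the duplicated string comparisons, though the patch keeps the script's existing if/else style.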
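
Because the SKYWORK tensor list mirrors BAICHUAN's, the existing name-mapping machinery should cover Skywork checkpoints without further changes. A second sketch, again not part of the patch: it assumes gguf.get_tensor_name_map and TensorNameMap.get_name keep the signatures the convert scripts already use, that Skywork checkpoints use llama-style HF tensor names, and a placeholder block count of 40 rather than Skywork's real num_hidden_layers:

    import gguf

    # Placeholder block count; the convert script reads the real value
    # from config.json.
    tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.SKYWORK, 40)

    # A llama-style HF tensor name should now resolve to its GGUF
    # equivalent for the new architecture.
    new_name = tensor_map.get_name("model.layers.0.self_attn.q_proj.weight",
                                   try_suffixes=(".weight", ".bias"))
    print(new_name)  # expected: "blk.0.attn_q.weight"

The remaining open item is the TODO on the skip list: LLAMA and BAICHUAN list ROPE_FREQS and ATTN_ROT_EMBD among the tensors that are not serialized, and the patch assumes the same holds for Skywork; that is worth confirming against an actual checkpoint.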