Fix gguf for mixtral. #490

Open: wants to merge 4 commits into base branch rc_054.
4 changes: 2 additions & 2 deletions aphrodite/engine/async_aphrodite.py
@@ -340,8 +340,8 @@ def from_engine_args(cls,
         engine_config = engine_args.create_engine_config()

         if engine_config.device_config.device_type == "neuron":
-            raise NotImplementedError("Neuron is not supported for "
-                                      "async engine yet.")
+            from aphrodite.executor.neuron_executor import NeuronExecutor
+            executor_class = NeuronExecutor
         elif engine_config.device_config.device_type == "cpu":
             from aphrodite.executor.cpu_executor import CPUExecutor
             executor_class = CPUExecutor
18 changes: 18 additions & 0 deletions aphrodite/executor/neuron_executor.py
@@ -3,6 +3,7 @@
 from aphrodite.lora.request import LoRARequest
 from aphrodite.executor.executor_base import ExecutorBase
 from aphrodite.common.sequence import SamplerOutput, SequenceGroupMetadata
+from aphrodite.common.utils import make_async


 class NeuronExecutor(ExecutorBase):

@@ -57,6 +58,23 @@ def execute_model(self,
             seq_group_metadata_list=seq_group_metadata_list)
         return output

+    async def execute_model_async(
+            self, seq_group_metadata_list: List[SequenceGroupMetadata],
+            blocks_to_swap_in: Dict[int, int],
+            blocks_to_swap_out: Dict[int, int],
+            blocks_to_copy: Dict[int, List[int]],
+            num_lookahead_slots: int) -> List[SamplerOutput]:
+        assert (blocks_to_swap_in == {} and blocks_to_swap_out == {}
+                and blocks_to_copy == {}), (
+                    "Cache operations are not supported for Neuron backend.")
+        assert num_lookahead_slots == 0, (
+            "lookahead not supported for Neuron backend.")
+
+        output = await make_async(
+            self.driver_worker.execute_model
+        )(seq_group_metadata_list=seq_group_metadata_list)
+        return output
+
     def add_lora(self, lora_request: LoRARequest) -> bool:
         return self.driver_worker.add_lora(lora_request)

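Note: the new execute_model_async offloads the blocking worker call through make_async. For reference, a minimal sketch of what a helper like make_async typically does; the actual aphrodite.common.utils.make_async may differ in detail:

import asyncio
from functools import partial
from typing import Any, Awaitable, Callable, TypeVar

T = TypeVar("T")

def make_async(func: Callable[..., T]) -> Callable[..., Awaitable[T]]:
    # Run a blocking callable in the default thread-pool executor so that
    # awaiting it does not stall the event loop.
    def _async_wrapper(*args: Any, **kwargs: Any) -> Awaitable[T]:
        loop = asyncio.get_event_loop()
        return loop.run_in_executor(None, partial(func, *args, **kwargs))
    return _async_wrapper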
17 changes: 17 additions & 0 deletions aphrodite/modeling/hf_downloader.py
@@ -3,6 +3,7 @@
 import glob
 import json
 import os
+import re
 from collections import defaultdict
 from typing import Any, Iterable, Iterator, List, Optional, Tuple

@@ -237,6 +238,8 @@ def convert_gguf_to_state_dict(checkpoint, config):
     # hack: ggufs have a different name than transformers
     if model_type == "cohere":
         model_type = "command-r"
+    elif model_type == "mistral" or model_type == "mixtral":
+        model_type = "llama"
     arch = None
     for key, value in MODEL_ARCH_NAMES.items():
         if value == model_type:
@@ -252,8 +255,22 @@

gguf_to_hf_name_map = {}
keys_to_remove = []
prog = re.compile(
r"model.layers.([^\.]*).block_sparse_moe.experts.([^\.]*).([^\.]*)")
for hf_name in state_dict:
name, suffix = hf_name.rsplit(".", 1)
if match := prog.fullmatch(name): # mixtral
bid, xid, wid = match.groups()
if wid == "w1":
wname = "ffn_gate"
elif wid == "w2":
wname = "ffn_down"
elif wid == "w3":
wname = "ffn_up"
gguf_name = f"blk.{bid}.{wname}.{xid}"
gguf_to_hf_name_map[f"{gguf_name}.{suffix}"] = hf_name
continue

gguf_name = name_map.get_name(name)
if gguf_name:
gguf_to_hf_name_map[f"{gguf_name}.{suffix}"] = hf_name
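Note: the regex branch rewrites HF-style Mixtral expert names into GGUF tensor names. A worked example of the mapping, using a hypothetical tensor name in the conventions the code above assumes:

import re

prog = re.compile(
    r"model.layers.([^\.]*).block_sparse_moe.experts.([^\.]*).([^\.]*)")

hf_name = "model.layers.0.block_sparse_moe.experts.1.w2.weight"
name, suffix = hf_name.rsplit(".", 1)
match = prog.fullmatch(name)
bid, xid, wid = match.groups()  # ("0", "1", "w2")
wname = {"w1": "ffn_gate", "w2": "ffn_down", "w3": "ffn_up"}[wid]
print(f"blk.{bid}.{wname}.{xid}.{suffix}")  # blk.0.ffn_down.1.weight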
11 changes: 7 additions & 4 deletions aphrodite/modeling/models/mixtral.py
@@ -49,6 +49,7 @@
 from aphrodite.modeling.hf_downloader import (default_weight_loader,
                                               hf_model_weights_iterator)
 from aphrodite.common.sequence import SamplerOutput
+from aphrodite.quantization.gguf import GGUFLinearMethod


 class MixtralMLP(nn.Module):

@@ -119,10 +120,12 @@ def __init__(
         if self.linear_method is None:
             self.linear_method = UnquantizedLinearMethod()

-        self.gate = ReplicatedLinear(self.hidden_size,
-                                     self.num_total_experts,
-                                     bias=False,
-                                     linear_method=None)
+        self.gate = ReplicatedLinear(
+            self.hidden_size,
+            self.num_total_experts,
+            bias=False,
+            linear_method=linear_method
+            if isinstance(linear_method, GGUFLinearMethod) else None)

         if not isinstance(
                 self.linear_method, UnquantizedLinearMethod
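Note: presumably the router (gate) weight is read straight out of the GGUF file, so it needs the GGUF-aware linear method to load correctly, while every other quantization scheme keeps the previous unquantized gate. A hypothetical restatement of the conditional in isolation:

gate_linear_method = (linear_method
                      if isinstance(linear_method, GGUFLinearMethod)
                      else None)  # None keeps the gate unquantized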
5 changes: 4 additions & 1 deletion aphrodite/transformers_utils/config.py
@@ -8,7 +8,10 @@
 from aphrodite.transformers_utils.configs import (BaiChuanConfig, DbrxConfig,
                                                   ChatGLMConfig, MPTConfig,
                                                   QWenConfig, RWConfig)
-from aphrodite.quantization.gguf_utils import GGUFReader
+from aphrodite.common.utils import is_neuron
+
+if not is_neuron():
+    from aphrodite.quantization.gguf_utils import GGUFReader

 _CONFIG_REGISTRY = {
     "baichuan": BaiChuanConfig,
5 changes: 3 additions & 2 deletions aphrodite/transformers_utils/tokenizer.py
@@ -8,10 +8,11 @@
 from loguru import logger

 from aphrodite.lora.request import LoRARequest
-from aphrodite.common.utils import make_async
-from aphrodite.quantization.gguf_utils import GGUFReader
+from aphrodite.common.utils import make_async, is_neuron
 from aphrodite.transformers_utils.tokenizers import BaichuanTokenizer

+if not is_neuron():
+    from aphrodite.quantization.gguf_utils import GGUFReader

 def convert_gguf_to_tokenizer(checkpoint):
     if os.path.isfile(checkpoint):
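Note: config.py and tokenizer.py now share the same guard, so Neuron installs that lack GGUF's dependencies still import cleanly. A sketch of the shape is_neuron likely has; this is an assumption, and the real aphrodite.common.utils.is_neuron may be implemented differently:

def is_neuron() -> bool:
    # Assume the host counts as Neuron when the AWS transformers-neuronx
    # integration is importable.
    try:
        import transformers_neuronx  # noqa: F401
    except ImportError:
        return False
    return True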
1 change: 0 additions & 1 deletion requirements-common.txt
@@ -13,7 +13,6 @@ fastapi
 colorlog
 einops # for phi
 prometheus_client # for prometheus metrics
-triton >= 2.2.0
 lark == 1.1.8 # for grammars
 scipy # for quip
 rich