Experimental Support to QLoRA #702

Status: Open · wants to merge 3 commits into base: llm
25 changes: 25 additions & 0 deletions LICENSE
@@ -703,3 +703,28 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

---------------------------------------------------------------------------------
The implementations of qlora in federatedscope/llm/model/model_builder.py and
federatedscope/llm/model/adapter_builder.py are adapted from
https://github.com/artidoro/qlora (MIT License)

Copyright (c) 2023 Artidoro Pagnoni, Tim Dettmers

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
9 changes: 9 additions & 0 deletions federatedscope/core/auxiliaries/optimizer_builder.py
@@ -54,6 +54,15 @@ def get_optimizer(model, type, lr, **kwargs):
**tmp_kwargs)
else:
return getattr(torch.optim, type)(model, lr, **tmp_kwargs)
elif 'bit.' in type:
    # Optimizer types prefixed with 'bit.' (e.g. 'bit.SGD') are dispatched
    # to bitsandbytes.optim, which provides memory-efficient optimizers.
    type = type.split('.')[-1]
    import bitsandbytes
    if isinstance(model, torch.nn.Module):
        return getattr(bitsandbytes.optim, type)(model.parameters(), lr,
                                                 **tmp_kwargs)
    else:
        return getattr(bitsandbytes.optim, type)(model, lr, **tmp_kwargs)
else:
raise NotImplementedError(
'Optimizer {} not implement'.format(type))
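For illustration, a minimal sketch of how an optimizer configured as `bit.SGD` is resolved by this new branch (an editorial example, not part of the diff; it assumes `bitsandbytes` is installed and uses a toy module):

```python
import torch
from federatedscope.core.auxiliaries.optimizer_builder import get_optimizer

# 'bit.SGD' -> strip the 'bit.' prefix -> bitsandbytes.optim.SGD;
# remaining kwargs (weight_decay, momentum, ...) are forwarded unchanged.
model = torch.nn.Linear(16, 4)
optimizer = get_optimizer(model, type='bit.SGD', lr=3e-4,
                          weight_decay=5e-4, momentum=0.9)
```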
41 changes: 41 additions & 0 deletions federatedscope/core/configs/cfg_computation_quantization.py
@@ -0,0 +1,41 @@
import logging

from federatedscope.core.configs.config import CN
from federatedscope.register import register_config

logger = logging.getLogger(__name__)


def extend_computation_quantization_cfg(cfg):
    # ---------------------------------------------------------------------- #
    # Computation quantization (for memory/computation efficiency)
    # related options
    # ---------------------------------------------------------------------- #
    cfg.computation_quantization = CN()

    # Params
    # Supported methods: ['none', 'qlora']
    cfg.computation_quantization.method = 'none'
    cfg.computation_quantization.nbits = 4  # one of [4, 8, 16]

    # --------------- register corresponding check function ----------
    cfg.register_cfg_check_fun(assert_quant_cfg)


def assert_quant_cfg(cfg):
    if cfg.computation_quantization.method.lower() not in ['none', 'qlora']:
        logger.warning(
            'Computation quantization method is expected to be one of '
            f'["none", "qlora"], but got '
            f'"{cfg.computation_quantization.method}", which will behave '
            'the same as "none".')

    if cfg.computation_quantization.method.lower() != 'none' and \
            cfg.computation_quantization.nbits not in [4, 8, 16]:
        raise ValueError(
            'The value of cfg.computation_quantization.nbits is invalid, '
            'which is expected to be one of [4, 8, 16] but got '
            f'{cfg.computation_quantization.nbits}.')


register_config("computation_quantization",
extend_computation_quantization_cfg)
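To illustrate, a sketch of what the new config node exposes once registered (assuming FederatedScope's standard `global_cfg`; the values shown are the defaults and the range enforced by `assert_quant_cfg`):

```python
from federatedscope.core.configs.config import global_cfg

cfg = global_cfg.clone()
print(cfg.computation_quantization.method)  # 'none' by default
print(cfg.computation_quantization.nbits)   # 4 by default

# Enable QLoRA-style computation quantization:
cfg.computation_quantization.method = 'qlora'
cfg.computation_quantization.nbits = 4  # must be one of [4, 8, 16]
```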
11 changes: 9 additions & 2 deletions federatedscope/core/trainers/torch_trainer.py
@@ -30,7 +30,13 @@ class GeneralTorchTrainer(Trainer):
def get_model_para(self):
if self.cfg.federate.process_num > 1 or \
self.cfg.federate.share_local_model or \
self.cfg.llm.deepspeed.use:
self.cfg.llm.deepspeed.use or \
self.cfg.computation_quantization.method == 'qlora':
# bitsandbytes-quantized (QLoRA) models cannot be moved back to CPU
# ("discharged"), so the state_dict is taken from the model as-is;
# noted by [email protected] at
# https://github.com/huggingface/transformers/blob/
# fb7d246951d5f60aa36a7958841dfea72f51fc6b/src/
# transformers/trainer.py#L506C9-L512C1
return self._param_filter(self.ctx.model.state_dict())
else:
return self._param_filter(self.ctx.model.cpu().state_dict())
@@ -467,5 +473,6 @@ def discharge_model(self):
return

if not self.cfg.federate.share_local_model and \
not self.cfg.llm.deepspeed.use:
not self.cfg.llm.deepspeed.use and \
self.cfg.computation_quantization.method != 'qlora':
self.ctx.model.to(torch.device("cpu"))
47 changes: 47 additions & 0 deletions federatedscope/llm/baseline/qlora.yaml
@@ -0,0 +1,47 @@
use_gpu: True
device: 0
early_stop:
  patience: 0
federate:
  mode: standalone
  client_num: 10
  total_round_num: 500
  save_to: "llama.ckpt"
  share_local_model: True
data:
  root: data/
  type: 'alpaca@llm'
  splits: [0.98,0.01,0.01]
  splitter: 'iid'
llm:
  tok_len: 1000
  chat:
    max_len: 2000
  adapter:
    use: True
    args: [ { 'adapter_package': 'peft', 'adapter_method': 'qlora', 'r': 8, 'lora_alpha': 16, 'lora_dropout': 0.05 } ]
dataloader:
  batch_size: 1
model:
  # type: 'decapoda-research/llama-7b-hf@huggingface_llm'
  type: 'openlm-research/open_llama_3b@huggingface_llm'
  # type: 'gpt2@huggingface_llm'
train:
  local_update_steps: 30
  batch_or_epoch: batch
  optimizer:
    type: bit.SGD
    lr: 0.0003
    weight_decay: 0.0005
    momentum: 0.9
  is_enable_half: True
criterion:
  type: CrossEntropyLoss
trainer:
  type: llmtrainer
eval:
  freq: 2
  metrics: ['loss']
  count_flops: False
computation_quantization:
  method: qlora
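For reference, this baseline is launched through FederatedScope's usual entry point; a sketch from the repository root (the `federatedscope/main.py` path is the standard one and is an assumption of this note, not part of the diff):

```python
import subprocess

# Equivalent to:
#   python federatedscope/main.py --cfg federatedscope/llm/baseline/qlora.yaml
subprocess.run(['python', 'federatedscope/main.py',
                '--cfg', 'federatedscope/llm/baseline/qlora.yaml'],
               check=True)
```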
45 changes: 45 additions & 0 deletions federatedscope/llm/model/adapter_builder.py
@@ -28,6 +28,7 @@ def enable_adapter(model, package, adapter, **kwargs):
PEFT: https://github.com/huggingface/peft
Support methods:
LoRA
QLoRA
Prefix Tuning
P-Tuning
Prompt Tuning
@@ -38,6 +39,50 @@ def enable_adapter(model, package, adapter, **kwargs):
from peft import LoraConfig
peft_config = LoraConfig(task_type=TaskType.CAUSAL_LM, **kwargs)
model = get_peft_model(model, peft_config)
elif adapter == 'qlora':
    # The implementation of QLoRA is adapted from
    # https://github.com/artidoro/qlora
    import bitsandbytes as bnb
    from peft import LoraConfig
    from peft.tuners.lora import LoraLayer

    def find_all_linear_names(bits, model):
        # Collect the names of all linear layers matching the chosen
        # quantization precision, so LoRA can be attached to each of them.
        cls = bnb.nn.Linear4bit if bits == 4 else \
            (bnb.nn.Linear8bitLt if bits == 8 else torch.nn.Linear)
        lora_module_names = set()
        for name, module in model.named_modules():
            if isinstance(module, cls):
                names = name.split('.')
                lora_module_names.add(names[0] if len(names) == 1
                                      else names[-1])
        if 'lm_head' in lora_module_names:  # needed for 16-bit
            lora_module_names.remove('lm_head')
        return list(lora_module_names)

    peft_config = LoraConfig(
        r=kwargs['r'],
        lora_alpha=kwargs['lora_alpha'],
        target_modules=find_all_linear_names(bits=4, model=model),
        lora_dropout=kwargs['lora_dropout'],
        bias="none",
        task_type=TaskType.CAUSAL_LM,
    )
    # Without the following line, training fails with
    # `element 0 of tensors does not require grad and does not have a
    # grad_fn`; see https://github.com/huggingface/peft/issues/137
    model.enable_input_require_grads()
    model = get_peft_model(model, peft_config)
    # Cast LoRA layers to fp16 and keep norm layers in fp32 for numerical
    # stability, following the reference QLoRA implementation.
    for name, module in model.named_modules():
        if isinstance(module, LoraLayer):
            module = module.to(torch.float16)
        if 'norm' in name:
            module = module.to(torch.float32)
        if 'lm_head' in name or 'embed_tokens' in name:
            if hasattr(module, 'weight'):
                if module.weight.dtype == torch.float32:
                    module = module.to(torch.float16)
elif adapter == 'prefix':
from peft import PrefixTuningConfig
peft_config = PrefixTuningConfig(task_type=TaskType.CAUSAL_LM,
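For context, a rough usage sketch of the new `'qlora'` branch (editorial illustration only: it assumes a CUDA GPU, that `enable_adapter` returns the adapted PEFT model, and that the base model is loaded in 4-bit as in `model_builder.py` below):

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from federatedscope.llm.model.adapter_builder import enable_adapter

# 4-bit NF4 base model (the model name mirrors qlora.yaml).
base_model = AutoModelForCausalLM.from_pretrained(
    'openlm-research/open_llama_3b',
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type='nf4'),
    device_map='auto')

# Attach LoRA adapters to every 4-bit linear layer found by
# find_all_linear_names; only the LoRA weights remain trainable.
model = enable_adapter(base_model, package='peft', adapter='qlora',
                       r=8, lora_alpha=16, lora_dropout=0.05)
model.print_trainable_parameters()
```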
25 changes: 24 additions & 1 deletion federatedscope/llm/model/model_builder.py
@@ -18,7 +18,30 @@ def get_model_from_huggingface(model_name, config):
kwargs = {}
if len(config.llm.cache.model):
kwargs['cache_dir'] = config.llm.cache.model

if config.computation_quantization.method == 'qlora':
    from transformers import BitsAndBytesConfig
    import torch
    from peft import prepare_model_for_kbit_training
    # Load the base model with 4-bit NF4 weights (double quantization
    # enabled) and bfloat16 compute, following the reference QLoRA setup.
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        load_in_4bit=True,
        load_in_8bit=False,
        device_map=config.device,
        quantization_config=BitsAndBytesConfig(
            load_in_4bit=True,
            load_in_8bit=False,
            llm_int8_threshold=6.0,
            llm_int8_has_fp16_weight=False,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type='nf4'),
        torch_dtype=torch.bfloat16,
        trust_remote_code=False,
        # use_auth_token=False
    )
    # Enable gradient checkpointing and prepare the quantized model for
    # k-bit (QLoRA) training.
    return prepare_model_for_kbit_training(model,
                                           use_gradient_checkpointing=True)
return AutoModelForCausalLM.from_pretrained(model_name, **kwargs)


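And a sketch of how this branch is reached through the builder itself (assuming the standard FederatedScope `global_cfg`; names follow this PR):

```python
from federatedscope.core.configs.config import global_cfg
from federatedscope.llm.model.model_builder import get_model_from_huggingface

cfg = global_cfg.clone()
cfg.computation_quantization.method = 'qlora'
cfg.device = 0  # forwarded to from_pretrained as device_map

# Returns a 4-bit (NF4) model already passed through
# prepare_model_for_kbit_training (gradient checkpointing enabled).
model = get_model_from_huggingface('openlm-research/open_llama_3b', cfg)
```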
5 changes: 4 additions & 1 deletion setup.py
@@ -54,7 +54,10 @@
'tokenizers==0.13.3',
'transformers==4.29.2',
'accelerate==0.20.3',
'peft==0.3.0',
# required by QLoRA: prepare_model_for_kbit_training
'peft==0.4.0',
# required by QLoRA
'bitsandbytes==0.41.1',
'sentencepiece==0.1.99',
]
