From 95d3a5689677572d6ad1a9096387da9711880111 Mon Sep 17 00:00:00 2001
From: zhenqincn_intern
Date: Tue, 22 Aug 2023 11:33:51 +0800
Subject: [PATCH] Fix the error `element 0 of tensors does not require grad and does not have a grad_fn` when fine-tuning with QLoRA

---
 LICENSE                                         | 26 ++++
 .../core/configs/cfg_compression.py             |  4 +-
 federatedscope/llm/baseline/qlora.yaml          | 44 +++++++++++++++++
 federatedscope/llm/model/adapter_builder.py     | 36 ++++++++++++++
 federatedscope/llm/model/model_builder.py       | 48 ++++++++++++++++++-
 setup.py                                        |  3 +-
 6 files changed, 156 insertions(+), 5 deletions(-)
 create mode 100644 federatedscope/llm/baseline/qlora.yaml

diff --git a/LICENSE b/LICENSE
index 5236fe4c1..31adea9f8 100644
--- a/LICENSE
+++ b/LICENSE
@@ -721,6 +721,32 @@ furnished to do so, subject to the following conditions:
 The above copyright notice and this permission notice shall be included in all
 copies or substantial portions of the Software.
 
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+
+
+---------------------------------------------------------------------------------
+The implementation of the [QLoRA](https://arxiv.org/abs/2305.14314) quantization method for LLMs is adapted from https://github.com/artidoro/qlora (MIT License)
+
+MIT License
+
+Copyright (c) 2023 Artidoro Pagnoni, Tim Dettmers
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
diff --git a/federatedscope/core/configs/cfg_compression.py b/federatedscope/core/configs/cfg_compression.py
index c4d90c610..2ad5ffcae 100644
--- a/federatedscope/core/configs/cfg_compression.py
+++ b/federatedscope/core/configs/cfg_compression.py
@@ -22,10 +22,10 @@ def extend_compression_cfg(cfg):
 
 
 def assert_compression_cfg(cfg):
-    if cfg.quantization.method.lower() not in ['none', 'uniform']:
+    if cfg.quantization.method.lower() not in ['none', 'uniform', 'qlora']:
         logger.warning(
             f'Quantization method is expected to be one of ["none",'
-            f'"uniform"], but got "{cfg.quantization.method}". So we '
+            f'"uniform", "qlora"], but got "{cfg.quantization.method}". So we '
             f'change it to "none"')
 
     if cfg.quantization.method.lower(
diff --git a/federatedscope/llm/baseline/qlora.yaml b/federatedscope/llm/baseline/qlora.yaml
new file mode 100644
index 000000000..37ffc0acd
--- /dev/null
+++ b/federatedscope/llm/baseline/qlora.yaml
@@ -0,0 +1,44 @@
+use_gpu: True
+device: 2
+early_stop:
+  patience: 0
+federate:
+  mode: standalone
+  client_num: 1
+  total_round_num: 500
+  save_to: "llama.ckpt"
+data:
+  root: data/
+  type: 'alpaca@llm'
+  splits: [0.98,0.01,0.01]
+  splitter: 'iid'
+llm:
+  tok_len: 1000
+  chat:
+    max_len: 2000
+  adapter:
+    use: True
+    args: [ { 'adapter_package': 'peft', 'adapter_method': 'qlora', 'r': 8, 'lora_alpha': 16, 'lora_dropout': 0.05 } ]
+dataloader:
+  batch_size: 1
+model:
+  # type: 'decapoda-research/llama-7b-hf@huggingface_llm'
+  type: 'openlm-research/open_llama_3b@huggingface_llm'
+  # type: 'gpt2@huggingface_llm'
+train:
+  local_update_steps: 30
+  batch_or_epoch: batch
+  optimizer:
+    lr: 0.0003
+    weight_decay: 0.0
+  is_enable_half: True
+criterion:
+  type: CrossEntropyLoss
+trainer:
+  type: llmtrainer
+eval:
+  freq: 1
+  metrics: ['loss']
+  count_flops: False
+quantization:
+  method: qlora
\ No newline at end of file
diff --git a/federatedscope/llm/model/adapter_builder.py b/federatedscope/llm/model/adapter_builder.py
index d1621c85c..f8efb44c2 100644
--- a/federatedscope/llm/model/adapter_builder.py
+++ b/federatedscope/llm/model/adapter_builder.py
@@ -10,6 +10,7 @@ def enable_adapter(model, package, adapter, **kwargs):
         PEFT: https://github.com/huggingface/peft
         Support methods:
             LoRA
+            QLoRA
             Prefix Tuning
             P-Tuning
             Prompt Tuning
@@ -20,6 +21,41 @@
             from peft import LoraConfig
             peft_config = LoraConfig(task_type=TaskType.CAUSAL_LM, **kwargs)
             model = get_peft_model(model, peft_config)
+        elif adapter == 'qlora':
+            # The implementation of QLoRA is adapted from https://github.com/artidoro/qlora
+            import bitsandbytes as bnb
+            from peft import LoraConfig
+            from peft.tuners.lora import LoraLayer
+            def find_all_linear_names(bits, model):
+                cls = bnb.nn.Linear4bit if bits == 4 else (bnb.nn.Linear8bitLt if bits == 8 else torch.nn.Linear)
+                lora_module_names = set()
+                for name, module in model.named_modules():
+                    if isinstance(module, cls):
+                        names = name.split('.')
+                        lora_module_names.add(names[0] if len(names) == 1 else names[-1])
+                if 'lm_head' in lora_module_names:  # needed for 16-bit
+                    lora_module_names.remove('lm_head')
+                return list(lora_module_names)
+            peft_config = LoraConfig(
+                r=kwargs['r'],
+                lora_alpha=kwargs['lora_alpha'],
+                target_modules=find_all_linear_names(bits=4, model=model),
+                lora_dropout=kwargs['lora_dropout'],
+                bias="none",
+                task_type=TaskType.CAUSAL_LM,
+            )
+            # Without the following line, backward fails with "element 0 of tensors does not require grad and does not have a grad_fn"; see https://github.com/huggingface/peft/issues/137
+            model.enable_input_require_grads()
+            model = get_peft_model(model, peft_config)
+            # for name, module in model.named_modules():
+            #     if isinstance(module, LoraLayer):
+            #         module = module.to(torch.bfloat16)
+            #     if 'norm' in name:
+            #         module = module.to(torch.float32)
+            #     if 'lm_head' in name or 'embed_tokens' in name:
+            #         if hasattr(module, 'weight'):
+            #             if module.weight.dtype == torch.float32:
+            #                 module = module.to(torch.bfloat16)
         elif adapter == 'prefix':
             from peft import PrefixTuningConfig
             peft_config = PrefixTuningConfig(task_type=TaskType.CAUSAL_LM,
diff --git a/federatedscope/llm/model/model_builder.py b/federatedscope/llm/model/model_builder.py
index 49c8f0f53..29b3a0e7f 100644
--- a/federatedscope/llm/model/model_builder.py
+++ b/federatedscope/llm/model/model_builder.py
@@ -1,4 +1,7 @@
 from federatedscope.llm.model.adapter_builder import AdapterModel
+from transformers import BitsAndBytesConfig
+import torch
+from peft import prepare_model_for_kbit_training
 
 
 def get_model_from_huggingface(model_name, config):
@@ -7,8 +10,49 @@ def get_model_from_huggingface(model_name, config):
     kwargs = {}
     if len(config.llm.cache.model):
         kwargs['cache_dir'] = config.llm.cache.model
-
-    return AutoModelForCausalLM.from_pretrained(model_name, **kwargs)
+    if config.quantization.method == 'qlora':
+        # The implementation of QLoRA is adapted from https://github.com/artidoro/qlora
+        # kwargs['load_in_4bit'] = True
+        # kwargs['load_in_8bit'] = False
+        # kwargs['quantization_config'] = BitsAndBytesConfig(
+        #     load_in_4bit=True,
+        #     load_in_8bit=False,
+        #     llm_int8_threshold=6.0,
+        #     llm_int8_has_fp16_weight=False,
+        #     # bnb_4bit_compute_dtype=torch.bfloat16,
+        #     bnb_4bit_compute_dtype=torch.float32,
+        #     bnb_4bit_use_double_quant=True,
+        #     bnb_4bit_quant_type='nf4'
+        # )
+        # kwargs['device_map'] = config.device
+        # kwargs['torch_dtype'] = torch.float32
+        # kwargs['trust_remote_code'] = False
+        # kwargs['use_auth_token'] = False
+        # print('\n\n\n\n\n\n\n\n\n\n\n LLM Model Loaded with K-bit quant \n\n\n\n\n\n\n\n\n\n\n ')
+        model = AutoModelForCausalLM.from_pretrained(
+            model_name,
+            load_in_4bit=True,
+            load_in_8bit=False,
+            device_map=config.device,
+            quantization_config=BitsAndBytesConfig(
+                load_in_4bit=True,
+                load_in_8bit=False,
+                llm_int8_threshold=6.0,
+                llm_int8_has_fp16_weight=False,
+                # bnb_4bit_compute_dtype=torch.bfloat16,
+                bnb_4bit_compute_dtype=torch.bfloat16,
+                bnb_4bit_use_double_quant=True,
+                bnb_4bit_quant_type='nf4'
+            ),
+            torch_dtype=torch.bfloat16,
+            trust_remote_code=False,
+            use_auth_token=False
+            # cache_dir=config.llm.cache.model if len(config.llm.cache.model) else None
+        )
+        model.config.torch_dtype = torch.float32
+        return prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)
+    else:
+        return AutoModelForCausalLM.from_pretrained(model_name, **kwargs)
 
 
 def get_model_from_modelscope(model_name, config):
diff --git a/setup.py b/setup.py
index 3ce80a0e7..4f72c0a36 100644
--- a/setup.py
+++ b/setup.py
@@ -53,8 +53,9 @@
     'tokenizers==0.13.3',
     'transformers==4.29.2',
     'accelerate==0.20.3',
-    'peft==0.3.0',
+    'peft==0.4.0',  # required by QLoRA: prepare_model_for_kbit_training
     'sentencepiece==0.1.99',
+    'bitsandbytes==0.41.1'
 ]
 
 benchmark_hpo_requires = [
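
Usage note (not part of the patch): a minimal sketch of how the new QLoRA baseline added above might be launched once the patch is applied. The entry point `federatedscope/main.py` and the `--cfg` flag are assumptions based on FederatedScope's usual command-line interface, not something this patch introduces; the pinned `peft==0.4.0` and `bitsandbytes==0.41.1` from the updated setup.py must be installed first.

    # hypothetical invocation; adjust the path to your checkout
    python federatedscope/main.py --cfg federatedscope/llm/baseline/qlora.yaml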