"Tuner" refers to additional structures attached to a model to reduce the number of trainable parameters or improve training accuracy. SWIFT currently supports the following tuners:
- LoRA: LoRA: Low-Rank Adaptation of Large Language Models
- LoRA+: LoRA+: Efficient Low Rank Adaptation of Large Models
- LLaMA PRO: LLaMA Pro: Progressive LLaMA with Block Expansion
- GaLore: GaLore: Memory-Efficient LLM Training by Gradient Low-Rank Projection
- LISA: LISA: Layerwise Importance Sampling for Memory-Efficient Large Language Model Fine-Tuning
- UnSloth: https://github.com/unslothai/unsloth
- SCEdit: SCEdit: Efficient and Controllable Image Diffusion Generation via Skip Connection Editing
- NEFTune: Noisy Embeddings Improve Instruction Finetuning
- LongLoRA: Efficient Fine-tuning of Long-Context Large Language Models
- Adapter: Parameter-Efficient Transfer Learning for NLP
- Vision Prompt Tuning: Visual Prompt Tuning
- Side: Side-Tuning: A Baseline for Network Adaptation via Additive Side Networks
- Res-Tuning: Res-Tuning: A Flexible and Efficient Tuning Paradigm via Unbinding Tuner from Backbone
- Tuners provided by PEFT, such as IA3, AdaLoRA, etc.
Call `Swift.prepare_model()` to add tuners to the model:
```python
from modelscope import Model
from swift import Swift, LoraConfig
import torch

model = Model.from_pretrained('ZhipuAI/chatglm3-6b', torch_dtype=torch.bfloat16, device_map='auto')
lora_config = LoraConfig(
    r=16,
    target_modules=['query_key_value'],
    lora_alpha=32,
    lora_dropout=0.)
model = Swift.prepare_model(model, lora_config)
```
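The pure-Python sketch below gives some intuition for what this configuration adds. It uses plain lists with no SWIFT or PyTorch code, and the helper names are hypothetical. It computes a LoRA-adapted linear layer as `W x + (lora_alpha / r) * B (A x)`, assuming the standard LoRA scaling `lora_alpha / r`. Because `B` starts at zero, the adapted layer initially reproduces the frozen layer. Only `A` and `B` are trained.

```python
import random


def matvec(m, v):
    """Multiply a matrix (a list of rows) by a vector."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def lora_forward(W, A, B, x, lora_alpha, r):
    """Frozen output W x plus the scaled low-rank update (lora_dropout=0., so no dropout)."""
    base = matvec(W, x)                 # frozen weight, shape d_out x d_in
    delta = matvec(B, matvec(A, x))     # A: r x d_in, B: d_out x r
    scale = lora_alpha / r
    return [b + scale * d for b, d in zip(base, delta)]


def lora_param_count(d_in, d_out, r):
    """Trainable parameters LoRA adds to one d_out x d_in linear layer."""
    return r * (d_in + d_out)


rng = random.Random(0)
d_in, d_out, r, lora_alpha = 4, 3, 2, 4
W = [[rng.uniform(-1, 1) for _ in range(d_in)] for _ in range(d_out)]
A = [[rng.gauss(0.0, 0.02) for _ in range(d_in)] for _ in range(r)]  # random init
B = [[0.0] * r for _ in range(d_out)]                                # zero init
x = [1.0, 2.0, 3.0, 4.0]

# With B = 0 the adapted layer matches the frozen layer exactly.
print(lora_forward(W, A, B, x, lora_alpha, r) == matvec(W, x))  # True
# For a hypothetical 4096 x 4096 projection with r=16:
print(lora_param_count(4096, 4096, 16), 4096 * 4096)  # 131072 16777216
```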
Multiple tuners can also be used simultaneously:
```python
from modelscope import Model
from swift import Swift, LoraConfig, AdapterConfig
import torch

model = Model.from_pretrained('ZhipuAI/chatglm3-6b', torch_dtype=torch.bfloat16, device_map='auto')
lora_config = LoraConfig(
    r=16,
    target_modules=['query_key_value'],
    lora_alpha=32,
    lora_dropout=0.)
adapter_config = AdapterConfig(
    dim=model.config.hidden_size,
    target_modules=['mlp'],
    method_name='forward',
    hidden_pos=0,
    adapter_length=32,
)
model = Swift.prepare_model(model, {'first_tuner': lora_config, 'second_tuner': adapter_config})
# use model to do other things
```
When using multiple tuners, pass a dict as the second argument. Each key is a tuner name and each value is that tuner's configuration.
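The small sketch below illustrates the naming rule. `normalize_tuner_configs` is a hypothetical helper, not part of SWIFT, and plain strings stand in for config objects. A dict is used as-is. A single config gets the name `default`, which is the name SWIFT uses when only one config is passed, as described for the saved checkpoint layout.

```python
from typing import Any, Dict


def normalize_tuner_configs(config: Any) -> Dict[str, Any]:
    """Return a {tuner_name: tuner_config} mapping for the `config` argument."""
    if isinstance(config, dict):
        return dict(config)        # already keyed by tuner name
    return {'default': config}     # a single config gets the name 'default'


print(normalize_tuner_configs('lora_config'))
# {'default': 'lora_config'}
print(normalize_tuner_configs({'first_tuner': 'lora_config', 'second_tuner': 'adapter_config'}))
# {'first_tuner': 'lora_config', 'second_tuner': 'adapter_config'}
```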
After training, call:
```python
model.save_pretrained(save_directory='./output')
```
to store the checkpoint. The checkpoint contains only the tuner weights, not the weights of the base model. The stored structure is as follows:
```text
output
|-- configuration.json
|-- first_tuner
|   |-- adapter_config.json
|   |-- adapter_model.bin
|-- second_tuner
|   |-- adapter_config.json
|   |-- adapter_model.bin
|-- ...
```
If only a single config is passed in, the default name `default` is used:
```text
output
|-- configuration.json
|-- default
|   |-- adapter_config.json
|   |-- adapter_model.bin
|-- ...
```
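Every tuner is saved in its own subfolder, so the tuner names in a checkpoint can be recovered from the directory layout. The sketch below uses only the standard library. `list_saved_tuners` is a hypothetical helper that assumes any folder containing an `adapter_config.json` is a tuner subfolder. The demo rebuilds the multi-tuner layout above with empty placeholder files.

```python
import os
import tempfile


def list_saved_tuners(save_directory: str) -> list:
    """Return the names of subfolders that contain an adapter_config.json."""
    names = []
    for entry in sorted(os.listdir(save_directory)):
        if os.path.isfile(os.path.join(save_directory, entry, 'adapter_config.json')):
            names.append(entry)
    return names


# Recreate the multi-tuner layout with empty placeholder files.
with tempfile.TemporaryDirectory() as root:
    open(os.path.join(root, 'configuration.json'), 'w').close()
    for tuner in ('first_tuner', 'second_tuner'):
        os.makedirs(os.path.join(root, tuner))
        for fname in ('adapter_config.json', 'adapter_model.bin'):
            open(os.path.join(root, tuner, fname), 'w').close()
    found = list_saved_tuners(root)

print(found)  # ['first_tuner', 'second_tuner']
```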
The following example trains a LoRA tuner on ChatGLM3-6B using the `AI-ModelScope/alpaca-gpt4-data-en` dataset:
```python
# Requires about 18 GB of GPU memory on an A100
from swift import Seq2SeqTrainer, Seq2SeqTrainingArguments
from modelscope import MsDataset, AutoTokenizer
from modelscope import AutoModelForCausalLM
from swift import Swift, LoraConfig
from swift.llm import get_template, TemplateType
import torch

# load model
model = AutoModelForCausalLM.from_pretrained('ZhipuAI/chatglm3-6b', torch_dtype=torch.bfloat16, device_map='auto', trust_remote_code=True)
lora_config = LoraConfig(
    r=16,
    target_modules=['query_key_value'],
    lora_alpha=32,
    lora_dropout=0.05)
model = Swift.prepare_model(model, lora_config)
tokenizer = AutoTokenizer.from_pretrained('ZhipuAI/chatglm3-6b', trust_remote_code=True)
dataset = MsDataset.load('AI-ModelScope/alpaca-gpt4-data-en', split='train')
template = get_template(TemplateType.chatglm3, tokenizer, max_length=1024)


def encode(example):
    # Build the query from the instruction and the optional input, then tokenize it with the template
    inst, inp, output = example['instruction'], example.get('input', None), example['output']
    if output is None:
        return {}
    if inp is None or len(inp) == 0:
        q = inst
    else:
        q = f'{inst}\n{inp}'
    example, kwargs = template.encode({'query': q, 'response': output})
    return example


# Tokenize, drop samples without input_ids, and hold out a small validation split
dataset = dataset.map(encode).filter(lambda e: e.get('input_ids'))
dataset = dataset.train_test_split(test_size=0.001)
train_dataset, val_dataset = dataset['train'], dataset['test']

train_args = Seq2SeqTrainingArguments(
    output_dir='output',
    learning_rate=1e-4,
    num_train_epochs=2,
    eval_steps=500,
    save_steps=500,
    evaluation_strategy='steps',
    save_strategy='steps',
    dataloader_num_workers=4,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,
    logging_steps=10,
)
trainer = Seq2SeqTrainer(
    model=model,
    args=train_args,
    data_collator=template.data_collator,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=tokenizer)
trainer.train()
```
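With `per_device_train_batch_size=1` and `gradient_accumulation_steps=16`, each optimizer step sees 16 samples per data-parallel process. The arithmetic sketch below estimates the number of optimizer steps. It rests on three assumptions:

- The dataset size of 50,000 examples is a made-up figure.
- The script runs as a single process. `device_map='auto'` splits the model across GPUs rather than replicating it.
- The Trainer's own rounding may differ slightly from the estimate.

```python
import math


def estimate_optimizer_steps(num_examples, per_device_batch, grad_accum, epochs, world_size=1):
    """Return (effective batch size, approximate total optimizer steps)."""
    effective_batch = per_device_batch * grad_accum * world_size
    steps_per_epoch = math.ceil(num_examples / effective_batch)
    return effective_batch, steps_per_epoch * epochs


effective_batch, total_steps = estimate_optimizer_steps(
    num_examples=50_000,  # hypothetical training-set size
    per_device_batch=1, grad_accum=16, epochs=2)
# eval_steps=500 / save_steps=500: roughly one evaluation and one checkpoint every 500 steps
num_checkpoints = total_steps // 500
print(effective_batch, total_steps, num_checkpoints)  # 16 6250 12
```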
Use `Swift.from_pretrained()` to load the stored checkpoint:
```python
from modelscope import Model
from swift import Swift
import torch

# Load the same base model that the tuners were trained on, then attach the stored tuners
model = Model.from_pretrained('ZhipuAI/chatglm3-6b', torch_dtype=torch.bfloat16, device_map='auto')
model = Swift.from_pretrained(model, './output')
```
The following example runs inference with a trained checkpoint. Replace `checkpoint-xxx` with the name of a checkpoint directory produced by training:
```python
# Requires about 14 GB of GPU memory on an A100
import torch
from modelscope import AutoModelForCausalLM, GenerationConfig
from modelscope import AutoTokenizer
from swift import Swift
from swift.llm import get_template, TemplateType, to_device

# load model
model = AutoModelForCausalLM.from_pretrained('ZhipuAI/chatglm3-6b', torch_dtype=torch.bfloat16,
                                             device_map='auto', trust_remote_code=True)
model = Swift.from_pretrained(model, 'output/checkpoint-xxx')
tokenizer = AutoTokenizer.from_pretrained('ZhipuAI/chatglm3-6b', trust_remote_code=True)
template = get_template(TemplateType.chatglm3, tokenizer, max_length=1024)

examples, tokenizer_kwargs = template.encode({'query': 'How are you?'})
if 'input_ids' in examples:
    input_ids = torch.tensor(examples['input_ids'])[None]  # add a batch dimension
    examples['input_ids'] = input_ids
    token_len = input_ids.shape[1]  # prompt length, used to strip the prompt from the output

generation_config = GenerationConfig(
    max_new_tokens=1024,
    temperature=0.3,
    top_k=25,
    top_p=0.8,
    do_sample=True,
    repetition_penalty=1.0,
    num_beams=10,
    pad_token_id=tokenizer.pad_token_id,
    eos_token_id=tokenizer.eos_token_id)
device = next(model.parameters()).device
examples = to_device(examples, device)

generate_ids = model.generate(
    generation_config=generation_config,
    **examples)
generate_ids = template.get_generate_ids(generate_ids, token_len)
print(tokenizer.decode(generate_ids, **tokenizer_kwargs))
# I'm an AI language model, so I don't have feelings or physical sensations. However, I'm here to assist you with any questions or tasks you may have. How can I help you today?
```
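The sampling parameters in `GenerationConfig` restrict which tokens can be drawn at each step. The pure-Python sketch below uses a hypothetical helper, `filter_next_token_probs`, and is not the transformers implementation. It applies temperature scaling, then top-k, then top-p (nucleus) filtering to one step's logits, and renormalizes the remaining probabilities. transformers applies these in a similar order, but its tie handling and edge cases may differ.

```python
import math


def softmax(values):
    m = max(values)
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def filter_next_token_probs(logits, temperature, top_k, top_p):
    """Return {token_id: probability} after temperature, top-k and top-p filtering."""
    scaled = [l / temperature for l in logits]          # temperature < 1 sharpens the distribution
    ranked = sorted(range(len(scaled)), key=lambda i: scaled[i], reverse=True)
    kept = ranked[:top_k]                               # top-k: keep the k most likely tokens
    probs = softmax([scaled[i] for i in kept])          # renormalize over the top-k tokens
    nucleus, cumulative = [], 0.0
    for token_id, p in zip(kept, probs):                # top-p: smallest prefix with mass >= top_p
        nucleus.append((token_id, p))
        cumulative += p
        if cumulative >= top_p:
            break
    mass = sum(p for _, p in nucleus)
    return {token_id: p / mass for token_id, p in nucleus}


print(filter_next_token_probs([2.0, 1.0, 0.0, -1.0], temperature=1.0, top_k=2, top_p=0.5))
# {0: 1.0}
```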
- `Swift.prepare_model(model, config, **kwargs)`
  - Explain: Load a tuner onto the model. If `config` is an instance of a `PeftConfig` subclass, the corresponding PEFT interface is used to load the tuner. With a `SwiftConfig`, this interface also accepts a `SwiftModel` instance and can be called repeatedly. Doing so has the same effect as passing a dict as `config`.
  - This interface supports loading multiple tuners of different types in parallel for simultaneous use.
  - Parameters:
    - model: An instance of `torch.nn.Module` or `SwiftModel`, the model to load the tuners onto
    - config: An instance of `SwiftConfig` or `PeftConfig`, or a dict mapping custom tuner names to configs
  - Return value: An instance of `SwiftModel` or `PeftModel`
- `Swift.merge_and_unload(model)`
  - Explain: Merge the LoRA weights back into the original model and completely unload the LoRA modules
  - Parameters:
    - model: An instance of `SwiftModel` or `PeftModel`, the model with LoRA loaded
  - Return value: None
- `Swift.merge(model)`
  - Explain: Merge the LoRA weights back into the original model without unloading the LoRA modules
  - Parameters:
    - model: An instance of `SwiftModel` or `PeftModel`, the model with LoRA loaded
  - Return value: None
- `Swift.unmerge(model)`
  - Explain: Split the merged LoRA weights out of the original model weights and back into the LoRA modules
  - Parameters:
    - model: An instance of `SwiftModel` or `PeftModel`, the model with LoRA loaded
  - Return value: None
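Conceptually, merging adds the scaled low-rank product to the frozen weight, `W' = W + s * B A` with `s = lora_alpha / r`, and unmerging subtracts it again. The pure-Python sketch below uses hypothetical helpers on small matrices and is not SWIFT's implementation. It shows this round trip. `merge_and_unload` corresponds to merging and then discarding `A` and `B`. `merge` keeps them, so `unmerge` remains possible.

```python
def matmul(X, Y):
    """Multiply two matrices given as lists of rows."""
    return [[sum(X[i][k] * Y[k][j] for k in range(len(Y))) for j in range(len(Y[0]))]
            for i in range(len(X))]


def merge(W, A, B, scale):
    """Return W + scale * (B @ A); A is r x d_in, B is d_out x r."""
    delta = matmul(B, A)
    return [[w + scale * d for w, d in zip(w_row, d_row)] for w_row, d_row in zip(W, delta)]


def unmerge(W_merged, A, B, scale):
    """Return W_merged - scale * (B @ A), recovering the original weight."""
    delta = matmul(B, A)
    return [[w - scale * d for w, d in zip(w_row, d_row)] for w_row, d_row in zip(W_merged, delta)]


W = [[1.0, 0.0], [0.0, 1.0]]   # frozen 2 x 2 weight
A = [[1.0, 2.0]]               # rank r=1
B = [[3.0], [4.0]]
scale = 2.0                    # e.g. lora_alpha=2, r=1

W_merged = merge(W, A, B, scale)
print(W_merged)                        # [[7.0, 12.0], [8.0, 17.0]]
print(unmerge(W_merged, A, B, scale))  # [[1.0, 0.0], [0.0, 1.0]]
```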
- `Swift.save_to_peft_format(ckpt_dir, output_dir)`
  - Explain: Convert a stored LoRA checkpoint to a PEFT-compatible format. The main changes are:
    - The `default` tuner is moved from its `default` subfolder into the root of `output_dir`
    - The `{tuner_name}.` segment is removed from the weight keys. For example, `model.layer.0.self.in_proj.lora_A.default.weight` becomes `model.layer.0.self.in_proj.lora_A.weight`
    - The prefix `base_model.model.` is added to the weight keys
  - Note: Only LoRA can be converted. Other tuner types cannot be converted because PEFT itself does not support them. In addition, conversion to PEFT format is not supported when extra parameters such as `dtype` are set in `LoraConfig`. In that case, you can manually delete the corresponding fields from `adapter_config.json`.
  - Parameters:
    - ckpt_dir: The original checkpoint directory
    - output_dir: The target directory
  - Return value: None
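The key renaming described above can be sketched as a string transformation. `to_peft_key` is a hypothetical helper, not SWIFT's converter. It assumes two things:

- The tuner name appears as its own dotted segment.
- PEFT expects the `base_model.model.` prefix on adapter weight keys.

This naive replacement would also strip any module segment whose name happens to equal the tuner name.

```python
def to_peft_key(key: str, tuner_name: str = 'default') -> str:
    """Drop the '.{tuner_name}.' segment and add the PEFT 'base_model.model.' prefix."""
    key = key.replace(f'.{tuner_name}.', '.')
    return 'base_model.model.' + key


state_dict_keys = [
    'model.layer.0.self.in_proj.lora_A.default.weight',
    'model.layer.0.self.in_proj.lora_B.default.weight',
]
for k in state_dict_keys:
    print(to_peft_key(k))
# base_model.model.model.layer.0.self.in_proj.lora_A.weight
# base_model.model.model.layer.0.self.in_proj.lora_B.weight
```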
- `Swift.from_pretrained(model, model_id, adapter_name, revision, **kwargs)`
  - Explain: Load tuners from a stored checkpoint directory onto the model. If `adapter_name` is not passed, all tuners under the `model_id` directory are loaded. Like `prepare_model`, this interface can be called repeatedly.
  - Parameters:
    - model: An instance of `torch.nn.Module` or `SwiftModel`, the model to load the tuners onto
    - model_id: `str`, the tuner checkpoint to load. This can be a ModelScope hub id or a local directory produced by training.
    - adapter_name: `str`, `List[str]`, `Dict[str, str]`, or `None`, the names of the tuners in the checkpoint directory to load:
      - If `None`, all named tuners are loaded.
      - If `str` or `List[str]`, only the specified tuners are loaded.
      - If `Dict`, the tuner named by each key is loaded and renamed to the corresponding value.
    - revision: If `model_id` is a ModelScope id, `revision` specifies the version to load
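The three accepted forms of `adapter_name` can be summarized as a mapping from the name stored in the checkpoint to the name used in the model. `resolve_adapter_names` below is a hypothetical helper written to mirror the description above. It does not check that the requested names exist in the checkpoint.

```python
from typing import Dict, List, Optional, Union


def resolve_adapter_names(
        saved_names: List[str],
        adapter_name: Optional[Union[str, List[str], Dict[str, str]]] = None) -> Dict[str, str]:
    """Return {name in checkpoint: name in model} for each tuner to load."""
    if adapter_name is None:
        return {name: name for name in saved_names}    # load every stored tuner
    if isinstance(adapter_name, str):
        adapter_name = [adapter_name]
    if isinstance(adapter_name, list):
        return {name: name for name in adapter_name}   # load only the listed tuners
    return dict(adapter_name)                          # load each key, rename it to its value


saved = ['first_tuner', 'second_tuner']
print(resolve_adapter_names(saved))
# {'first_tuner': 'first_tuner', 'second_tuner': 'second_tuner'}
print(resolve_adapter_names(saved, 'second_tuner'))
# {'second_tuner': 'second_tuner'}
print(resolve_adapter_names(saved, {'first_tuner': 'lora'}))
# {'first_tuner': 'lora'}
```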
The following lists the interfaces that users may call. Other internal interfaces, and interfaces not recommended for use, are covered by the API documentation, which can be generated with the `make docs` command.
- `SwiftModel.create_optimizer_param_groups(self, **defaults)`
  - Explain: Create optimizer parameter groups based on the loaded tuners. Currently this only takes effect for the LoRA+ algorithm.
  - Parameters:
    - defaults: Default parameters for the optimizer groups, such as `lr` and `weight_decay`
  - Return value:
    - The created optimizer groups
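LoRA+ trains the LoRA `B` matrices with a larger learning rate than the `A` matrices. The sketch below shows one way such parameter groups could be built. `loraplus_param_groups` and its `lr_ratio` argument are hypothetical, and strings stand in for parameter tensors. SWIFT's own grouping may distinguish more parameter types. The result is a list of dicts with a `params` key plus per-group options, which is the shape PyTorch optimizers accept as parameter groups.

```python
def loraplus_param_groups(named_params, lr, lr_ratio, weight_decay=0.0):
    """Split parameters into an A-side group (lr) and a B-side group (lr * lr_ratio)."""
    a_side, b_side = [], []
    for name, param in named_params:
        (b_side if 'lora_B' in name else a_side).append(param)
    return [
        {'params': a_side, 'lr': lr, 'weight_decay': weight_decay},
        {'params': b_side, 'lr': lr * lr_ratio, 'weight_decay': weight_decay},
    ]


named = [
    ('layer.0.query_key_value.lora_A.default.weight', 'A0'),
    ('layer.0.query_key_value.lora_B.default.weight', 'B0'),
]
groups = loraplus_param_groups(named, lr=1e-4, lr_ratio=16)
for group in groups:
    print(group['params'], group['lr'])
```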
- `SwiftModel.add_weighted_adapter(self, ...)`
  - Explain: Merge existing LoRA tuners into one
  - Parameters:
    - This interface is a transparent pass-through of `PeftModel.add_weighted_adapter`. For the parameters, see the PEFT `add_weighted_adapter` documentation.
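Conceptually, combining LoRA tuners with weights `w_i` yields a weight update of the form `sum_i w_i * s_i * B_i A_i`. The sketch below uses hypothetical helpers on plain lists and computes that weighted update directly. PEFT's `add_weighted_adapter` offers several `combination_type` strategies. These differ in how the combination is expressed as new `A` and `B` matrices and, for some strategies, in whether the result equals this weighted sum exactly. Treat this only as an illustration of the weighting.

```python
def matmul(X, Y):
    """Multiply two matrices given as lists of rows."""
    return [[sum(X[i][k] * Y[k][j] for k in range(len(Y))) for j in range(len(Y[0]))]
            for i in range(len(X))]


def weighted_delta(adapters, weights):
    """adapters: list of (A, B, scale); returns sum_i weights[i] * scale_i * (B_i @ A_i)."""
    total = None
    for (A, B, scale), w in zip(adapters, weights):
        delta = matmul(B, A)
        if total is None:
            total = [[0.0] * len(row) for row in delta]
        total = [[t + w * scale * d for t, d in zip(t_row, d_row)]
                 for t_row, d_row in zip(total, delta)]
    return total


first = ([[1.0, 0.0]], [[1.0], [0.0]], 1.0)    # rank-1 update touching entry (0, 0)
second = ([[0.0, 1.0]], [[0.0], [1.0]], 1.0)   # rank-1 update touching entry (1, 1)
print(weighted_delta([first, second], [0.5, 0.5]))  # [[0.5, 0.0], [0.0, 0.5]]
```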
- `SwiftModel.save_pretrained(self, save_directory, safe_serialization, adapter_name)`
  - Explain: Store the tuner weights
  - Parameters:
    - save_directory: The storage directory
    - safe_serialization: Whether to save in safetensors format. Defaults to False.
    - adapter_name: The tuner(s) to store. If not passed, all tuners are stored.
- `SwiftModel.set_active_adapters(self, adapter_names, offload=None)`
  - Explain: Set the currently active adapters. Adapters not in the list are deactivated.
    - For inference, the environment variable `USE_UNIQUE_THREAD=0/1` is supported, with a default of `1`. If it is set to `0`, `set_active_adapters` only takes effect for the current thread. Each thread then uses the tuners it activated, and tuners in different threads do not interfere with each other.
  - Parameters:
    - adapter_names: The tuners to activate
    - offload: How to handle deactivated adapters. The default, `None`, leaves them in GPU memory. `cpu` and `meta` offload them to the CPU or the meta device to reduce GPU memory consumption. When `USE_UNIQUE_THREAD=0`, do not pass a value for `offload`, to avoid affecting other threads.
  - Return value: None
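The per-thread behavior of `USE_UNIQUE_THREAD=0` resembles Python's `threading.local`. The toy tracker below is a hypothetical illustration, not SWIFT's implementation. It reads the environment variable once. It keeps one shared set of active adapters by default, and a separate set per thread when the variable is `0`. The demo sets the variable itself only to enable per-thread mode.

```python
import os
import threading


class ActiveAdapterTracker:
    """Toy tracker of active adapter names (hypothetical, for illustration only)."""

    def __init__(self):
        # Per-thread mode when USE_UNIQUE_THREAD=0; shared mode otherwise (default '1').
        self.per_thread = os.environ.get('USE_UNIQUE_THREAD', '1') == '0'
        self._local = threading.local()
        self._shared = set()

    def set_active_adapters(self, adapter_names):
        if self.per_thread:
            self._local.active = set(adapter_names)   # visible to the calling thread only
        else:
            self._shared = set(adapter_names)         # visible to every thread

    def active(self):
        if self.per_thread:
            return getattr(self._local, 'active', set())
        return self._shared


os.environ['USE_UNIQUE_THREAD'] = '0'   # demo only: switch on per-thread mode
tracker = ActiveAdapterTracker()
results = {}


def worker(name):
    tracker.set_active_adapters([name])
    results[name] = tracker.active()


threads = [threading.Thread(target=worker, args=(name,)) for name in ('first_tuner', 'second_tuner')]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(results['first_tuner'], results['second_tuner'], tracker.active())
# {'first_tuner'} {'second_tuner'} set()
```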
- `SwiftModel.activate_adapter(self, adapter_name)`
  - Explain: Activate a tuner
    - For inference, the environment variable `USE_UNIQUE_THREAD=0/1` is supported, with a default of `1`. If it is set to `0`, `activate_adapter` only takes effect for the current thread. Each thread then uses the tuners it activated, and tuners in different threads do not interfere with each other.
  - Parameters:
    - adapter_name: The name of the tuner to activate
  - Return value: None
- `SwiftModel.deactivate_adapter(self, adapter_name, offload)`
  - Explain: Deactivate a tuner
    - Do not call this interface when the environment variable `USE_UNIQUE_THREAD=0` is set
  - Parameters:
    - adapter_name: The name of the tuner to deactivate
    - offload: How to handle the deactivated adapter. The default, `None`, leaves it in GPU memory. `cpu` and `meta` offload it to the CPU or the meta device to reduce GPU memory consumption.
  - Return value: None
- `SwiftModel.get_trainable_parameters(self)`
  - Explain: Return information about the trainable parameters
  - Parameters: None
  - Return value: Trainable parameter information in the following format:
    ```text
    trainable params: 100M || all params: 1000M || trainable%: 10.00% || cuda memory: 10GiB.
    ```
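The percentage in this string is simply trainable parameters divided by all parameters. The sketch below reproduces the documented example format from a list of `(num_elements, requires_grad)` pairs. `format_trainable_info` is hypothetical, and the real method's rounding (whole millions here) and memory figure may differ. With PyTorch, such pairs could be collected as `[(p.numel(), p.requires_grad) for p in model.parameters()]`.

```python
def format_trainable_info(params, cuda_memory_gib):
    """params: iterable of (num_elements, requires_grad) pairs, one per parameter tensor."""
    params = list(params)
    trainable = sum(n for n, requires_grad in params if requires_grad)
    total = sum(n for n, _ in params)
    return (f'trainable params: {trainable / 1e6:.0f}M || all params: {total / 1e6:.0f}M || '
            f'trainable%: {100 * trainable / total:.2f}% || cuda memory: {cuda_memory_gib:.0f}GiB.')


# A hypothetical model: 900M frozen parameters plus 100M trainable tuner parameters.
info = format_trainable_info([(900_000_000, False), (100_000_000, True)], cuda_memory_gib=10)
print(info)
# trainable params: 100M || all params: 1000M || trainable%: 10.00% || cuda memory: 10GiB.
```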