modules_to_save Incorrect Overlap in Multiple LoRA Adapters #2206

Open
saeid93 opened this issue Nov 8, 2024 · 6 comments · May be fixed by #2220
Comments

@saeid93
Contributor

saeid93 commented Nov 8, 2024

System Info

Python 3.11.9
transformers==4.40.2
peft==0.11.2

Who can help?

@BenjaminBossan
A bug occurs in the PEFT library when using multiple LoRA adapters, each with a unique modules_to_save configuration. The issue arises when the modules_to_save from the first LoRA adapter (e.g., adapter_1) is applied to subsequent adapters (e.g., adapter_2), rather than maintaining independent configurations. As a result, modules specified in modules_to_save for adapter_1 also appear in adapter_2, leading to unintended behavior and possibly affecting fine-tuning accuracy. This incorrect handling of modules_to_save causes duplicate entries where only the respective LoRA adapter’s modules should be saved.

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder
  • My own task or dataset (give details below)

Reproduction

The following example code demonstrates this issue, displaying the model structure where adapter_2 contains modules meant only for adapter_1.

Example Code

import os
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model, PeftModel

# Get the directory of the current Python script
script_dir = os.path.dirname(os.path.abspath(__file__))

# Define relative paths for adapters
adapter_1_path = os.path.join(script_dir, "adapter_1")
adapter_2_path = os.path.join(script_dir, "adapter_2")

# Load base model
base_model = AutoModelForCausalLM.from_pretrained("gpt2")

# Define LoRA configs with different modules_to_save
lora_config_1 = LoraConfig(
    r=8,
    lora_alpha=32,
    target_modules=["c_attn"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    modules_to_save=["lm_head"]
)

lora_config_2 = LoraConfig(
    r=8,
    lora_alpha=32,
    target_modules=["c_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    modules_to_save=["wte"]
)

# Apply and save the first adapter
os.makedirs(adapter_1_path, exist_ok=True)
model_with_lora_1 = get_peft_model(base_model, lora_config_1, adapter_name="adapter_1")
model_with_lora_1.save_pretrained(adapter_1_path)

# Apply and save the second adapter
os.makedirs(adapter_2_path, exist_ok=True)
model_with_lora_2 = get_peft_model(base_model, lora_config_2, adapter_name="adapter_2")
model_with_lora_2.save_pretrained(adapter_2_path)

# Load a fresh base model and wrap it in PeftModel by loading the first adapter
base_model = AutoModelForCausalLM.from_pretrained("gpt2")
peft_model = PeftModel.from_pretrained(base_model, os.path.join(adapter_1_path, "adapter_1"), adapter_name="adapter_1")

# Load the second adapter into the PeftModel
peft_model.load_adapter(os.path.join(adapter_2_path, "adapter_2"), adapter_name="adapter_2")

# Display structure and inspect unexpected 'modules_to_save' overlap
print("Expected `modules_to_save` for each adapter:")
print("Adapter 1 `modules_to_save`: ['lm_head']")
print("Adapter 2 `modules_to_save`: ['wte']")
print("\nActual model structure and `modules_to_save` contents:\n")
print(peft_model.transformer.wte)
print(peft_model.lm_head)

Running this script produces the following output:

Expected `modules_to_save` for each adapter:
Adapter 1 `modules_to_save`: ['lm_head']
Adapter 2 `modules_to_save`: ['wte']

Actual model structure and `modules_to_save` contents:

ModulesToSaveWrapper(
  (original_module): Embedding(50257, 768)
  (modules_to_save): ModuleDict(
    (adapter_2): Embedding(50257, 768)
  )
)
ModulesToSaveWrapper(
  (original_module): Linear(in_features=768, out_features=50257, bias=False)
  (modules_to_save): ModuleDict(
    (adapter_1): Linear(in_features=768, out_features=50257, bias=False)
    (adapter_2): Linear(in_features=768, out_features=50257, bias=False)
  )
)

Expected behavior

As you can see, adapter_2 is also wrapped around the "lm_head" module, which it shouldn't be. The expected output is shown below:

Expected `modules_to_save` for each adapter:
Adapter 1 `modules_to_save`: ['lm_head']
Adapter 2 `modules_to_save`: ['wte']

Actual model structure and `modules_to_save` contents:

ModulesToSaveWrapper(
  (original_module): Embedding(50257, 768)
  (modules_to_save): ModuleDict(
    (adapter_2): Embedding(50257, 768)
  )
)
ModulesToSaveWrapper(
  (original_module): Linear(in_features=768, out_features=50257, bias=False)
  (modules_to_save): ModuleDict(
    (adapter_1): Linear(in_features=768, out_features=50257, bias=False)
  )
)
@BenjaminBossan
Member

Thanks a lot for reporting this. Indeed, the handling of modules_to_save can be messy at times, and the outcome you show should be avoided. I don't have the opportunity to test this right now, but my assumption is that the extra module won't disrupt the results for adapter 2, because it is a copy of the original layer and behaves exactly the same. Is that right?

@saeid93
Contributor Author

saeid93 commented Nov 11, 2024

No worries, glad to be of help. As far as I have tested, the results are fine and the correct loaded layer is used; the only problem is the redundant module that gets loaded. I also dug a bit deeper and noticed that the problem originates from this function:

def set_additional_trainable_modules(self, peft_config, adapter_name):

For an unknown reason, when using load_adapter, this line:

self.modules_to_save = set(peft_config.modules_to_save)

does not update the set to contain only the new adapter's layer; it still holds the old adapter's layer as well (which it shouldn't). For example, if I manually patch the script above, the problem is solved:

...

# Apply and save the second adapter
os.makedirs(adapter_2_path, exist_ok=True)
model_with_lora_2 = get_peft_model(base_model, lora_config_2, adapter_name="adapter_2")
model_with_lora_2.save_pretrained(adapter_2_path)

# Load a fresh base model and wrap it in PeftModel by loading the first adapter
base_model = AutoModelForCausalLM.from_pretrained("gpt2")
peft_model = PeftModel.from_pretrained(base_model, os.path.join(adapter_1_path, "adapter_1"), adapter_name="adapter_1")

peft_model.modules_to_save = {"wte"} # <----------- HERE manually changing the modules_to_save
# Load the second adapter into the PeftModel
peft_model.load_adapter(os.path.join(adapter_2_path, "adapter_2"), adapter_name="adapter_2")

...

BenjaminBossan added a commit to BenjaminBossan/peft that referenced this issue Nov 18, 2024
Resolves huggingface#2206

NOT READY TO MERGE YET.

Tentative solution to that issue.

The problem is that we keep a "global" modules_to_save on the model
which contains all possible modules_to_save for each adapter. When the
first adapter targets layer "foo" with modules_to_save and the second
adapter targets "bar", then "foo" will create a copy of the original
module for the second adapter, even though it's not needed.

This does not change the result but is unnecessary and takes up memory.
Thus it should be avoided.

TODO: Tests.
BenjaminBossan linked a pull request Nov 18, 2024 that will close this issue
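
To make the difference concrete, here is a minimal, self-contained sketch of the two bookkeeping strategies described in the commit message above. The names (FakeConfig, register_global, register_per_adapter) are purely illustrative and are not PEFT internals.

from dataclasses import dataclass, field

@dataclass
class FakeConfig:
    # Stand-in for the modules_to_save field of a LoraConfig (illustration only).
    modules_to_save: list = field(default_factory=list)

# Buggy bookkeeping: a single "global" set shared by all adapters.
global_modules_to_save = set()

def register_global(config, adapter_name):
    # Entries from every adapter accumulate, so adapter_2 also "sees" lm_head
    # and a redundant copy of that layer gets created for it.
    global_modules_to_save.update(config.modules_to_save)
    return sorted(global_modules_to_save)

# Fixed bookkeeping: one set per adapter name.
per_adapter_modules_to_save = {}

def register_per_adapter(config, adapter_name):
    # Only the modules listed in this adapter's own config are wrapped for it.
    per_adapter_modules_to_save[adapter_name] = set(config.modules_to_save)
    return sorted(per_adapter_modules_to_save[adapter_name])

cfg1 = FakeConfig(modules_to_save=["lm_head"])
cfg2 = FakeConfig(modules_to_save=["wte"])

print(register_global(cfg1, "adapter_1"))       # ['lm_head']
print(register_global(cfg2, "adapter_2"))       # ['lm_head', 'wte']  <- overlap
print(register_per_adapter(cfg1, "adapter_1"))  # ['lm_head']
print(register_per_adapter(cfg2, "adapter_2"))  # ['wte']             <- independent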
@BenjaminBossan
Member

Okay, I managed to reproduce the error. My tentative fix is in #2220. Right now, there is a CI issue but it should hopefully resolve itself soon. Meanwhile, it would be great if you could check if the fix makes sense to you.

Just a side note: when adding multiple adapters, don't call get_peft_model twice; use peft_model.add_adapter instead. But even with that change, the problem is reproducible.
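
For reference, the recommended add_adapter pattern looks roughly like this (a sketch based on the public PeftModel API, with the configs abbreviated from the reproduction script above):

from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

# Abbreviated versions of the configs from the reproduction script.
lora_config_1 = LoraConfig(target_modules=["c_attn"], task_type="CAUSAL_LM", modules_to_save=["lm_head"])
lora_config_2 = LoraConfig(target_modules=["c_proj"], task_type="CAUSAL_LM", modules_to_save=["wte"])

base_model = AutoModelForCausalLM.from_pretrained("gpt2")

# Wrap the base model once with the first adapter ...
peft_model = get_peft_model(base_model, lora_config_1, adapter_name="adapter_1")
# ... then register further adapters on the same PeftModel
# instead of calling get_peft_model a second time.
peft_model.add_adapter("adapter_2", lora_config_2)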

@BenjaminBossan
Member

@saeid93 did you have the opportunity to test this fix on your original use case?

@saeid93
Contributor Author

saeid93 commented Dec 2, 2024

@BenjaminBossan sorry, I forgot to check this one. I have now checked it, after your message, with a fresh installation of peft from your draft pull request, and the problem I had earlier does not appear anymore. Thank you for fixing this!

ModulesToSaveWrapper(
  (original_module): Embedding(50257, 768)
  (modules_to_save): ModuleDict(
    (adapter_2): Embedding(50257, 768)
  )
)
ModulesToSaveWrapper(
  (original_module): Linear(in_features=768, out_features=50257, bias=False)
  (modules_to_save): ModuleDict(
    (adapter_1): Linear(in_features=768, out_features=50257, bias=False)
  )
)

@BenjaminBossan
Member

Great, thanks for testing!
