FIX Check for prefix tuning + grad checkpointing
See huggingface#869

Since transformers is moving to the new cache implementation, prefix tuning had to be changed to use it as well. However, caching does not work with gradient checkpointing, so this combination currently runs into an error about size mismatches.

Now, PEFT checks for gradient checkpointing and raises a helpful error.
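
For illustration only (not part of this commit), a minimal sketch of how the new check surfaces to users. The model id is taken from the test grid added below; the prefix tuning settings are arbitrary assumptions.

# Hypothetical reproduction sketch; config values are illustrative assumptions.
from transformers import AutoModelForCausalLM
from peft import PrefixTuningConfig, get_peft_model

base_model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-OPTForCausalLM")
# gradient_checkpointing_enable() marks supporting submodules with
# gradient_checkpointing=True, which is the attribute the new PEFT check looks for.
base_model.gradient_checkpointing_enable()

peft_config = PrefixTuningConfig(task_type="CAUSAL_LM", num_virtual_tokens=10)
try:
    model = get_peft_model(base_model, peft_config)
except ValueError as exc:
    # With this fix, the failure happens here with a clear message instead of a
    # size-mismatch error later during training.
    print(exc)

In practice, the error can be avoided by disabling gradient checkpointing on the base model or by using a different PEFT method.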
BenjaminBossan committed Oct 31, 2024
1 parent ff6dd9e commit 7fcc139
Showing 3 changed files with 70 additions and 2 deletions.
3 changes: 3 additions & 0 deletions src/peft/peft_model.py
@@ -647,6 +647,9 @@ def _setup_prompt_encoder(self, adapter_name: str):
         elif config.peft_type == PeftType.P_TUNING:
             prompt_encoder = PromptEncoder(config)
         elif config.peft_type == PeftType.PREFIX_TUNING:
+            # prefix tuning now uses Cache but that won't work with gradient checkpointing
+            if any(getattr(module, "gradient_checkpointing", False) for module in self.get_base_model().modules()):
+                raise ValueError("Prefix tuning does not work with gradient checkpointing.")
             prompt_encoder = PrefixEncoder(config)
         else:
             raise ValueError("Not supported")
63 changes: 63 additions & 0 deletions tests/test_decoder_models.py
@@ -34,6 +34,7 @@
     LoraConfig,
     OFTConfig,
     PrefixTuningConfig,
+    PromptLearningConfig,
     PromptTuningConfig,
     PromptTuningInit,
     get_peft_model,
@@ -59,6 +60,17 @@
     "task_type": "CAUSAL_LM",
 }
 
+SMALL_GRID = {
+    "model_ids": [
+        "hf-internal-testing/tiny-random-gpt2",
+        "hf-internal-testing/tiny-random-OPTForCausalLM",
+        "hf-internal-testing/tiny-random-MistralForCausalLM",
+        "peft-internal-testing/tiny-dummy-qwen2",
+        "trl-internal-testing/tiny-random-LlamaForCausalLM",
+    ],
+    "task_type": "CAUSAL_LM",
+}
+
 
 def skip_adalora_and_gpt2(test_list):
     return [test for test in test_list if not (("GPT2LMHeadModel" in test[1]) and (test[2] == AdaLoraConfig))]
@@ -91,6 +103,10 @@ def skip_adalora_or_oft_or_hra_and_gpt2(test_list):
 ]
 
 
+def only_prompt_learning_filter(test_list):
+    return [test for test in test_list if issubclass(test[2], PromptLearningConfig)]
+
+
 class PeftDecoderModelTester(unittest.TestCase, PeftCommonTester):
     r"""
     Test if the PeftModel behaves as expected. This includes:
@@ -505,3 +521,50 @@ def process(samples):
             data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
         )
         trainer.train()
+
+    @parameterized.expand(
+        PeftTestConfigManager.get_grid_parameters(SMALL_GRID, filter_params_func=only_prompt_learning_filter)
+    )
+    def test_prompt_learning_with_gradient_checkpointing(self, test_name, model_id, config_cls, config_kwargs):
+        # See issue 869
+        # Test prompt learning methods with gradient checkpointing in a semi realistic setting.
+        # Prefix tuning does not work if the model uses the new caching implementation. In that case, a helpful error
+        # should be raised.
+        peft_config = config_cls(
+            base_model_name_or_path=model_id,
+            **config_kwargs,
+        )
+        base_model = self.transformers_class.from_pretrained(model_id)
+        base_model.gradient_checkpointing_enable()
+
+        try:
+            model = get_peft_model(base_model, peft_config)
+        except ValueError as exc:
+            # Some methods will raise a helpful error. After this, exit the test, as training would fail.
+            assert config_cls == PrefixTuningConfig
+            assert "Prefix tuning does not work with gradient checkpointing" in str(exc)
+            return
+
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        tokenizer.pad_token = tokenizer.eos_token
+
+        def process(samples):
+            tokenized = tokenizer(samples["quote"], truncation=True, max_length=128)
+            return tokenized
+
+        data = load_dataset("ybelkada/english_quotes_copy")
+        data = data.map(process, batched=True)
+
+        with tempfile.TemporaryDirectory() as tmp_dirname:
+            trainer = Trainer(
+                model=model,
+                train_dataset=data["train"],
+                args=TrainingArguments(
+                    num_train_epochs=1,
+                    max_steps=3,
+                    per_device_train_batch_size=4,
+                    output_dir=tmp_dirname,
+                ),
+                data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
+            )
+            trainer.train()
6 changes: 4 additions & 2 deletions tests/testing_common.py
@@ -1109,7 +1109,7 @@ def _test_training_layer_indexing(self, model_id, config_cls, config_kwargs):
         assert nb_trainable < nb_trainable_all
 
     def _test_training_gradient_checkpointing(self, model_id, config_cls, config_kwargs):
-        if issubclass(config_cls, PromptLearningConfig):
+        if config_cls == PrefixTuningConfig:
            return pytest.skip(f"Test not applicable for {config_cls}")
 
        if (config_cls == AdaLoraConfig) and ("roberta" in model_id.lower()):
@@ -1143,7 +1143,9 @@ def _test_training_gradient_checkpointing(self, model_id, config_cls, config_kwargs):
         loss.backward()
 
         for n, param in model.named_parameters():
-            if model.prefix in n:
+            if "prompt_encoder." in n:  # prompt tuning methods
+                assert param.grad is not None
+            elif hasattr(model, "prefix") and (model.prefix in n):  # non-prompt tuning methods
                 assert param.grad is not None
             else:
                 assert param.grad is None
