diff --git a/src/peft/peft_model.py b/src/peft/peft_model.py
index a38e0750f2..df640fc4ed 100644
--- a/src/peft/peft_model.py
+++ b/src/peft/peft_model.py
@@ -647,6 +647,9 @@ def _setup_prompt_encoder(self, adapter_name: str):
         elif config.peft_type == PeftType.P_TUNING:
             prompt_encoder = PromptEncoder(config)
         elif config.peft_type == PeftType.PREFIX_TUNING:
+            # prefix tuning now uses Cache but that won't work with gradient checkpointing
+            if any(getattr(module, "gradient_checkpointing", False) for module in self.get_base_model().modules()):
+                raise ValueError("Prefix tuning does not work with gradient checkpointing.")
             prompt_encoder = PrefixEncoder(config)
         else:
             raise ValueError("Not supported")
diff --git a/tests/test_decoder_models.py b/tests/test_decoder_models.py
index 3ad373ac01..f8c9d2c65a 100644
--- a/tests/test_decoder_models.py
+++ b/tests/test_decoder_models.py
@@ -34,6 +34,7 @@
     LoraConfig,
     OFTConfig,
     PrefixTuningConfig,
+    PromptLearningConfig,
     PromptTuningConfig,
     PromptTuningInit,
     get_peft_model,
@@ -59,6 +60,17 @@
     "task_type": "CAUSAL_LM",
 }
 
+SMALL_GRID = {
+    "model_ids": [
+        "hf-internal-testing/tiny-random-gpt2",
+        "hf-internal-testing/tiny-random-OPTForCausalLM",
+        "hf-internal-testing/tiny-random-MistralForCausalLM",
+        "peft-internal-testing/tiny-dummy-qwen2",
+        "trl-internal-testing/tiny-random-LlamaForCausalLM",
+    ],
+    "task_type": "CAUSAL_LM",
+}
+
 
 def skip_adalora_and_gpt2(test_list):
     return [test for test in test_list if not (("GPT2LMHeadModel" in test[1]) and (test[2] == AdaLoraConfig))]
@@ -91,6 +103,10 @@ def skip_adalora_or_oft_or_hra_and_gpt2(test_list):
     ]
 
 
+def only_prompt_learning_filter(test_list):
+    return [test for test in test_list if issubclass(test[2], PromptLearningConfig)]
+
+
 class PeftDecoderModelTester(unittest.TestCase, PeftCommonTester):
     r"""
     Test if the PeftModel behaves as expected. This includes:
@@ -505,3 +521,50 @@ def process(samples):
                 data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
             )
             trainer.train()
+
+    @parameterized.expand(
+        PeftTestConfigManager.get_grid_parameters(SMALL_GRID, filter_params_func=only_prompt_learning_filter)
+    )
+    def test_prompt_learning_with_gradient_checkpointing(self, test_name, model_id, config_cls, config_kwargs):
+        # See issue 869
+        # Test prompt learning methods with gradient checkpointing in a semi realistic setting.
+        # Prefix tuning does not work if the model uses the new caching implementation. In that case, a helpful error
+        # should be raised.
+        peft_config = config_cls(
+            base_model_name_or_path=model_id,
+            **config_kwargs,
+        )
+        base_model = self.transformers_class.from_pretrained(model_id)
+        base_model.gradient_checkpointing_enable()
+
+        try:
+            model = get_peft_model(base_model, peft_config)
+        except ValueError as exc:
+            # Some methods will raise a helpful error. After this, exit the test, as training would fail.
+            assert config_cls == PrefixTuningConfig
+            assert "Prefix tuning does not work with gradient checkpointing" in str(exc)
+            return
+
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        tokenizer.pad_token = tokenizer.eos_token
+
+        def process(samples):
+            tokenized = tokenizer(samples["quote"], truncation=True, max_length=128)
+            return tokenized
+
+        data = load_dataset("ybelkada/english_quotes_copy")
+        data = data.map(process, batched=True)
+
+        with tempfile.TemporaryDirectory() as tmp_dirname:
+            trainer = Trainer(
+                model=model,
+                train_dataset=data["train"],
+                args=TrainingArguments(
+                    num_train_epochs=1,
+                    max_steps=3,
+                    per_device_train_batch_size=4,
+                    output_dir=tmp_dirname,
+                ),
+                data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
+            )
+            trainer.train()
diff --git a/tests/testing_common.py b/tests/testing_common.py
index 954f79be5f..77dd529862 100644
--- a/tests/testing_common.py
+++ b/tests/testing_common.py
@@ -1109,7 +1109,7 @@ def _test_training_layer_indexing(self, model_id, config_cls, config_kwargs):
         assert nb_trainable < nb_trainable_all
 
     def _test_training_gradient_checkpointing(self, model_id, config_cls, config_kwargs):
-        if issubclass(config_cls, PromptLearningConfig):
+        if config_cls == PrefixTuningConfig:
            return pytest.skip(f"Test not applicable for {config_cls}")
 
         if (config_cls == AdaLoraConfig) and ("roberta" in model_id.lower()):
@@ -1143,7 +1143,9 @@ def _test_training_gradient_checkpointing(self, model_id, config_cls, config_kwa
         loss.backward()
 
         for n, param in model.named_parameters():
-            if model.prefix in n:
+            if "prompt_encoder." in n:  # prompt tuning methods
+                assert param.grad is not None
+            elif hasattr(model, "prefix") and (model.prefix in n):  # non-prompt tuning methods
                 assert param.grad is not None
             else:
                 assert param.grad is None
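For reference, a minimal standalone sketch of the behavior the new guard in _setup_prompt_encoder enforces, outside the test suite. The model id and num_virtual_tokens below are illustrative, not taken from this patch; the point is that enabling gradient checkpointing on the base model before wrapping it with prefix tuning should now raise the helpful ValueError instead of failing later during training.

# Hypothetical reproduction; model id and hyperparameters are assumptions.
from transformers import AutoModelForCausalLM
from peft import PrefixTuningConfig, get_peft_model

base_model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")
base_model.gradient_checkpointing_enable()  # sets gradient_checkpointing=True on supporting submodules

config = PrefixTuningConfig(task_type="CAUSAL_LM", num_virtual_tokens=10)
try:
    model = get_peft_model(base_model, config)
except ValueError as exc:
    # Expected with this patch: "Prefix tuning does not work with gradient checkpointing."
    print(exc)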