From 7fcc139003d56ac746f5cf320777d491d2104fda Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Thu, 31 Oct 2024 12:21:13 +0100 Subject: [PATCH] FIX Check for prefix tuning + grad checkpointing See #869 Since transformers is moving to the new cache implementation, we had to change prefix tuning to use this too. However, caching does not work with gradient checkpointing. Therefore, this currently runs into an error about size mismatches. Now, PEFT checks for gradient checkpointing and raises a helpful error. --- src/peft/peft_model.py | 3 ++ tests/test_decoder_models.py | 63 ++++++++++++++++++++++++++++++++++++ tests/testing_common.py | 6 ++-- 3 files changed, 70 insertions(+), 2 deletions(-) diff --git a/src/peft/peft_model.py b/src/peft/peft_model.py index a38e0750f2..df640fc4ed 100644 --- a/src/peft/peft_model.py +++ b/src/peft/peft_model.py @@ -647,6 +647,9 @@ def _setup_prompt_encoder(self, adapter_name: str): elif config.peft_type == PeftType.P_TUNING: prompt_encoder = PromptEncoder(config) elif config.peft_type == PeftType.PREFIX_TUNING: + # prefix tuning now uses Cache but that won't work with gradient checkpointing + if any(getattr(module, "gradient_checkpointing", False) for module in self.get_base_model().modules()): + raise ValueError("Prefix tuning does not work with gradient checkpointing.") prompt_encoder = PrefixEncoder(config) else: raise ValueError("Not supported") diff --git a/tests/test_decoder_models.py b/tests/test_decoder_models.py index 3ad373ac01..f8c9d2c65a 100644 --- a/tests/test_decoder_models.py +++ b/tests/test_decoder_models.py @@ -34,6 +34,7 @@ LoraConfig, OFTConfig, PrefixTuningConfig, + PromptLearningConfig, PromptTuningConfig, PromptTuningInit, get_peft_model, @@ -59,6 +60,17 @@ "task_type": "CAUSAL_LM", } +SMALL_GRID = { + "model_ids": [ + "hf-internal-testing/tiny-random-gpt2", + "hf-internal-testing/tiny-random-OPTForCausalLM", + "hf-internal-testing/tiny-random-MistralForCausalLM", + "peft-internal-testing/tiny-dummy-qwen2", + "trl-internal-testing/tiny-random-LlamaForCausalLM", + ], + "task_type": "CAUSAL_LM", +} + def skip_adalora_and_gpt2(test_list): return [test for test in test_list if not (("GPT2LMHeadModel" in test[1]) and (test[2] == AdaLoraConfig))] @@ -91,6 +103,10 @@ def skip_adalora_or_oft_or_hra_and_gpt2(test_list): ] +def only_prompt_learning_filter(test_list): + return [test for test in test_list if issubclass(test[2], PromptLearningConfig)] + + class PeftDecoderModelTester(unittest.TestCase, PeftCommonTester): r""" Test if the PeftModel behaves as expected. This includes: @@ -505,3 +521,50 @@ def process(samples): data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False), ) trainer.train() + + @parameterized.expand( + PeftTestConfigManager.get_grid_parameters(SMALL_GRID, filter_params_func=only_prompt_learning_filter) + ) + def test_prompt_learning_with_gradient_checkpointing(self, test_name, model_id, config_cls, config_kwargs): + # See issue 869 + # Test prompt learning methods with gradient checkpointing in a semi realistic setting. + # Prefix tuning does not work if the model uses the new caching implementation. In that case, a helpful error + # should be raised. + peft_config = config_cls( + base_model_name_or_path=model_id, + **config_kwargs, + ) + base_model = self.transformers_class.from_pretrained(model_id) + base_model.gradient_checkpointing_enable() + + try: + model = get_peft_model(base_model, peft_config) + except ValueError as exc: + # Some methods will raise a helpful error. 
+            # After this, exit the test, as training would fail.
+            assert config_cls == PrefixTuningConfig
+            assert "Prefix tuning does not work with gradient checkpointing" in str(exc)
+            return
+
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        tokenizer.pad_token = tokenizer.eos_token
+
+        def process(samples):
+            tokenized = tokenizer(samples["quote"], truncation=True, max_length=128)
+            return tokenized
+
+        data = load_dataset("ybelkada/english_quotes_copy")
+        data = data.map(process, batched=True)
+
+        with tempfile.TemporaryDirectory() as tmp_dirname:
+            trainer = Trainer(
+                model=model,
+                train_dataset=data["train"],
+                args=TrainingArguments(
+                    num_train_epochs=1,
+                    max_steps=3,
+                    per_device_train_batch_size=4,
+                    output_dir=tmp_dirname,
+                ),
+                data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
+            )
+            trainer.train()
diff --git a/tests/testing_common.py b/tests/testing_common.py
index 954f79be5f..77dd529862 100644
--- a/tests/testing_common.py
+++ b/tests/testing_common.py
@@ -1109,7 +1109,7 @@ def _test_training_layer_indexing(self, model_id, config_cls, config_kwargs):
         assert nb_trainable < nb_trainable_all

     def _test_training_gradient_checkpointing(self, model_id, config_cls, config_kwargs):
-        if issubclass(config_cls, PromptLearningConfig):
+        if config_cls == PrefixTuningConfig:
             return pytest.skip(f"Test not applicable for {config_cls}")

         if (config_cls == AdaLoraConfig) and ("roberta" in model_id.lower()):
@@ -1143,7 +1143,9 @@ def _test_training_gradient_checkpointing(self, model_id, config_cls, config_kwa
         loss.backward()

         for n, param in model.named_parameters():
-            if model.prefix in n:
+            if "prompt_encoder." in n:  # prompt tuning methods
+                assert param.grad is not None
+            elif hasattr(model, "prefix") and (model.prefix in n):  # non-prompt tuning methods
                 assert param.grad is not None
             else:
                 assert param.grad is None
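
Example of the new behavior (not part of the diff): a minimal sketch of how the guard added in _setup_prompt_encoder surfaces to users. The model id is taken from the test grid above; num_virtual_tokens is an illustrative value.

    from transformers import AutoModelForCausalLM
    from peft import PrefixTuningConfig, get_peft_model

    base_model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-OPTForCausalLM")
    # gradient_checkpointing_enable() sets gradient_checkpointing=True on supporting submodules,
    # which is what the new check in _setup_prompt_encoder looks for.
    base_model.gradient_checkpointing_enable()

    config = PrefixTuningConfig(task_type="CAUSAL_LM", num_virtual_tokens=10)
    try:
        get_peft_model(base_model, config)
    except ValueError as exc:
        # Instead of a size-mismatch error during training, users now get the helpful message:
        # "Prefix tuning does not work with gradient checkpointing."
        print(exc)

The other prompt learning methods (prompt tuning, p-tuning) still work with gradient checkpointing; the new test only exits early for PrefixTuningConfig and runs a short training loop for the rest of the grid.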