FIX: Check for prefix tuning + gradient checkpointing fails #2191

Merged
3 changes: 3 additions & 0 deletions src/peft/peft_model.py
@@ -647,6 +647,9 @@ def _setup_prompt_encoder(self, adapter_name: str):
         elif config.peft_type == PeftType.P_TUNING:
             prompt_encoder = PromptEncoder(config)
         elif config.peft_type == PeftType.PREFIX_TUNING:
+            # prefix tuning now uses Cache but that won't work with gradient checkpointing
+            if any(getattr(module, "gradient_checkpointing", False) for module in self.get_base_model().modules()):
+                raise ValueError("Prefix tuning does not work with gradient checkpointing.")
             prompt_encoder = PrefixEncoder(config)
         else:
             raise ValueError("Not supported")
63 changes: 63 additions & 0 deletions tests/test_decoder_models.py
@@ -34,6 +34,7 @@
     LoraConfig,
     OFTConfig,
     PrefixTuningConfig,
+    PromptLearningConfig,
     PromptTuningConfig,
     PromptTuningInit,
     get_peft_model,
@@ -59,6 +60,17 @@
     "task_type": "CAUSAL_LM",
 }
 
+SMALL_GRID = {
+    "model_ids": [
+        "hf-internal-testing/tiny-random-gpt2",
+        "hf-internal-testing/tiny-random-OPTForCausalLM",
+        "hf-internal-testing/tiny-random-MistralForCausalLM",
+        "peft-internal-testing/tiny-dummy-qwen2",
+        "trl-internal-testing/tiny-random-LlamaForCausalLM",
+    ],
+    "task_type": "CAUSAL_LM",
+}
+
 
 def skip_adalora_and_gpt2(test_list):
     return [test for test in test_list if not (("GPT2LMHeadModel" in test[1]) and (test[2] == AdaLoraConfig))]
@@ -91,6 +103,10 @@ def skip_adalora_or_oft_or_hra_and_gpt2(test_list):
 ]
 
 
+def only_prompt_learning_filter(test_list):
+    return [test for test in test_list if issubclass(test[2], PromptLearningConfig)]
+
+
 class PeftDecoderModelTester(unittest.TestCase, PeftCommonTester):
     r"""
     Test if the PeftModel behaves as expected. This includes:
@@ -505,3 +521,50 @@ def process(samples):
                 data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
             )
             trainer.train()
+
+    @parameterized.expand(
+        PeftTestConfigManager.get_grid_parameters(SMALL_GRID, filter_params_func=only_prompt_learning_filter)
+    )
+    def test_prompt_learning_with_gradient_checkpointing(self, test_name, model_id, config_cls, config_kwargs):
+        # See issue 869
+        # Test prompt learning methods with gradient checkpointing in a semi realistic setting.
+        # Prefix tuning does not work if the model uses the new caching implementation. In that case, a helpful error
+        # should be raised.
+        peft_config = config_cls(
+            base_model_name_or_path=model_id,
+            **config_kwargs,
+        )
+        base_model = self.transformers_class.from_pretrained(model_id)
+        base_model.gradient_checkpointing_enable()
+
+        try:
+            model = get_peft_model(base_model, peft_config)
+        except ValueError as exc:
+            # Some methods will raise a helpful error. After this, exit the test, as training would fail.
+            assert config_cls == PrefixTuningConfig
+            assert "Prefix tuning does not work with gradient checkpointing" in str(exc)
+            return
+
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        tokenizer.pad_token = tokenizer.eos_token
+
+        def process(samples):
+            tokenized = tokenizer(samples["quote"], truncation=True, max_length=128)
+            return tokenized
+
+        data = load_dataset("ybelkada/english_quotes_copy")
+        data = data.map(process, batched=True)
+
+        with tempfile.TemporaryDirectory() as tmp_dirname:
+            trainer = Trainer(
+                model=model,
+                train_dataset=data["train"],
+                args=TrainingArguments(
+                    num_train_epochs=1,
+                    max_steps=3,
+                    per_device_train_batch_size=4,
+                    output_dir=tmp_dirname,
+                ),
+                data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
+            )
+            trainer.train()
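A note on the new parameterization: SMALL_GRID limits the run to a handful of tiny models, and only_prompt_learning_filter keeps only grid entries whose config class derives from PromptLearningConfig. As a rough illustration (the exact tuple layout emitted by PeftTestConfigManager.get_grid_parameters is assumed here, apart from index 2 holding the config class, which the existing filters also rely on):

from peft import LoraConfig, PrefixTuningConfig, PromptTuningConfig, PromptLearningConfig

# Hypothetical grid entries: (test_name, model_id, config_cls, config_kwargs)
grid = [
    ("test_lora", "hf-internal-testing/tiny-random-gpt2", LoraConfig, {"task_type": "CAUSAL_LM"}),
    ("test_prefix", "hf-internal-testing/tiny-random-gpt2", PrefixTuningConfig, {"task_type": "CAUSAL_LM"}),
    ("test_prompt", "hf-internal-testing/tiny-random-gpt2", PromptTuningConfig, {"task_type": "CAUSAL_LM"}),
]

kept = [t for t in grid if issubclass(t[2], PromptLearningConfig)]
# PrefixTuningConfig and PromptTuningConfig subclass PromptLearningConfig; LoraConfig does not.
assert [t[0] for t in kept] == ["test_prefix", "test_prompt"]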
6 changes: 4 additions & 2 deletions tests/testing_common.py
@@ -1109,7 +1109,7 @@ def _test_training_layer_indexing(self, model_id, config_cls, config_kwargs):
         assert nb_trainable < nb_trainable_all
 
     def _test_training_gradient_checkpointing(self, model_id, config_cls, config_kwargs):
-        if issubclass(config_cls, PromptLearningConfig):
+        if config_cls == PrefixTuningConfig:
             return pytest.skip(f"Test not applicable for {config_cls}")
 
         if (config_cls == AdaLoraConfig) and ("roberta" in model_id.lower()):
@@ -1143,7 +1143,9 @@ def _test_training_gradient_checkpointing(self, model_id, config_cls, config_kwargs):
         loss.backward()
 
         for n, param in model.named_parameters():
-            if model.prefix in n:
+            if "prompt_encoder." in n:  # prompt tuning methods
+                assert param.grad is not None
+            elif hasattr(model, "prefix") and (model.prefix in n):  # non-prompt tuning methods
                 assert param.grad is not None
             else:
                 assert param.grad is None
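The updated assertion reflects where trainable weights live: prompt-learning adapters keep theirs under the prompt_encoder module, while other tuners (e.g. LoRA) are matched via model.prefix; prompt-learning PEFT models may not expose a prefix attribute at all, hence the hasattr guard. A minimal sketch of what the name check sees, assuming a tiny model and an illustrative num_virtual_tokens:

from transformers import AutoModelForCausalLM
from peft import PromptTuningConfig, get_peft_model

base_model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-OPTForCausalLM")
peft_config = PromptTuningConfig(task_type="CAUSAL_LM", num_virtual_tokens=10)
model = get_peft_model(base_model, peft_config)

# Only the prompt encoder is trainable; its parameter names start with "prompt_encoder."
trainable = [n for n, p in model.named_parameters() if p.requires_grad]
print(trainable)  # e.g. ['prompt_encoder.default.embedding.weight']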