
How to correctly use Prefix Tuning? #869

Closed

Vincent-Li-9701 opened this issue Aug 27, 2023 · 44 comments

Comments

@Vincent-Li-9701 commented Aug 27, 2023

System Info

peft 0.5.0
transformers 4.32.0

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder
  • My own task or dataset (give details below)

Reproduction

from transformers import AutoModelForSeq2SeqLM
from peft import PrefixTuningConfig, TaskType, get_peft_model, prepare_model_for_int8_training

model = AutoModelForSeq2SeqLM.from_pretrained('bigscience/T0pp', load_in_8bit=True)
model = prepare_model_for_int8_training(model)
config = PrefixTuningConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,
    num_virtual_tokens=100,
    token_dim=model.config.hidden_size,
    num_transformer_submodules=1,
    num_attention_heads=model.config.num_heads,
    num_layers=model.config.num_layers,
    encoder_hidden_size=1792,
)
model = get_peft_model(model, config)

Expected behavior

I'm assuming num_layers, num_attention_heads, and token_dim need to match the base model. In the sample, num_transformer_submodules is 1, but an encoder-decoder model has two transformer stacks, right? Should this be 2?
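
As far as I understand, PEFT can fill in token_dim, num_attention_heads, num_layers, and num_transformer_submodules from the base model's config when they are left unset, so a minimal sketch of the setup (assuming that inference works for T0pp, and still choosing num_virtual_tokens by hand) would be:

from transformers import AutoModelForSeq2SeqLM
from peft import PrefixTuningConfig, TaskType, get_peft_model

# Minimal sketch: rely on PEFT to infer the model dimensions from the T0pp config.
model = AutoModelForSeq2SeqLM.from_pretrained('bigscience/T0pp', load_in_8bit=True)
config = PrefixTuningConfig(task_type=TaskType.SEQ_2_SEQ_LM, num_virtual_tokens=100)
model = get_peft_model(model, config)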

When I run the code above, I get:

File "/python3.10/site-packages/transformers/models/t5/modeling_t5.py", line 551, in forward
position_bias = position_bias + mask  # (batch_size, n_heads, seq_length, key_length)
RuntimeError: The size of tensor a (3) must match the size of tensor b (103) at non-singleton dimension 3

When I print out the shapes of position_bias and mask, mask has 100 more tokens than position_bias, seemingly on the decoder side; it also appears to be taking in the prefix embeddings.

@Vincent-Li-9701 (Author)

@pacman100 Would you mind taking a look?

@LesterGong

Have you solved this problem?

@Vincent-Li-9701 (Author) commented Sep 5, 2023

@WhoopeeHg no, unfortunately not. I decided to skip the prefix tuning part since I found it to be less effective than P-Tuning or LoRA on my dataset.

@Vincent-Li-9701 (Author)

Based on my discussion with others, the problem seems to surface when we load the model with 8-bit quantization.

@vikram71198

@pacman100 @BenjaminBossan this issue occurs even without 8-bit quantization and basically renders HF's integration of Prefix Tuning useless unless this bug is fixed.

@BenjaminBossan (Member)

@vikram71198 Could you please give more details? Ideally some code to reproduce the error and the PEFT version you're using.

@vikram71198

My environment:

transformers = 4.38.1
peft = 0.8.2

This is the code snippet I'm using, which has been adapted from here:

import os
import random
import numpy as np
import pandas as pd
import torch
import datasets
from torch.utils.data import Dataset, DataLoader
import transformers
import peft
import trl
import json
from pprint import pprint
import flash_attn
import accelerate
from transformers import BitsAndBytesConfig
from tqdm import tqdm as tqdm
import mlflow

max_length = 500
lr = 2e-4
num_epochs = 5
batch_size = 1
num_virtual_tokens = 30
random_seed = 42

model_name = "teknium/OpenHermes-2.5-Mistral-7B"
text_column = "Transcript"
label_column = "RFC"

def get_prompt(transcript: str) -> str:
    prompt = """Transcript:

{transcript}

---

I want you to act as a transcript analysis expert. I have provided you with a transcript between agent & customer above and your goal is to summarize the reason why the customer calls up the agent. If there is no discernible reason, output "No reason identified".

Answer:"""

    return prompt.format(transcript = transcript)

def get_mistral_prompt(transcript: str, system_message : str = "") -> str:
    messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": get_prompt(transcript)}
    ]
    return tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)

from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(model_name, device_map = "auto", torch_dtype = torch.bfloat16)

tokenizer = AutoTokenizer.from_pretrained(model_name)

if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

tokenizer.use_default_system_prompt = False

def preprocess_function(examples):
    batch_size = len(examples[text_column])
    # inputs = [f"{text_column} : {x} Label : " for x in examples[text_column]]
    inputs = [get_mistral_prompt(x) for x in examples[text_column]]
    targets = [str(x) for x in examples[label_column]]
    model_inputs = tokenizer(inputs)
    labels = tokenizer(targets)
    for i in range(batch_size):
        sample_input_ids = model_inputs["input_ids"][i]
        label_input_ids = labels["input_ids"][i] + [tokenizer.pad_token_id]
        # print(i, sample_input_ids, label_input_ids)
        model_inputs["input_ids"][i] = sample_input_ids + label_input_ids
        labels["input_ids"][i] = [-100] * len(sample_input_ids) + label_input_ids
        model_inputs["attention_mask"][i] = [1] * len(model_inputs["input_ids"][i])
    # print(model_inputs)
    for i in range(batch_size):
        sample_input_ids = model_inputs["input_ids"][i]
        label_input_ids = labels["input_ids"][i]

        #padding if length of this example is smaller than max_seq_length
        model_inputs["input_ids"][i] = [tokenizer.pad_token_id] * (
            max_length - len(sample_input_ids)
        ) + sample_input_ids
        model_inputs["attention_mask"][i] = [0] * (max_length - len(sample_input_ids)) + model_inputs[
            "attention_mask"
        ][i]
        labels["input_ids"][i] = [-100] * (max_length - len(sample_input_ids)) + label_input_ids

        #if sequence length of this example exceeds max_seq_length, we're performing truncation here
        model_inputs["input_ids"][i] = torch.tensor(model_inputs["input_ids"][i][:max_length])
        model_inputs["attention_mask"][i] = torch.tensor(model_inputs["attention_mask"][i][:max_length])
        labels["input_ids"][i] = torch.tensor(labels["input_ids"][i][:max_length])
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

transcripts = ["""
Agent: Hello may I know why you're calling today?
Customer: Yeah, I'm calling to cancel my insurance policy
Agent: May I ask you why you chose to do so?
Customer: Yeah, I'm just not interested in the product you have to offer anymore
Agent: That is totally understandable""",
"""
Agent: Hello
Customer: Hello
Agent: Have a nice day
Customer: Thanks"""]

rfcs = ["Customer called to cancel their insurance policy.", "No reason identified."]

import pandas as pd
df = pd.DataFrame({"Transcript": transcripts, "RFC": rfcs})

from datasets import Dataset, DatasetDict
rfc_dataset = DatasetDict()
rfc_dataset["train"] = Dataset.from_pandas(df)

formatted_dataset = rfc_dataset.map(
    preprocess_function,
    batched=True,
    num_proc=1,
    remove_columns=rfc_dataset["train"].column_names,
    load_from_cache_file=False,
    desc="Running tokenizer on dataset",
)

formatted_dataset = formatted_dataset.shuffle(seed = random_seed)

train_dataset = formatted_dataset["train"]

from transformers import default_data_collator

train_dataloader = DataLoader(
    train_dataset, shuffle=True, collate_fn=default_data_collator, batch_size=batch_size, pin_memory=True
)

from peft import get_peft_config, get_peft_model, PrefixTuningConfig, TaskType, PeftType
from torch.utils.data import DataLoader
from transformers import get_linear_schedule_with_warmup

peft_config = PrefixTuningConfig(task_type=TaskType.CAUSAL_LM, num_virtual_tokens=num_virtual_tokens, prefix_projection = False)

model = get_peft_model(model, peft_config)

optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

lr_scheduler = get_linear_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=(len(train_dataloader) * num_epochs),
)

run_name = "prefix-tuning-v1"

with mlflow.start_run(run_name = run_name):
    for epoch in tqdm(range(num_epochs), total = num_epochs):

        model.train()
        total_loss = 0

        for step, batch in enumerate(tqdm(train_dataloader)):
            batch = {k: v.to(torch.device("cuda")) for k, v in batch.items()}
            outputs = model(**batch)
            loss = outputs.loss
            total_loss += loss.detach().float()
            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

        train_epoch_loss = total_loss / len(train_dataloader)
        train_ppl = torch.exp(train_epoch_loss)
        print(f"{epoch=}: {train_ppl=} {train_epoch_loss=}")

        mlflow.log_param("epoch", epoch + 1, step = epoch)
        mlflow.log_param("train_loss", train_epoch_loss, step = epoch)
        mlflow.log_param("train perplexity", train_ppl, step = epoch)

    mlflow.end_run()

And this is the exact error message I'm seeing:

RuntimeError: The size of tensor a (530) must match the size of tensor b (500) at non-singleton dimension 3

The difference between 530 and 500 is num_virtual_tokens (30 in this case), and this seems to happen every time.

There was some investigation done by @Vincent-Li-9701 here, but the bug is still unresolved.

@BenjaminBossan please let me know if you need anything else.

@BenjaminBossan (Member)

Thanks for providing this reproducer. I could condense the code to the following:

import torch
from transformers import MistralConfig, MistralForCausalLM
from peft import PrefixTuningConfig, get_peft_model

# using small mistral for testing, real mistral would also work
model_config = MistralConfig(
    vocab_size=32000,
    hidden_size=512,
    max_position_embeddings=32768,
    num_attention_heads=16,
    num_hidden_layers=8,
    num_key_value_heads=4,
)
model = MistralForCausalLM(model_config)

config = PrefixTuningConfig(task_type="CAUSAL_LM", num_virtual_tokens=30)
model = get_peft_model(model, config)
model.config.use_cache = False

input_ids = torch.tensor([[1, 1, 1], [1, 2, 1]])
attention_mask = torch.tensor([[1, 1, 1], [1, 0, 1]])
outputs = model(input_ids=input_ids, attention_mask=attention_mask)

which gives the same error:

    outputs = model(input_ids=input_ids, attention_mask=attention_mask)
  File "...site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "...site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/vinh/work/forks/peft/src/peft/peft_model.py", line 1126, in forward
    return self.base_model(
  File "...site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "...site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "...site-packages/transformers/models/mistral/modeling_mistral.py", line 1157, in forward
    outputs = self.model(
  File "...site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "...site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "...site-packages/transformers/models/mistral/modeling_mistral.py", line 1004, in forward
    attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
  File "...site-packages/transformers/modeling_attn_mask_utils.py", line 398, in _prepare_4d_causal_attention_mask_for_sdpa
    expanded_4d_mask = attn_mask_converter.to_4d(
  File "...site-packages/transformers/modeling_attn_mask_utils.py", line 137, in to_4d
    expanded_attn_mask = causal_4d_mask.masked_fill(expanded_attn_mask.bool(), torch.finfo(dtype).min)
RuntimeError: The size of tensor a (33) must match the size of tensor b (3) at non-singleton dimension 3

Unfortunately, even after some digging, I haven't been able to figure out how to fix this issue yet. I've asked around; let's see if someone has a solution.

@vikram71198 commented Mar 5, 2024

One (important) difference between the implementations of Prefix Tuning and other PEFT techniques is evident in the forward() method of PeftModelForCausalLM here.

if peft_config.peft_type == PeftType.PREFIX_TUNING:
    past_key_values = self.get_prompt(batch_size)
    return self.base_model(
        input_ids=input_ids, inputs_embeds=inputs_embeds, past_key_values=past_key_values, **kwargs
    )
else:
    if inputs_embeds is None:
        inputs_embeds = self.word_embeddings(input_ids)
    # concat prompt labels
    if labels is not None:
        prefix_labels = torch.full((batch_size, peft_config.num_virtual_tokens), -100).to(labels.device)
        kwargs["labels"] = torch.cat((prefix_labels, labels), dim=1)
    prompts = self.get_prompt(batch_size=batch_size, task_ids=task_ids)
    prompts = prompts.to(inputs_embeds.dtype)
    inputs_embeds = torch.cat((prompts, inputs_embeds), dim=1)
    return self.base_model(inputs_embeds=inputs_embeds, **kwargs)

It seems like past_key_values is passed as a separate argument for Prefix Tuning, while the prompt is prepended to the input embeddings otherwise.

Could this be causing the issue?

I picked up on this from #870.
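
To make that difference concrete, here is a minimal sketch of the two code paths (hypothetical shapes, not PEFT's actual code): prefix tuning injects learned key/value states for every layer via past_key_values, while the other prompt-learning methods prepend learned embeddings to inputs_embeds (and extend the attention mask accordingly).

import torch

batch_size, num_virtual_tokens = 2, 30
num_layers, num_heads, head_dim, hidden_size = 32, 32, 128, 4096

# Prefix tuning: one learned (key, value) pair per layer, each of shape
# (batch, num_heads, num_virtual_tokens, head_dim), handed to the base model
# as past_key_values.
past_key_values = tuple(
    (
        torch.zeros(batch_size, num_heads, num_virtual_tokens, head_dim),
        torch.zeros(batch_size, num_heads, num_virtual_tokens, head_dim),
    )
    for _ in range(num_layers)
)

# Prompt tuning / P-tuning: the virtual tokens live in embedding space and are
# concatenated in front of the input embeddings.
inputs_embeds = torch.zeros(batch_size, 10, hidden_size)
prompts = torch.zeros(batch_size, num_virtual_tokens, hidden_size)
inputs_embeds = torch.cat((prompts, inputs_embeds), dim=1)  # (2, 40, 4096)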

@BenjaminBossan (Member)

It seems like past_key_values is passed as a separate argument for Prefix Tuning, while the prompt is prepended to the input embeddings otherwise.

It could be related, but I don't know enough about these techniques to be sure. Whatever the cause is, it must depend on the model architecture, as some models work and others don't. I modified the above snippet to check multiple models like so:

import torch
from transformers import MistralConfig, MistralForCausalLM, AutoModelForCausalLM
from peft import PrefixTuningConfig, get_peft_model


def get_model(name):
    if name == "mistral":
        model_config = MistralConfig(
            vocab_size=32000,
            hidden_size=512,
            max_position_embeddings=32768,
            num_attention_heads=16,
            num_hidden_layers=8,
            num_key_value_heads=4,
        )
        return MistralForCausalLM(model_config)

    return AutoModelForCausalLM.from_pretrained(name)


for name in ("gpt2", "facebook/opt-125m", "bigscience/bloomz-560m",  "HuggingFaceH4/tiny-random-LlamaForCausalLM", "mistral"):
    config = PrefixTuningConfig(task_type="CAUSAL_LM", num_virtual_tokens=30)
    model = get_model(name)
    model = get_peft_model(model, config)
    model.config.use_cache = False

    input_ids = torch.tensor([[1, 1, 1], [1, 2, 1]])
    attention_mask = torch.tensor([[1, 1, 1], [1, 0, 1]])
    try:
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        print(f"PASS: model {name} passed")
    except Exception as e:
        print(f"FAIL: model {name} failed with {e}")

and I got:

PASS: model gpt2 passed
PASS: model facebook/opt-125m passed
PASS: model bigscience/bloomz-560m passed
FAIL: model HuggingFaceH4/tiny-random-LlamaForCausalLM failed with 'tuple' object has no attribute 'update'
FAIL: model mistral failed with The size of tensor a (33) must match the size of tensor b (3) at non-singleton dimension 3

with transformers 4.38.1.

Regarding the Llama error, that could be some kv-cache thing, as I get a different error with older transformers versions -- with 4.36 and 4.37, I got the same error as for mistral. Whether the mistral error could also be related to that, I'm not sure.

Pinging @younesbelkada @pacman100 for help.

@vikram71198

I slightly modified your script to:

import torch
from transformers import MistralConfig, MistralForCausalLM, AutoModelForCausalLM
from peft import PrefixTuningConfig, get_peft_model

def get_model(name):
    if name == "mistral":
        model_config = MistralConfig(
            vocab_size=32000,
            hidden_size=512,
            max_position_embeddings=32768,
            num_attention_heads=16,
            num_hidden_layers=8,
            num_key_value_heads=4,
        )
        return MistralForCausalLM(model_config)

    return AutoModelForCausalLM.from_pretrained(name)


for name in ("gpt2", "facebook/opt-125m", "bigscience/bloomz-560m",  "HuggingFaceH4/tiny-random-LlamaForCausalLM", "mistral", "teknium/OpenHermes-2.5-Mistral-7B", "meta-llama/Llama-2-7b-chat-hf"):
    config = PrefixTuningConfig(task_type="CAUSAL_LM", num_virtual_tokens=30)
    model = get_model(name)
    model = get_peft_model(model, config)
    model = model.to(torch.device("cuda"))
    model.config.use_cache = False

    input_ids = torch.tensor([[1, 1, 1], [1, 2, 1]]).to(torch.device("cuda"))
    attention_mask = torch.tensor([[1, 1, 1], [1, 0, 1]]).to(torch.device("cuda"))

    try:
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        print(f"PASS: model {name} passed")
    except Exception as e:
        print(f"FAIL: model {name} failed with {e}")
    finally:
        torch.cuda.empty_cache()
        model.to(torch.device("cpu"))
        del model

With the script above, this is what I see with transformers == 4.34.1:

PASS: model gpt2 passed
PASS: model facebook/opt-125m passed
PASS: model bigscience/bloomz-560m passed
PASS: model HuggingFaceH4/tiny-random-LlamaForCausalLM passed
FAIL: model mistral failed with Sizes of tensors must match except in dimension 2. Expected size 16 but got size 4 for tensor number 1 in the list.
FAIL: model teknium/OpenHermes-2.5-Mistral-7B failed with Sizes of tensors must match except in dimension 2. Expected size 32 but got size 8 for tensor number 1 in the list.
PASS: model meta-llama/Llama-2-7b-chat-hf passed

So, it seems like LlamaForCausalLM works with 4.34.1.

But clearly there's something nefarious going on. I hope we can find and fix this, because it locks the community out of using this fine-tuning method on many, many models that are far more relevant now.

@vikram71198 commented Mar 6, 2024

So, I downgraded to transformers == 4.34.1 and Prefix Tuning seemed to run bug-free with meta-llama/Llama-2-7b-chat-hf.

But after fine-tuning, I mostly saw gibberish outputs from the fine-tuned model. I'm about 95% sure there's nothing wrong with my implementation either, and I'm starting to think that this bug in Prefix Tuning silently creeps up even in cases where we don't get the aforementioned RuntimeError.

So it seems to me that Prefix Tuning currently just doesn't work at all.

Hoping someone has a fix for this.

@younesbelkada @pacman100 @BenjaminBossan

@vikram71198

Any updates on this?

@BenjaminBossan @younesbelkada @pacman100

@GGG-c commented Jul 25, 2024

same

@BenjaminBossan (Member)

Good and bad news. With the latest transformers (4.43.2) and PEFT version (0.12.0), the tensor size mismatch error is no longer occurring. However, there is a new error:

'tuple' object has no attribute 'get_seq_length'

This is because past_key_values provided by prefix tuning is a tuple of tensors, but transformers assumes it's either a Cache instance or None. I have asked internally for clarification of the type and will report back if I get an answer.

@jrrw10 commented Aug 1, 2024

This is because past_key_values provided by prefix tuning is a tuple of tensors but transformers assumes it's either a Cache instance or None

I attempted changing past_key_values to a DynamicCache like so in PeftModel.get_prompt():

past_key_values = DynamicCache.from_legacy_cache(past_key_values)

Calling trainer.train() on a Mistral model is resulting in a shape mismatch during attention computation:

RuntimeError                              Traceback (most recent call last)
<ipython-input-9-3435b262f1ae> in <cell line: 1>()
----> 1 trainer.train()

15 frames
/usr/local/lib/python3.10/dist-packages/transformers/models/mistral/modeling_mistral.py in forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position, **kwargs)
    471         is_causal = True if causal_mask is None and q_len > 1 else False
    472 
--> 473         attn_output = torch.nn.functional.scaled_dot_product_attention(
    474             query_states,
    475             key_states,

RuntimeError: The expanded size of the tensor (2078) must match the existing size (1054) at non-singleton dimension 3.  Target sizes: [4, 32, 1024, 2078].  Tensor sizes: [4, 1, 1024, 1054]

Any idea why this is? @BenjaminBossan

@BenjaminBossan (Member)

@jrrw10 What version of transformers and PEFT are you using? Do you still get the same error when upgrading to the latest versions from main for the two packages?

@jrrw10 commented Aug 1, 2024

@BenjaminBossan Yes, I get this error with the latest version of main for both PEFT and transformers.

transformers==4.44.0.dev0
peft==0.12.1.dev0

@BenjaminBossan (Member)

Okay, strange. I tried with transformers commit 51ab25e2932da15511ced35bcbdfa92d25c4794c and PEFT commit 269aba5303216984d03a8be2e6ef270a9130550f. Following your suggestion, I added

            from transformers import DynamicCache
            past_key_values = DynamicCache.from_legacy_cache(past_key_values)

before this line. Then I ran the example shown above and I get:

PASS: model gpt2 passed
PASS: model facebook/opt-125m passed
PASS: model bigscience/bloomz-560m passed
/home/name/anaconda3/envs/peft/lib/python3.11/site-packages/transformers/generation/configuration_utils.py:546: UserWarning: `pad_token_id` should be positive but got -1. This will cause errors when batch generating, if there is padding. Please set `pad_token_id` explicitly by `model.generation_config.pad_token_id=PAD_TOKEN_ID` to avoid errors in generation, and ensure your `input_ids` input does not have negative values.
  warnings.warn(
PASS: model HuggingFaceH4/tiny-random-LlamaForCausalLM passed
PASS: model llamafactory/tiny-random-Llama-3 passed
PASS: model mistral passed
PASS: model peft-internal-testing/tiny-dummy-qwen2 passed

(note that use_cache = False is unnecessary in the snippet)

@jrrw10 commented Aug 1, 2024

@BenjaminBossan Thank you for testing this out. I get the same output with the above testing for a forward pass. While the forward pass tests are successful, calling trainer.train() is what is resulting in the shape mismatch for me.

Debugging prints in transformers/models/mistral/modeling_mistral.py in MistralSdpaAttention.forward() reveal the following shapes of query_states, key_states, value_states, and causal_mask:

...
Before attention computation - Query States: torch.Size([4, 32, 1024, 128]), Key States: torch.Size([4, 32, 1054, 128]), Value States: torch.Size([4, 32, 1054, 128]), mask: torch.Size([4, 1, 1024, 1054])
Before attention computation - Query States: torch.Size([4, 32, 1024, 128]), Key States: torch.Size([4, 32, 2078, 128]), Value States: torch.Size([4, 32, 2078, 128]), mask: torch.Size([4, 1, 1024, 1054])

Which then hits the mismatch:

/usr/local/lib/python3.10/dist-packages/transformers/models/mistral/modeling_mistral.py in forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position, **kwargs)
    475         is_causal = True if causal_mask is None and q_len > 1 else False
    476 
--> 477         attn_output = torch.nn.functional.scaled_dot_product_attention(
    478             query_states,
    479             key_states,
    480            dropout_p=self.attention_dropout if self.training else 0.0,
    481            is_causal=is_causal,
    482 )


RuntimeError: The size of tensor a (2078) must match the size of tensor b (1054) at non-singleton dimension 3

It seems the mismatch occurs because the key length (2078) does not match the key length in the mask (1054).

Do you have any insights on why this discrepancy might be occurring during trainer.train() and how to address this? Thanks again for the help!

@jrrw10 commented Aug 1, 2024

It seems the mismatch occurs because the key length (2078) does not match the key length in the mask (1054).

I was experimenting in the MistralSdpaAttention class forward() method after discovering the attention mask was the wrong size.

I changed causal_mask to:

causal_mask = attention_mask
if attention_mask is not None:
    # Ensure the mask matches key length
    if causal_mask.shape[-1] != key_states.shape[-2]:
        if causal_mask.shape[-1] < key_states.shape[-2]:
            padding = key_states.shape[-2] - causal_mask.shape[-1]
            causal_mask = torch.nn.functional.pad(causal_mask, (0, padding))
        else:
            causal_mask = causal_mask[:, :, :, :key_states.shape[-2]]

since it was incorrectly sized before. This allowed training to kick off; however, I am getting strange output:

Step Training Loss Validation Loss
10 9.897600 9.929580
20 9.884000 9.929580
30 9.878300 9.929580
40 9.895400 9.929580
50 9.874400 9.929580

Similarly, if the model is loaded with attn_implementation="flash_attention_2", i.e.:

base_model = AutoModelForCausalLM.from_pretrained(
    base_model_id,
    attn_implementation="flash_attention_2",
    device_map='cuda',
    torch_dtype=torch.bfloat16,)

I did not have to edit anything in the respective class, and training kicked off with an almost identical, constant loss (with a typical learning rate, among other hyperparameters). Any help/update is much appreciated.

@BenjaminBossan (Member)

Thanks @jrrw10 for this continued investigation. I tried training mistral and did not run into the errors you reported, only the fix with the DynamicCache was required. The loss also improved nicely. Here is the code that I used:

import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, DataCollatorForLanguageModeling, Trainer, TrainingArguments
from peft import LoraConfig, PrefixTuningConfig, get_peft_model

model_id = "mistralai/Mistral-7B-v0.1"
# "teknium/OpenHermes-2.5-Mistral-7B" works as well
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map=0,
    torch_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token

config = PrefixTuningConfig(task_type="CAUSAL_LM", num_virtual_tokens=10)
model = get_peft_model(model, config)

def process(samples):
    tokenized = tokenizer(samples["quote"], truncation=True, max_length=128)
    return tokenized

data = load_dataset("ybelkada/english_quotes_copy")
data = data.map(process, batched=True)

trainer = Trainer(
    model=model,
    train_dataset=data["train"],
    args=TrainingArguments(
        num_train_epochs=1,
        per_device_train_batch_size=4,
        bf16=True,
        learning_rate=3e-4,
        logging_steps=10,
        output_dir="/tmp/peft/869",
    ),
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
trainer.train()

This is the loss that I get:

step train loss
10 8.9031
20 8.4677
30 8.4124
40 8.231
50 7.9302
60 7.996
70 8.1211
80 7.7546
90 8.0796
100 7.5609
110 7.67
120 7.3477
130 7.5419
140 7.4619
150 7.1993
160 7.4286
170 7.0095
180 7.1371
190 7.2012
200 6.9664
210 6.9731
220 7.0729
230 7.1143
240 7.1078
250 6.9686
260 6.8469
270 6.7001
280 6.7319
290 7.0498
300 6.8004
310 6.755
320 6.8177
330 6.4536
340 6.7542
350 6.6656
360 6.7039
370 6.8224
380 6.7475
390 6.6785
400 6.6081
410 6.6096
420 6.6462
430 6.9221
440 6.7322
450 6.6007
460 6.8302
470 6.6958
480 6.758
490 6.6558
500 6.4734
510 6.3721
520 6.9516
530 6.4559
540 6.3372
550 6.4975
560 6.484
570 6.4277
580 6.5164
590 6.6368
600 6.2812
610 6.4681
620 6.5265

Any idea what the difference could be?

@jrrw10 commented Aug 6, 2024

@BenjaminBossan I've found the issue/s I was experiencing.

The difference between our implementations is that I am using gradient checkpointing like so:

model = AutoModelForCausalLM.from_pretrained(
   model_id,
   device_map=0,
   torch_dtype=torch.bfloat16,
)
model.gradient_checkpointing_enable()

This should be reproducible for you in your example above if you add that line. Here's the error:

/usr/local/lib/python3.10/dist-packages/transformers/models/mistral/modeling_mistral.py in forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position, **kwargs)
    475         is_causal = True if causal_mask is None and q_len > 1 else False
    476 
--> 477         attn_output = torch.nn.functional.scaled_dot_product_attention(
    478             query_states,
    479             key_states,

RuntimeError: The size of tensor a (72) must match the size of tensor b (41) at non-singleton dimension 3

This is the exact error I was experiencing before. However, when I manually fixed the attention mask size mismatch as I mentioned before, I was also met with the constant loss problem I showed above. I'm having trouble reproducing that now - maybe it was fixed in a recent commit, I'm not sure.

In conclusion: I got my implementation to work by adding the DynamicCache fix, and gradient checkpointing currently breaks Prefix Tuning (attention mask size mismatch). Thank you for the help, sir!

@BenjaminBossan (Member)

Thanks for digging deeper. Indeed, with gradient checkpointing, there is an issue. IIUC, adjusting the size of the causal mask is not the correct solution though, which could explain the bad losses you see. I printed the layer index, as well as the shapes of key_states and value_states once before and once after they get updated from the cache in this line:

https://github.com/huggingface/transformers/blob/e0d82534cc95b582ab072c1bbc060852ba7f9d51/src/transformers/models/mistral/modeling_mistral.py#L457

0 torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128]) torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128])
1 torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128]) torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128])
2 torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128]) torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128])
3 torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128]) torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128])
4 torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128]) torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128])
5 torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128]) torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128])
6 torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128]) torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128])
7 torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128]) torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128])
8 torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128]) torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128])
9 torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128]) torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128])
10 torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128]) torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128])
11 torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128]) torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128])
12 torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128]) torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128])
13 torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128]) torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128])
14 torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128]) torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128])
15 torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128]) torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128])
16 torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128]) torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128])
17 torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128]) torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128])
18 torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128]) torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128])
19 torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128]) torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128])
20 torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128]) torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128])
21 torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128]) torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128])
22 torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128]) torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128])
23 torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128]) torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128])
24 torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128]) torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128])
25 torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128]) torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128])
26 torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128]) torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128])
27 torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128]) torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128])
28 torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128]) torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128])
29 torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128]) torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128])
30 torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128]) torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128])
31 torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128]) torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 41, 128])
31 torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 72, 128]) torch.Size([4, 8, 31, 128]) torch.Size([4, 8, 72, 128])

The shapes are all correct except for the last line: We have sequence length of 31 before applying cache and 41 after applying cache, which is expected as we add 10 virtual tokens.

Now let's check the last line. As we can see, it's again layer 31, i.e. it looks like we're in the gradient checkpointing phase. We can see the wrong length of 72 after updating. This is 31+41, i.e. the updated cache is added on top of the states, which IMO is not the correct way of handling this. It should be the same as during the first time we visited this layer, i.e. 31+10. I wonder if gradient checkpointing + cache is generally broken or if we're using it incorrectly.
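
To illustrate that double update in isolation, here is a minimal sketch using DynamicCache directly (not the transformers internals; the shapes mirror the printout above):

import torch
from transformers import DynamicCache

cache = DynamicCache()
prefix_kv = torch.zeros(4, 8, 10, 128)  # 10 virtual tokens from prefix tuning
token_kv = torch.zeros(4, 8, 31, 128)   # 31 input tokens

cache.update(prefix_kv, prefix_kv, layer_idx=0)       # cache now holds the prefix
k, v = cache.update(token_kv, token_kv, layer_idx=0)  # first forward: 10 + 31 = 41
print(k.shape[-2])  # 41

k, v = cache.update(token_kv, token_kv, layer_idx=0)  # recomputation under gradient checkpointing: 41 + 31 = 72
print(k.shape[-2])  # 72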

@BenjaminBossan (Member)

@jrrw10 FYI: #1962 (comment). So the workaround of creating a cache object to hold the past_key_values is a bit of a dead end.

@jrrw10 commented Aug 7, 2024

@BenjaminBossan Thanks for the detailed explanation of this.

Just to confirm, Prefix Tuning does successfully train with the DynamicCache workaround and gradient checkpointing disabled as of right now - this is sufficient for me and hopefully other users until the past_key_values handling is reworked. Hopefully this discussion helped!

@BenjaminBossan (Member)

Great that it works for you. I'd just be cautious with this as it's not using transformers as intended. This can be risky because:

  • It could work but not quite correctly, resulting in lower performance.
  • It could work with one architecture but not with others.
  • It could work now but break any time there is a transformers update.

@hammoudhasan commented Aug 11, 2024

@jrrw10 do you mind sharing the final script here? I'm having an issue where the training loss goes down but the performance is bad when I do test inference - I suspect it's because I don't apply the chat template during training and inference. Do you optimize the prefix with the chat template applied, or without?

@BenjaminBossan do you know how your provided code could be modified to use chat models where templates need to be applied?

@BenjaminBossan (Member)

do you know how your provided code could be modified to use chat models where templates need to be applied?

Did you try calling apply_chat_template in the process function? Also check out the docs on chat templates.
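
For illustration, a rough sketch of what that could look like, reusing the tokenizer and the quote dataset from the training example above and treating each quote as a single user turn (the roles and template usage here are assumptions, not a verified recipe):

def process(samples):
    # Render each quote through the model's chat template before tokenizing.
    texts = [
        tokenizer.apply_chat_template(
            [{"role": "user", "content": quote}],
            tokenize=False,
            add_generation_prompt=False,
        )
        for quote in samples["quote"]
    ]
    return tokenizer(texts, truncation=True, max_length=128)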

@hammoudhasan

@BenjaminBossan I got it sorted out (: thank you for the reply. It seems what was holding me back was a very weird error I got with Llama models (e.g. TinyLlama/TinyLlama-1.1B-Chat-v1.0), where calling model(**batch) - with batch containing the input_ids, labels, and attention_mask values - was leading to a weird dimensions error.

huggingface llama RuntimeError: Sizes of tensors must match except in dimension 2. Expected size 32 but got size 4 for tensor number 1 in the list.

Several people have reported having a similar error with no solution yet :-(

@BenjaminBossan (Member)

It seems what was holding me back was a very weird error I got with Llama models (e.g. TinyLlama/TinyLlama-1.1B-Chat-v1.0), where calling model(**batch) - with batch containing the input_ids, labels, and attention_mask values - was leading to a weird dimensions error.

You mean independent of PEFT?

@hammoudhasan

@BenjaminBossan yep! It seems to be something related to the modeling files.

@BenjaminBossan (Member)

Just an update to all who encountered problems with this in the past: Could you please upgrade to the latest PEFT (v0.13.2+) and transformers versions (v4.45.2+) and report back if you're still encountering issues?

BenjaminBossan added a commit that referenced this issue Oct 24, 2024
See #869, #1962

Fix several issues caused by changes to cache in transformers. In
particular, past_key_values for prefix tuning is now converted to a
transformers Cache instance.

---------

Co-authored-by: Raushan Turganbay <[email protected]>
@DavdGao commented Oct 25, 2024

Just an update to all who encountered problems with this in the past: Could you please upgrade to the latest PEFT (v0.13.2+) and transformers versions (v4.45.2+) and report back if you're still encountering issues?

@BenjaminBossan The error 'tuple' object has no attribute 'get_seq_length' still exists.

My env:

  • python 3.11
  • transformers 4.46.0
  • peft 0.13.2
  • accelerate 1.0.1
  • deepspeed 0.15.3
  • tokenizers 0.20.1
  • torch 2.5.0

What I'm doing: train Qwen2.5-72B Instruct with Prefix-tuning

My code:

import peft
from transformers import AutoTokenizer, TrainingArguments, HfArgumentParser, AutoModelForCausalLM, Trainer
from dataclasses import dataclass, field
from peft import get_peft_model, PrefixTuningConfig
from typing import Dict, List

import warnings
import torch
import os
from sft_dataset import make_supervised_data_module, _print_rank

warnings.filterwarnings("ignore")
os.environ["WANDB_SILENT"] = "true"


@dataclass
class CustomizeArguments:
    model_name_or_path: str = field(default=None)
    tokenizer: str = field(default=None)
    data_path: List[str] = field(default=None)
    
    eval_path: str = field(default=None)

    max_sft_length: int = field(default=None)

    # LoRA (this flag is referenced by the assert in train() below)
    lora_enable: bool = field(default=False)

    # Prefix-tuning
    pft_enable: bool = field(default=False)
    pft_num_virtual_tokens: int = field(default=None)

@dataclass
class TrainingArguments(TrainingArguments):
    optim: str = field(default="adamw_torch")


def train():
    parser = HfArgumentParser((TrainingArguments, CustomizeArguments))
    training_args, customize_args = parser.parse_args_into_dataclasses()

    tokenizer = AutoTokenizer.from_pretrained(customize_args.tokenizer, use_fast=False)
    
    data_modules = make_supervised_data_module(tokenizer, customize_args.data_path, customize_args.eval_path,customize_args.max_sft_length or tokenizer.model_max_length)

    model = AutoModelForCausalLM.from_pretrained(customize_args.model_name_or_path)
    model.gradient_checkpointing_enable()
    
    assert customize_args.lora_enable and not customize_args.pft_enable or not customize_args.lora_enable and customize_args.pft_enable or not (customize_args.lora_enable or customize_args.pft_enable), "Only one of lora and pft can be enabled."

    if customize_args.pft_enable:
        _print_rank("Prefix-tuning is enabled.")
        pft_config = PrefixTuningConfig(
            peft_type=peft.PeftType.PREFIX_TUNING,
            inference_mode=False,
            task_type="CAUSAL_LM",
            num_virtual_tokens=customize_args.pft_num_virtual_tokens,
        )
        model = get_peft_model(model, pft_config)
        model.print_trainable_parameters()
    else:
        _print_rank("Fine-tuning full parameters.")

    trainer = Trainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,

        **data_modules,
    )

    print("Start training...")
    trainer.train()
    
    model.save_pretrained(training_args.output_dir)
    tokenizer.save_pretrained(training_args.output_dir)


if __name__ == "__main__":
    train()

The error information:

[rank7]: Traceback (most recent call last):
[rank7]:   File "/cpfs/data/gaodawei.gdw/train/sft.py", line 112, in <module>
[rank7]:     train()
[rank7]:   File "/cpfs/data/gaodawei.gdw/train/sft.py", line 101, in train
[rank7]:     trainer.train()
[rank7]:   File "/cpfs/data/gaodawei.gdw/miniconda3/envs/pft/lib/python3.11/site-packages/transformers/trainer.py", line 2122, in train
[rank7]:     return inner_training_loop(
[rank7]:            ^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/cpfs/data/gaodawei.gdw/miniconda3/envs/pft/lib/python3.11/site-packages/transformers/trainer.py", line 2474, in _inner_training_loop
[rank7]:     tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
[rank7]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/cpfs/data/gaodawei.gdw/miniconda3/envs/pft/lib/python3.11/site-packages/transformers/trainer.py", line 3572, in training_step
[rank7]:     loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
[rank7]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/cpfs/data/gaodawei.gdw/miniconda3/envs/pft/lib/python3.11/site-packages/transformers/trainer.py", line 3625, in compute_loss
[rank7]:     outputs = model(**inputs)
[rank7]:               ^^^^^^^^^^^^^^^
[rank7]:   File "/cpfs/data/gaodawei.gdw/miniconda3/envs/pft/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank7]:     return self._call_impl(*args, **kwargs)
[rank7]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/cpfs/data/gaodawei.gdw/miniconda3/envs/pft/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank7]:     return forward_call(*args, **kwargs)
[rank7]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/cpfs/data/gaodawei.gdw/miniconda3/envs/pft/lib/python3.11/site-packages/deepspeed/utils/nvtx.py", line 18, in wrapped_fn
[rank7]:     ret_val = func(*args, **kwargs)
[rank7]:               ^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/cpfs/data/gaodawei.gdw/miniconda3/envs/pft/lib/python3.11/site-packages/deepspeed/runtime/engine.py", line 1899, in forward
[rank7]:     loss = self.module(*inputs, **kwargs)
[rank7]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/cpfs/data/gaodawei.gdw/miniconda3/envs/pft/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank7]:     return self._call_impl(*args, **kwargs)
[rank7]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/cpfs/data/gaodawei.gdw/miniconda3/envs/pft/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1844, in _call_impl
[rank7]:     return inner()
[rank7]:            ^^^^^^^
[rank7]:   File "/cpfs/data/gaodawei.gdw/miniconda3/envs/pft/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in inner
[rank7]:     result = forward_call(*args, **kwargs)
[rank7]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/cpfs/data/gaodawei.gdw/miniconda3/envs/pft/lib/python3.11/site-packages/peft/peft_model.py", line 1680, in forward
[rank7]:     return self.base_model(input_ids=input_ids, inputs_embeds=inputs_embeds, **kwargs)
[rank7]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/cpfs/data/gaodawei.gdw/miniconda3/envs/pft/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank7]:     return self._call_impl(*args, **kwargs)
[rank7]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/cpfs/data/gaodawei.gdw/miniconda3/envs/pft/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1844, in _call_impl
[rank7]:     return inner()
[rank7]:            ^^^^^^^
[rank7]:   File "/cpfs/data/gaodawei.gdw/miniconda3/envs/pft/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in inner
[rank7]:     result = forward_call(*args, **kwargs)
[rank7]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/cpfs/data/gaodawei.gdw/miniconda3/envs/pft/lib/python3.11/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 1164, in forward
[rank7]:     outputs = self.model(
[rank7]:               ^^^^^^^^^^^
[rank7]:   File "/cpfs/data/gaodawei.gdw/miniconda3/envs/pft/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank7]:     return self._call_impl(*args, **kwargs)
[rank7]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/cpfs/data/gaodawei.gdw/miniconda3/envs/pft/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1844, in _call_impl
[rank7]:     return inner()
[rank7]:            ^^^^^^^
[rank7]:   File "/cpfs/data/gaodawei.gdw/miniconda3/envs/pft/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in inner
[rank7]:     result = forward_call(*args, **kwargs)
[rank7]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/cpfs/data/gaodawei.gdw/miniconda3/envs/pft/lib/python3.11/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 857, in forward
[rank7]:     past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
[rank7]:                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]: AttributeError: 'tuple' object has no attribute 'get_seq_length'

@DavdGao commented Oct 25, 2024

When I remove the line model.gradient_checkpointing_enable() and change the base model to Qwen/Qwen2.5-7B-Instruct due to limited GPU memory, training runs normally.

@BenjaminBossan (Member)

Thanks for reporting @DavdGao. What model did you use that led to the error? Could you perhaps share how you called the script? Also, we recently merged some fixes to PEFT, so if you could install from source and check again, that would be helpful to know.

@DavdGao commented Oct 28, 2024

Thanks for reporting @DavdGao. What model did you use that led to the error? Could you perhaps share how you called the script? Also, we recently merged some fixes to PEFT, so if you could install from source and check again, that would be helpful to know.

@BenjaminBossan

  • Model: The model I'm using is Qwen/Qwen2.5-72B-Instruct downloaded from HuggingFace.
  • I'm using DeepSpeed with the following scripts; they may be a little complex. If you need any further details, just let me know.
    • basic.sh:
#!/bin/bash

set -e 

export CUTLASS_PATH=/cpfs/data/gaodawei.gdw/cutlass

# task
path_data=${1}
path_model=${2}
path_tokenizer=${3}

# params
bs=${4}
lr=${5}
wd=${6}
epo=${7}

pft_enable=${8}
pft_num_virtual_tokens=${9}

ds_config=${10}

name_run=${11}

eval_dir=${12}

filename_with_extension=${path_data##*/}
filename_without_extension=${filename_with_extension%.*}

second_last_dir="${path_data%/*}"
second_last_dir="${second_last_dir##*/}"

third_last_dir="${path_data%/*/*}"
third_last_dir="${third_last_dir##*/}"

name_data="${third_last_dir}_${second_last_dir}_${filename_without_extension}"

name_run=${name_run}/${bs}bs_${lr}lr_${wd}wd_${epo}epo

if [ "$pft_enable" = "True" ]; then
    name_run=${name_run}_${pft_num_virtual_tokens}vtoken
fi

echo "Task: $name_run"

path_save=/home/data/shared/checkpoints/prefix-tuning/${name_run}
mkdir -p ${path_save}

# wandb
WANDB_PROJECT=PFT
WANDB_NAME=${name_run}

ROOT=/cpfs/data/gaodawei.gdw/train

path_ds_config=${ds_config}

cd ${ROOT}

gas=$((${bs}/8))

deepspeed --num_gpus 8 --num_nodes 1 --master_port 5900 \
    ${ROOT}/sft.py \
    --model_name_or_path ${path_model} \
    --tokenizer ${path_tokenizer} \
    --do_train \
    --data_path ${path_data} \
    --eval_strategy "no"\
    --bf16 True \
    --output_dir ${path_save} \
    --num_train_epochs ${epo} \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps ${gas} \
    --save_strategy "epoch" \
    --save_total_limit 99999 \
    --learning_rate ${lr} \
    --weight_decay ${wd} \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 10 \
    --log_level info \
    --tf32 True \
    --wandb_project ${WANDB_PROJECT} \
    --wandb_name ${WANDB_NAME} \
    --deepspeed ${path_ds_config} \
    --save_only_model True \
    --pft_enable ${pft_enable} \
    --pft_num_virtual_tokens ${pft_num_virtual_tokens}  2>&1 | tee ${path_save}/training_log.txt
    • The final training script:
#!/bin/bash

set -e

name_run=Prefix-Tuning_Qwen2.5-72B-Instruct_v1

path_data=/cpfs/data/gaodawei.gdw/data/train/v3

train_dir=${path_data}/split_train
eval_dir=${path_data}/split_val

path_model=/home/data/shared/checkpoints/qwen/qwen2.5/Qwen2.5-7B-Instruct
ds_config=/cpfs/data/gaodawei.gdw/scripts/sft/multi_node/configs/ds_config_stage0.json

path_tokenizer=${path_model}

global_batch_size=32

pft_enable=True

LRS=(1e-5 5e-5 1e-4)
WDS=(0.01)
EPOS=(10)
VTS=(10 50 100)

for epo in ${EPOS[*]};
do
   for wd in ${WDS[*]};
   do
       for lr in ${LRS[*]};
       do
           for pft_num_virtual_tokens in ${VTS[*]};
           do
               bash /cpfs/data/gaodawei.gdw/train/scripts_prefix-tuning/basic.sh \
                   "${train_dir}" \
                   ${path_model} \
                   ${path_tokenizer} \
                   ${global_batch_size} \
                   ${lr} \
                   ${wd} \
                   ${epo} \
                   ${pft_enable} \
                   ${pft_num_virtual_tokens} \
                   ${ds_config} \
                   ${name_run} \
                   ${eval_dir}
               echo "######################################################################"
           done
       done
   done 
done

@DavdGao commented Oct 28, 2024

@BenjaminBossan

I have tried installing the latest peft from source. When I enable model.gradient_checkpointing_enable(), I get the following error:

  • peft version: 0.13.3.dev0

  • error information

[rank7]: Traceback (most recent call last):
[rank7]:   File "/cpfs/data/gaodawei.gdw/train/sft.py", line 112, in <module>
[rank7]:     train()
[rank7]:   File "/cpfs/data/gaodawei.gdw/train/sft.py", line 101, in train
[rank7]:     trainer.train()
[rank7]:   File "/cpfs/data/gaodawei.gdw/miniconda3/envs/pft/lib/python3.11/site-packages/transformers/trainer.py", line 2122, in train
[rank7]:     return inner_training_loop(
[rank7]:            ^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/cpfs/data/gaodawei.gdw/miniconda3/envs/pft/lib/python3.11/site-packages/transformers/trainer.py", line 2474, in _inner_training_loop
[rank7]:     tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
[rank7]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/cpfs/data/gaodawei.gdw/miniconda3/envs/pft/lib/python3.11/site-packages/transformers/trainer.py", line 3606, in training_step
[rank7]:     self.accelerator.backward(loss, **kwargs)
[rank7]:   File "/cpfs/data/gaodawei.gdw/miniconda3/envs/pft/lib/python3.11/site-packages/accelerate/accelerator.py", line 2238, in backward
[rank7]:     self.deepspeed_engine_wrapped.backward(loss, **kwargs)
[rank7]:   File "/cpfs/data/gaodawei.gdw/miniconda3/envs/pft/lib/python3.11/site-packages/accelerate/utils/deepspeed.py", line 186, in backward
[rank7]:     self.engine.backward(loss, **kwargs)
[rank7]:   File "/cpfs/data/gaodawei.gdw/miniconda3/envs/pft/lib/python3.11/site-packages/deepspeed/utils/nvtx.py", line 18, in wrapped_fn
[rank7]:     ret_val = func(*args, **kwargs)
[rank7]:               ^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/cpfs/data/gaodawei.gdw/miniconda3/envs/pft/lib/python3.11/site-packages/deepspeed/runtime/engine.py", line 2020, in backward
[rank7]:     self.optimizer.backward(loss, retain_graph=retain_graph)
[rank7]:   File "/cpfs/data/gaodawei.gdw/miniconda3/envs/pft/lib/python3.11/site-packages/deepspeed/utils/nvtx.py", line 18, in wrapped_fn
[rank7]:     ret_val = func(*args, **kwargs)
[rank7]:               ^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/cpfs/data/gaodawei.gdw/miniconda3/envs/pft/lib/python3.11/site-packages/deepspeed/runtime/zero/stage3.py", line 2250, in backward
[rank7]:     self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)
[rank7]:   File "/cpfs/data/gaodawei.gdw/miniconda3/envs/pft/lib/python3.11/site-packages/deepspeed/runtime/fp16/loss_scaler.py", line 63, in backward
[rank7]:     scaled_loss.backward(retain_graph=retain_graph)
[rank7]:   File "/cpfs/data/gaodawei.gdw/miniconda3/envs/pft/lib/python3.11/site-packages/torch/_tensor.py", line 521, in backward
[rank7]:     torch.autograd.backward(
[rank7]:   File "/cpfs/data/gaodawei.gdw/miniconda3/envs/pft/lib/python3.11/site-packages/torch/autograd/__init__.py", line 289, in backward
[rank7]:     _engine_run_backward(
[rank7]:   File "/cpfs/data/gaodawei.gdw/miniconda3/envs/pft/lib/python3.11/site-packages/torch/autograd/graph.py", line 769, in _engine_run_backward
[rank7]:     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[rank7]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/cpfs/data/gaodawei.gdw/miniconda3/envs/pft/lib/python3.11/site-packages/torch/autograd/function.py", line 306, in apply
[rank7]:     return user_fn(self, *args)
[rank7]:            ^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/cpfs/data/gaodawei.gdw/miniconda3/envs/pft/lib/python3.11/site-packages/torch/utils/checkpoint.py", line 296, in backward
[rank7]:     outputs = ctx.run_function(*detached_inputs)
[rank7]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/cpfs/data/gaodawei.gdw/miniconda3/envs/pft/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank7]:     return self._call_impl(*args, **kwargs)
[rank7]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/cpfs/data/gaodawei.gdw/miniconda3/envs/pft/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1603, in _call_impl
[rank7]:     result = forward_call(*args, **kwargs)
[rank7]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/cpfs/data/gaodawei.gdw/miniconda3/envs/pft/lib/python3.11/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 623, in forward
[rank7]:     hidden_states, self_attn_weights, present_key_value = self.self_attn(
[rank7]:                                                           ^^^^^^^^^^^^^^^
[rank7]:   File "/cpfs/data/gaodawei.gdw/miniconda3/envs/pft/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank7]:     return self._call_impl(*args, **kwargs)
[rank7]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/cpfs/data/gaodawei.gdw/miniconda3/envs/pft/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1603, in _call_impl
[rank7]:     result = forward_call(*args, **kwargs)
[rank7]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/cpfs/data/gaodawei.gdw/miniconda3/envs/pft/lib/python3.11/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 544, in forward
[rank7]:     attn_output = torch.nn.functional.scaled_dot_product_attention(
[rank7]:                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]: RuntimeError: The expanded size of the tensor (140) must match the existing size (75) at non-singleton dimension 3.  Target sizes: [1, 64, 65, 140].  Tensor sizes: [1, 1, 65, 75]

A similar error occurs when I change the model to Qwen/Qwen2.5-7B-Instruct.

@BenjaminBossan (Member)

Thanks @DavdGao I can reproduce the error. I'll investigate and get back to you when I find something out.

@DavdGao commented Oct 29, 2024

@BenjaminBossan Thank you for your assistance. I greatly appreciate it.

Currently, I removed the line model.gradient_checkpointing_enable() and tried to train a smaller 7B model. Training went well but inference failed. I have described this error in #2134, and I think it may be relevant.

BenjaminBossan added a commit to BenjaminBossan/peft that referenced this issue Oct 31, 2024
See huggingface#869

Since transformers is moving to the new cache implementation, we had to
change prefix tuning to use this too. However, caching does not work
with gradient checkpointing. Therefore, this currently runs into an
error about size mismatches.

Now, PEFT checks for gradient checkpointing and raises a helpful error.
@BenjaminBossan (Member)

@DavdGao After some investigation with colleagues, we came to the conclusion that, unfortunately, prefix tuning won't work with gradient checkpointing. The reason is that transformers has made some changes to caching, which is reflected in prefix tuning now using DynamicCache. But gradient checkpointing does not work properly with the cache. Therefore, the only options are to use a different PEFT method or to not use gradient checkpointing. I created a PR, #2191, to raise a proper error when we detect this situation.
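
For context, here is a rough sketch (not the actual code from #2191) of the kind of guard described here:

from peft import PeftType

def check_prefix_tuning_compat(model, peft_config):
    # Hypothetical helper: fail early instead of hitting a confusing shape
    # mismatch when prefix tuning is combined with gradient checkpointing.
    if peft_config.peft_type == PeftType.PREFIX_TUNING and getattr(
        model, "is_gradient_checkpointing", False
    ):
        raise ValueError(
            "Prefix tuning does not support gradient checkpointing; disable "
            "gradient checkpointing or use a different PEFT method."
        )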

@DavdGao commented Nov 5, 2024

@DavdGao After some investigation with colleagues, we came to the conclusion that, unfortunately, prefix tuning won't work with gradient checkpointing. The reason is that transformers has made some changes to caching, which is reflected in prefix tuning now using DynamicCache. But gradient checkpointing does not work properly with the cache. Therefore, the only options are to use a different PEFT method or to not use gradient checkpointing. I created a PR, #2191, to raise a proper error when we detect this situation.

Thanks a lot.
I wonder if there's any way to support larger LLMs with prefix tuning? In my case, even with DeepSpeed ZeRO optimization (stage 3) and A100 GPUs (80GB memory), I still hit OOM errors when training a 72B LLM with prefix tuning.

@BenjaminBossan (Member)

@DavdGao It is unfortunate that caching precludes the use of gradient checkpointing, thus resulting in higher memory usage. Not sure if you already tried quantization, but that should work with prefix tuning.
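
For reference, a rough sketch of that quantization route (4-bit via bitsandbytes; the exact settings here are illustrative assumptions, not a verified recipe for the 72B model):

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import PrefixTuningConfig, get_peft_model, prepare_model_for_kbit_training

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-72B-Instruct",
    quantization_config=bnb_config,
    device_map="auto",
)
# Keep gradient checkpointing off, since prefix tuning does not support it.
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=False)

config = PrefixTuningConfig(task_type="CAUSAL_LM", num_virtual_tokens=10)
model = get_peft_model(model, config)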
