RuntimeError in Prefix #870

Closed
2 of 4 tasks
liuao743 opened this issue Aug 28, 2023 · 8 comments
Comments

@liuao743

liuao743 commented Aug 28, 2023

System Info

When I use prefix tuning on LLaMA, the following error occurs:

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward

The PEFT version is 0.4.0.

I have tried all the other tuning methods supported by PEFT; this problem did not occur with any of them.

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder
  • My own task or dataset (give details below)

Reproduction

My code is as follows:

prefix_config = PrefixTuningConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
    encoder_hidden_size=finetuning_args.encoder_hidden_size,
    num_virtual_tokens=finetuning_args.num_virtual_tokens,
    token_dim=finetuning_args.token_dim,
    num_transformer_submodules=finetuning_args.num_transformer_submodules,
    num_attention_heads=finetuning_args.num_attention_heads,
    num_layers=finetuning_args.num_layers,
    prefix_projection=finetuning_args.prefix_projection,
)
model = get_peft_model(model, prefix_config)

The specific values are:

--encoder_hidden_size 4096 \
--num_virtual_tokens 20 \
--token_dim 4096 \
--num_transformer_submodules 1 \
--num_attention_heads 32 \
--num_layers 32 \
--num_layer_trainable 32 \

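A minimal standalone sketch of this setup follows. It is only a sketch: the model checkpoint, tokenizer usage, and the gradient-checkpointing call are assumptions, since the actual training loop lives in LLaMA-Efficient-Tuning.

# Repro sketch; checkpoint name and gradient checkpointing are assumptions, not taken from the report.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PrefixTuningConfig, TaskType, get_peft_model

model = AutoModelForCausalLM.from_pretrained("huggyllama/llama-7b")  # assumed LLaMA checkpoint
model.gradient_checkpointing_enable()  # assumed: the trainer enables gradient checkpointing

prefix_config = PrefixTuningConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
    encoder_hidden_size=4096,
    num_virtual_tokens=20,
    token_dim=4096,
    num_transformer_submodules=1,
    num_attention_heads=32,
    num_layers=32,
    prefix_projection=True,  # assumed; the report passes finetuning_args.prefix_projection
)
model = get_peft_model(model, prefix_config)
model.train()  # checkpointing only kicks in while training

tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
batch = tokenizer("Hello world", return_tensors="pt")
out = model(input_ids=batch["input_ids"], labels=batch["input_ids"])
out.loss.backward()  # the reported RuntimeError is raised during backward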
Expected behavior

Solve this problem.

@BenjaminBossan
Member

Could you please provide more information: What model are you using, what data, how do you train the model, what is the full stacktrace? Otherwise, it's hard to help you.

@Vincent-Li-9701

@BenjaminBossan we had two very similar errors here and here.

@liuao743
Author

liuao743 commented Aug 29, 2023

@BenjaminBossan Sorry, I can't post all of the code. I used LLaMA-Efficient-Tuning and added the prefix tuning method on top of it, so it would be very troublesome to extract just this part, and the amount of code would be huge otherwise.
I tried to find the cause of the problem myself.
I use PeftModelForCausalLM, and in the forward method of this class:

if peft_config.peft_type == PeftType.PREFIX_TUNING:
    past_key_values = self.get_prompt(batch_size)
    return self.base_model(input_ids=input_ids, past_key_values=past_key_values, **kwargs)
else:
    if inputs_embeds is None:
        inputs_embeds = self.word_embeddings(input_ids)
    # concat prompt labels
    if labels is not None:
        prefix_labels = torch.full((batch_size, peft_config.num_virtual_tokens), -100).to(self.device)
        kwargs["labels"] = torch.cat((prefix_labels, labels), dim=1)
    prompts = self.get_prompt(batch_size=batch_size)
    prompts = prompts.to(inputs_embeds.dtype)
    inputs_embeds = torch.cat((prompts, inputs_embeds), dim=1)
    return self.base_model(inputs_embeds=inputs_embeds, **kwargs)
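To see what the prefix branch actually hands to the base model, the prompt encoder output can be inspected directly. This is a small sketch continuing the reproduction sketch above (`model` is that PeftModel); the exact tensor layout depends on the peft version.

past_key_values = model.get_prompt(batch_size=1)
print(len(past_key_values))       # one entry per transformer layer (with num_transformer_submodules=1)
print(past_key_values[0].shape)   # stacked key/value tensors for that layer, produced by the trainable prompt encoder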

This is the main difference between prefix tuning and the other PEFT methods: it passes past_key_values to the model instead of prepending the prompt embeddings to the input. So I checked the forward method of BloomModel in transformers, as follows (transformers/models/bloom/modeling_bloom.py, lines 761 to 791):

for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
    if output_hidden_states:
        all_hidden_states = all_hidden_states + (hidden_states,)

    if self.gradient_checkpointing and self.training:

        def create_custom_forward(module):
            def custom_forward(*inputs):
                # None for past_key_value
                return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)

            return custom_forward

        outputs = torch.utils.checkpoint.checkpoint(
            create_custom_forward(block),
            hidden_states,
            alibi,
            causal_mask,
            layer_past,
            head_mask[i],
        )

I think that during backward, gradients are only propagated through hidden_states, not through layer_past in past_key_values, which leads to this problem.
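For reference, the error message itself is what autograd raises whenever it is asked to backprop a second time through a subgraph whose saved activations were already freed. A generic standalone illustration (not the exact code path here):

import torch

w = torch.randn(3, requires_grad=True)
loss = (w * w).sum()   # MulBackward0 saves its operands for the backward pass
loss.backward()        # saved tensors of this graph are freed here
loss.backward()        # RuntimeError: Trying to backward through the graph a second time ...

Passing retain_graph=True to the first call makes the second one succeed, which is why the message suggests it; but in the prefix tuning case the real question is why part of the graph is traversed twice at all.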
Thanks for your reply.

@BenjaminBossan
Member

I don't have much experience with prefix tuning, maybe @pacman100 has an idea here.

@guilinzys2016

@liuao743 I also ran into this problem when adding the prefix tuning method to LLaMA-Efficient-Tuning. Is there a solution now, for example upgrading or downgrading the peft or transformers version?

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

@fan-niu

fan-niu commented Oct 16, 2023

@liuao743 I also ran into this problem when adding the prefix tuning method to LLaMA-Efficient-Tuning. Is there a solution now, for example upgrading or downgrading the peft or transformers version?

To fix this error, you need to modify transformers/src/transformers/models/llama/modeling_llama.py (around line 914):

def create_custom_forward(module):
    def custom_forward(*inputs):
        # None for past_key_value
        return module(*inputs, output_attentions, padding_mask=padding_mask, use_cache=None)

    return custom_forward

layer_outputs = torch.utils.checkpoint.checkpoint(
    create_custom_forward(decoder_layer), hidden_states, attention_mask, position_ids, past_key_value,
)

Pass past_key_value into the checkpoint call and set use_cache=None, and it works.
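An alternative that avoids patching transformers is to keep prefix tuning but turn off gradient checkpointing, so the decoder layers never go through torch.utils.checkpoint. This is only a sketch of that assumption, not something proposed in the thread, and the model id is a placeholder.

# Workaround sketch, assuming the conflict only appears when gradient checkpointing is active.
from transformers import AutoModelForCausalLM
from peft import PrefixTuningConfig, TaskType, get_peft_model

model = AutoModelForCausalLM.from_pretrained("huggyllama/llama-7b")  # assumed checkpoint
model.gradient_checkpointing_disable()  # standard PreTrainedModel method
model.config.use_cache = True           # caching can stay on once checkpointing is off

peft_model = get_peft_model(
    model,
    PrefixTuningConfig(task_type=TaskType.CAUSAL_LM, num_virtual_tokens=20),
)

The trade-off is higher activation memory, since checkpointing is exactly what frees intermediate activations during the forward pass.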


github-actions bot commented Nov 9, 2023

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
