
Llama patch for FlashAttention support fails with use_cache #26

Open
qmdnls opened this issue Aug 1, 2023 · 2 comments
qmdnls commented Aug 1, 2023

I came across your llama_patch.py when looking to patch Llama for inference myself, and unless I'm doing something wrong, the implementation fails when use_cache=True and past_key_value is not None.

Specifically, during generation with use_cache=True, at the line below query_states will have sequence length 1 while key_states and value_states will have length 1 + past_key_value[0].shape[-2], so these tensors cannot be stacked.

https://github.com/philschmid/deep-learning-pytorch-huggingface/blob/05d83eaa3c2ad6088227fa26dffb097e06439aef/training/utils/llama_patch.py#L76C3-L76C3
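For illustration, here is a minimal standalone sketch of the mismatch (the batch size, head count, head dimension, and cache length below are hypothetical, chosen just to show the failure):

```python
import torch

bsz, num_heads, head_dim, past_len = 1, 32, 128, 10

# During incremental decoding only the newest token is projected into queries...
query_states = torch.randn(bsz, 1, num_heads, head_dim)
# ...but the cached keys/values are concatenated with it, so their length is past_len + 1.
key_states = torch.randn(bsz, past_len + 1, num_heads, head_dim)
value_states = torch.randn(bsz, past_len + 1, num_heads, head_dim)

# torch.stack requires identical shapes, so packing q/k/v for the
# qkv-packed FlashAttention entry point raises a RuntimeError here.
qkv = torch.stack([query_states, key_states, value_states], dim=2)
```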

I think this is also why the other Llama patches referenced in the comments don't support FlashAttention and the KV cache at the same time. Not sure if there's a clever workaround?
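One untested idea, as a rough sketch only: flash-attn 2.x also exposes flash_attn_func, which takes q, k, and v as separate tensors and allows the query length to differ from the key/value length, so the stacking step could potentially be avoided entirely. The shapes below are hypothetical, and how the causal mask is aligned when seqlen_q != seqlen_k would need to be checked against the installed flash-attn version:

```python
import torch
from flash_attn import flash_attn_func  # flash-attn 2.x

bsz, num_heads, head_dim, past_len = 1, 32, 128, 10

# FlashAttention expects (batch, seqlen, nheads, headdim) in fp16/bf16 on GPU.
query_states = torch.randn(bsz, 1, num_heads, head_dim, dtype=torch.float16, device="cuda")
key_states = torch.randn(bsz, past_len + 1, num_heads, head_dim, dtype=torch.float16, device="cuda")
value_states = torch.randn(bsz, past_len + 1, num_heads, head_dim, dtype=torch.float16, device="cuda")

# q and k/v lengths may differ, so no packing is required for the decoding step.
attn_output = flash_attn_func(query_states, key_states, value_states, causal=True)
```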

philschmid (Owner) commented

Hey @qmdnls,

What you say could well be true. I created the patch only for training, where gradient checkpointing is used and the cache is disabled.
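For reference, a minimal sketch of that training-time setup (the model name is just an example; gradient checkpointing on, KV cache off):

```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
model.gradient_checkpointing_enable()   # recompute activations instead of storing them
model.config.use_cache = False          # the KV cache is not used with gradient checkpointing
```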

If you are interested in inference, I recommend checking out text-generation-inference.

qmdnls (Author) commented Aug 1, 2023

I see, no worries! I just came across this and thought I would let you know, since the patch seemed to specifically implement the case with past_key_value, unlike the other referenced implementations.

Thanks for the pointer, I will have a look!
