I came across your llama_patch.py when looking to patch Llama for inference myself, and unless I'm doing something wrong, the implementation fails when `use_cache=True` and `past_key_value` is not `None`.
Specifically, during generation with `use_cache=True`, in this line (https://github.com/philschmid/deep-learning-pytorch-huggingface/blob/05d83eaa3c2ad6088227fa26dffb097e06439aef/training/utils/llama_patch.py#L76C3-L76C3) `query_states` will have sequence length 1 while `key_states` and `value_states` will have length `1 + past_key_value[0].shape[-2]`, and thus these tensors won't stack.
I think this is also why the other Llama patches referenced in the comments don't support flash attention + KV cache at the same time. Not sure if there's a clever workaround?
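For reference, here is a minimal sketch (not the patch itself) of the shape mismatch during cached generation, plus one possible fallback. The shapes and names (`batch`, `n_heads`, `past_len`, `head_dim`) are illustrative assumptions; the fallback uses PyTorch's `scaled_dot_product_attention`, which accepts separate q/k/v tensors with different sequence lengths, rather than the packed flash-attention path.

```python
# Minimal sketch: why stacking q/k/v fails during cached generation,
# and one possible fallback. Shapes are illustrative assumptions.
import torch
import torch.nn.functional as F

batch, n_heads, head_dim, past_len = 1, 32, 128, 10

# With use_cache=True only the new token is projected, so the query has
# length 1 while the cached keys/values have length past_len + 1.
query_states = torch.randn(batch, n_heads, 1, head_dim)
key_states   = torch.randn(batch, n_heads, past_len + 1, head_dim)
value_states = torch.randn(batch, n_heads, past_len + 1, head_dim)

# The packed-qkv path needs all three tensors to share the same sequence
# length, so torch.stack raises a RuntimeError here.
try:
    torch.stack([query_states, key_states, value_states], dim=2)
except RuntimeError as e:
    print("stack fails:", e)

# One possible workaround (an assumption, not the patch's approach): skip the
# packed path whenever a cache is present and call an attention kernel that
# takes separate q/k/v of different lengths. With a single query token,
# is_causal=False is correct because it may attend to every cached position.
attn_output = F.scaled_dot_product_attention(
    query_states, key_states, value_states, is_causal=False
)
print(attn_output.shape)  # (batch, n_heads, 1, head_dim)
```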
I see, no worries! Just came across this and thought I would let you know, since the patch seemed to specifically implement the `past_key_value` case, unlike the other referenced implementations.