Enable flash attention #20448
base: master
Conversation
Codecov Report

Attention: Patch coverage is

Additional details and impacted files

```diff
@@            Coverage Diff             @@
##           master   #20448      +/-   ##
==========================================
- Coverage   82.01%   76.84%   -5.18%
==========================================
  Files         514      514
  Lines       47194    47261      +67
  Branches     7408     7417       +9
==========================================
- Hits        38706    36317    -2389
- Misses       6698     9220    +2522
+ Partials     1790     1724      -66
```
Thanks for the PR!
keras/src/backend/config.py (outdated)

```python
    use flash attention for faster computations.
    """
    global _ENABLE_FLASH_ATTENTION
    _ENABLE_FLASH_ATTENTION = value
```
This needs to be thread-local. Instead of doing it like this, use `set_global_attribute`/`get_global_attribute` from `keras.src.backend.common.global_state`. See how other global flags are implemented.
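For context, a minimal sketch of what that pattern could look like, assuming the `enable_flash_attention(value)` signature used in this PR; the attribute key `"flash_attention"` and the `False` default are illustrative assumptions, not the final code:

```python
# Sketch only: store the flag in Keras' thread-local global state helpers
# instead of a module-level variable. The attribute key "flash_attention"
# and the False default are assumed names/values for illustration.
from keras.src.backend.common import global_state


def enable_flash_attention(value=True):
    # Each thread sees its own copy of the flag.
    global_state.set_global_attribute("flash_attention", value)


def is_flash_attention_enabled():
    # Returns False when the flag has never been set (assumed default).
    return global_state.get_global_attribute("flash_attention", False)
```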
Importing `global_state` in `config.py` created a circular import error, so I have moved the configs to `attention.py`.
```python
class MultiHeadAttentionTest(testing.TestCase):
    def test_basics(self):
        config.enable_flash_attention(True)
```
Add a numerical correctness test with and without FA.
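A minimal sketch of such a check, assuming the `enable_flash_attention(value)` signature from this PR's tests; the input shapes and tolerance are illustrative choices:

```python
# Sketch of a numerical correctness test: run the same MultiHeadAttention
# layer with flash attention disabled and enabled, and compare outputs.
import numpy as np
from keras import config, layers, ops


def test_flash_attention_numerical_correctness():
    # Fixed random inputs so both paths see identical data.
    rng = np.random.default_rng(0)
    query = rng.normal(size=(2, 8, 16)).astype("float32")
    value = rng.normal(size=(2, 8, 16)).astype("float32")
    layer = layers.MultiHeadAttention(num_heads=2, key_dim=16)

    # Reference path: flash attention disabled.
    config.enable_flash_attention(False)
    reference = layer(query, value)

    # Candidate path: flash attention enabled, same layer weights.
    config.enable_flash_attention(True)
    flash = layer(query, value)

    np.testing.assert_allclose(
        ops.convert_to_numpy(reference),
        ops.convert_to_numpy(flash),
        atol=1e-4,  # illustrative tolerance
    )
```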
This PR changes the `_compute_attention` method to just call `ops.dot_product_attention`, and adds `keras.config.enable_flash_attention` and `keras.config.is_flash_attention_enabled`.
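A rough sketch of that dispatch, written as a standalone helper rather than the PR's actual `_compute_attention` diff; the shapes, einsum layout, and masking details are assumptions:

```python
# Sketch only: gate between the fused ops.dot_product_attention call and an
# explicit softmax(QK^T / sqrt(d))V path based on the config flag.
from keras import config, ops


def compute_attention(query, key, value, mask=None):
    # query/key/value assumed shaped (batch, seq_len, num_heads, head_dim);
    # this is a hypothetical helper, not the MultiHeadAttention method body.
    if config.is_flash_attention_enabled():
        # Fused path: backends may lower this to a flash-attention kernel.
        return ops.dot_product_attention(query, key, value, mask=mask)

    # Reference path: explicit scores, softmax, weighted sum.
    scale = 1.0 / ops.sqrt(ops.cast(ops.shape(query)[-1], "float32"))
    scores = ops.einsum("bqhd,bkhd->bhqk", query, key) * scale
    if mask is not None:
        # mask assumed boolean and broadcastable to (batch, heads, q, k),
        # with True meaning "attend".
        scores = ops.where(mask, scores, ops.full_like(scores, -1e9))
    weights = ops.softmax(scores, axis=-1)
    return ops.einsum("bhqk,bkhd->bqhd", weights, value)
```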