Enable flash attention #20448

Draft
wants to merge 5 commits into master

Conversation

@divyashreepathihalli (Collaborator) commented Nov 4, 2024

This PR

  • refactors the MHA layer so that its _compute_attention method simply calls ops.dot_product_attention
  • adds a global toggle keras.config.enable_flash_attention and a getter keras.config.is_flash_attention_enabled (see the usage sketch after this list)
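
A rough usage sketch, based on the API names above and the test snippet later in this thread (draft API, so exact names and signatures may still change):

```python
import keras

# Draft API from this PR (may change): flip the global flag, then any
# MultiHeadAttention layer routes its attention computation through
# keras.ops.dot_product_attention, which can use a fused flash-attention path.
keras.config.enable_flash_attention(True)
print(keras.config.is_flash_attention_enabled())  # True

layer = keras.layers.MultiHeadAttention(num_heads=2, key_dim=16)
```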

@codecov-commenter commented Nov 4, 2024

Codecov Report

Attention: Patch coverage is 72.72727% with 6 lines in your changes missing coverage. Please review.

Project coverage is 76.84%. Comparing base (c052cea) to head (3a47c53).
Report is 2 commits behind head on master.

Files with missing lines                               Patch %   Lines
keras/src/layers/attention/multi_head_attention.py    63.63%    3 Missing, 1 partial ⚠️
keras/api/_tf_keras/keras/config/__init__.py          0.00%     2 Missing ⚠️

❗ There is a different number of reports uploaded between BASE (c052cea) and HEAD (3a47c53). HEAD has 2 fewer uploads than BASE:

Flag        BASE (c052cea)   HEAD (3a47c53)
keras       4                3
keras-jax   1                0
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #20448      +/-   ##
==========================================
- Coverage   82.01%   76.84%   -5.18%     
==========================================
  Files         514      514              
  Lines       47194    47261      +67     
  Branches     7408     7417       +9     
==========================================
- Hits        38706    36317    -2389     
- Misses       6698     9220    +2522     
+ Partials     1790     1724      -66     
Flag               Coverage Δ
keras              76.72% <72.72%> (-5.15%) ⬇️
keras-jax          ?
keras-numpy        59.86% <72.72%> (-0.01%) ⬇️
keras-tensorflow   65.90% <72.72%> (-0.02%) ⬇️
keras-torch        64.85% <72.72%> (-0.02%) ⬇️

Flags with carried forward coverage won't be shown.


@fchollet (Member) left a comment:
Thanks for the PR!

keras/src/backend/config.py (outdated)
```python
    use flash attention for faster computations.
    """
    global _ENABLE_FLASH_ATTENTION
    _ENABLE_FLASH_ATTENTION = value
```
A Member commented:

This needs to be thread-local. Instead of doing it like this, use set_global_attribute/get_global_attribute from keras.src.backend.common.global_state. See how other global flags are implemented.
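
For illustration, a minimal sketch of that pattern, assuming the helpers behave the way they do for other global flags (the attribute key and default used here are placeholders, not the PR's final code):

```python
from keras.src.backend.common import global_state


def enable_flash_attention(value=True):
    # global_state keeps attributes in thread-local storage, so each thread
    # can carry its own setting instead of sharing one module-level global.
    global_state.set_global_attribute("flash_attention", bool(value))


def is_flash_attention_enabled():
    # Fall back to False if the flag has never been set in this thread.
    return global_state.get_global_attribute("flash_attention", default=False)
```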

@divyashreepathihalli (Collaborator, Author) replied Nov 5, 2024:

Importing global_state in config created a circular import error, so I have moved the configs to attention.py.



```python
class MultiHeadAttentionTest(testing.TestCase):
    def test_basics(self):
        config.enable_flash_attention(True)
```
@fchollet (Member) commented Nov 5, 2024:
Add a numerical correctness test with and without FA.
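
One possible shape for such a test (hypothetical sketch, not code from this PR; whether flash attention actually kicks in depends on the backend and hardware):

```python
import numpy as np
from keras import config, layers, ops


def test_flash_attention_numerical_correctness():
    # The same layer instance (and therefore the same weights) is used for
    # both runs, so any difference comes from the attention kernel used.
    query = np.random.random((2, 4, 8)).astype("float32")
    layer = layers.MultiHeadAttention(num_heads=2, key_dim=4)

    config.enable_flash_attention(False)
    expected = ops.convert_to_numpy(layer(query, query))

    config.enable_flash_attention(True)
    actual = ops.convert_to_numpy(layer(query, query))

    np.testing.assert_allclose(actual, expected, atol=1e-4)
```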

@divyashreepathihalli divyashreepathihalli marked this pull request as draft November 5, 2024 01:31
4 participants