[CK Tile] Generic attention masking support for FMHA fwd and bwd #1340

Draft · wants to merge 14 commits into develop
Conversation

cameronshinn

The goal of these changes is to support generic attention masks for the FMHA operator in CK Tile. The motivation is to enable a variety of masking strategies from existing research, beyond the current causal masking. The two I aim to add support for are:

Big Bird

Longformer

With the existing masking from SimplifiedGenericAttentionMask, masks could only be expressed as a diagonal window, which limits us to windowed or causal attention. Additionally, the FMHA fwd/bwd pipelines assume that the non-zero tiles in a tile row (column for backward) are contiguous. Neither constraint can accommodate the Big Bird or Longformer masks.

These changes instead let the main mask interface, GenericAttentionMask, accept a mask definition object that contains the mask-specific details. A different mask definition can be passed in for each kind of mask; for example, DiagonalMask mimics the previous windowed masking. The required signature of a mask definition is laid out in the MaskDefABC struct.
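To make the shape of a mask definition concrete, here is a rough sketch of what such an interface could look like. The member names and signature here (`is_unmasked`, `IsTileMask`, `x_tile`, `y_tile`, `left_size`, `right_size`) are illustrative assumptions for a diagonal-window definition; the authoritative contract is the MaskDefABC struct in this PR.

```cpp
#include <cstdint>

// Illustrative sketch only: the real contract lives in MaskDefABC in this PR.
// A mask definition answers whether a given (y, x) position (or tile)
// participates in attention, purely from index arithmetic, so no sparse data
// structure is needed.
template <int32_t y_tile_ = 1, int32_t x_tile_ = 1>
struct ExampleDiagonalMaskDef
{
    static constexpr bool IsTileMask   = false;    // per-element mask in this example
    static constexpr int32_t y_tile    = y_tile_;  // tile sizes as template parameters
    static constexpr int32_t x_tile    = x_tile_;

    int32_t left_size;   // window size to the left of the diagonal
    int32_t right_size;  // window size to the right of the diagonal (0 => causal)

    // Returns true if attention from query row y to key column x is allowed.
    constexpr bool is_unmasked(int32_t y, int32_t x) const
    {
        return (x >= y - left_size) && (x <= y + right_size);
    }
};
```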

Masks can also be defined at a tile granularity instead of a per-element granularity, signified by IsTileMask. This is helpful since Big Bird uses block sparsity. Tile sizes need to live on the struct somehow, and I found it easiest to make them template parameters (x_tile, y_tile).
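As a hypothetical illustration of a tile-granularity definition in the same spirit (the Big Bird definition is not yet part of this PR, and the names below are assumptions, not its API), the predicate would be tested against tile indices rather than element indices, signalled by IsTileMask = true:

```cpp
#include <cstdint>

// Hypothetical block-sparse (Big Bird-like) mask definition sketch.
// y_t / x_t are tile indices, i.e. element index divided by the tile size.
template <int32_t y_tile_, int32_t x_tile_>
struct ExampleBlockSparseMaskDef
{
    static constexpr bool IsTileMask = true;   // predicate operates on tile indices
    static constexpr int32_t y_tile  = y_tile_;
    static constexpr int32_t x_tile  = x_tile_;

    int32_t window;      // number of diagonal tiles kept on each side
    int32_t num_global;  // leading "global" tiles attended by / attending to every row

    constexpr bool is_unmasked(int32_t y_t, int32_t x_t) const
    {
        const bool global   = (x_t < num_global) || (y_t < num_global);
        const bool windowed = (x_t >= y_t - window) && (x_t <= y_t + window);
        return global || windowed;
        // A full Big Bird mask would also include random tiles, omitted here.
    }
};
```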

The pipelines have been modified to use an IndexIterator to skip to the next non-zero tile, since the non-zero tiles can now be non-contiguous. The iterator walks forward through the tile mask indices, incrementing until a non-zero tile is found.
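Conceptually, the skip-ahead works like the following sketch (illustrative names and a free function, not the PR's actual IndexIterator API):

```cpp
#include <cstdint>

// Minimal sketch of the skip-ahead idea, assuming a mask definition that
// exposes an is_unmasked(y_tile_idx, x_tile_idx) predicate: starting from the
// current tile column, keep incrementing until the mask reports a non-zero
// tile or the row ends.
template <typename MaskDef>
int32_t next_nonzero_tile(const MaskDef& mask, int32_t y_tile_idx,
                          int32_t x_tile_idx, int32_t num_x_tiles)
{
    while (x_tile_idx < num_x_tiles && !mask.is_unmasked(y_tile_idx, x_tile_idx))
    {
        ++x_tile_idx;  // every index in between must be checked (tradeoff 5 below)
    }
    return x_tile_idx;  // == num_x_tiles when no further non-zero tile exists
}
```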

From what I can tell, the tradeoffs are as follows:

  1. 😊 Construct a variety of mask types easily
  2. 😊 Scalable to arbitrary sequence lengths
  3. 😊 Mask is defined in instructions rather than sparse data structure arrays
  4. ☹️ Mask size is unknown without evaluating the predicate across the entire attention matrix index space
  5. ☹️ The next non-zero tile in a row can't be determined without checking every index in between

I am opening this as a draft PR to initiate discussion. I am still working on adding mask definitions for Big Bird and Longformer, as well as gathering performance results to show (verifying that there is no perf regression for the existing masking).
