[CK Tile] Generic attention masking support for FMHA fwd and bwd #1340
The goal of these changes is to support generic attention masks for the FMHA operator in CK tile. The motivation is to support a variety of masking strategies from existing research (beyond just the existing causal masking). The two that I aim to add support for are:
- Big Bird
- Longformer
With the existing masking from `SimplifiedGenericAttentionMask`, it was only possible to create masks as a diagonal window, which limits us to windowed attention or causal attention. Additionally, the FMHA fwd/bwd pipelines assume that the masked tiles in a tile row (tile column for backward) are contiguous. Neither restriction can accommodate the Big Bird and Longformer masks.

These changes instead let the main mask interface, `GenericAttentionMask`, accept a mask definition object, which is where the mask-specific details are contained. A different mask definition can be passed in for different kinds of masks. For example, `DiagonalMask` mimics the previous method of windowed masking. The required signature of a mask definition is laid out in the `MaskDefABC` struct.

Masks can also be defined at a tile granularity instead of a per-element granularity, signified with `IsTileMask`. This is helpful since Big Bird uses block sparsity. Tile sizes need to be members of the struct somehow, and I found it easier to make them template parameters (`x_tile`, `y_tile`).

The pipelines have been modified to use an `IndexIterator` to skip to the next non-zero tile, since non-zero tiles can now be non-contiguous. The index iterator loops through the tile mask indices, incrementing the index until a non-zero tile is found.

From what I can tell, the tradeoffs are such:
I am opening this as a draft PR to start the discussion. I am still working on mask definitions for Big Bird and Longformer, as well as performance results verifying that there is no perf regression for the existing masking.
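To make the tile-skipping idea easier to discuss, here is a minimal standalone sketch of a tile-granularity mask definition and an iterator that advances to the next non-zero tile. It is not the code in this PR and does not use any CK Tile types: `BandedGlobalTileMask`, `is_nonzero_tile`, `TileIndexIterator`, and `visited_tiles` are hypothetical names chosen for illustration, and the mask shape (a diagonal band plus a few always-visible tile columns) is only loosely in the spirit of Longformer.

```cpp
#include <cassert>
#include <cstdlib>
#include <vector>

// Hypothetical tile-granularity mask definition (not a CK Tile type).
// A tile is kept if it lies within `window` tiles of the diagonal,
// or if its column is listed in `global_cols`.
struct BandedGlobalTileMask
{
    int window;                   // tiles kept on each side of the diagonal
    std::vector<int> global_cols; // tile columns that every row attends to

    // True when tile (row, col) has at least one unmasked element.
    bool is_nonzero_tile(int row, int col) const
    {
        if(std::abs(row - col) <= window)
            return true;
        for(int g : global_cols)
            if(g == col)
                return true;
        return false;
    }
};

// Hypothetical stand-in for the index iterator described above.
// It walks the tile columns of one tile row and only stops on non-zero
// tiles, incrementing the index past masked tiles in between.
template <typename MaskDef>
class TileIndexIterator
{
    public:
    TileIndexIterator(const MaskDef& mask, int row, int num_cols)
        : mask_(mask), row_(row), num_cols_(num_cols), col_(0)
    {
        assert(num_cols >= 0);
        skip_masked(); // position on the first non-zero tile, if any
    }

    bool done() const { return col_ >= num_cols_; }
    int index() const { return col_; }

    void next()
    {
        ++col_;
        skip_masked();
    }

    private:
    // Linear scan: increment until a non-zero tile or the end of the row.
    void skip_masked()
    {
        while(col_ < num_cols_ && !mask_.is_nonzero_tile(row_, col_))
            ++col_;
    }

    const MaskDef& mask_;
    int row_;
    int num_cols_;
    int col_;
};

// Collects the tile columns a pipeline loop would visit for one tile row.
template <typename MaskDef>
std::vector<int> visited_tiles(const MaskDef& mask, int row, int num_cols)
{
    std::vector<int> out;
    for(TileIndexIterator<MaskDef> it(mask, row, num_cols); !it.done(); it.next())
        out.push_back(it.index());
    return out;
}
```

With `window = 1` and `global_cols = {0}`, tile row 5 of an 8-column tile grid visits columns 0, 4, 5, and 6, which is exactly the kind of non-contiguous pattern that a contiguous start/end range cannot express. The linear scan in `skip_masked` mirrors the "check incrementing indices" description; a precomputed list of non-zero tile indices would be an alternative if the scan cost matters.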