[whisper] compile compatibility with long-form decoding #31772
What does this PR do?
PR #31166 introduced a static k/v cache for Whisper short-form decoding. It was noted in that PR that the current generation logic is not compatible with sequential long-form generation, since the batch size is reduced dynamically in two places: the outer loop over audio segments and the inner generation loop.
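A toy sketch of the effect (illustrative names and shapes only, not the actual Whisper source):

```python
import torch

# Stand-in for the dynamic batching in sequential long-form generation.
input_features = torch.randn(4, 80, 3000)  # (batch, mel bins, frames)

# When a sequence finishes transcribing its audio, it is dropped from the batch.
still_active = torch.tensor([True, True, False, True])
input_features = input_features[still_active]  # batch size shrinks: 4 -> 3

# With a static k/v cache allocated per batch size, the next `.generate` call
# would build a fresh cache object for batch size 3; torch.compile sees new
# tensor data pointers and triggers a re-compile.
print(input_features.shape)  # torch.Size([3, 80, 3000])
```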
For `torch.compile` compatibility with our current cache design, we require the batch size to be fixed. Otherwise, for every batch size we create a new cache object in `.generate`, which changes the data pointers of the k/v cache tensors, causing a re-compile. As things currently stand, we get re-compiles due to both the outer and inner loop dynamically changing the batch size. This PR introduces a simple fix: pad the inputs to the max batch size before calling the model, and remove any padded outputs before post-processing.
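A minimal sketch of this pad-then-trim idea, assuming hypothetical names (`pad_batch`, `MAX_BATCH_SIZE`) and a stand-in for the actual model call:

```python
import torch

MAX_BATCH_SIZE = 4  # assumption: fixed when the static cache is first allocated

def pad_batch(features: torch.Tensor, max_batch_size: int) -> torch.Tensor:
    """Zero-pad the batch dimension so the model always sees a fixed batch size."""
    n_pad = max_batch_size - features.shape[0]
    if n_pad <= 0:
        return features
    padding = torch.zeros(
        (n_pad, *features.shape[1:]), dtype=features.dtype, device=features.device
    )
    return torch.cat([features, padding], dim=0)

# Two of the four sequences are still active at this point in the outer loop.
active_features = torch.randn(2, 80, 3000)  # (batch, mel bins, frames)

padded = pad_batch(active_features, MAX_BATCH_SIZE)
assert padded.shape[0] == MAX_BATCH_SIZE  # static cache shape is preserved

# ... run the compiled model on `padded`; stand-in for the generated token ids:
generated_ids = torch.zeros(MAX_BATCH_SIZE, 10, dtype=torch.long)

# Drop the padded rows before post-processing.
generated_ids = generated_ids[: active_features.shape[0]]
print(generated_ids.shape)  # torch.Size([2, 10])
```

Since the padded rows never reach post-processing, the extra forward compute on them is the only cost, in exchange for a single compiled graph across the whole long-form loop.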
The alternative would be to change the `batch_idx_map` logic such that we always keep the full sequence of input features, but only update the sequence generations for the elements of interest. Having tried this quickly as a PoC, I found the changes more involved than those proposed in this PR, and they quickly clutter the dynamic generation logic, which we're retaining for faster eager mode.