
[whisper] compile compatibility with long-form decoding #31772

Open
wants to merge 2 commits into base: main

Conversation

sanchit-gandhi
Contributor

What does this PR do?

PR #31166 introduced a static k/v cache for Whisper short-form decoding. It was noted in that PR that the current generation logic is not compatible with sequential long-form generation, since the batch size is reduced dynamically in two places:

  1. In the outer loop over the time position, we remove audio samples that have already finished generating.
  2. In the inner loop over temperatures, we remove inputs that don't need a fallback at the next temperature increment.

For torch.compile compatibility with our current cache design, we require the batch size to be fixed. Otherwise, for every new batch size we create a new cache object in .generate, which changes the data pointer of the k/v cache tensors and causes a re-compile.
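A toy example of why this matters (not the Whisper code; the function and tensor names are made up): by default, torch.compile specializes the graph to the shapes it first sees, so feeding a freshly allocated, smaller cache tensor invalidates the existing graph and forces a re-trace.

```python
import torch

@torch.compile
def decode_step(hidden, cache):
    # toy stand-in for a decoder step that updates a pre-allocated k/v cache in place
    cache.copy_(cache + hidden.mean())
    return hidden @ hidden.transpose(-1, -2)

cache = torch.zeros(4, 8)                       # cache allocated for batch size 4
decode_step(torch.randn(4, 8), cache)           # first call: the graph is compiled

smaller_cache = torch.zeros(2, 8)               # batch shrinks -> new cache object
decode_step(torch.randn(2, 8), smaller_cache)   # new shape / data pointer -> re-compile
```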

As things currently stand, we get re-compiles because both the outer and inner loops dynamically change the batch size. This PR introduces a simple fix: pad the inputs to the max batch size before calling the model, and remove any padded outputs before post-processing.
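A minimal sketch of that padding idea (illustrative only; pad_to_batch_size and segment_input are made-up names, not the helpers added in this PR): zero-pad the batch dimension up to the original batch_size before the forward call, remember how many real samples there were, and slice the outputs back down afterwards.

```python
import torch

def pad_to_batch_size(segment_input, batch_size):
    """Zero-pad the batch dimension so the compiled graph and static cache
    always see the original batch_size."""
    cur_bsz = segment_input.shape[0]
    if cur_bsz < batch_size:
        padding = torch.zeros(
            (batch_size - cur_bsz, *segment_input.shape[1:]),
            dtype=segment_input.dtype,
            device=segment_input.device,
        )
        segment_input = torch.cat([segment_input, padding], dim=0)
    return segment_input, cur_bsz

# Usage inside the long-form loop (illustrative):
# segment_input, cur_bsz = pad_to_batch_size(segment_input, batch_size)
# seek_sequences, seek_outputs = ...          # run generation on the padded batch
# seek_sequences = seek_sequences[:cur_bsz]   # the padded rows are discarded before
# seek_outputs = seek_outputs[:cur_bsz]       # post-processing, so their generations
#                                             # never reach the user
```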

The alternative would be to change the batch_idx_map logic such that we always keep the full batch of input features, but only update the sequence generations for the elements of interest. Having tried this quickly as a PoC, I found the changes more involved than those proposed in this PR; they quickly clutter the dynamic generation logic, which we're retaining for faster eager mode.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Contributor

@kamilakesbi left a comment

Hi @sanchit-gandhi,

Those changes look good to me!

Padding the inputs to the maximum batch size before calling the model and then removing the padded outputs before post-processing is a nice way to fix the problem.

This solution should also be compatible with PR #30984, where we want to unify short- and long-form generation in Whisper.

@@ -807,6 +834,10 @@ def generate_with_fallback(
                generation_config=generation_config,
            )

+           if cur_bsz < batch_size:
+               seek_sequences = seek_sequences[:cur_bsz]
+               seek_outputs = seek_outputs[:cur_bsz]
Contributor

Nice trick! :)
