
[Bugfix] Fix M-RoPE position calculation when chunked prefill is enabled #10388

Open

imkero wants to merge 2 commits into main

Conversation

@imkero (Contributor) commented on Nov 16, 2024

This PR fixes MRotaryEmbedding's get_input_positions when chunked prefill is enabled.

Currently it only slices the generated llm_positions on the left-hand side (dropping the already-computed context) and forgets the right-hand side, so positions for tokens beyond the current chunk leak through. This PR adds the right-hand slice bound to support chunked prefill.

    llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
    llm_positions = llm_positions[:, context_len:]
    mrope_position_delta = (llm_positions.max() + 1 -
                            len(input_tokens)).item()
    return llm_positions.tolist(), mrope_position_delta
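
A minimal sketch of the corrected slicing (assuming seq_len, i.e. context_len plus the current chunk's query length, is available at this point in get_input_positions):

    llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
    # Keep only this chunk's positions: drop the already-computed context
    # on the left and the not-yet-scheduled tokens on the right.
    llm_positions = llm_positions[:, context_len:seq_len]
    mrope_position_delta = (llm_positions.max() + 1 -
                            len(input_tokens)).item()
    return llm_positions.tolist(), mrope_position_delta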

Explanation

To make it clearer, here is an example with the following configuration:

  • assume a prompt of length 40

  • enable_chunked_prefill=True, and max_num_batched_tokens=32

  • add some logging in model_runner.py::ModelInputForGPUBuilder::build near:

    return self.model_input_cls(
        input_tokens=input_tokens_tensor,
        input_positions=input_positions_tensor,
        attn_metadata=attn_metadata,
        seq_lens=seq_lens,
        query_lens=query_lens,
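
For reference, a hypothetical sketch of the kind of debug logging that could produce the numbers below (only the names shown in the snippet above are known from this PR; logger and the exact format are assumptions):

    logger.info(
        "query_lens=%s seq_lens=%s input_tokens_len=%d "
        "input_positions_shape=%s",
        query_lens, seq_lens, len(input_tokens_tensor),
        tuple(input_positions_tensor.shape))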

Result:

step: 1st prefill chunk
  before this fix:
    context_lens: [0]
    query_lens: [32]
    seq_lens: [32]
    input_tokens_lens: [40]
    mrope_input_positions: torch.Size([3, 40])
  after this fix:
    context_lens: [0]
    query_lens: [32]
    seq_lens: [32]
    input_tokens_lens: [40]
    mrope_input_positions: torch.Size([3, 32])

step: 2nd prefill chunk
  before this fix: broken in prev step
  after this fix:
    context_lens: [32]
    query_lens: [8]
    seq_lens: [40]
    input_tokens_lens: [40]
    mrope_input_positions: torch.Size([3, 8])

step: 1st decode
  before this fix: broken in prev step
  after this fix:
    context_lens: [40]
    query_lens: [1]
    seq_lens: [41]
    input_tokens_lens: [40]
    mrope_input_positions: torch.Size([3, 1])
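
The [3, 40] shape in the first chunk comes from the missing right-hand bound. A standalone toy sketch (not vLLM code) reproducing the shapes above:

    import torch

    # 40-token prompt; the first chunk schedules tokens 0..31
    # (enable_chunked_prefill=True, max_num_batched_tokens=32).
    llm_positions = torch.arange(40).unsqueeze(0).expand(3, -1)  # [3, 40]
    context_len, seq_len = 0, 32

    before = llm_positions[:, context_len:]        # [3, 40]: too many positions
    after = llm_positions[:, context_len:seq_len]  # [3, 32]: matches query_lens
    print(before.shape, after.shape)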

Related error log:

    RuntimeError: shape '[40, -1, 128]' is invalid for input of size 49152

The error occurs near the following code in the M-RoPE forward pass. num_tokens is taken from positions, which has 40 entries due to the bug, while query only holds 32 tokens' worth of elements, so the final view fails:

    num_tokens = positions.shape[-1]
    cos_sin = self.cos_sin_cache[positions]
    cos, sin = cos_sin.chunk(2, dim=-1)
    if positions.ndim == 2:
        assert self.mrope_section
        cos = torch.cat([
            m[i]
            for i, m in enumerate(cos.split(self.mrope_section, dim=-1))
        ], dim=-1)
        sin = torch.cat([
            m[i]
            for i, m in enumerate(sin.split(self.mrope_section, dim=-1))
        ], dim=-1)
    query_shape = query.shape
    query = query.view(num_tokens, -1, self.head_size)
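
A hypothetical illustration of the failing view (the 12 heads are an assumption made for the example; only head_size=128 is visible in the error message):

    import torch

    head_size = 128
    num_tokens = 40                            # from the mis-sized positions
    query = torch.randn(32 * 12 * head_size)   # 32 real tokens -> 49152 elements

    # 49152 is not divisible by 40 * 128, so this raises:
    # RuntimeError: shape '[40, -1, 128]' is invalid for input of size 49152
    query.view(num_tokens, -1, head_size)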

About the test I added

  1. Qwen2-VL's M-RoPE takes effect only when there are multi-modal inputs,
    so an image is included in the test inputs.

  2. However, Qwen2-VL currently doesn't work properly when chunked prefill is enabled and multi-modal inputs are present, because _merge_multimodal_embeddings assumes the input is never chunked:

    def _merge_multimodal_embeddings(
        self,
        input_ids: torch.Tensor,
        inputs_embeds: torch.Tensor,
        multimodal_embeddings: torch.Tensor,
        placeholder_token_id: int,
    ) -> torch.Tensor:
        # If the prompt is chunked mid-image, the mask selects fewer rows
        # than multimodal_embeddings provides and the assignment fails.
        mask = (input_ids == placeholder_token_id)
        inputs_embeds[mask, :] = multimodal_embeddings
        return inputs_embeds

    To work around this, the test uses a hack: a zero-length image is provided so that the merge becomes a no-op (see the sketch after this list).

  3. With these in place, the test meets both requirements and can proceed:

    • chunked prefill is enabled
    • M-RoPE takes effect
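
A hypothetical, self-contained sketch of the failure mode from item 2 (the token id and sizes are made up for illustration):

    import torch

    placeholder_token_id = 9999                    # made-up placeholder id
    # A chunk cut off mid-image: only 2 of the 4 placeholder tokens land in
    # this chunk's input_ids.
    input_ids = torch.tensor([1, 2, placeholder_token_id, placeholder_token_id])
    inputs_embeds = torch.zeros(4, 8)
    multimodal_embeddings = torch.randn(4, 8)      # embeddings for all 4 patches

    mask = (input_ids == placeholder_token_id)     # only 2 True entries
    # RuntimeError: 2 masked rows cannot receive 4 embedding rows
    inputs_embeds[mask, :] = multimodal_embeddings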


👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of tests to quickly catch errors. You can run further CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

Signed-off-by: imkero <[email protected]>
@imkero imkero changed the title [Bugfix] M-RoPE position calculation when chunked prefill is enabled [Bugfix] Fix M-RoPE position calculation when chunked prefill is enabled Nov 16, 2024
@DarkLight1337 (Member) commented on Nov 16, 2024

@ywang96 I thought chunked prefill isn't supported for VLMs yet? Or is this just not tested properly?
