Audio streaming training with masking #148

Merged (15 commits) on Nov 20, 2024
Changes from all commits
69 changes: 62 additions & 7 deletions ultravox/model/ultravox_model.py
@@ -190,7 +190,8 @@ def forward(

# B x A/3200 x D
audio_tower_output = self.audio_tower.forward(
- audio_values.to(self.audio_tower.dtype), audio_len=audio_len
+ audio_values.to(self.audio_tower.dtype),
+ audio_len=audio_len,
).last_hidden_state
audio_tower_output = audio_tower_output.to(inputs_embeds.dtype)

@@ -286,14 +287,28 @@ def _create_audio_tower(
audio_tower = ModifiedWhisperEncoder.from_pretrained(
config.audio_model_id, torch_dtype=config.torch_dtype
)
audio_tower.init_latency_mask(
config.audio_latency_block_size, dtype=config.torch_dtype
)
else:
assert config.audio_latency_block_size in (
None,
0,
), "only whisper audio tower supports audio latency masking, got non-zero value for 'audio_latency_block_size'"
audio_tower = transformers.AutoModel.from_pretrained(
config.audio_model_id, torch_dtype=config.torch_dtype
)
else:
if "whisper" in config.audio_config._name_or_path:
audio_tower = ModifiedWhisperEncoder(config.audio_config)
audio_tower.init_latency_mask(
config.audio_latency_block_size, dtype=config.torch_dtype
)
else:
assert config.audio_latency_block_size in (
None,
0,
), "only whisper audio tower supports audio latency masking, got non-zero value for 'audio_latency_block_size'"
with transformers.modeling_utils.no_init_weights():
# we only ever use from_config if the weights will be retrained, hence initializing is not
# required. This makes model creation much faster since weight init on CPU is quite slow.
@@ -529,6 +544,39 @@ class ModifiedWhisperEncoder(
base_model_prefix = "model.encoder"
_no_split_modules = ["WhisperEncoderLayer"]

def init_latency_mask(self, audio_latency_block_size: int, dtype: torch.dtype):
if audio_latency_block_size is None:
self.audio_streaming_mask = None
return

# maximum sequence length
max_seqlen = (
self.config.max_source_positions
* self.conv1.stride[0]
* self.conv2.stride[0]
)
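# e.g. for stock Whisper: 1500 positions * conv1 stride 1 * conv2 stride 2 = 3000 mel frames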
assert (
max_seqlen > 0
), f"maximum sequence length must be positive, got {max_seqlen}"
assert (
max_seqlen % audio_latency_block_size == 0
), f"audio_latency_block_size {audio_latency_block_size} must divide {max_seqlen} evenly."
# Given the block size, compute the number of causal blocks.
audio_latency_nblocks = max_seqlen // audio_latency_block_size
audio_streaming_mask = (
torch.tril(
torch.ones(audio_latency_nblocks, audio_latency_nblocks),
diagonal=0,
)
.repeat_interleave(audio_latency_block_size, dim=0)
.repeat_interleave(audio_latency_block_size, dim=1)
)
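# the result is block-lower-triangular: each frame can attend to every frame in
# its own latency block and in earlier blocks, but never to a later block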
audio_streaming_mask = (1.0 - audio_streaming_mask) * torch.finfo(dtype).min
audio_streaming_mask = audio_streaming_mask[None, None, :, :]
self.register_buffer(
"audio_streaming_mask", audio_streaming_mask, persistent=False
)

def forward(
self,
input_features,
@@ -586,20 +634,27 @@ def forward(
attention_mask = None
if audio_len is not None:
audio_feature_len = self._get_feat_extract_output_lengths(audio_len)
- batch_size = hidden_states.shape[0]
max_seq_len = hidden_states.shape[1]
- attention_mask = (
- torch.arange(max_seq_len, device=hidden_states.device)[None, :]
- .expand(batch_size, -1)
- .lt(audio_feature_len.view(batch_size, 1))
- )
+ attention_mask = torch.arange(max_seq_len, device=hidden_states.device)[
+ None, :
+ ].lt(audio_feature_len.view(-1, 1))
attention_mask = self.get_extended_attention_mask(
attention_mask,
None,
device=hidden_states.device,
dtype=hidden_states.dtype,
)

if self.audio_streaming_mask is not None:
seqlen = hidden_states.size(-2)
if attention_mask is not None:
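# both masks are additive (0 = attend, dtype-min = blocked), so the elementwise
# minimum keeps a position open only where both masks allow it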
attention_mask = torch.minimum(
self.audio_streaming_mask[:, :, :seqlen, :seqlen], attention_mask
) # merge
else:
attention_mask = self.audio_streaming_mask[:, :, :seqlen, :seqlen]
attention_mask = attention_mask.to(hidden_states.dtype)

# check if head_mask has a correct number of layers specified if desired
if head_mask is not None:
assert head_mask.size()[0] == (
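The block pattern that init_latency_mask builds is easiest to see at toy sizes. Below is a minimal standalone sketch (toy block counts, not the real Whisper dimensions) of the same tril + repeat_interleave construction:

import torch

# 3 latency blocks of 2 frames each -> a 6x6 block-causal mask
block_size, nblocks = 2, 3
mask = (
    torch.tril(torch.ones(nblocks, nblocks), diagonal=0)
    .repeat_interleave(block_size, dim=0)
    .repeat_interleave(block_size, dim=1)
)
print(mask)
# tensor([[1., 1., 0., 0., 0., 0.],
#         [1., 1., 0., 0., 0., 0.],
#         [1., 1., 1., 1., 0., 0.],
#         [1., 1., 1., 1., 0., 0.],
#         [1., 1., 1., 1., 1., 1.],
#         [1., 1., 1., 1., 1., 1.]])

# converted to an additive mask: 0 keeps a position, finfo.min blocks it
additive = (1.0 - mask) * torch.finfo(torch.float32).min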
62 changes: 62 additions & 0 deletions ultravox/model/ultravox_model_test.py
@@ -0,0 +1,62 @@
import pytest
import torch
from transformers import WhisperConfig

from ultravox.model import ultravox_model


@pytest.fixture
def encoder():
config = WhisperConfig(
max_source_positions=1500,
d_model=256,
encoder_attention_heads=4,
encoder_layers=4,
)
return ultravox_model.ModifiedWhisperEncoder(config)


def test_init_latency_mask_none(encoder):
encoder.init_latency_mask(None, torch.float32)
assert encoder.audio_streaming_mask is None


def test_init_latency_mask_valid(encoder):
block_size = 100
encoder.init_latency_mask(block_size, torch.float32)
assert encoder.audio_streaming_mask is not None

assert len(encoder.audio_streaming_mask.shape) == 4
assert encoder.audio_streaming_mask.shape[0] == 1
assert encoder.audio_streaming_mask.shape[1] == 1

mask = encoder.audio_streaming_mask[0, 0]
# max_seqlen 3000 = 30 blocks x block_size 100
source_mask = (
torch.tril(torch.ones(30, 30), diagonal=0)
.repeat_interleave(block_size, dim=0)
.repeat_interleave(block_size, dim=1)
)
source_mask = (1.0 - source_mask) * torch.finfo(torch.float32).min
assert torch.allclose(mask, source_mask)


def test_init_latency_mask_invalid_block_size(encoder):
invalid_block_size = 13

with pytest.raises(AssertionError, match="must divide .* evenly"):
encoder.init_latency_mask(invalid_block_size, torch.float32)


def test_init_latency_mask_different_dtypes(encoder):
block_size = 50
for dtype in (torch.float32, torch.float16):
encoder.init_latency_mask(block_size, dtype)
assert encoder.audio_streaming_mask.min() == torch.finfo(dtype).min


def test_init_latency_mask_persistence(encoder):
block_size = 50
encoder.init_latency_mask(block_size, torch.float32)
assert "audio_streaming_mask" in encoder._buffers
3 changes: 3 additions & 0 deletions ultravox/training/config_base.py
@@ -106,6 +106,9 @@ def get_val_sets(self) -> List[DatasetOptions]:
# loss function to use
loss_config: Optional[ultravox_config.LossConfig] = None

# Simulates audio streaming with masking during training. None for non-causal, 100 for 1s, 200 for 2s, and so on.
audio_latency_block_size: Optional[int] = None

def __post_init__(self):
assert self.data_type in ["bfloat16", "float16", "float32"]
if self.device == "cuda" and not torch.cuda.is_available():
2 changes: 2 additions & 0 deletions ultravox/training/configs/meta_config.yaml
@@ -31,3 +31,5 @@ batch_size: 4
data_type: "bfloat16"

report_logs_to: ["tensorboard", "wandb"]

audio_latency_block_size: null # null for non-causal, 100 for 1s, 200 for 2s, and so on.
26 changes: 26 additions & 0 deletions ultravox/training/configs/streaming_tinyllama.yaml
@@ -0,0 +1,26 @@

exp_name: "ultravox-streaming-experiments-1s"
# Make sure to accept the license agreement on huggingface hub
text_model: "meta-llama/Llama-3.2-1B-Instruct"
audio_model: "openai/whisper-small"
loss_config:
# Choose from ["KL_Divergence", "CrossEntropy"], default is "KL_Divergence"
loss_function: "KL_Divergence"
train_sets:
- name: librispeech-clean-continuation
- name: librispeech-other-continuation
- name: peoplespeech-clean-continuation
weight: 4
- name: commonvoice-en-continuation
weight: 4
- name: librispeech-clean-transcription
weight: 4
- name: librispeech-other-transcription
- name: peoplespeech-clean-transcription
- name: commonvoice-en-transcription
# Temporarily remove heysquad_human from val_sets as it causes the training to fail.
val_sets:
- name: peoplespeech
batch_size: 24
max_steps: 10000 # 10,000 steps x 8 x 24 batch = 1,920,000 samples
audio_latency_block_size: 100 # null for non-causal, 100 for 1s, 200 for 2s, and so on.
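Assuming the repo's usual module entry point, this config would be launched with something like the following (the exact flag name is an assumption, not confirmed by this diff):

python -m ultravox.training.train --config_path ultravox/training/configs/streaming_tinyllama.yaml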
1 change: 1 addition & 0 deletions ultravox/training/train.py
@@ -123,6 +123,7 @@ def train(args: config_base.TrainConfig):
audio_model_lora_config=args.audio_model_lora_config,
torch_dtype=args.data_type,
pad_token_id=text_tokenizer.eos_token_id,
audio_latency_block_size=args.audio_latency_block_size,
)

logging.info("Instantiating model...")