Skip to content

Commit

Permalink
Audio streaming training with masking (#148)
Browse files Browse the repository at this point in the history
* Audio streaming training with masking

* ultravox_model test

---------

Co-authored-by: Farzad Abdolhosseini <[email protected]>
  • Loading branch information
saeeddhqan and farzadab authored Nov 20, 2024
1 parent 3866fd7 commit 9fc2732
Show file tree
Hide file tree
Showing 6 changed files with 156 additions and 7 deletions.
69 changes: 62 additions & 7 deletions ultravox/model/ultravox_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,8 @@ def forward(

# B x A/3200 x D
audio_tower_output = self.audio_tower.forward(
audio_values.to(self.audio_tower.dtype), audio_len=audio_len
audio_values.to(self.audio_tower.dtype),
audio_len=audio_len,
).last_hidden_state
audio_tower_output = audio_tower_output.to(inputs_embeds.dtype)

Expand Down Expand Up @@ -286,14 +287,28 @@ def _create_audio_tower(
audio_tower = ModifiedWhisperEncoder.from_pretrained(
config.audio_model_id, torch_dtype=config.torch_dtype
)
audio_tower.init_latency_mask(
config.audio_latency_block_size, dtype=config.torch_dtype
)
else:
assert config.audio_latency_block_size not in (
None,
0,
), "only whisper audio tower supports audio latency masking, got non-zero value for 'audio_latency_block_size'"
audio_tower = transformers.AutoModel.from_pretrained(
config.audio_model_id, torch_dtype=config.torch_dtype
)
else:
if "whisper" in config.audio_config._name_or_path:
audio_tower = ModifiedWhisperEncoder(config.audio_config)
audio_tower.init_latency_mask(
config.audio_latency_block_size, dtype=config.torch_dtype
)
else:
assert config.audio_latency_block_size not in (
None,
0,
), "only whisper audio tower supports audio latency masking, got non-zero value for 'audio_latency_block_size'"
with transformers.modeling_utils.no_init_weights():
# we only ever use from_config if the weights are retrained, hence initializing is not
# required. This makes the model quite creation faster since init on CPU is quite slow.
Expand Down Expand Up @@ -529,6 +544,39 @@ class ModifiedWhisperEncoder(
base_model_prefix = "model.encoder"
_no_split_modules = ["WhisperEncoderLayer"]

def init_latency_mask(self, audio_latency_block_size: int, dtype: torch.dtype):
if audio_latency_block_size is None:
self.audio_streaming_mask = None
return

# maximum sequence length
max_seqlen = (
self.config.max_source_positions
* self.conv1.stride[0]
* self.conv2.stride[0]
)
assert (
max_seqlen > 0
), f"maximum sequence length must be positive, got {max_seqlen}"
assert (
max_seqlen % audio_latency_block_size == 0
), f"audio_latency_block_size {audio_latency_block_size} must divide {max_seqlen} evenly."
# Given the block size, we calculate number of blocks.
audio_latency_nblocks = max_seqlen // audio_latency_block_size
audio_streaming_mask = (
torch.tril(
torch.ones(audio_latency_nblocks, audio_latency_nblocks),
diagonal=0,
)
.repeat_interleave(audio_latency_block_size, dim=0)
.repeat_interleave(audio_latency_block_size, dim=1)
)
audio_streaming_mask = (1.0 - audio_streaming_mask) * torch.finfo(dtype).min
audio_streaming_mask = audio_streaming_mask[None, None, :, :]
self.register_buffer(
"audio_streaming_mask", audio_streaming_mask, persistent=False
)

def forward(
self,
input_features,
Expand Down Expand Up @@ -586,20 +634,27 @@ def forward(
attention_mask = None
if audio_len != None:
audio_feature_len = self._get_feat_extract_output_lengths(audio_len)
batch_size = hidden_states.shape[0]
max_seq_len = hidden_states.shape[1]
attention_mask = (
torch.arange(max_seq_len, device=hidden_states.device)[None, :]
.expand(batch_size, -1)
.lt(audio_feature_len.view(batch_size, 1))
)
attention_mask = torch.arange(max_seq_len, device=hidden_states.device)[
None, :
].lt(audio_feature_len.view(-1, 1))
attention_mask = self.get_extended_attention_mask(
attention_mask,
None,
device=hidden_states.device,
dtype=hidden_states.dtype,
)

if self.audio_streaming_mask is not None:
seqlen = hidden_states.size(-2)
if attention_mask is not None:
attention_mask = torch.minimum(
self.audio_streaming_mask[:, :, :seqlen, :seqlen], attention_mask
) # merge
else:
attention_mask = self.audio_streaming_mask[:, :, :seqlen, :seqlen]
attention_mask = attention_mask.to(hidden_states.dtype)

# check if head_mask has a correct number of layers specified if desired
if head_mask is not None:
assert head_mask.size()[0] == (
Expand Down
62 changes: 62 additions & 0 deletions ultravox/model/ultravox_model_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import pytest
import torch
from transformers import WhisperConfig

from ultravox.model import ultravox_model


@pytest.fixture
def encoder():
config = WhisperConfig(
max_source_positions=1500,
d_model=256,
encoder_attention_heads=4,
encoder_layers=4,
)
return ultravox_model.ModifiedWhisperEncoder(config)


def test_init_latency_mask_none(encoder):
encoder.init_latency_mask(None, torch.float32)
assert encoder.audio_streaming_mask is None


def test_init_latency_mask_valid(encoder):
block_size = 100
encoder.init_latency_mask(block_size, torch.float32)
assert encoder.audio_streaming_mask is not None

assert len(encoder.audio_streaming_mask.shape) == 4
assert encoder.audio_streaming_mask.shape[0] == 1
assert encoder.audio_streaming_mask.shape[1] == 1

mask = encoder.audio_streaming_mask[0, 0]
# 100*30=3000
source_mask = (
torch.tril(torch.ones(30, 30), diagonal=0)
.repeat_interleave(block_size, dim=0)
.repeat_interleave(block_size, dim=1)
)
source_mask = (1.0 - source_mask) * torch.finfo(torch.float32).min
print(mask.shape)
assert torch.allclose(mask, source_mask)


def test_init_latency_mask_invalid_block_size(encoder):
invalid_block_size = 13

with pytest.raises(AssertionError, match="must divide .* evenly"):
encoder.init_latency_mask(invalid_block_size, torch.float32)


def test_init_latency_mask_different_dtypes(encoder):
block_size = 50
for dtype in (torch.float32, torch.float16):
encoder.init_latency_mask(block_size, dtype)
assert encoder.audio_streaming_mask.min() == torch.finfo(dtype).min


def test_init_latency_mask_persistence(encoder):
block_size = 50
encoder.init_latency_mask(block_size, torch.float32)
assert "audio_streaming_mask" in encoder._buffers
3 changes: 3 additions & 0 deletions ultravox/training/config_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,9 @@ def get_val_sets(self) -> List[DatasetOptions]:
# loss function to use
loss_config: Optional[ultravox_config.LossConfig] = None

# To simulate audio streaming with masking. None for non-causal, 100 for 1s, 200 for 2s, and so on.
audio_latency_block_size: Optional[int] = None

def __post_init__(self):
assert self.data_type in ["bfloat16", "float16", "float32"]
if self.device == "cuda" and not torch.cuda.is_available():
Expand Down
2 changes: 2 additions & 0 deletions ultravox/training/configs/meta_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,5 @@ batch_size: 4
data_type: "bfloat16"

report_logs_to: ["tensorboard", "wandb"]

audio_latency_block_size: null # null for non-causal, 100 for 1s, 200 for 2s, and so on.
26 changes: 26 additions & 0 deletions ultravox/training/configs/streaming_tinyllama.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@

exp_name: "ultravox-streaming-experiments-1s"
# Make sure to accept the license agreement on huggingface hub
text_model: "meta-llama/Llama-3.2-1B-Instruct"
audio_model: "openai/whisper-small"
loss_config:
# Choose from ["KL_Divergence", "CrossEntropy"], default is "KL_Divergence"
loss_function: "KL_Divergence"
train_sets:
- name: librispeech-clean-continuation
- name: librispeech-other-continuation
- name: peoplespeech-clean-continuation
weight: 4
- name: commonvoice-en-continuation
weight: 4
- name: librispeech-clean-transcription
weight: 4
- name: librispeech-other-transcription
- name: peoplespeech-clean-transcription
- name: commonvoice-en-transcription
# Temporarily remove heysquad_human from val_sets as it causes the training to fail.
val_sets:
- name: peoplespeech
batch_size: 24
max_steps: 10000 # x8x24 = 2,764,800
audio_latency_block_size: 100 # null for non-causal, 100 for 1s, 200 for 2s, and so on.
1 change: 1 addition & 0 deletions ultravox/training/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ def train(args: config_base.TrainConfig):
audio_model_lora_config=args.audio_model_lora_config,
torch_dtype=args.data_type,
pad_token_id=text_tokenizer.eos_token_id,
audio_latency_block_size=args.audio_latency_block_size,
)

logging.info("Instantiating model...")
Expand Down

0 comments on commit 9fc2732

Please sign in to comment.