From ac6f970023205e8fb360c38e691e7b8d118b5eac Mon Sep 17 00:00:00 2001 From: Saeed Dehqan Date: Sat, 9 Nov 2024 16:59:36 +0000 Subject: [PATCH 01/15] Audio streaming training with masking --- ultravox/model/ultravox_model.py | 35 ++++++++++++++++--- ultravox/training/config_base.py | 2 ++ ultravox/training/configs/meta_config.yaml | 2 ++ .../training/configs/streaming_tinyllama.yaml | 26 ++++++++++++++ ultravox/training/train.py | 1 + 5 files changed, 62 insertions(+), 4 deletions(-) create mode 100644 ultravox/training/configs/streaming_tinyllama.yaml diff --git a/ultravox/model/ultravox_model.py b/ultravox/model/ultravox_model.py index 1a190df..609bce1 100644 --- a/ultravox/model/ultravox_model.py +++ b/ultravox/model/ultravox_model.py @@ -55,6 +55,24 @@ def __init__(self, config: UltravoxConfig): self.audio_tower._no_split_modules or [] ) + audio_latency_bsize = config.audio_latency_bsize + audio_streaming_mask = None + if audio_latency_bsize != -1: + audio_latency_nblocks = (self.audio_tower.config.max_source_positions + * self.audio_tower.conv1.stride[0] + * self.audio_tower.conv2.stride[0] + ) + audio_latency_nblocks //= 2 * audio_latency_bsize + audio_streaming_mask = torch.tril( + torch.ones( + audio_latency_nblocks, audio_latency_nblocks), + diagonal=0, + ).repeat_interleave(audio_latency_bsize, dim=0).repeat_interleave(audio_latency_bsize, dim=1) + audio_streaming_mask[audio_streaming_mask == 0] = float('-inf') + audio_streaming_mask[audio_streaming_mask == 1] = 0 + audio_streaming_mask = audio_streaming_mask[None, None, :, :] + self.register_buffer('audio_streaming_mask', audio_streaming_mask, persistent=False) + self.loss_config = LossConfig() self.post_init() @@ -190,7 +208,9 @@ def forward( # B x A/3200 x D audio_tower_output = self.audio_tower.forward( - audio_values.to(self.audio_tower.dtype), audio_len=audio_len + audio_values.to(self.audio_tower.dtype), + audio_len=audio_len, + attention_mask_bw=self.audio_streaming_mask, ).last_hidden_state audio_tower_output = audio_tower_output.to(inputs_embeds.dtype) @@ -537,6 +557,7 @@ def forward( output_attentions=None, output_hidden_states=None, return_dict=None, + attention_mask_bw=None, ): expected_seq_length = ( self.config.max_source_positions @@ -586,12 +607,10 @@ def forward( attention_mask = None if audio_len != None: audio_feature_len = self._get_feat_extract_output_lengths(audio_len) - batch_size = hidden_states.shape[0] max_seq_len = hidden_states.shape[1] attention_mask = ( torch.arange(max_seq_len, device=hidden_states.device)[None, :] - .expand(batch_size, -1) - .lt(audio_feature_len.view(batch_size, 1)) + .lt(audio_feature_len.view(-1, 1)) ) attention_mask = self.get_extended_attention_mask( attention_mask, @@ -600,6 +619,14 @@ def forward( dtype=hidden_states.dtype, ) + if attention_mask_bw is not None: + seqlen = hidden_states.size(-2) + if attention_mask is not None: + attention_mask = torch.minimum(attention_mask_bw[:,:,:seqlen, :seqlen], attention_mask) # merge + else: + attention_mask = attention_mask_bw[:,:,:seqlen, :seqlen] + attention_mask = attention_mask.to(hidden_states.dtype) + # check if head_mask has a correct number of layers specified if desired if head_mask is not None: assert head_mask.size()[0] == ( diff --git a/ultravox/training/config_base.py b/ultravox/training/config_base.py index 22d84f9..e7b8288 100644 --- a/ultravox/training/config_base.py +++ b/ultravox/training/config_base.py @@ -106,6 +106,8 @@ def get_val_sets(self) -> List[DatasetOptions]: # loss function to use loss_config: 
Optional[ultravox_config.LossConfig] = None + audio_latency_bsize: int = -1 + def __post_init__(self): assert self.data_type in ["bfloat16", "float16", "float32"] if self.device == "cuda" and not torch.cuda.is_available(): diff --git a/ultravox/training/configs/meta_config.yaml b/ultravox/training/configs/meta_config.yaml index bfe3626..e045117 100644 --- a/ultravox/training/configs/meta_config.yaml +++ b/ultravox/training/configs/meta_config.yaml @@ -31,3 +31,5 @@ batch_size: 4 data_type: "bfloat16" report_logs_to: ["tensorboard", "wandb"] + +audio_latency_bsize: 50 # -1 for non-causal, 50 for 1s, 100 for 2s, and so on. \ No newline at end of file diff --git a/ultravox/training/configs/streaming_tinyllama.yaml b/ultravox/training/configs/streaming_tinyllama.yaml new file mode 100644 index 0000000..284a2ee --- /dev/null +++ b/ultravox/training/configs/streaming_tinyllama.yaml @@ -0,0 +1,26 @@ +# SLM with ultravox & llama3.1, trained wtih knowledge distillation. +exp_name: "ultravox-streaming-experiments-1s" +# Make sure to accept the license agreement on huggingface hub +text_model: "meta-llama/Llama-3.2-1B-Instruct" +audio_model: "openai/whisper-small" +loss_config: + # Choose from ["KL_Divergence", "CrossEntropy"], default is "KL_Divergence" + loss_function: "KL_Divergence" +train_sets: + - name: librispeech-clean-continuation + - name: librispeech-other-continuation + - name: peoplespeech-clean-continuation + weight: 4 + - name: commonvoice-en-continuation + weight: 4 + - name: librispeech-clean-transcription + weight: 4 + - name: librispeech-other-transcription + - name: peoplespeech-clean-transcription + - name: commonvoice-en-transcription +# Temporarily remove heysquad_human from val_sets as it causes the training to fail. +val_sets: + - name: peoplespeech +batch_size: 24 +max_steps: 10000 # x8x24 = 2,764,800 +audio_latency_bsize: 50 # -1 for non-causal, 50 for 1s, 100 for 2s, and so on. \ No newline at end of file diff --git a/ultravox/training/train.py b/ultravox/training/train.py index 5281d80..0cbc8dc 100644 --- a/ultravox/training/train.py +++ b/ultravox/training/train.py @@ -123,6 +123,7 @@ def train(args: config_base.TrainConfig): audio_model_lora_config=args.audio_model_lora_config, torch_dtype=args.data_type, pad_token_id=text_tokenizer.eos_token_id, + audio_latency_bsize=args.audio_latency_bsize, ) logging.info("Instantiating model...") From 87fdd0695c5ead0ad5306c6557436fccf1c2943c Mon Sep 17 00:00:00 2001 From: Saeed Dehqan <31902891+saeeddhqan@users.noreply.github.com> Date: Sat, 9 Nov 2024 12:02:15 -0500 Subject: [PATCH 02/15] non-causal as default config --- ultravox/training/configs/meta_config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ultravox/training/configs/meta_config.yaml b/ultravox/training/configs/meta_config.yaml index e045117..5d4f436 100644 --- a/ultravox/training/configs/meta_config.yaml +++ b/ultravox/training/configs/meta_config.yaml @@ -32,4 +32,4 @@ data_type: "bfloat16" report_logs_to: ["tensorboard", "wandb"] -audio_latency_bsize: 50 # -1 for non-causal, 50 for 1s, 100 for 2s, and so on. \ No newline at end of file +audio_latency_bsize: -1 # -1 for non-causal, 50 for 1s, 100 for 2s, and so on. 
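Before the next patch, a brief illustration of the masking idea introduced in PATCH 01 may be useful. The sketch below is not part of the patch series: the helper name build_latency_mask and the toy sizes are invented for illustration, and it uses the additive-mask convention adopted by the later patches (0 for allowed positions, torch.finfo(dtype).min for blocked ones), whereas PATCH 01 itself writes float('-inf') and 0 into the buffer. In the model, the sequence length is derived from the Whisper encoder config (max_source_positions times the two conv strides) rather than passed in.

# Illustrative sketch only; build_latency_mask is a hypothetical helper, not code from the patch.
import torch

def build_latency_mask(seq_len: int, block_size: int, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    # Additive attention mask of shape (1, 1, seq_len, seq_len): every frame may attend
    # to all frames in its own block and in earlier blocks, but not to later blocks,
    # which simulates streaming audio with a bounded look-ahead of one block.
    assert seq_len % block_size == 0, "block_size must divide seq_len evenly"
    n_blocks = seq_len // block_size
    block_pattern = torch.tril(torch.ones(n_blocks, n_blocks))  # block i attends to blocks 0..i
    mask = block_pattern.repeat_interleave(block_size, dim=0).repeat_interleave(block_size, dim=1)
    mask = (1.0 - mask) * torch.finfo(dtype).min  # allowed -> 0, blocked -> very large negative
    return mask[None, None, :, :].to(dtype)

# Example: 6 frames in blocks of 2. Frames 0-1 attend only to block 0,
# frames 2-3 to blocks 0-1, and frames 4-5 to all three blocks.
print(build_latency_mask(6, 2)[0, 0])

How a given block size maps to wall-clock latency depends on the sequence length the mask is built over, which changes during the series; this is why the configuration comments shift from "50 for 1s" above to "100 for 1s" in the later patches.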
From b041999340ca491433811f03a10f6740b6f61d2a Mon Sep 17 00:00:00 2001 From: Saeed Dehqan Date: Sat, 16 Nov 2024 07:48:12 +0000 Subject: [PATCH 03/15] Move audio latency mask to the corresponding class --- ultravox/model/ultravox_model.py | 58 +++++++++++-------- ultravox/training/config_base.py | 3 +- ultravox/training/configs/meta_config.yaml | 2 +- .../training/configs/streaming_tinyllama.yaml | 2 +- ultravox/training/train.py | 2 +- 5 files changed, 40 insertions(+), 27 deletions(-) diff --git a/ultravox/model/ultravox_model.py b/ultravox/model/ultravox_model.py index 609bce1..4be610f 100644 --- a/ultravox/model/ultravox_model.py +++ b/ultravox/model/ultravox_model.py @@ -55,24 +55,6 @@ def __init__(self, config: UltravoxConfig): self.audio_tower._no_split_modules or [] ) - audio_latency_bsize = config.audio_latency_bsize - audio_streaming_mask = None - if audio_latency_bsize != -1: - audio_latency_nblocks = (self.audio_tower.config.max_source_positions - * self.audio_tower.conv1.stride[0] - * self.audio_tower.conv2.stride[0] - ) - audio_latency_nblocks //= 2 * audio_latency_bsize - audio_streaming_mask = torch.tril( - torch.ones( - audio_latency_nblocks, audio_latency_nblocks), - diagonal=0, - ).repeat_interleave(audio_latency_bsize, dim=0).repeat_interleave(audio_latency_bsize, dim=1) - audio_streaming_mask[audio_streaming_mask == 0] = float('-inf') - audio_streaming_mask[audio_streaming_mask == 1] = 0 - audio_streaming_mask = audio_streaming_mask[None, None, :, :] - self.register_buffer('audio_streaming_mask', audio_streaming_mask, persistent=False) - self.loss_config = LossConfig() self.post_init() @@ -210,7 +192,6 @@ def forward( audio_tower_output = self.audio_tower.forward( audio_values.to(self.audio_tower.dtype), audio_len=audio_len, - attention_mask_bw=self.audio_streaming_mask, ).last_hidden_state audio_tower_output = audio_tower_output.to(inputs_embeds.dtype) @@ -306,6 +287,8 @@ def _create_audio_tower( audio_tower = ModifiedWhisperEncoder.from_pretrained( config.audio_model_id, torch_dtype=config.torch_dtype ) + audio_tower.init_latency_mask( + config.audio_latency_block_size, dtype=config.torch_dtype) else: audio_tower = transformers.AutoModel.from_pretrained( config.audio_model_id, torch_dtype=config.torch_dtype @@ -313,6 +296,7 @@ def _create_audio_tower( else: if "whisper" in config.audio_config._name_or_path: audio_tower = ModifiedWhisperEncoder(config.audio_config) + audio_tower.init_latency_mask(config.audio_latency_block_size, dtype=config.torch_dtype) else: with transformers.modeling_utils.no_init_weights(): # we only ever use from_config if the weights are retrained, hence initializing is not @@ -549,6 +533,35 @@ class ModifiedWhisperEncoder( base_model_prefix = "model.encoder" _no_split_modules = ["WhisperEncoderLayer"] + def init_latency_mask(self, audio_latency_block_size: int, dtype: torch.float): + if audio_latency_block_size is None: + self.audio_streaming_mask = None + return + audio_streaming_mask = None + # total sequence length + audio_latency_nblocks = (self.config.max_source_positions + * self.conv1.stride[0] + * self.conv2.stride[0] + ) + + assert audio_latency_nblocks > 0, f"total sequence length must be positive, got {audio_latency_nblocks}" + + assert audio_latency_nblocks % (2 * audio_latency_block_size) == 0, ( + f"audio_latency_block_size {audio_latency_block_size} must divide {audio_latency_nblocks} evenly. 
" + ) + + audio_latency_nblocks //= 2 * audio_latency_block_size + audio_streaming_mask = torch.tril( + torch.ones( + audio_latency_nblocks, audio_latency_nblocks), + diagonal=0, + ).repeat_interleave(audio_latency_block_size, dim=0).repeat_interleave(audio_latency_block_size, dim=1) + audio_streaming_mask = (1.0 - audio_streaming_mask) * torch.finfo(dtype).min + # audio_streaming_mask[audio_streaming_mask == 0] = float('-inf') + # audio_streaming_mask[audio_streaming_mask == 1] = 0 + audio_streaming_mask = audio_streaming_mask[None, None, :, :] + self.register_buffer('audio_streaming_mask', audio_streaming_mask, persistent=False) + def forward( self, input_features, @@ -557,7 +570,6 @@ def forward( output_attentions=None, output_hidden_states=None, return_dict=None, - attention_mask_bw=None, ): expected_seq_length = ( self.config.max_source_positions @@ -619,12 +631,12 @@ def forward( dtype=hidden_states.dtype, ) - if attention_mask_bw is not None: + if self.audio_streaming_mask is not None: seqlen = hidden_states.size(-2) if attention_mask is not None: - attention_mask = torch.minimum(attention_mask_bw[:,:,:seqlen, :seqlen], attention_mask) # merge + attention_mask = torch.minimum(self.audio_streaming_mask[:,:,:seqlen, :seqlen], attention_mask) # merge else: - attention_mask = attention_mask_bw[:,:,:seqlen, :seqlen] + attention_mask = self.audio_streaming_mask[:,:,:seqlen, :seqlen] attention_mask = attention_mask.to(hidden_states.dtype) # check if head_mask has a correct number of layers specified if desired diff --git a/ultravox/training/config_base.py b/ultravox/training/config_base.py index e7b8288..7ee2b26 100644 --- a/ultravox/training/config_base.py +++ b/ultravox/training/config_base.py @@ -106,7 +106,8 @@ def get_val_sets(self) -> List[DatasetOptions]: # loss function to use loss_config: Optional[ultravox_config.LossConfig] = None - audio_latency_bsize: int = -1 + # To simulate audio streaming with masking. None for non-causal, 50 for 1s, 100 for 2s, and so on. + audio_latency_block_size: int = None def __post_init__(self): assert self.data_type in ["bfloat16", "float16", "float32"] diff --git a/ultravox/training/configs/meta_config.yaml b/ultravox/training/configs/meta_config.yaml index 5d4f436..5bc0f5a 100644 --- a/ultravox/training/configs/meta_config.yaml +++ b/ultravox/training/configs/meta_config.yaml @@ -32,4 +32,4 @@ data_type: "bfloat16" report_logs_to: ["tensorboard", "wandb"] -audio_latency_bsize: -1 # -1 for non-causal, 50 for 1s, 100 for 2s, and so on. +audio_latency_block_size: null # None for non-causal, 50 for 1s, 100 for 2s, and so on. diff --git a/ultravox/training/configs/streaming_tinyllama.yaml b/ultravox/training/configs/streaming_tinyllama.yaml index 284a2ee..2ab3a6b 100644 --- a/ultravox/training/configs/streaming_tinyllama.yaml +++ b/ultravox/training/configs/streaming_tinyllama.yaml @@ -23,4 +23,4 @@ val_sets: - name: peoplespeech batch_size: 24 max_steps: 10000 # x8x24 = 2,764,800 -audio_latency_bsize: 50 # -1 for non-causal, 50 for 1s, 100 for 2s, and so on. \ No newline at end of file +audio_latency_block_size: null # None for non-causal, 50 for 1s, 100 for 2s, and so on. 
\ No newline at end of file diff --git a/ultravox/training/train.py b/ultravox/training/train.py index 0cbc8dc..9bc3339 100644 --- a/ultravox/training/train.py +++ b/ultravox/training/train.py @@ -123,7 +123,7 @@ def train(args: config_base.TrainConfig): audio_model_lora_config=args.audio_model_lora_config, torch_dtype=args.data_type, pad_token_id=text_tokenizer.eos_token_id, - audio_latency_bsize=args.audio_latency_bsize, + audio_latency_block_size=args.audio_latency_block_size, ) logging.info("Instantiating model...") From b9fc8544e22890fb5d88254aacf44ed06a23ccc1 Mon Sep 17 00:00:00 2001 From: Saeed Dehqan <31902891+saeeddhqan@users.noreply.github.com> Date: Sat, 16 Nov 2024 02:51:38 -0500 Subject: [PATCH 04/15] Move audio latency mask to the corresponding class From be1906df6862b159e6bdab8d3aee7d94b3c2bdc1 Mon Sep 17 00:00:00 2001 From: Saeed Dehqan <31902891+saeeddhqan@users.noreply.github.com> Date: Sat, 16 Nov 2024 02:53:36 -0500 Subject: [PATCH 05/15] Add comment for audio latency and change var name From a0ea1d49fa9fa9afc5c9f03d155bd5465d32d979 Mon Sep 17 00:00:00 2001 From: Saeed Dehqan <31902891+saeeddhqan@users.noreply.github.com> Date: Sat, 16 Nov 2024 02:54:04 -0500 Subject: [PATCH 06/15] Add comment and change var name for audio latency From 57548167ac73d86c64302c9e4eb6a6cb7892ce6c Mon Sep 17 00:00:00 2001 From: Saeed Dehqan <31902891+saeeddhqan@users.noreply.github.com> Date: Sat, 16 Nov 2024 02:56:10 -0500 Subject: [PATCH 07/15] Add comment and change var name for audio latency From 201f8520f0b08135f02d8a3ec887cc273818c97c Mon Sep 17 00:00:00 2001 From: Saeed Dehqan Date: Sun, 17 Nov 2024 16:28:33 +0000 Subject: [PATCH 08/15] ultravox_model test --- ultravox/model/ultravox_model_test.py | 52 +++++++++++++++++++++++++++ 1 file changed, 52 insertions(+) create mode 100644 ultravox/model/ultravox_model_test.py diff --git a/ultravox/model/ultravox_model_test.py b/ultravox/model/ultravox_model_test.py new file mode 100644 index 0000000..8b675de --- /dev/null +++ b/ultravox/model/ultravox_model_test.py @@ -0,0 +1,52 @@ +import pytest +import torch +from transformers import WhisperConfig +from transformers.models.whisper import modeling_whisper as whisper +from ultravox.model import ultravox_model + +class TestModifiedWhisperEncoder: + @pytest.fixture + def encoder(self): + config = WhisperConfig( + max_source_positions=3000, + d_model=256, + encoder_attention_heads=4, + encoder_layers=4, + ) + return ultravox_model.ModifiedWhisperEncoder(config) + + def test_init_latency_mask_none(self, encoder): + encoder.init_latency_mask(None, torch.float32) + assert encoder.audio_streaming_mask is None + + def test_init_latency_mask_valid(self, encoder): + block_size = 50 + encoder.init_latency_mask(block_size, torch.float32) + assert encoder.audio_streaming_mask is not None + + assert len(encoder.audio_streaming_mask.shape) == 4 + assert encoder.audio_streaming_mask.shape[0] == 1 + assert encoder.audio_streaming_mask.shape[1] == 1 + + mask = encoder.audio_streaming_mask[0, 0] + # 50x60=3000 + source_mask = torch.tril(torch.ones(60, 60), diagonal=0).repeat_interleave(block_size, dim=0).repeat_interleave(block_size, dim=1) + source_mask = (1.0 - source_mask) * torch.finfo(torch.float32).min + assert torch.allclose(mask, source_mask) + + def test_init_latency_mask_invalid_block_size(self, encoder): + invalid_block_size = 13 + + with pytest.raises(AssertionError, match="must divide .* evenly"): + encoder.init_latency_mask(invalid_block_size, torch.float32) + + def 
test_init_latency_mask_different_dtypes(self, encoder): + block_size = 50 + for dtype in (torch.float32, torch.float16): + encoder.init_latency_mask(block_size, dtype) + assert encoder.audio_streaming_mask.min() == torch.finfo(dtype).min + + def test_init_latency_mask_persistence(self, encoder): + block_size = 50 + encoder.init_latency_mask(block_size, torch.float32) + assert 'audio_streaming_mask' in encoder._buffers From 15cbeffb662d249e58cf8cc6277061796bc7e6e7 Mon Sep 17 00:00:00 2001 From: Saeed Dehqan <31902891+saeeddhqan@users.noreply.github.com> Date: Tue, 19 Nov 2024 14:45:55 -0500 Subject: [PATCH 09/15] Update ultravox/model/ultravox_model.py Co-authored-by: Farzad Abdolhosseini --- ultravox/model/ultravox_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ultravox/model/ultravox_model.py b/ultravox/model/ultravox_model.py index 4be610f..34a76f9 100644 --- a/ultravox/model/ultravox_model.py +++ b/ultravox/model/ultravox_model.py @@ -538,7 +538,7 @@ def init_latency_mask(self, audio_latency_block_size: int, dtype: torch.float): self.audio_streaming_mask = None return audio_streaming_mask = None - # total sequence length + # maximum sequence length audio_latency_nblocks = (self.config.max_source_positions * self.conv1.stride[0] * self.conv2.stride[0] From 5d18d50bd44f60afdbb5d60e81cf771946fe8bb1 Mon Sep 17 00:00:00 2001 From: Saeed Dehqan Date: Tue, 19 Nov 2024 21:43:47 +0000 Subject: [PATCH 10/15] amendments --- ultravox/model/ultravox_model.py | 25 ++--- ultravox/model/ultravox_model_test.py | 92 +++++++++---------- ultravox/training/config_base.py | 2 +- ultravox/training/configs/meta_config.yaml | 2 +- .../training/configs/streaming_tinyllama.yaml | 4 +- 5 files changed, 63 insertions(+), 62 deletions(-) diff --git a/ultravox/model/ultravox_model.py b/ultravox/model/ultravox_model.py index 34a76f9..dad66d1 100644 --- a/ultravox/model/ultravox_model.py +++ b/ultravox/model/ultravox_model.py @@ -290,14 +290,19 @@ def _create_audio_tower( audio_tower.init_latency_mask( config.audio_latency_block_size, dtype=config.torch_dtype) else: + assert config.audio_latency_block_size not in (None, 0), \ + "only whisper audio tower supports audio latency masking, got non-zero value for 'audio_latency_block_size'" audio_tower = transformers.AutoModel.from_pretrained( config.audio_model_id, torch_dtype=config.torch_dtype ) else: if "whisper" in config.audio_config._name_or_path: audio_tower = ModifiedWhisperEncoder(config.audio_config) - audio_tower.init_latency_mask(config.audio_latency_block_size, dtype=config.torch_dtype) + audio_tower.init_latency_mask( + config.audio_latency_block_size, dtype=config.torch_dtype) else: + assert config.audio_latency_block_size not in (None, 0), \ + "only whisper audio tower supports audio latency masking, got non-zero value for 'audio_latency_block_size'" with transformers.modeling_utils.no_init_weights(): # we only ever use from_config if the weights are retrained, hence initializing is not # required. This makes the model quite creation faster since init on CPU is quite slow. 
@@ -539,28 +544,24 @@ def init_latency_mask(self, audio_latency_block_size: int, dtype: torch.float): return audio_streaming_mask = None # maximum sequence length - audio_latency_nblocks = (self.config.max_source_positions + max_seqlen = (self.config.max_source_positions * self.conv1.stride[0] * self.conv2.stride[0] ) - - assert audio_latency_nblocks > 0, f"total sequence length must be positive, got {audio_latency_nblocks}" - - assert audio_latency_nblocks % (2 * audio_latency_block_size) == 0, ( - f"audio_latency_block_size {audio_latency_block_size} must divide {audio_latency_nblocks} evenly. " + assert max_seqlen > 0, f"maximum sequence length must be positive, got {max_seqlen}" + assert max_seqlen % audio_latency_block_size == 0, ( + f"audio_latency_block_size {audio_latency_block_size} must divide {max_seqlen} evenly." ) - - audio_latency_nblocks //= 2 * audio_latency_block_size + # Given the block size, we calculate number of blocks. + audio_latency_nblocks = max_seqlen // audio_latency_block_size audio_streaming_mask = torch.tril( torch.ones( audio_latency_nblocks, audio_latency_nblocks), diagonal=0, ).repeat_interleave(audio_latency_block_size, dim=0).repeat_interleave(audio_latency_block_size, dim=1) audio_streaming_mask = (1.0 - audio_streaming_mask) * torch.finfo(dtype).min - # audio_streaming_mask[audio_streaming_mask == 0] = float('-inf') - # audio_streaming_mask[audio_streaming_mask == 1] = 0 audio_streaming_mask = audio_streaming_mask[None, None, :, :] - self.register_buffer('audio_streaming_mask', audio_streaming_mask, persistent=False) + self.register_buffer("audio_streaming_mask", audio_streaming_mask, persistent=False) def forward( self, diff --git a/ultravox/model/ultravox_model_test.py b/ultravox/model/ultravox_model_test.py index 8b675de..01b0b22 100644 --- a/ultravox/model/ultravox_model_test.py +++ b/ultravox/model/ultravox_model_test.py @@ -4,49 +4,49 @@ from transformers.models.whisper import modeling_whisper as whisper from ultravox.model import ultravox_model -class TestModifiedWhisperEncoder: - @pytest.fixture - def encoder(self): - config = WhisperConfig( - max_source_positions=3000, - d_model=256, - encoder_attention_heads=4, - encoder_layers=4, - ) - return ultravox_model.ModifiedWhisperEncoder(config) - - def test_init_latency_mask_none(self, encoder): - encoder.init_latency_mask(None, torch.float32) - assert encoder.audio_streaming_mask is None - - def test_init_latency_mask_valid(self, encoder): - block_size = 50 - encoder.init_latency_mask(block_size, torch.float32) - assert encoder.audio_streaming_mask is not None - - assert len(encoder.audio_streaming_mask.shape) == 4 - assert encoder.audio_streaming_mask.shape[0] == 1 - assert encoder.audio_streaming_mask.shape[1] == 1 - - mask = encoder.audio_streaming_mask[0, 0] - # 50x60=3000 - source_mask = torch.tril(torch.ones(60, 60), diagonal=0).repeat_interleave(block_size, dim=0).repeat_interleave(block_size, dim=1) - source_mask = (1.0 - source_mask) * torch.finfo(torch.float32).min - assert torch.allclose(mask, source_mask) - - def test_init_latency_mask_invalid_block_size(self, encoder): - invalid_block_size = 13 - - with pytest.raises(AssertionError, match="must divide .* evenly"): - encoder.init_latency_mask(invalid_block_size, torch.float32) - - def test_init_latency_mask_different_dtypes(self, encoder): - block_size = 50 - for dtype in (torch.float32, torch.float16): - encoder.init_latency_mask(block_size, dtype) - assert encoder.audio_streaming_mask.min() == torch.finfo(dtype).min - - def 
test_init_latency_mask_persistence(self, encoder): - block_size = 50 - encoder.init_latency_mask(block_size, torch.float32) - assert 'audio_streaming_mask' in encoder._buffers +@pytest.fixture +def encoder(): + config = WhisperConfig( + max_source_positions=1500, + d_model=256, + encoder_attention_heads=4, + encoder_layers=4, + ) + return ultravox_model.ModifiedWhisperEncoder(config) + +def test_init_latency_mask_none(encoder): + encoder.init_latency_mask(None, torch.float32) + assert encoder.audio_streaming_mask is None + +def test_init_latency_mask_valid(encoder): + block_size = 100 + encoder.init_latency_mask(block_size, torch.float32) + assert encoder.audio_streaming_mask is not None + + assert len(encoder.audio_streaming_mask.shape) == 4 + assert encoder.audio_streaming_mask.shape[0] == 1 + assert encoder.audio_streaming_mask.shape[1] == 1 + + mask = encoder.audio_streaming_mask[0, 0] + # 50x60=3000 + source_mask = torch.tril(torch.ones(30, 30), diagonal=0).repeat_interleave(block_size, dim=0).repeat_interleave(block_size, dim=1) + source_mask = (1.0 - source_mask) * torch.finfo(torch.float32).min + print(mask.shape) + assert torch.allclose(mask, source_mask) + +def test_init_latency_mask_invalid_block_size(encoder): + invalid_block_size = 13 + + with pytest.raises(AssertionError, match="must divide .* evenly"): + encoder.init_latency_mask(invalid_block_size, torch.float32) + +def test_init_latency_mask_different_dtypes(encoder): + block_size = 50 + for dtype in (torch.float32, torch.float16): + encoder.init_latency_mask(block_size, dtype) + assert encoder.audio_streaming_mask.min() == torch.finfo(dtype).min + +def test_init_latency_mask_persistence(encoder): + block_size = 50 + encoder.init_latency_mask(block_size, torch.float32) + assert "audio_streaming_mask" in encoder._buffers diff --git a/ultravox/training/config_base.py b/ultravox/training/config_base.py index 7ee2b26..20b3396 100644 --- a/ultravox/training/config_base.py +++ b/ultravox/training/config_base.py @@ -106,7 +106,7 @@ def get_val_sets(self) -> List[DatasetOptions]: # loss function to use loss_config: Optional[ultravox_config.LossConfig] = None - # To simulate audio streaming with masking. None for non-causal, 50 for 1s, 100 for 2s, and so on. + # To simulate audio streaming with masking. None for non-causal, 100 for 1s, 200 for 2s, and so on. audio_latency_block_size: int = None def __post_init__(self): diff --git a/ultravox/training/configs/meta_config.yaml b/ultravox/training/configs/meta_config.yaml index 5bc0f5a..e7807bd 100644 --- a/ultravox/training/configs/meta_config.yaml +++ b/ultravox/training/configs/meta_config.yaml @@ -32,4 +32,4 @@ data_type: "bfloat16" report_logs_to: ["tensorboard", "wandb"] -audio_latency_block_size: null # None for non-causal, 50 for 1s, 100 for 2s, and so on. +audio_latency_block_size: null # null for non-causal, 100 for 1s, 200 for 2s, and so on. diff --git a/ultravox/training/configs/streaming_tinyllama.yaml b/ultravox/training/configs/streaming_tinyllama.yaml index 2ab3a6b..46d36db 100644 --- a/ultravox/training/configs/streaming_tinyllama.yaml +++ b/ultravox/training/configs/streaming_tinyllama.yaml @@ -1,4 +1,4 @@ -# SLM with ultravox & llama3.1, trained wtih knowledge distillation. 
+ exp_name: "ultravox-streaming-experiments-1s" # Make sure to accept the license agreement on huggingface hub text_model: "meta-llama/Llama-3.2-1B-Instruct" @@ -23,4 +23,4 @@ val_sets: - name: peoplespeech batch_size: 24 max_steps: 10000 # x8x24 = 2,764,800 -audio_latency_block_size: null # None for non-causal, 50 for 1s, 100 for 2s, and so on. \ No newline at end of file +audio_latency_block_size: 100 # null for non-causal, 100 for 1s, 200 for 2s, and so on. \ No newline at end of file From 68ae6e07985fd1c1fb02fcefcedcf2575a65c67d Mon Sep 17 00:00:00 2001 From: Saeed Dehqan Date: Tue, 19 Nov 2024 21:47:42 +0000 Subject: [PATCH 11/15] reformat --- ultravox/model/ultravox_model.py | 59 +++++++++++++++++---------- ultravox/model/ultravox_model_test.py | 14 ++++++- 2 files changed, 49 insertions(+), 24 deletions(-) diff --git a/ultravox/model/ultravox_model.py b/ultravox/model/ultravox_model.py index dad66d1..09ed280 100644 --- a/ultravox/model/ultravox_model.py +++ b/ultravox/model/ultravox_model.py @@ -288,10 +288,13 @@ def _create_audio_tower( config.audio_model_id, torch_dtype=config.torch_dtype ) audio_tower.init_latency_mask( - config.audio_latency_block_size, dtype=config.torch_dtype) + config.audio_latency_block_size, dtype=config.torch_dtype + ) else: - assert config.audio_latency_block_size not in (None, 0), \ - "only whisper audio tower supports audio latency masking, got non-zero value for 'audio_latency_block_size'" + assert config.audio_latency_block_size not in ( + None, + 0, + ), "only whisper audio tower supports audio latency masking, got non-zero value for 'audio_latency_block_size'" audio_tower = transformers.AutoModel.from_pretrained( config.audio_model_id, torch_dtype=config.torch_dtype ) @@ -299,10 +302,13 @@ def _create_audio_tower( if "whisper" in config.audio_config._name_or_path: audio_tower = ModifiedWhisperEncoder(config.audio_config) audio_tower.init_latency_mask( - config.audio_latency_block_size, dtype=config.torch_dtype) + config.audio_latency_block_size, dtype=config.torch_dtype + ) else: - assert config.audio_latency_block_size not in (None, 0), \ - "only whisper audio tower supports audio latency masking, got non-zero value for 'audio_latency_block_size'" + assert config.audio_latency_block_size not in ( + None, + 0, + ), "only whisper audio tower supports audio latency masking, got non-zero value for 'audio_latency_block_size'" with transformers.modeling_utils.no_init_weights(): # we only ever use from_config if the weights are retrained, hence initializing is not # required. This makes the model quite creation faster since init on CPU is quite slow. @@ -544,24 +550,32 @@ def init_latency_mask(self, audio_latency_block_size: int, dtype: torch.float): return audio_streaming_mask = None # maximum sequence length - max_seqlen = (self.config.max_source_positions + max_seqlen = ( + self.config.max_source_positions * self.conv1.stride[0] * self.conv2.stride[0] ) - assert max_seqlen > 0, f"maximum sequence length must be positive, got {max_seqlen}" - assert max_seqlen % audio_latency_block_size == 0, ( - f"audio_latency_block_size {audio_latency_block_size} must divide {max_seqlen} evenly." - ) + assert ( + max_seqlen > 0 + ), f"maximum sequence length must be positive, got {max_seqlen}" + assert ( + max_seqlen % audio_latency_block_size == 0 + ), f"audio_latency_block_size {audio_latency_block_size} must divide {max_seqlen} evenly." # Given the block size, we calculate number of blocks. 
audio_latency_nblocks = max_seqlen // audio_latency_block_size - audio_streaming_mask = torch.tril( - torch.ones( - audio_latency_nblocks, audio_latency_nblocks), + audio_streaming_mask = ( + torch.tril( + torch.ones(audio_latency_nblocks, audio_latency_nblocks), diagonal=0, - ).repeat_interleave(audio_latency_block_size, dim=0).repeat_interleave(audio_latency_block_size, dim=1) + ) + .repeat_interleave(audio_latency_block_size, dim=0) + .repeat_interleave(audio_latency_block_size, dim=1) + ) audio_streaming_mask = (1.0 - audio_streaming_mask) * torch.finfo(dtype).min audio_streaming_mask = audio_streaming_mask[None, None, :, :] - self.register_buffer("audio_streaming_mask", audio_streaming_mask, persistent=False) + self.register_buffer( + "audio_streaming_mask", audio_streaming_mask, persistent=False + ) def forward( self, @@ -621,10 +635,9 @@ def forward( if audio_len != None: audio_feature_len = self._get_feat_extract_output_lengths(audio_len) max_seq_len = hidden_states.shape[1] - attention_mask = ( - torch.arange(max_seq_len, device=hidden_states.device)[None, :] - .lt(audio_feature_len.view(-1, 1)) - ) + attention_mask = torch.arange(max_seq_len, device=hidden_states.device)[ + None, : + ].lt(audio_feature_len.view(-1, 1)) attention_mask = self.get_extended_attention_mask( attention_mask, None, @@ -635,9 +648,11 @@ def forward( if self.audio_streaming_mask is not None: seqlen = hidden_states.size(-2) if attention_mask is not None: - attention_mask = torch.minimum(self.audio_streaming_mask[:,:,:seqlen, :seqlen], attention_mask) # merge + attention_mask = torch.minimum( + self.audio_streaming_mask[:, :, :seqlen, :seqlen], attention_mask + ) # merge else: - attention_mask = self.audio_streaming_mask[:,:,:seqlen, :seqlen] + attention_mask = self.audio_streaming_mask[:, :, :seqlen, :seqlen] attention_mask = attention_mask.to(hidden_states.dtype) # check if head_mask has a correct number of layers specified if desired diff --git a/ultravox/model/ultravox_model_test.py b/ultravox/model/ultravox_model_test.py index 01b0b22..51c467a 100644 --- a/ultravox/model/ultravox_model_test.py +++ b/ultravox/model/ultravox_model_test.py @@ -1,9 +1,10 @@ import pytest import torch from transformers import WhisperConfig -from transformers.models.whisper import modeling_whisper as whisper + from ultravox.model import ultravox_model + @pytest.fixture def encoder(): config = WhisperConfig( @@ -14,10 +15,12 @@ def encoder(): ) return ultravox_model.ModifiedWhisperEncoder(config) + def test_init_latency_mask_none(encoder): encoder.init_latency_mask(None, torch.float32) assert encoder.audio_streaming_mask is None + def test_init_latency_mask_valid(encoder): block_size = 100 encoder.init_latency_mask(block_size, torch.float32) @@ -29,23 +32,30 @@ def test_init_latency_mask_valid(encoder): mask = encoder.audio_streaming_mask[0, 0] # 50x60=3000 - source_mask = torch.tril(torch.ones(30, 30), diagonal=0).repeat_interleave(block_size, dim=0).repeat_interleave(block_size, dim=1) + source_mask = ( + torch.tril(torch.ones(30, 30), diagonal=0) + .repeat_interleave(block_size, dim=0) + .repeat_interleave(block_size, dim=1) + ) source_mask = (1.0 - source_mask) * torch.finfo(torch.float32).min print(mask.shape) assert torch.allclose(mask, source_mask) + def test_init_latency_mask_invalid_block_size(encoder): invalid_block_size = 13 with pytest.raises(AssertionError, match="must divide .* evenly"): encoder.init_latency_mask(invalid_block_size, torch.float32) + def test_init_latency_mask_different_dtypes(encoder): block_size 
= 50 for dtype in (torch.float32, torch.float16): encoder.init_latency_mask(block_size, dtype) assert encoder.audio_streaming_mask.min() == torch.finfo(dtype).min + def test_init_latency_mask_persistence(encoder): block_size = 50 encoder.init_latency_mask(block_size, torch.float32) From feb2394a048b67f95bd1f356a749774a42bf82e5 Mon Sep 17 00:00:00 2001 From: Saeed Dehqan <31902891+saeeddhqan@users.noreply.github.com> Date: Wed, 20 Nov 2024 01:15:36 -0500 Subject: [PATCH 12/15] Update ultravox/training/config_base.py Co-authored-by: Farzad Abdolhosseini --- ultravox/training/config_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ultravox/training/config_base.py b/ultravox/training/config_base.py index 20b3396..6a66b40 100644 --- a/ultravox/training/config_base.py +++ b/ultravox/training/config_base.py @@ -107,7 +107,7 @@ def get_val_sets(self) -> List[DatasetOptions]: loss_config: Optional[ultravox_config.LossConfig] = None # To simulate audio streaming with masking. None for non-causal, 100 for 1s, 200 for 2s, and so on. - audio_latency_block_size: int = None + audio_latency_block_size: Optional[int] = None def __post_init__(self): assert self.data_type in ["bfloat16", "float16", "float32"] From 93164a8b8a990e96a86de8762acfe2a62da36ff9 Mon Sep 17 00:00:00 2001 From: Saeed Dehqan <31902891+saeeddhqan@users.noreply.github.com> Date: Wed, 20 Nov 2024 01:16:26 -0500 Subject: [PATCH 13/15] Update ultravox/model/ultravox_model.py Co-authored-by: Farzad Abdolhosseini --- ultravox/model/ultravox_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ultravox/model/ultravox_model.py b/ultravox/model/ultravox_model.py index 09ed280..6d008e7 100644 --- a/ultravox/model/ultravox_model.py +++ b/ultravox/model/ultravox_model.py @@ -544,7 +544,7 @@ class ModifiedWhisperEncoder( base_model_prefix = "model.encoder" _no_split_modules = ["WhisperEncoderLayer"] - def init_latency_mask(self, audio_latency_block_size: int, dtype: torch.float): + def init_latency_mask(self, audio_latency_block_size: int, dtype: torch.dtype): if audio_latency_block_size is None: self.audio_streaming_mask = None return From 5e06b2a94fd78d8b4fa4338ea311bd3699497b93 Mon Sep 17 00:00:00 2001 From: Saeed Dehqan <31902891+saeeddhqan@users.noreply.github.com> Date: Wed, 20 Nov 2024 01:29:37 -0500 Subject: [PATCH 14/15] Update ultravox_model_test.py --- ultravox/model/ultravox_model_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ultravox/model/ultravox_model_test.py b/ultravox/model/ultravox_model_test.py index 51c467a..a0b95b2 100644 --- a/ultravox/model/ultravox_model_test.py +++ b/ultravox/model/ultravox_model_test.py @@ -31,7 +31,7 @@ def test_init_latency_mask_valid(encoder): assert encoder.audio_streaming_mask.shape[1] == 1 mask = encoder.audio_streaming_mask[0, 0] - # 50x60=3000 + # 100*30=3000 source_mask = ( torch.tril(torch.ones(30, 30), diagonal=0) .repeat_interleave(block_size, dim=0) From 8296295f45c065db2ba4b6787a788a2c8b689670 Mon Sep 17 00:00:00 2001 From: Saeed Dehqan Date: Wed, 20 Nov 2024 18:22:51 +0000 Subject: [PATCH 15/15] Solve mypy issue --- ultravox/model/ultravox_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ultravox/model/ultravox_model.py b/ultravox/model/ultravox_model.py index 6d008e7..1b0a575 100644 --- a/ultravox/model/ultravox_model.py +++ b/ultravox/model/ultravox_model.py @@ -548,7 +548,7 @@ def init_latency_mask(self, audio_latency_block_size: int, dtype: torch.dtype): if audio_latency_block_size is 
None: self.audio_streaming_mask = None return - audio_streaming_mask = None + # maximum sequence length max_seqlen = ( self.config.max_source_positions
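A closing note on how the registered buffer is consumed: in the patched ModifiedWhisperEncoder.forward, the streaming mask is sliced to the actual sequence length and, when a per-sample padding mask exists, the two additive masks are merged with an elementwise torch.minimum, so a position is blocked if either mask blocks it; when no audio_len is given, the sliced streaming mask is used on its own. The toy example below only illustrates that merge; the sizes, the two-sample batch, and the variable names are invented for illustration and are not taken from the patch.

# Illustrative sketch only; toy sizes, not code from the patch.
import torch

seq_len, block_size = 6, 2
neg = torch.finfo(torch.float32).min

# Block-causal streaming mask, shape (1, 1, seq_len, seq_len), like the buffer built by init_latency_mask.
blocks = torch.tril(torch.ones(seq_len // block_size, seq_len // block_size))
streaming = (1.0 - blocks.repeat_interleave(block_size, 0).repeat_interleave(block_size, 1)) * neg
streaming = streaming[None, None, :, :]

# Additive padding mask for a batch of two samples with 6 and 4 valid frames,
# in the (batch, 1, 1, seq_len) layout produced by get_extended_attention_mask.
valid_lens = torch.tensor([6, 4])
keep = torch.arange(seq_len)[None, :] < valid_lens[:, None]
padding = (1.0 - keep.float())[:, None, None, :] * neg

# Merge: elementwise minimum keeps the more restrictive of the two masks at every position.
merged = torch.minimum(streaming[:, :, :seq_len, :seq_len], padding)
print(merged.shape)  # torch.Size([2, 1, 6, 6])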