Support KL loss (#63)
- Implement the (response) KL loss described in the BLSP-KD paper: https://arxiv.org/abs/2405.19041.
- Add an option to configure the loss function in the config file, as shown below:

loss_config:
  # Choose from ["KL_Divergence", "CrossEntropy"], default is "KL_Divergence"
  loss_function: "KL_Divergence"
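
For reference, here is a minimal, self-contained sketch of the response-level KL objective as it is implemented in UltravoxModel._compute_kl_loss later in this diff. The shapes and logits below are placeholders for illustration and are not part of the commit:

# Illustrative only: temperature-scaled KL between the text-conditioned teacher
# distribution and the speech-conditioned student distribution, computed over
# response tokens.
import torch
import torch.nn.functional as F

temperature = 2.0  # matches the LossConfig.kl_temperature default added below
num_response_tokens, vocab_size = 8, 32000  # placeholder shapes

student_logits = torch.randn(num_response_tokens, vocab_size)  # speech input (student)
teacher_logits = torch.randn(num_response_tokens, vocab_size)  # text input (teacher)

kl_loss = F.kl_div(
    F.log_softmax(student_logits / temperature, dim=-1),
    F.softmax(teacher_logits / temperature, dim=-1),
    reduction="batchmean",  # averages the summed KL over the response tokens
)
print(kl_loss.item())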
zqhuang211 authored Aug 5, 2024
1 parent 7202cfc commit ecd58c4
Showing 8 changed files with 171 additions and 13 deletions.
4 changes: 2 additions & 2 deletions mcloud.yaml
@@ -3,7 +3,7 @@ name: ultravox
image: mosaicml/composer:latest
compute:
gpus: 8
- cluster: r7z22p1
+ cluster: r14z3p1
integrations:
- integration_type: git_repo
git_repo: fixie-ai/ultravox
@@ -14,4 +14,4 @@ command: >-
env_variables:
MLFLOW_TRACKING_URI: databricks
UV_BRANCH: main
- TRAIN_ARGS: --config_path ultravox/training/configs/llama3_whisper.yaml
+ TRAIN_ARGS: --config_path ultravox/training/configs/llama3_whisper_kd.yaml
22 changes: 19 additions & 3 deletions ultravox/data/datasets.py
@@ -69,9 +69,27 @@

@dataclasses.dataclass
class DataCollatorForSeq2SeqWithAudio(transformers.DataCollatorForSeq2Seq):
# when enabled, the alt_input_ids, alt_attention_mask, and alt_labels fields are used for computing the KL loss in UltravoxModel
include_alt_fields: bool = False

def __call__(self, features, *args, **kwargs):
audio_values = [f.pop("audio_values", None) for f in features]
if self.include_alt_fields:
# These fields are hard-coded in the transformers data collator, so they need special handling before calling the super method.
alt_features = [
{
"input_ids": f.pop("alt_input_ids"),
"attention_mask": f.pop("alt_attention_mask"),
"labels": f.pop("alt_labels"),
}
for f in features
]
batch = super().__call__(features, *args, **kwargs)
if self.include_alt_fields:
alt_batch = super().__call__(alt_features, *args, **kwargs)
batch["alt_input_ids"] = alt_batch["input_ids"]
batch["alt_attention_mask"] = alt_batch["attention_mask"]
batch["alt_labels"] = alt_batch["labels"]

# Pad the last dimension of all audio_values to the same length, with 0s on the right.
if audio_values and audio_values[0] is not None:
@@ -433,15 +451,13 @@ class AnyInstructDataset(VoiceDataset):

def __init__(self, args: VoiceDatasetArgs) -> None:
# TODO(juberti): convert to MDS
- # The last 7 samples are missing audio files, so we exclude them.
- NUM_SAMPLES = 108193 - 7
super().__init__(args)
dataset = datasets.load_dataset(
"json",
"anyinstruct",
data_files="https://huggingface.co/datasets/fnlp/AnyInstruct/resolve/main/speech_conv/metadata.jsonl",
split="train",
- ).select(range(NUM_SAMPLES))
+ )
dataset = dataset.train_test_split(
test_size=0.01, seed=args.shuffle_seed, shuffle=True
)
35 changes: 32 additions & 3 deletions ultravox/model/data_processing.py
@@ -15,6 +15,7 @@ def __init__(
processor: ultravox_processing.UltravoxProcessor,
train_on_inputs: bool = False,
inference_mode: bool = False,
include_alt_fields: bool = False,
) -> None:
"""
Pre-processing for the Ultravox model: applies tokenization and audio processing using the UltravoxProcessor
@@ -28,13 +29,16 @@
inference_mode: If True, only the input message is included in input_ids and labels, and the assistant
message is removed from the sample. This is used for inference (e.g. testing) since the model should
generate the assistant message. For training and validation, this should be False.
include_alt_fields: If True, the alt_input_ids, alt_attention_mask, and alt_labels are included in the output,
computed with <|audio|> replaced by the audio transcript.
"""
super().__init__(dataset)
self.processor = processor
self.train_on_inputs = train_on_inputs
self.inference_mode = inference_mode
if self.inference_mode:
self.train_on_inputs = True
self.include_alt_fields = include_alt_fields

def _process(self, sample: datasets.VoiceSample) -> Dict[str, Any]:
if self.inference_mode:
@@ -84,15 +88,40 @@ def _process(self, sample: datasets.VoiceSample) -> Dict[str, Any]:

# TODO: this might be slow due to calling audio_processor twice. We can compute modified input_text_len directly too.
# Revisit when using WhisperProcessor.
- input_text_len = self.processor(
+ input_token_len = self.processor(
text=input_text,
audio=audio,
sampling_rate=sample.sample_rate,
)["input_ids"].shape[-1]
- labels[:input_text_len] = -100
+ labels[:input_token_len] = -100

# If include_alt_fields is True, also include alt_input_ids, alt_attention_mask, and alt_labels
if self.include_alt_fields:
# sample.audio_transcript should never be None, but this is not currently guaranteed and needs investigation.
alt_text = text.replace("<|audio|>", sample.audio_transcript or "")

alt_inputs = self.processor(
text=alt_text,
audio=None,
return_tensors="pt",
)
alt_input_ids = alt_inputs["input_ids"].squeeze_(0)
alt_inputs["attention_mask"].squeeze_(0)

alt_labels = alt_input_ids.clone()
if not self.train_on_inputs:
alt_input_token_len = (
input_token_len + len(alt_input_ids) - len(input_ids)
)
alt_labels[:alt_input_token_len] = -100

inputs["alt_input_ids"] = alt_input_ids
inputs["alt_attention_mask"] = alt_inputs["attention_mask"]
inputs["alt_labels"] = alt_labels

return {
- **inputs,
+ # input_ids, attention_mask, audio_values, audio_token_start_idx, audio_token_len
+ # if include_alt_fields is True, also include alt_input_ids, alt_attention_mask, alt_labels
+ **inputs,
"labels": labels,
}
16 changes: 16 additions & 0 deletions ultravox/model/ultravox_config.py
@@ -1,4 +1,5 @@
import dataclasses
from enum import Enum
from typing import Any, Dict, List, Optional

import transformers
@@ -20,6 +21,21 @@ class LoraConfigSimplified:
)


class LossFunction(str, Enum):
CrossEntropy = "ce"
KL_Divergence = "kl"


@dataclasses.dataclass
class LossConfig:
loss_function: LossFunction = LossFunction.KL_Divergence
kl_temperature: float = 2.0

@property
def requires_alt_fields(self):
return self.loss_function == LossFunction.KL_Divergence


class UltravoxConfig(transformers.PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`UltravoxForConditionalGeneration`]. It is used to instantiate an
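
A hypothetical usage sketch of the classes added above (not part of the commit), showing how requires_alt_fields gates the KL-distillation path:

from ultravox.model.ultravox_config import LossConfig, LossFunction

# The default config uses KL distillation and therefore needs the alt_* tensors.
cfg = LossConfig()
assert cfg.loss_function == LossFunction.KL_Divergence
assert cfg.requires_alt_fields

# Plain cross-entropy training does not.
ce_cfg = LossConfig(loss_function=LossFunction.CrossEntropy)
assert not ce_cfg.requires_alt_fields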
67 changes: 65 additions & 2 deletions ultravox/model/ultravox_model.py
@@ -12,6 +12,8 @@

# We must use relative import in this directory to allow uploading to HF Hub
# Even "from . import X" pattern doesn't work (undocumented and unclear why)
from .ultravox_config import LossConfig
from .ultravox_config import LossFunction
from .ultravox_config import UltravoxConfig
from .whisper_model_modified import WhisperEncoder as ModifiedWhisperEncoder

@@ -52,6 +54,7 @@ def __init__(self, config: UltravoxConfig):
self.multi_modal_projector = UltravoxProjector(config)
self.language_model = self._create_language_model(config)

self.loss_config = LossConfig()
self.post_init()

def get_input_embeddings(self):
@@ -75,6 +78,9 @@ def get_decoder(self):
def tie_weights(self):
return self.language_model.tie_weights()

def set_loss_config(self, loss_config: LossConfig):
self.loss_config = loss_config

def _setup_cache(
self, cache_cls, max_batch_size: int, max_cache_len: Optional[int] = None
):
@@ -97,6 +103,42 @@ def resize_token_embeddings(
self.vocab_size = model_embeds.num_embeddings
return model_embeds

def _compute_kl_loss(
self,
lm_output: transformers.modeling_outputs.CausalLMOutputWithPast,
labels: Optional[torch.Tensor] = None,
past_key_values: Optional[Union[Tuple, transformers.cache_utils.Cache]] = None,
alt_input_ids: Optional[torch.Tensor] = None,
alt_attention_mask: Optional[torch.Tensor] = None,
alt_labels: Optional[torch.Tensor] = None,
**kwargs,
):
# disable gradient computation for the teacher model
with torch.no_grad():
# compute the teacher (text-only) model's distribution
alt_inputs_embeds = self.get_input_embeddings().forward(alt_input_ids)
alt_lm_output = self.language_model.forward(
inputs_embeds=alt_inputs_embeds,
labels=alt_labels,
attention_mask=alt_attention_mask,
past_key_values=past_key_values,
**kwargs,
)
# compute the KL divergence loss between the two models
kl_loss = F.kl_div(
F.log_softmax(
lm_output.logits[labels != -100] / self.loss_config.kl_temperature,
dim=-1,
),
F.softmax(
alt_lm_output.logits[alt_labels != -100]
/ self.loss_config.kl_temperature,
dim=-1,
),
reduction="batchmean",
)
return {"loss": kl_loss}

def forward(
self,
input_ids: torch.Tensor,
@@ -107,6 +149,10 @@ def forward(
audio_token_start_idx: Optional[torch.Tensor] = None,
audio_token_len: Optional[torch.Tensor] = None,
past_key_values: Optional[Union[Tuple, transformers.cache_utils.Cache]] = None,
# the alt_* fields are needed for KL divergence loss
alt_input_ids: Optional[torch.Tensor] = None,
alt_attention_mask: Optional[torch.Tensor] = None,
alt_labels: Optional[torch.Tensor] = None,
**kwargs,
) -> Union[Tuple, transformers.modeling_outputs.CausalLMOutputWithPast]:
"""
@@ -162,8 +208,25 @@ def forward(
past_key_values=past_key_values,
**kwargs,
)

- return lm_output
if self.training:
if self.loss_config.loss_function == LossFunction.CrossEntropy:
return lm_output
elif self.loss_config.loss_function == LossFunction.KL_Divergence:
return self._compute_kl_loss(
lm_output=lm_output,
labels=labels,
past_key_values=past_key_values,
alt_input_ids=alt_input_ids,
alt_attention_mask=alt_attention_mask,
alt_labels=alt_labels,
**kwargs,
)
else:
raise ValueError(
f"Unsupported loss function: {self.loss_config.loss_function}"
)
else:
return lm_output

def prepare_inputs_for_generation(
self,
3 changes: 3 additions & 0 deletions ultravox/training/config_base.py
@@ -76,6 +76,9 @@ class TrainConfig:
# A list of tags for filtering runs. Only used for wandb.
run_tags: List[str] = simple_parsing.list_field()

# loss function to use
loss_config: Optional[ultravox_config.LossConfig] = None

def __post_init__(self):
assert self.data_type in ["bfloat16", "float16", "float32"]
if self.device == "cuda" and not torch.cuda.is_available():
19 changes: 19 additions & 0 deletions ultravox/training/configs/llama3_whisper_kd.yaml
@@ -0,0 +1,19 @@
# SLM with ultravox & llama3, trained with knowledge distillation.
exp_name: "llama3_whisper_s"

# Make sure to accept the license agreement on huggingface hub
text_model: "meta-llama/Meta-Llama-3-8B-Instruct"
audio_model: "openai/whisper-small"


loss_config:
# Choose from ["KL_Divergence", "CrossEntropy"], default is "KL_Divergence"
loss_function: "KL_Divergence"

# Temporarily remove heysquad_human from val_sets as it causes the training to fail.
val_sets: ["anyinstruct", "soda", "peoplespeech"]

batch_size: 4
max_steps: 1000

data_sets: ["gigaspeech", "anyinstruct", "soda"]
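
To launch a run with this config outside of MosaicML, something like the following should work; the module entry point is an assumption, and only the --config_path flag is confirmed by TRAIN_ARGS in mcloud.yaml:

python -m ultravox.training.train --config_path ultravox/training/configs/llama3_whisper_kd.yaml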
18 changes: 15 additions & 3 deletions ultravox/training/train.py
@@ -45,12 +45,16 @@ def prepare_dataset(
train_on_inputs: bool,
repeat_data: bool,
num_samples: Optional[int] = None,
include_alt_fields: bool = False, # whether to generate tensors for text-only input (e.g., used for KD training)
) -> data.IterableDataset:

data_sets = [datasets.create_dataset(ds, data_args) for ds in dataset_names]
interleave = datasets.InterleaveDataset(data_sets, repeat=repeat_data)
ds_with_proc = data_processing.UltravoxDataproc(
- interleave, processor=processor, train_on_inputs=train_on_inputs
+ interleave,
+ processor=processor,
+ train_on_inputs=train_on_inputs,
+ include_alt_fields=include_alt_fields,
)
limited_ds = datasets.Range(ds_with_proc, num_samples=num_samples)
return limited_ds
@@ -121,6 +125,10 @@ def main() -> None:
# https://github.com/huggingface/transformers/issues/17116#issuecomment-1121340890
model.audio_tower.config.layerdrop = 0.0

# loss_config needs to be passed separately just for model training
if args.loss_config is not None:
model.set_loss_config(args.loss_config)

logging.info("Model and processor instantiated.")

# Starting W&B. HF Trainer can also do this, but this way we can include the config.
@@ -191,6 +199,7 @@ def main() -> None:
use_mds=args.mds,
mds_batch_size=args.batch_size,
),
include_alt_fields=model.loss_config.requires_alt_fields,
)
val_ds_args = datasets.VoiceDatasetArgs(
num_prompts=1,
@@ -211,6 +220,7 @@
processor=processor,
num_samples=args.val_num_samples,
data_args=val_ds_args_text if k.startswith("text_") else val_ds_args,
include_alt_fields=model.loss_config.requires_alt_fields,
)
for k in val_sets
}
@@ -224,10 +234,12 @@
val_datasets = {k: datasets.EmptyDataset() for k in val_sets}

# Set up the data loader
- data_collator = datasets.DataCollatorForSeq2SeqWithAudio(tokenizer=text_tokenizer)
+ data_collator = datasets.DataCollatorForSeq2SeqWithAudio(
+ tokenizer=text_tokenizer,
+ include_alt_fields=model.loss_config.requires_alt_fields,
+ )

logging.info(f"Config Params: {args}")

trainer = transformers.Seq2SeqTrainer(
model,
train_dataset=train_dataset,
