Skip to content

Commit

Permalink
Make sure splits have the same columns in audioextensiontask
Browse files Browse the repository at this point in the history
  • Loading branch information
liPatrick committed Sep 13, 2024
1 parent 633b3ca commit 9cdc08b
Showing 1 changed file with 9 additions and 6 deletions.
15 changes: 9 additions & 6 deletions ultravox/tools/ds_tool/ds_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,12 +231,15 @@ def _map_sample_repeat(self, sample):
repeated_audio = np.tile(audio_data, self.multiplier)
repeated_sentence = " ".join([sentence] * self.multiplier)
repeated_translation = " ".join([translation] * self.multiplier)
sample[self.audio_column_name]["array"] = repeated_audio
sample[self.audio_column_name].pop("path")
sample[self.asr_column_name] = repeated_sentence
sample[self.translation_column_name] = repeated_translation

return sample
new_sample = {}
new_sample[self.audio_column_name]["array"] = repeated_audio
new_sample[self.audio_column_name].pop("path")
new_sample[self.asr_column_name] = repeated_sentence
new_sample[self.translation_column_name] = repeated_translation
new_sample[self.id_column_name] = sample[self.id_column_name]

return new_sample

def _map_batch_combine(self, batch):
audios = batch[self.audio_column_name]
Expand Down Expand Up @@ -446,7 +449,7 @@ def _upload(self, ds_chunk_processed: datasets.Dataset, data_dir: str, split_nam
"split": split_name,
}
assert isinstance(self.args.upload_name, str)
try:
try:
ds_split_chunked.push_to_hub(self.args.upload_name, **hub_args)
except Exception as e:
print(f"Failed to upload chunk to hub: {e}")
Expand Down

0 comments on commit 9cdc08b

Please sign in to comment.