From 35a8dd40c10bbbfa6a9d68c06a5f9deae4adb0d1 Mon Sep 17 00:00:00 2001
From: Farzad Abdolhosseini
Date: Tue, 24 Sep 2024 14:45:33 -0700
Subject: [PATCH] monkey patch for HF hub error

---
NOTE(review): line breaks and indentation were reconstructed from a copy of
this patch that had been collapsed onto two lines. The added Python is
unchanged token-for-token. The two blank context lines in the train.py hunk
(after the prefetch_weights import and after OUTPUT_EXAMPLE) are inferred from
the "-29,8 +29,12" counts -- confirm against the original commit. No author
e-mail is present in this copy, so `git am` may refuse it; `git apply` only
needs the diff below.

 ultravox/training/helpers/hf_hub_patch.py | 39 +++++++++++++++++++++++
 ultravox/training/train.py                |  4 +++
 2 files changed, 43 insertions(+)
 create mode 100644 ultravox/training/helpers/hf_hub_patch.py

diff --git a/ultravox/training/helpers/hf_hub_patch.py b/ultravox/training/helpers/hf_hub_patch.py
new file mode 100644
index 00000000..bd072eeb
--- /dev/null
+++ b/ultravox/training/helpers/hf_hub_patch.py
@@ -0,0 +1,39 @@
+import huggingface_hub
+from huggingface_hub import file_download
+from huggingface_hub import hf_file_system
+
+
+def _fetch_range(self, start: int, end: int) -> bytes:
+    """
+    This is a copy of the original _fetch_range method from HfFileSystemFile.
+    The only modification is the addition of the 500 status code to the retry_on_status_codes tuple.
+
+    Original source code:
+    https://github.com/huggingface/huggingface_hub/blob/c0fd4e0f7519a4e3659c836081cc7e38c0d14b35/src/huggingface_hub/hf_file_system.py#L717
+    """
+    headers = {
+        "range": f"bytes={start}-{end - 1}",
+        **self.fs._api._build_hf_headers(),
+    }
+    url = file_download.hf_hub_url(
+        repo_id=self.resolved_path.repo_id,
+        revision=self.resolved_path.revision,
+        filename=self.resolved_path.path_in_repo,
+        repo_type=self.resolved_path.repo_type,
+        endpoint=self.fs.endpoint,
+    )
+    r = hf_file_system.http_backoff(
+        "GET",
+        url,
+        headers=headers,
+        retry_on_status_codes=(500, 502, 503, 504),  # add 500 to retry on server errors
+        timeout=huggingface_hub.constants.HF_HUB_DOWNLOAD_TIMEOUT,
+    )
+    hf_file_system.hf_raise_for_status(r)
+    return r.content
+
+
+def monkey_patch_fetch_range():
+    import huggingface_hub
+
+    huggingface_hub.HfFileSystemFile._fetch_range = _fetch_range
diff --git a/ultravox/training/train.py b/ultravox/training/train.py
index 5dde3f62..414a75c4 100644
--- a/ultravox/training/train.py
+++ b/ultravox/training/train.py
@@ -29,8 +29,12 @@
 from ultravox.model import wandb_utils
 from ultravox.training import config_base
 from ultravox.training import ddp_utils
+from ultravox.training.helpers import hf_hub_patch
 from ultravox.training.helpers import prefetch_weights
 
+# Patching HF Hub to avoid throwing an error on 500 dataset errors
+hf_hub_patch.monkey_patch_fetch_range()
+
 INPUT_EXAMPLE = {"text": "Transcribe\n<|audio|>", "audio": b"\x00\x00" * 16000}
 OUTPUT_EXAMPLE = {"text": "Hello, world!"}
 