Skip to content

Commit

Permalink
monkey patch for HF hub error
Browse files Browse the repository at this point in the history
  • Loading branch information
farzadab committed Sep 24, 2024
1 parent fa88073 commit 35a8dd4
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 0 deletions.
39 changes: 39 additions & 0 deletions ultravox/training/helpers/hf_hub_patch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import huggingface_hub
from huggingface_hub import file_download
from huggingface_hub import hf_file_system


def _fetch_range(self, start: int, end: int) -> bytes:
    """
    Download bytes ``start`` (inclusive) up to ``end`` (exclusive) of the remote file.

    Intended as a replacement for ``HfFileSystemFile._fetch_range``. It mirrors
    the upstream implementation, with HTTP 500 added to the status codes that
    are passed to ``http_backoff`` as retryable.
    Upstream reference:
    https://github.com/huggingface/huggingface_hub/blob/c0fd4e0f7519a4e3659c836081cc7e38c0d14b35/src/huggingface_hub/hf_file_system.py#L717

    Args:
        start: Offset of the first byte to fetch.
        end: Offset one past the last byte to fetch.

    Returns:
        The body of the HTTP response, as bytes.
    """
    # 500 is the only addition compared to the upstream tuple.
    retryable_statuses = (500, 502, 503, 504)

    # The HTTP Range header is inclusive on both ends, hence ``end - 1``.
    # Hub headers are merged in afterwards, same as in the upstream method.
    request_headers = {"range": f"bytes={start}-{end - 1}"}
    request_headers.update(self.fs._api._build_hf_headers())

    target = self.resolved_path
    download_url = file_download.hf_hub_url(
        repo_id=target.repo_id,
        revision=target.revision,
        filename=target.path_in_repo,
        repo_type=target.repo_type,
        endpoint=self.fs.endpoint,
    )

    response = hf_file_system.http_backoff(
        "GET",
        download_url,
        headers=request_headers,
        retry_on_status_codes=retryable_statuses,
        timeout=huggingface_hub.constants.HF_HUB_DOWNLOAD_TIMEOUT,
    )
    hf_file_system.hf_raise_for_status(response)
    return response.content


def monkey_patch_fetch_range():
    """
    Install this module's ``_fetch_range`` on ``huggingface_hub.HfFileSystemFile``.

    After this call, ``HfFileSystemFile._fetch_range`` refers to the
    replacement defined in this module, which passes HTTP 500 as an
    additional retryable status code.
    """
    from huggingface_hub import HfFileSystemFile

    HfFileSystemFile._fetch_range = _fetch_range
4 changes: 4 additions & 0 deletions ultravox/training/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,12 @@
from ultravox.model import wandb_utils
from ultravox.training import config_base
from ultravox.training import ddp_utils
from ultravox.training.helpers import hf_hub_patch
from ultravox.training.helpers import prefetch_weights

# Patch HfFileSystemFile._fetch_range when this module is imported, so that
# ranged reads from the HF Hub treat HTTP 500 responses as retryable (in
# addition to 502/503/504).
hf_hub_patch.monkey_patch_fetch_range()

# NOTE(review): these look like sample input/output payloads; the audio value
# is 32000 zero bytes. Confirm intended sample rate/width against the consumer.
INPUT_EXAMPLE = {"text": "Transcribe\n<|audio|>", "audio": b"\x00\x00" * 16000}
OUTPUT_EXAMPLE = {"text": "Hello, world!"}

Expand Down

0 comments on commit 35a8dd4

Please sign in to comment.