-
Notifications
You must be signed in to change notification settings - Fork 2
/
train.py
169 lines (147 loc) · 5.41 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
import io
import json
import tarfile
import time
from collections import namedtuple
from dataclasses import asdict
import httpx
import tqdm
from cog import BaseModel, Input, Path, Secret
from huggingface_hub import (
HfApi,
get_hf_file_metadata,
hf_hub_url,
)
from huggingface_hub._login import _login as hf_login
from huggingface_hub.utils import filter_repo_objects
from predict import PredictorConfig
class TrainingOutput(BaseModel):
    """Result of a training run: the packaged model weights archive."""

    # Path to the tar archive produced by `train` (contains the downloaded
    # model files plus predictor_config.json).
    weights: Path
def train(
    hf_model_id: str = Input(
        description="""
            Hugging Face model identifier
            (e.g. NousResearch/Hermes-2-Theta-Llama-3-8B).
            """,
    ),
    hf_model_sha: str = Input(
        description="""
            The version of the model.
            If unspecified, the latest version is used.
            """,
        default=None,
    ),
    hf_token: Secret = Input(
        description="""
            Hugging Face API token.
            Get your token at https://huggingface.co/settings/tokens
            """,
        default=None,
    ),
    allow_patterns: str = Input(
        description="""
            Patterns constituting the allowlist.
            If provided, item paths must match at least one pattern from the allowlist.
            (e.g. "*.safetensors").
            """,
        default=None,
    ),
    ignore_patterns: str = Input(
        description="""
            Patterns constituting the denylist.
            If provided, item paths must not match any patterns from the denylist.
            (e.g. "*.gguf").
            """,
        default="*.gguf",
    ),
    prompt_template: str = Input(
        description="""
            Prompt template. This is a Jinja2 template that overrides the
            HuggingFace tokenizer configuration. If this is set to None and nothing
            is configured on HuggingFace, no formatting is applied.
            To override HuggingFace configuration, set it to the string
            `{{messages[0]['content']}}`.""",
        default=None,
    ),
) -> TrainingOutput:
    """Download a model from the Hugging Face Hub and package it as a tar.

    Resolves the repo's file list (honoring allow/ignore patterns), streams
    each file over HTTP, and writes everything — plus a generated
    ``predictor_config.json`` — into a local ``model.tar``.

    Returns:
        TrainingOutput: ``weights`` points at the created ``model.tar``.

    Raises:
        ValueError: if the allow/ignore patterns filter out every file.
        httpx.HTTPStatusError: if any file download returns an error status.
    """
    # Normalize the optional token to a plain string and log in so that
    # gated/private repos are accessible for the metadata calls below.
    if hf_token is not None and isinstance(hf_token, Secret):
        print("Logging in to Hugging Face Hub...")
        hf_token = hf_token.get_secret_value().strip()
        hf_login(token=hf_token, add_to_git_credential=False)
    else:
        print("No HuggingFace token provided.")

    api = HfApi()

    # Fetch the model info; files_metadata=True populates sibling sizes,
    # which we need for the progress bar total and tar entries.
    model = api.model_info(
        hf_model_id, revision=hf_model_sha, token=hf_token, files_metadata=True
    )
    print(f"Using model {model.id} with SHA {model.sha}")

    # Determine which files to download.
    files = list(
        filter_repo_objects(
            items=[f.rfilename for f in model.siblings],
            allow_patterns=allow_patterns,
            ignore_patterns=ignore_patterns,
        )
    )
    if not files:
        raise ValueError("No files to download")

    Entry = namedtuple("Entry", ["filename", "url", "metadata"])
    entries = [
        Entry(
            filename=filename,
            url=hf_hub_url(repo_id=hf_model_id, filename=filename),
            metadata=get_hf_file_metadata(
                hf_hub_url(repo_id=hf_model_id, filename=filename), token=hf_token
            ),
        )
        for filename in files
    ]

    config = PredictorConfig(prompt_template=prompt_template)

    start = time.time()
    print(f"Downloading {len(files)} files...")

    # Download the files and write them to an uncompressed tar file.
    weights = Path("model.tar")
    with tarfile.open(name=str(weights), mode="w:") as tar:
        # Add predictor_config.json so the predictor can pick up settings
        # (e.g. the prompt template) at load time.
        predictor_config_data = json.dumps(asdict(config)).encode("utf-8")
        tar_info = tarfile.TarInfo("predictor_config.json")
        tar_info.mtime = int(time.time())
        tar_info.size = len(predictor_config_data)
        tar.addfile(tar_info, fileobj=io.BytesIO(predictor_config_data))

        with tqdm.tqdm(
            total=sum(entry.metadata.size for entry in entries),
            unit="B",
            unit_divisor=1024,
            unit_scale=True,
            mininterval=1,
        ) as pbar:
            # BUG FIX: only attach an Authorization header when a token is
            # actually present. The original built the header
            # unconditionally, sending the literal string "Bearer None" on
            # anonymous downloads, which servers may reject as malformed.
            headers = {"Authorization": f"Bearer {hf_token}"} if hf_token else {}
            with httpx.Client(
                headers=headers, follow_redirects=True, timeout=None
            ) as client:
                for n, entry in enumerate(entries, start=1):
                    pbar.update(0)
                    pbar.set_postfix(
                        n=f"{n}/{len(entries)}",
                        file=entry.filename,
                        refresh=True,
                    )
                    with client.stream("GET", entry.url) as response:
                        response.raise_for_status()
                        # NOTE(review): each file is buffered fully in
                        # memory before being written to the tar — fine for
                        # typical shards, but peak RSS scales with the
                        # largest file.
                        with io.BytesIO() as buffer:
                            for chunk in response.iter_bytes(chunk_size=32 * 1024):
                                buffer.write(chunk)
                                pbar.update(len(chunk))
                                pbar.set_postfix(
                                    n=f"{n}/{len(entries)}",
                                    file=entry.filename,
                                    refresh=False,
                                )
                            buffer.seek(0)
                            tar_info = tarfile.TarInfo(entry.filename)
                            tar_info.mtime = int(time.time())
                            # Use the byte count we actually downloaded
                            # rather than the Hub metadata size: a mismatch
                            # between the two would silently corrupt the
                            # tar archive.
                            tar_info.size = buffer.getbuffer().nbytes
                            tar.addfile(tar_info, fileobj=buffer)

    print(f"Downloaded {len(files)} files in {time.time() - start:.2f} seconds")
    return TrainingOutput(weights=weights)