-
Notifications
You must be signed in to change notification settings - Fork 38
/
convert_to_hf.py
40 lines (33 loc) · 1.7 KB
/
convert_to_hf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
from argparse import ArgumentParser
from contrastors.models.biencoder import BiEncoder, BiEncoderConfig
from contrastors.models.dual_encoder import DualEncoder, DualEncoderConfig
from contrastors.models.huggingface import NomicBertConfig, NomicBertForPreTraining, NomicVisionModel
def parse_args():
parser = ArgumentParser()
parser.add_argument("--ckpt_path", type=str, required=True)
parser.add_argument("--model_name", type=str, required=True)
parser.add_argument("--private", action="store_true")
parser.add_argument("--biencoder", action="store_true")
parser.add_argument("--vision", action="store_true")
return parser.parse_args()
if __name__ == "__main__":
args = parse_args()
if args.biencoder:
config = BiEncoderConfig.from_pretrained(args.ckpt_path)
model = BiEncoder.from_pretrained(args.ckpt_path, config=config)
model = model.trunk
elif args.vision:
NomicBertConfig.register_for_auto_class()
NomicVisionModel.register_for_auto_class("AutoModel")
config = DualEncoderConfig.from_pretrained(args.ckpt_path)
model = DualEncoder.from_pretrained(args.ckpt_path, config=config)
vision = model.vision
hf_config = NomicBertConfig(**model.vision.trunk.config.to_dict())
model = NomicVisionModel(hf_config)
state_dict = vision.state_dict()
state_dict = {k.replace("trunk.", ""): v for k, v in state_dict.items()}
model.load_state_dict(state_dict)
else:
config = NomicBertConfig.from_pretrained(args.ckpt_path)
model = NomicBertForPreTraining.from_pretrained(args.ckpt_path, config=config)
model.push_to_hub(args.model_name, private=args.private, use_temp_dir=False)