-
Notifications
You must be signed in to change notification settings - Fork 1
/
encoder_initial.py
42 lines (36 loc) · 1.35 KB
/
encoder_initial.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import torch
def model_preparing(model_path, save_path):
    """Strip the decoder parameters from a checkpoint and save the remainder.

    Used to keep only the AR encoder params so that they can initialize the
    NAR drafter's encoder.

    Every entry of ``checkpoint['model']`` whose name starts with
    ``'decoder'`` is removed; all other entries of the checkpoint (including
    any keys outside ``'model'``) are written back unchanged. The names of
    the remaining parameters are printed after a separator line.

    Args:
        model_path: Path of the checkpoint to read. It must contain a
            ``'model'`` mapping of parameter names to values.
        save_path: Path the stripped checkpoint is written to.
    """
    # Load onto the CPU so that a checkpoint saved from GPU training can be
    # processed on a machine without CUDA (and without using GPU memory).
    # NOTE: torch.load unpickles the file; only use it on trusted checkpoints.
    checkpoint = torch.load(model_path, map_location='cpu')
    state = checkpoint['model']

    # Snapshot the names first: deleting while iterating the mapping itself
    # would raise. Deleting in place keeps the original mapping type.
    for name in [key for key in state if key.startswith('decoder')]:
        del state[name]

    print('*' * 100)
    for name in state:
        print(name)

    torch.save(checkpoint, save_path)
def param_checking(model1, model2):
    """Print and compare the parameter names of two checkpoints.

    Intended for checking the parameters of the AR verifier against those of
    the NAR drafter.

    Both name lists are printed in full, then the names are compared
    position by position and every differing pair is printed (first
    checkpoint's name, then the second's). If one checkpoint has more
    parameters than the other, the surplus names are reported too, paired
    with ``None`` for the missing side.

    Args:
        model1: Path of the first checkpoint; must contain a ``'model'``
            mapping.
        model2: Path of the second checkpoint; must contain a ``'model'``
            mapping.
    """
    from itertools import zip_longest

    # Only the names are needed, so load onto the CPU: this works on
    # machines without CUDA and does not occupy GPU memory.
    names1 = list(torch.load(model1, map_location='cpu')['model'])
    names2 = list(torch.load(model2, map_location='cpu')['model'])

    print(names1)
    print(names2)

    # zip_longest rather than zip: plain zip stops at the shorter list and
    # would silently hide parameters that exist in only one checkpoint.
    for name1, name2 in zip_longest(names1, names2):
        if name1 != name2:
            print(name1)
            print(name2)
if __name__ == "__main__":
    # Script entry point: strip the decoder from the AR verifier checkpoint
    # and write the result as the initial checkpoint for the NAR drafter.

    # Checkpoint file of the AR verifier (input).
    verifier_ckpt = './checkpoints/wmt14-en-de-base-at-verifier.pt'
    # File the decoder-stripped checkpoint is written to (output).
    drafter_init_ckpt = './checkpoints/initial_checkpoint.pt'

    model_preparing(model_path=verifier_ckpt, save_path=drafter_init_ckpt)