-
Notifications
You must be signed in to change notification settings - Fork 0
/
package_model.py
77 lines (62 loc) · 2.37 KB
/
package_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import copy
import torch
import time
import os
import models
from utils import load_config
class Compile:
def __init__(self, cfg, device, half_precision):
self.cfg = cfg
self.device = device
self.half_precision = half_precision
if device == "cpu" and half_precision:
raise ValueError(f"Half precision does not work with device {device}")
# load model
checkpoint_dir = vars(cfg.model)
checkpoint_dir = [f"{key}[{checkpoint_dir[key]}]" for key in checkpoint_dir]
checkpoint_dir = "-".join(checkpoint_dir)
checkpoint_dir = os.path.join("checkpoints", checkpoint_dir)
checkpoint_path = os.path.join(checkpoint_dir, "ckpt.pth")
model = models.load(cfg)
model.load_checkpoint(checkpoint_path, device=cfg.training.device, train=False)
model.to(self.device)
model.eval()
self.model = model.half() if self.half_precision else model
def run(self):
dummy = torch.rand(1, 3, cfg.data.input_size, cfg.data.input_size).to(self.device)
dummy = dummy.half() if self.half_precision else dummy
self.module = torch.jit.trace(self.model.forward, dummy)
def save(self):
torch.jit.save(self.module, './packaged.pth')
def load(self):
self.loaded_module = torch.jit.load(f='./packaged.pth', map_location=self.device)
self.loaded_module.eval()
def test(self):
dummy = torch.rand(32, 3, cfg.data.input_size, cfg.data.input_size).to(self.device)
dummy = dummy.half() if self.half_precision else dummy
t0 = time.time()
with torch.no_grad():
out_jl = self.loaded_module(copy.deepcopy(dummy)).detach()
t1 = time.time()
with torch.no_grad():
out_nj = self.model.forward(dummy).detach()
t2 = time.time()
print(
f"Out size: {out_nj.size()}\n"
"Out diff:", (out_nj - out_jl).abs().mean().item(), "\n",
"\n",
f" Torch JIT: {t1 - t0}\n",
f"Torch Model: {t2 - t1}\n",
"--------------------------"
)
if __name__ == "__main__":
# import config
cfg = load_config("./config.yml")
instance = Compile(cfg, "cuda", half_precision=True)
instance.run()
instance.save()
instance.load()
instance.test()
instance.test()
instance.test()
instance.test()