forked from cvat-ai/cvat
-
Notifications
You must be signed in to change notification settings - Fork 2
/
model_handler.py
78 lines (64 loc) · 3.34 KB
/
model_handler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
# Copyright (C) 2022 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
import jsonpickle
import numpy as np
import torch
from pysot_toolkit.bbox import get_axis_aligned_bbox
from pysot_toolkit.trackers.net_wrappers import NetWithBackbone
from pysot_toolkit.trackers.tracker import Tracker
class ModelHandler:
    """Drive a ``pysot_toolkit`` tracker and carry its state between calls.

    The handler owns a single ``Tracker`` instance. Because the per-object
    tracking state has to survive between independent ``infer`` calls, it is
    exported to / restored from a plain dictionary of jsonpickle-encoded
    strings (see ``encode_state`` / ``decode_state``).
    """

    # Prefix shared by every key of the serialized state dictionary.
    _STATE_PREFIX = 'model.'

    # Dotted attribute paths, relative to ``self.tracker``, that form the
    # mandatory part of the state. Each value is stored under the key
    # ``_STATE_PREFIX + path``. Keeping one list for both directions ensures
    # ``encode_state`` and ``decode_state`` cannot drift apart.
    _STATE_ATTRS = (
        'net.net.zf',
        'net.net.pos_template',
        'window',
        'center_pos',
        'size',
        'channel_average',
        'mean',
        'std',
        'inplace',
    )

    # Optional tracker attribute: a state that lacks it is still accepted and
    # the attribute is then treated as ``False``.
    _OPTIONAL_FLAG = 'features_initialized'

    def __init__(self, net_path='/transt.pth'):
        """Load the network and build the tracker.

        Args:
            net_path: Absolute path of the model checkpoint. Defaults to
                ``'/transt.pth'``, the location used before this became
                configurable.
        """
        use_gpu = torch.cuda.is_available()
        net = NetWithBackbone(net_path=net_path, use_gpu=use_gpu)
        self.tracker = Tracker(
            name='transt',
            net=net,
            window_penalty=0.49,
            exemplar_size=128,
            instance_size=256,
        )

    def _locate(self, path):
        """Return ``(owner, attribute_name)`` for a dotted path under the tracker."""
        *parents, attr_name = path.split('.')
        owner = self.tracker
        for parent in parents:
            owner = getattr(owner, parent)
        return owner, attr_name

    def decode_state(self, state):
        """Restore the tracker attributes from a dictionary made by ``encode_state``.

        Args:
            state: Mapping from ``'model.<attribute path>'`` to a
                jsonpickle-encoded string. Every path in ``_STATE_ATTRS`` is
                required; ``'model.features_initialized'`` is optional.

        Raises:
            KeyError: If a required key is missing from ``state``.
        """
        # NOTE(security): jsonpickle.decode can instantiate arbitrary Python
        # objects, so ``state`` must only ever come from a trusted source.
        for path in self._STATE_ATTRS:
            value = jsonpickle.decode(state[self._STATE_PREFIX + path])
            owner, attr_name = self._locate(path)
            setattr(owner, attr_name, value)

        # States that do not carry the flag are still accepted; it then
        # falls back to False.
        self.tracker.features_initialized = False
        flag_key = self._STATE_PREFIX + self._OPTIONAL_FLAG
        if flag_key in state:
            self.tracker.features_initialized = jsonpickle.decode(state[flag_key])

    def encode_state(self):
        """Export the tracker attributes as a dictionary of jsonpickle strings.

        Returns:
            A dict keyed by ``'model.<attribute path>'`` whose values are the
            jsonpickle encodings of the corresponding tracker attributes.
        """
        state = {}
        for path in self._STATE_ATTRS:
            owner, attr_name = self._locate(path)
            state[self._STATE_PREFIX + path] = jsonpickle.encode(getattr(owner, attr_name))

        # The tracker may not define the flag at all; encode False in that case.
        state[self._STATE_PREFIX + self._OPTIONAL_FLAG] = jsonpickle.encode(
            getattr(self.tracker, self._OPTIONAL_FLAG, False)
        )
        return state

    def init_tracker(self, img, bbox):
        """Initialize the tracker on ``img`` with the target box ``bbox``.

        Args:
            img: Image handed unchanged to ``Tracker.initialize``.
            bbox: Target box; ``infer`` passes it as ``(x, y, width, height)``.
        """
        # The helper's result is unpacked as center/size and converted back to
        # a top-left based [x, y, w, h] box for the tracker.
        cx, cy, w, h = get_axis_aligned_bbox(np.array(bbox))
        gt_bbox_ = [cx - w / 2, cy - h / 2, w, h]
        init_info = {'init_bbox': gt_bbox_}
        self.tracker.initialize(img, init_info)

    def track(self, img):
        """Run the tracker on ``img`` and return the predicted box.

        Returns:
            A ``(left, top, right, bottom)`` tuple computed from the tracker's
            ``'target_bbox'`` output, which is read as ``[x, y, width, height]``.
        """
        outputs = self.tracker.track(img)
        box = outputs['target_bbox']
        left, top = box[0], box[1]
        right = box[0] + box[2]
        bottom = box[1] + box[3]
        return (left, top, right, bottom)

    def infer(self, image, shape, state):
        """Initialize or advance the tracker for one frame.

        Args:
            image: Frame to process.
            shape: Box as ``(x1, y1, x2, y2)``; only used when ``state`` is
                ``None``, i.e. for the frame that starts the track.
            state: Dictionary previously returned by this method, or ``None``
                to start tracking from ``shape``.

        Returns:
            A ``(shape, state)`` pair: the input ``shape`` unchanged on the
            initial frame, otherwise the predicted
            ``(left, top, right, bottom)`` box, together with the freshly
            encoded tracker state.
        """
        if state is None:
            # Corner-based box -> (x, y, width, height) expected by init_tracker.
            init_shape = (shape[0], shape[1], shape[2] - shape[0], shape[3] - shape[1])
            self.init_tracker(image, init_shape)
        else:
            self.decode_state(state)
            shape = self.track(image)
        state = self.encode_state()
        return shape, state