Skip to content

Commit

Permalink
refine demo
Browse files Browse the repository at this point in the history
  • Loading branch information
CoinCheung committed Jul 3, 2022
1 parent 83ab7a5 commit 97ea008
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 13 deletions.
7 changes: 4 additions & 3 deletions tools/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import numpy as np
import cv2

import lib.transform_cv2 as T
import lib.data.transform_cv2 as T
from lib.models import model_factory
from configs import set_cfg_from_file

Expand All @@ -34,7 +34,7 @@
palette = np.random.randint(0, 256, (256, 3), dtype=np.uint8)

# define model
net = model_factory[cfg.model_type](cfg.n_cats, aux_mode='pred')
net = model_factory[cfg.model_type](cfg.n_cats, aux_mode='eval')
net.load_state_dict(torch.load(args.weight_path, map_location='cpu'), strict=False)
net.eval()
net.cuda()
Expand All @@ -53,8 +53,9 @@

# inference
im = F.interpolate(im, size=new_size, align_corners=False, mode='bilinear')
out = net(im)
out = net(im)[0]
out = F.interpolate(out, size=org_size, align_corners=False, mode='bilinear')
out = out.argmax(dim=1)

# visualize
out = out.squeeze().detach().cpu().numpy()
Expand Down
23 changes: 13 additions & 10 deletions tools/demo_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.multiprocessing import Process, Queue
import torch.multiprocessing as mp
import time
from PIL import Image
import numpy as np
import cv2

import lib.transform_cv2 as T
import lib.data.transform_cv2 as T
from lib.models import model_factory
from configs import set_cfg_from_file

Expand Down Expand Up @@ -40,7 +40,7 @@ def get_model():


# fetch frames
def get_func(inpth, in_q):
def get_func(inpth, in_q, done):
cap = cv2.VideoCapture(args.input)
width = cap.get(cv2.CAP_PROP_FRAME_WIDTH) # type is float
height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT) # type is float
Expand All @@ -59,7 +59,8 @@ def get_func(inpth, in_q):
in_q.put(frame)

in_q.put('quit')
while not in_q.empty(): continue
done.wait()

cap.release()
time.sleep(1)
print('input queue done')
Expand Down Expand Up @@ -105,14 +106,15 @@ def infer_batch(frames):


if __name__ == '__main__':
torch.multiprocessing.set_start_method('spawn')
mp.set_start_method('spawn')

in_q = Queue(1024)
out_q = Queue(1024)
in_q = mp.Queue(1024)
out_q = mp.Queue(1024)
done = mp.Event()

in_worker = Process(target=get_func,
args=(args.input, in_q))
out_worker = Process(target=save_func,
in_worker = mp.Process(target=get_func,
args=(args.input, in_q, done))
out_worker = mp.Process(target=save_func,
args=(args.input, args.output, out_q))

in_worker.start()
Expand All @@ -133,6 +135,7 @@ def infer_batch(frames):
infer_batch(frames)

out_q.put('quit')
done.set()

out_worker.join()
in_worker.join()

0 comments on commit 97ea008

Please sign in to comment.