-
Notifications
You must be signed in to change notification settings - Fork 448
/
video_demo.py
108 lines (88 loc) · 3.64 KB
/
video_demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
# Copyright (c) Tencent Inc. All rights reserved.
# This file is modified from mmyolo/demo/video_demo.py
import argparse
import cv2
import mmcv
import torch
from mmengine.dataset import Compose
from mmdet.apis import init_detector
from mmengine.utils import track_iter_progress
from mmyolo.registry import VISUALIZERS
def parse_args():
    """Parse the command-line arguments of the video demo.

    Positional arguments, in order: ``config``, ``checkpoint``, ``video`` and
    ``text``. Optional flags: ``--device`` (default ``'cuda:0'``),
    ``--score-thr`` (float, default ``0.1``) and ``--out`` (output video path,
    ``None`` when omitted).

    Returns:
        argparse.Namespace: the parsed arguments.
    """
    ap = argparse.ArgumentParser(description='YOLO-World video demo')

    # Positional arguments; registration order defines their CLI order.
    positionals = (
        ('config', 'Config file'),
        ('checkpoint', 'Checkpoint file'),
        ('video', 'video file path'),
        ('text',
         'text prompts, including categories separated by a comma or a txt '
         'file with each line as a prompt.'),
    )
    for arg_name, arg_help in positionals:
        ap.add_argument(arg_name, help=arg_help)

    ap.add_argument('--device', default='cuda:0',
                    help='device used for inference')
    ap.add_argument('--score-thr', type=float, default=0.1,
                    help='confidence score threshold for predictions.')
    ap.add_argument('--out', type=str, help='output video file')
    return ap.parse_args()
def inference_detector(model, image, texts, test_pipeline, score_thr=0.3):
data_info = dict(img_id=0, img=image, texts=texts)
data_info = test_pipeline(data_info)
data_batch = dict(inputs=data_info['inputs'].unsqueeze(0),
data_samples=[data_info['data_samples']])
with torch.no_grad():
output = model.test_step(data_batch)[0]
pred_instances = output.pred_instances
pred_instances = pred_instances[pred_instances.scores.float() >
score_thr]
output.pred_instances = pred_instances
return output
def _load_texts(text):
    """Build the nested prompt list passed to ``model.reparameterize``.

    Args:
        text: Either a path ending in ``.txt`` (one prompt per line, line
            endings stripped) or a comma-separated string of prompts
            (surrounding whitespace stripped from each).

    Returns:
        A list of single-element lists, one per prompt, followed by a final
        ``[' ']`` entry.
    """
    if text.endswith('.txt'):
        # Decode explicitly as UTF-8 so non-ASCII prompts are read the same way
        # on every platform rather than depending on the locale's encoding.
        with open(text, encoding='utf-8') as f:
            prompts = [line.rstrip('\r\n') for line in f]
    else:
        prompts = [t.strip() for t in text.split(',')]
    # NOTE(review): the trailing single-space prompt is kept from the original
    # demo; presumably a padding/placeholder entry the model expects - confirm.
    return [[p] for p in prompts] + [[' ']]


def main():
    """Run the YOLO-World video demo over every frame of the input video.

    Parses CLI arguments, builds the detector and its test pipeline, sets the
    text prompts, draws the predictions on each frame with the configured
    visualizer and, when ``--out`` is given, writes the drawn frames to an
    ``mp4v``-encoded video file.

    Raises:
        RuntimeError: If ``--out`` is given but OpenCV cannot open the writer.
    """
    args = parse_args()
    model = init_detector(args.config, args.checkpoint, device=args.device)

    # Build the test pipeline. The first transform is switched to
    # 'mmdet.LoadImageFromNDArray', presumably so decoded frames (arrays) can
    # be fed directly instead of image file paths.
    model.cfg.test_dataloader.dataset.pipeline[
        0].type = 'mmdet.LoadImageFromNDArray'
    test_pipeline = Compose(model.cfg.test_dataloader.dataset.pipeline)

    texts = _load_texts(args.text)
    # Reparameterize the model with the text prompts.
    model.reparameterize(texts)

    visualizer = VISUALIZERS.build(model.cfg.visualizer)
    # The dataset_meta is loaded from the checkpoint and then passed to the
    # model in init_detector.
    visualizer.dataset_meta = model.dataset_meta

    video_reader = mmcv.VideoReader(args.video)
    video_writer = None
    if args.out:
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        video_writer = cv2.VideoWriter(
            args.out, fourcc, video_reader.fps,
            (video_reader.width, video_reader.height))
        # cv2.VideoWriter does not raise when it fails to open; without this
        # check every write() would be silently dropped and no file produced.
        if not video_writer.isOpened():
            raise RuntimeError(
                f'Could not open video writer for output file: {args.out}')

    # try/finally guarantees release() even if inference fails or the run is
    # interrupted, so the output file is properly closed/finalized.
    try:
        for frame in track_iter_progress(video_reader):
            result = inference_detector(model,
                                        frame,
                                        texts,
                                        test_pipeline,
                                        score_thr=args.score_thr)
            visualizer.add_datasample(name='video',
                                      image=frame,
                                      data_sample=result,
                                      draw_gt=False,
                                      show=False,
                                      pred_score_thr=args.score_thr)
            frame = visualizer.get_image()
            if video_writer is not None:
                video_writer.write(frame)
    finally:
        if video_writer is not None:
            video_writer.release()
if __name__ == '__main__':
    # Run the demo only when this file is executed as a script, not on import.
    main()