-
Notifications
You must be signed in to change notification settings - Fork 82
/
demo.py
87 lines (67 loc) · 2.47 KB
/
demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import argparse
import pathlib
import numpy as np
import cv2
import time
import torch
import torch.nn as nn
from torch.autograd import Variable
from torchvision import transforms
import torch.backends.cudnn as cudnn
import torchvision
from PIL import Image
from PIL import Image, ImageOps
from face_detection import RetinaFace
from l2cs import select_device, draw_gaze, getArch, Pipeline, render
CWD = pathlib.Path.cwd()
def parse_args(argv=None):
    """Parse command-line arguments for the gaze-estimation demo.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default), argparse falls back to ``sys.argv[1:]``, so
            existing callers are unaffected; passing a list makes the
            function testable without touching process arguments.

    Returns:
        argparse.Namespace with attributes ``device``, ``snapshot``,
        ``cam_id`` and ``arch``.
    """
    parser = argparse.ArgumentParser(
        # typo fixed: "evalution" -> "evaluation"
        description='Gaze evaluation using model pretrained with L2CS-Net on Gaze360.')
    parser.add_argument(
        '--device', dest='device', help='Device to run model: cpu or gpu:0',
        default="cpu", type=str)
    parser.add_argument(
        '--snapshot', dest='snapshot', help='Path of model snapshot.',
        default='output/snapshots/L2CS-gaze360-_loader-180-4/_epoch_55.pkl',
        type=str)
    parser.add_argument(
        '--cam', dest='cam_id', help='Camera device id to use [0]',
        default=0, type=int)
    parser.add_argument(
        '--arch', dest='arch',
        help='Network architecture, can be: ResNet18, ResNet34, ResNet50, '
             'ResNet101, ResNet152',
        default='ResNet50', type=str)
    return parser.parse_args(argv)
if __name__ == '__main__':
    args = parse_args()
    cudnn.enabled = True
    arch = args.arch
    cam = args.cam_id
    # snapshot_path = args.snapshot

    # Build the gaze-estimation pipeline. Honor the --arch flag: the
    # original hard-coded arch='ResNet50' and silently ignored args.arch.
    gaze_pipeline = Pipeline(
        weights=CWD / 'models' / 'L2CSNet_gaze360.pkl',
        arch=arch,
        device=select_device(args.device, batch_size=1)
    )

    cap = cv2.VideoCapture(cam)
    # Check if the webcam is opened correctly
    if not cap.isOpened():
        raise IOError("Cannot open webcam")

    try:
        with torch.no_grad():
            while True:
                # Get frame
                success, frame = cap.read()
                start_fps = time.time()

                if not success:
                    # Skip this iteration: the original fell through and
                    # passed the invalid frame to gaze_pipeline.step().
                    print("Failed to obtain frame")
                    time.sleep(0.1)
                    continue

                # Process frame (face detection + gaze estimation)
                results = gaze_pipeline.step(frame)

                # Visualize output
                frame = render(frame, results)
                myFPS = 1.0 / (time.time() - start_fps)
                cv2.putText(frame, 'FPS: {:.1f}'.format(myFPS), (10, 20),
                            cv2.FONT_HERSHEY_COMPLEX_SMALL, 1, (0, 255, 0), 1,
                            cv2.LINE_AA)
                cv2.imshow("Demo", frame)
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
                # NOTE: the original issued an extra cap.read() here whose
                # result was immediately overwritten at the top of the loop,
                # discarding every other frame; removed.
    finally:
        # Release the camera and close windows even on error or Ctrl-C.
        cap.release()
        cv2.destroyAllWindows()