forked from hzwer/ECCV2022-RIFE
-
Notifications
You must be signed in to change notification settings - Fork 0
/
dataset.py
109 lines (101 loc) · 4.24 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import os
import cv2
import ast
import torch
import numpy as np
import random
from torch.utils.data import DataLoader, Dataset
# Keep OpenCV's own thread pool at a single thread. Presumably this avoids
# oversubscription when several DataLoader worker processes decode images in
# parallel -- confirm against the training script.
cv2.setNumThreads(1)

# Module-level default device: the CUDA device when available, else the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
class VimeoDataset(Dataset):
    """Vimeo-90K triplet dataset for frame interpolation.

    Each sample is a triplet of frames ``im1.png``, ``im2.png``, ``im3.png``
    read from ``<data_root>/sequences/<entry>/``. The middle frame is the
    target (``gt``) at ``timestep`` 0.5 between the two outer frames.
    Sample entries come from ``tri_trainlist.txt`` / ``tri_testlist.txt``
    under ``data_root``.

    Split selection via ``dataset_name``:
      * ``'train'``  - first 95% of the train list, randomly cropped/augmented;
      * ``'test'``   - the whole test list;
      * anything else - the remaining 5% of the train list (presumably used
        as a validation split).
    """

    def __init__(self, dataset_name, batch_size=32, data_root='vimeo_triplet'):
        """Read the split lists and select this split's samples.

        Args:
            dataset_name: split selector (see class docstring).
            batch_size: stored as ``self.batch_size``; not used by this class.
            data_root: directory holding ``sequences/`` and the
                ``tri_*list.txt`` files. Defaults to ``'vimeo_triplet'``
                (relative to the working directory), as before.

        Raises:
            OSError: if either split list file cannot be opened.
        """
        self.batch_size = batch_size
        self.dataset_name = dataset_name
        # Stored for callers; not read by any method here (training crops
        # are a fixed 224x224). Presumably the native Vimeo-90K frame size.
        self.h = 256
        self.w = 448
        self.data_root = data_root
        self.image_root = os.path.join(self.data_root, 'sequences')
        train_fn = os.path.join(self.data_root, 'tri_trainlist.txt')
        test_fn = os.path.join(self.data_root, 'tri_testlist.txt')
        self.trainlist = self._read_list(train_fn)
        self.testlist = self._read_list(test_fn)
        self.load_data()

    @staticmethod
    def _read_list(path):
        """Return the stripped, non-blank lines of the split file ``path``."""
        with open(path, 'r') as f:
            # A blank line (e.g. a trailing empty line) would otherwise become
            # a sample pointing at the bare ``sequences/`` directory and fail
            # to load; stray surrounding whitespace would break the path.
            return [line.strip() for line in f.read().splitlines() if line.strip()]

    def __len__(self):
        """Number of samples in the selected split."""
        return len(self.meta_data)

    def load_data(self):
        """Set ``self.meta_data`` to the sample entries of the selected split."""
        cnt = int(len(self.trainlist) * 0.95)
        if self.dataset_name == 'train':
            self.meta_data = self.trainlist[:cnt]
        elif self.dataset_name == 'test':
            self.meta_data = self.testlist
        else:
            self.meta_data = self.trainlist[cnt:]

    def crop(self, img0, gt, img1, h, w):
        """Cut the same random ``h`` x ``w`` window out of all three frames.

        Args:
            img0, gt, img1: arrays indexed as (rows, cols, channels), all
                with the same spatial size as ``img0``.
            h, w: crop height and width.

        Returns:
            ``(img0, gt, img1)`` cropped views.

        Raises:
            ValueError: if the crop is larger than the frame.
        """
        ih, iw, _ = img0.shape
        if h > ih or w > iw:
            raise ValueError(f'crop {h}x{w} is larger than frame {ih}x{iw}')
        top = np.random.randint(0, ih - h + 1)
        left = np.random.randint(0, iw - w + 1)
        window = (slice(top, top + h), slice(left, left + w))
        return img0[window], gt[window], img1[window]

    @staticmethod
    def _imread(path):
        """Read an image with ``cv2.imread``, raising if it cannot be loaded."""
        img = cv2.imread(path)
        # cv2.imread signals failure by returning None rather than raising.
        if img is None:
            raise FileNotFoundError(f'could not read image: {path}')
        return img

    def getimg(self, index):
        """Load the frame triplet for ``self.meta_data[index]``.

        Returns:
            ``(img0, gt, img1, timestep)``: the frames ``im1``, ``im2``,
            ``im3`` as loaded by ``cv2.imread`` (BGR channel order) and the
            target time 0.5.

        Raises:
            FileNotFoundError: if any of the three frames cannot be read.
        """
        imgpath = os.path.join(self.image_root, self.meta_data[index])
        imgpaths = [os.path.join(imgpath, name) for name in ('im1.png', 'im2.png', 'im3.png')]
        img0, gt, img1 = (self._imread(p) for p in imgpaths)
        timestep = 0.5
        # NOTE: RIFEm variant with Vimeo-Septuplet (kept for reference):
        # imgpaths = [imgpath + '/im1.png', ..., imgpath + '/im7.png']
        # ind = [0, 1, 2, 3, 4, 5, 6]
        # random.shuffle(ind)
        # ind = ind[:3]
        # ind.sort()
        # img0 = cv2.imread(imgpaths[ind[0]])
        # gt = cv2.imread(imgpaths[ind[1]])
        # img1 = cv2.imread(imgpaths[ind[2]])
        # timestep = (ind[1] - ind[0]) * 1.0 / (ind[2] - ind[0] + 1e-6)
        return img0, gt, img1, timestep

    @staticmethod
    def _augment(img0, gt, img1, timestep):
        """Apply the random training-time augmentations to one triplet.

        The same transform is applied to every frame, in this order:
        channel-order reversal, vertical flip, horizontal flip (each with
        p=0.5), swapping the two input frames (p=0.5, which maps
        ``timestep`` to ``1 - timestep``), then a rotation by 90 degrees
        clockwise, 180 degrees, or 90 degrees counter-clockwise (p=0.25 each,
        otherwise no rotation).
        """
        frames = [img0, gt, img1]
        if random.uniform(0, 1) < 0.5:
            # Reverse the channel axis (BGR <-> RGB) as a colour augmentation.
            frames = [f[:, :, ::-1] for f in frames]
        if random.uniform(0, 1) < 0.5:
            frames = [f[::-1] for f in frames]
        if random.uniform(0, 1) < 0.5:
            frames = [f[:, ::-1] for f in frames]
        img0, gt, img1 = frames
        if random.uniform(0, 1) < 0.5:
            # Playing the triplet backwards: gt now sits at 1 - t.
            img0, img1 = img1, img0
            timestep = 1 - timestep
        p = random.uniform(0, 1)
        if p < 0.25:
            code = cv2.ROTATE_90_CLOCKWISE
        elif p < 0.5:
            code = cv2.ROTATE_180
        elif p < 0.75:
            code = cv2.ROTATE_90_COUNTERCLOCKWISE
        else:
            code = None
        if code is not None:
            img0, gt, img1 = (cv2.rotate(f, code) for f in (img0, gt, img1))
        return img0, gt, img1, timestep

    def __getitem__(self, index):
        """Return ``(frames, timestep)`` for sample ``index``.

        Each frame is converted to a tensor and permuted with
        ``permute(2, 0, 1)`` (channel-last to channel-first). ``frames``
        concatenates ``img0, img1, gt`` along that leading channel axis.
        ``timestep`` is a tensor of shape (1, 1, 1). In the ``'train'``
        split the triplet is first randomly cropped to 224x224 and
        augmented.
        """
        img0, gt, img1, timestep = self.getimg(index)
        if self.dataset_name == 'train':
            img0, gt, img1 = self.crop(img0, gt, img1, 224, 224)
            img0, gt, img1, timestep = self._augment(img0, gt, img1, timestep)
        # .copy() makes the (possibly negatively-strided) views contiguous,
        # which torch.from_numpy requires.
        img0 = torch.from_numpy(img0.copy()).permute(2, 0, 1)
        img1 = torch.from_numpy(img1.copy()).permute(2, 0, 1)
        gt = torch.from_numpy(gt.copy()).permute(2, 0, 1)
        timestep = torch.tensor(timestep).reshape(1, 1, 1)
        return torch.cat((img0, img1, gt), 0), timestep