-
Notifications
You must be signed in to change notification settings - Fork 2
/
data.py
67 lines (55 loc) · 2.18 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
from torchvision.datasets import CIFAR10
from torch.utils.data import Dataset
from torchvision.transforms import RandomCrop, Resize, RandomHorizontalFlip
from torchvision.transforms import ToTensor, Normalize, Compose
from PIL import Image
import numpy as np
import glob
class CIFARSubset(Dataset):
    """Images of a single CIFAR-10 class, returned as transformed tensors.

    Args:
        data_path: root directory handed to ``torchvision.datasets.CIFAR10``.
        label: class name, looked up in ``CIFAR10.class_to_idx``; an unknown
            name raises ``KeyError``.
        n_tracks: maximum number of images to keep (the first ``n_tracks``
            samples of the class, in dataset order).
    """

    def __init__(self, data_path, label, n_tracks):
        dataset = CIFAR10(data_path)
        targets = np.array(dataset.targets)
        class_idx = dataset.class_to_idx[label]
        # Positions of every sample of the requested class, in dataset order.
        indices = np.flatnonzero(targets == class_idx)
        # Trim the index array before fancy-indexing so only the kept images
        # are copied out of ``dataset.data``.
        self.data = dataset.data[indices[:n_tracks]]
        # No resize/crop/flip: only ToTensor followed by Normalize(0.5, 0.5).
        self.transform = get_transform(-1, -1, False, True)

    def __getitem__(self, idx):
        """Return the transformed image at position ``idx``."""
        return self.transform(self.data[idx])

    def __len__(self):
        """Return the number of images kept for this class."""
        return len(self.data)
class ImageNetSubset(Dataset):
    """Images of one class folder, loaded eagerly into memory as RGB PIL images.

    Args:
        data_path: root directory containing one sub-directory per class.
        label: name of the class sub-directory under ``data_path``.
        n_tracks: maximum number of images to load; files are taken in
            sorted filename order.
    """

    def __init__(self, data_path, label, n_tracks):
        # Only lowercase ``*.jpg`` files are matched.
        # NOTE(review): stock ImageNet ships ``.JPEG`` files; confirm the data
        # on disk uses the lowercase extension.
        files = sorted(glob.glob(data_path + '/{}/*.jpg'.format(label)))
        self.data = []
        # Slicing replaces the old break-after-append loop, which always loaded
        # at least one image even for n_tracks == 0; this matches
        # CIFARSubset's ``data[:n_tracks]``.
        for path in files[:n_tracks]:
            # The context manager closes the file handle opened by Image.open;
            # convert() returns an independent in-memory image.
            with Image.open(path) as img:
                self.data.append(img.convert('RGB'))
        # Resize to 160, random 128 crop, random horizontal flip, then
        # ToTensor + Normalize(0.5, 0.5).
        self.transform = get_transform(160, 128, True, True)

    def __getitem__(self, idx):
        """Return the transformed (randomly augmented) image at ``idx``."""
        return self.transform(self.data[idx])

    def __len__(self):
        """Return the number of images loaded for this class."""
        return len(self.data)
def get_datasets(type, data_path, label_a, label_b, n_tracks=5000):
    """Build a pair of single-class datasets of the same kind.

    Args:
        type: dataset kind, ``"cifar"`` or ``"imagenet"``. (The name shadows
            the builtin but is kept so keyword callers keep working.)
        data_path: root directory of the dataset on disk.
        label_a: class label for the first dataset.
        label_b: class label for the second dataset.
        n_tracks: maximum number of images per dataset.

    Returns:
        Tuple ``(dataset_a, dataset_b)``.

    Raises:
        ValueError: if ``type`` is not a supported dataset kind (previously
            this fell through to an ``UnboundLocalError``).
    """
    if type == "cifar":
        subset_cls = CIFARSubset
    elif type == "imagenet":
        subset_cls = ImageNetSubset
    else:
        raise ValueError(
            "unknown dataset type {!r}; expected 'cifar' or 'imagenet'".format(type)
        )
    dataset_a = subset_cls(data_path, label_a, n_tracks)
    dataset_b = subset_cls(data_path, label_b, n_tracks)
    return dataset_a, dataset_b
def get_transform(resize, cropsize, flip, normalize):
    """Assemble a torchvision ``Compose`` preprocessing pipeline.

    Stages are added in this fixed order, and each optional stage is skipped
    when disabled:

    1. ``Resize(resize)`` when ``resize > 0``
    2. ``RandomCrop(cropsize)`` when ``cropsize > 0``
    3. ``RandomHorizontalFlip(0.5)`` when ``flip`` is truthy
    4. ``ToTensor()`` (always)
    5. per-channel ``Normalize`` with mean 0.5 and std 0.5 when ``normalize``

    Returns:
        The ``Compose`` object wrapping the selected stages.
    """
    pipeline = [Resize(resize)] if resize > 0 else []
    if cropsize > 0:
        pipeline += [RandomCrop(cropsize)]
    if flip:
        pipeline += [RandomHorizontalFlip(0.5)]
    pipeline += [ToTensor()]
    if normalize:
        half = (0.5, 0.5, 0.5)
        pipeline += [Normalize(half, half)]
    return Compose(pipeline)