-
Notifications
You must be signed in to change notification settings - Fork 0
/
generator.py
141 lines (115 loc) · 5.55 KB
/
generator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import os.path
import json
import scipy.misc
import numpy as np
import matplotlib.pyplot as plt
import skimage
import random
# In this exercise you will implement an image generator. Generator-like objects in Python expose a next function.
# Each call to next returns the next generated object; in our case, that is the input of a neural network.
# This input consists of a batch of images and their corresponding labels.
class ImageGenerator:
    """Endless batch generator over a directory of ``.npy`` images and their labels.

    Each call to :meth:`next` returns one batch ``(images, labels)``. When the
    end of the file list is reached in the middle of a batch, the generator
    wraps around to the first file, increments the epoch counter and keeps
    filling the batch, so every batch has exactly ``batch_size`` entries.
    """

    def __init__(self, file_path: str, label_path: str, batch_size: int, image_size,
                 rotation: bool = False, mirroring: bool = False, shuffle: bool = False):
        """Collect the image file names and load the label dictionary.

        Args:
            file_path: Image directory, joined onto ``./data/`` (an absolute
                path is used as given, per ``os.path.join`` semantics).
            label_path: JSON label file, joined onto ``./data/`` the same way.
                Its keys are the image file names without the ``.npy`` suffix.
            batch_size: Number of samples returned by each ``next()`` call.
            image_size: Target image shape; any sequence (list or tuple).
            rotation: If true, every image is rotated by a random multiple of 90°.
            mirroring: If true, every image is flipped left-right or up-down.
            shuffle: If true, the file order is reshuffled for every epoch.
        """
        # Maps the integer labels stored in the JSON file to readable class names.
        self.class_dict = {0: 'airplane', 1: 'automobile', 2: 'bird', 3: 'cat', 4: 'deer', 5: 'dog', 6: 'frog',
                           7: 'horse', 8: 'ship', 9: 'truck'}
        self.file_path = os.path.join('./data/', file_path)
        # os.listdir returns entries in arbitrary order and may contain stray
        # files; sort and keep only .npy files so that un-shuffled batches are
        # reproducible and every entry can be loaded with np.load.
        self.image_filenames = sorted(
            name for name in os.listdir(self.file_path) if name.endswith('.npy')
        )
        self.label_path = os.path.join('./data/', label_path)
        with open(self.label_path, 'r') as file:
            self.labels = json.load(file)
        self.batch_size = batch_size
        self.image_size = image_size
        self.rotation = rotation
        self.mirroring = mirroring
        self.shuffle = shuffle
        self.current_index = 0  # position of the next file to load
        self.epoch_index = 0  # number of completed passes over the data
        # Number of full batches per epoch (kept under its original name for callers).
        self.Num_batches = len(self.image_filenames) // self.batch_size

    def next(self) -> tuple:
        """Return the next batch as ``(images, labels)`` numpy arrays.

        The batch is topped up from the start of the data set when the end is
        reached, which also advances the epoch counter (and reshuffles the
        file order when shuffling is enabled).
        """
        # Shuffle before the very first sample is drawn; later epochs are
        # shuffled at the wrap-around inside the loop.
        if self.current_index == 0 and self.shuffle:
            random.shuffle(self.image_filenames)

        images = []
        labels = []
        for _ in range(self.batch_size):
            if self.current_index >= len(self.image_filenames):
                # Data exhausted: start a new epoch and continue filling this batch.
                self.current_index = 0
                self.epoch_index += 1
                if self.shuffle:
                    random.shuffle(self.image_filenames)

            file_name = self.image_filenames[self.current_index]
            image = np.load(os.path.join(self.file_path, file_name))
            images.append(self.augment(image))
            # Label keys are the file names without their extension; splitext
            # strips only the suffix, unlike str.replace.
            labels.append(self.labels[os.path.splitext(file_name)[0]])
            self.current_index += 1

        return np.array(images), np.array(labels)

    def augment(self, img: np.ndarray) -> np.ndarray:
        """Resize ``img`` to ``image_size`` if needed and apply the enabled augmentations.

        Mirroring (when enabled) flips the image left-right or up-down at
        random; rotation (when enabled) rotates it by 90, 180 or 270 degrees
        at random.
        """
        # Compare as tuples: image_size may be a list, and tuple != list is
        # always True, which would force a resize of every image.
        if tuple(img.shape) != tuple(self.image_size):
            # Import the submodule explicitly; a bare "import skimage" does not
            # guarantee that skimage.transform is loaded.
            import skimage.transform
            # NOTE(review): resize may change the dtype/value range of the image - confirm against the data.
            img = skimage.transform.resize(img, self.image_size)

        if self.mirroring:
            mirroring_type = random.choice(("lr", "ud"))
            img = np.fliplr(img) if mirroring_type == "lr" else np.flipud(img)

        if self.rotation:
            angle = np.random.choice((90, 180, 270))
            # np.rot90 rotates in quarter turns.
            img = np.rot90(img, int(angle) // 90)

        return img

    def current_epoch(self) -> int:
        """Return the current epoch number (0 during the first pass over the data)."""
        return self.epoch_index

    def class_name(self, x) -> str:
        """Return the class name for the integer label ``x``."""
        # The labels dict maps file names to integer labels; the readable
        # names live in class_dict.
        return self.class_dict[int(x)]

    def show(self):
        """Fetch the next batch and display it as a three-column grid titled with class names."""
        images, labels = self.next()
        cols = 3
        # Ceiling division so that no fully empty row is created; at least one
        # row is required by plt.subplots.
        rows = max(1, -(-len(images) // cols))
        # squeeze=False always yields a 2D array of axes, even for a single row.
        _, grid = plt.subplots(rows, cols, squeeze=False)
        axes = grid.flatten()

        for axis, img, label in zip(axes, images, labels):
            axis.imshow(img)
            axis.set_xticks([])
            axis.set_yticks([])
            axis.set_title(self.class_dict[label])

        # Hide the subplots that are not needed for this batch.
        for axis in axes[len(images):]:
            axis.axis('off')

        plt.tight_layout()
        plt.show()