-
Notifications
You must be signed in to change notification settings - Fork 0
/
batch_image_generator.py
35 lines (26 loc) · 1.18 KB
/
batch_image_generator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import random
import numpy as np
from matplotlib import image as mpimg
class BatchImageGenerator:
    """Produce endless batches of (image, steering angle) samples.

    Augmentation and preprocessing are delegated to the ``image_augmentor``
    supplied at construction; un-augmented images are read with
    ``matplotlib.image.imread``.
    """

    def __init__(self, image_augmentor):
        """Remember the augmentor used to build every sample.

        Args:
            image_augmentor: Object providing
                ``random_augment(image_path, steering_angle)``, which must
                return exactly three values (image, steering angle, and a
                third value that is ignored here), and
                ``image_preprocess(image)``, which returns the image that
                is placed in the batch.
        """
        self.image_augmentor = image_augmentor

    def batch_generator(self, image_paths, steering_angles, batch_size, is_training):
        """Yield batches of images and steering angles forever.

        Every sample is picked uniformly at random, with replacement, from
        ``image_paths``; the same index is used to look up
        ``steering_angles``, so that sequence must be at least as long.

        Args:
            image_paths: Sized, indexable collection of image paths.
            steering_angles: Indexable collection of angles aligned with
                ``image_paths``.
            batch_size: Number of samples placed in each yielded batch.
            is_training: When true, each sample comes from
                ``image_augmentor.random_augment``; otherwise the image is
                read from disk unchanged and paired with its stored angle.

        Yields:
            tuple: ``(images, angles)`` where both members are the result of
            ``np.asarray`` over the ``batch_size`` collected samples.

        Raises:
            ValueError: If ``image_paths`` is empty (raised when the first
                batch is requested, since this is a generator).
        """
        # The dataset size is loop-invariant, so measure it once.
        sample_count = len(image_paths)
        # Compare the length explicitly: truth-testing a NumPy array of
        # paths would raise instead of answering "is it empty?".
        if sample_count == 0:
            raise ValueError("image_paths must contain at least one entry")

        augmentor = self.image_augmentor
        while True:
            batch_img = []
            batch_steering = []
            for _ in range(batch_size):
                random_index = random.randrange(sample_count)
                if is_training:
                    # The third returned value is not needed for the batch,
                    # but the three-way unpack is kept so the augmentor's
                    # return contract is still enforced.
                    im, steering, _unused = augmentor.random_augment(
                        image_paths[random_index],
                        steering_angles[random_index],
                    )
                else:
                    im = mpimg.imread(image_paths[random_index])
                    steering = steering_angles[random_index]
                # Preprocessing is applied on both the training and the
                # evaluation path.
                batch_img.append(augmentor.image_preprocess(im))
                batch_steering.append(steering)
            yield (np.asarray(batch_img), np.asarray(batch_steering))