-
Notifications
You must be signed in to change notification settings - Fork 2
/
dataset.py
68 lines (54 loc) · 2.34 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import tensorflow as tf
import numpy as np
# Reference : https://stackoverflow.com/questions/54567986/python-numpy-remove-empty-zeroes-border-of-3d-array
def bounds_per_dimension(ndarray):
    """Return, for every axis, the index range spanned by non-zero entries.

    For each axis of ``ndarray`` the result holds
    ``range(lowest_nonzero_index, highest_nonzero_index + 1)``.

    Args:
        ndarray: N-dimensional numpy array.

    Returns:
        A one-shot iterator (``map`` object) yielding one ``range`` per axis,
        in axis order. When ``ndarray`` contains no non-zero entry, every
        range is empty instead of ``min()`` raising on an empty index array.
    """
    nonzero_indices = np.where(ndarray != 0)
    return map(
        # An index array is empty only when the input has no non-zero entry;
        # min()/max() would raise ValueError on it, so yield an empty range.
        lambda e: range(int(e.min()), int(e.max()) + 1) if e.size else range(0),
        nonzero_indices,
    )
def zero_trim_ndarray(ndarray):
    """Crop ``ndarray`` to the bounding box of its non-zero entries.

    The per-axis index ranges come from ``bounds_per_dimension``; they are
    combined into an open mesh with ``np.ix_`` and used to index the array.

    Args:
        ndarray: N-dimensional numpy array.

    Returns:
        The sub-array selected by the per-axis ranges.
    """
    kept_ranges = tuple(bounds_per_dimension(ndarray))
    open_mesh = np.ix_(*kept_ranges)
    return ndarray[open_mesh]
# process ground-truth data for YOLO format
def process_each_ground_truth(original_image,
                              bbox,
                              class_labels,
                              input_width,
                              input_height
                              ):
    """Resize one image to the YOLO input size and build its box labels.

    Reference:
      https://github.com/tensorflow/datasets/blob/master/tensorflow_datasets/object_detection/voc.py#L115
      bbox return : (ymin / height, xmin / width, ymax / height, xmax / width)

    Args:
      original_image : (original_height, original_width, channel) image tensor
      bbox : (max_object_num_in_batch, 4) = (ymin / height, xmin / width, ymax / height, xmax / width)
      class_labels : (max_object_num_in_batch) = class labels without one-hot-encoding
      input_width : yolo input width
      input_height : yolo input height

    Returns:
      image: (resized_height, resized_width, channel) image ndarray
      labels: 2-D list [object_num, 5] (xcenter (Absolute Coordinate), ycenter (Absolute Coordinate), w (Absolute Coordinate), h (Absolute Coordinate), class_num)
      object_num: total object number in image
    """
    image = original_image.numpy()
    # Strip the all-zero border around the image. An image without any
    # non-zero value has no bounds to trim to, so it is left untouched.
    # NOTE(review): the trim acts on every axis, so all-zero channels at
    # either end of the channel axis would be dropped as well - confirm
    # this cannot happen for the images fed in here.
    if np.any(image):
        image = zero_trim_ndarray(image)

    # set original width height (after trimming)
    original_h = image.shape[0]
    original_w = image.shape[1]

    width_rate = input_width / original_w
    height_rate = input_height / original_h

    image = tf.image.resize(image, [input_height, input_width])

    # Count a row as an object when ANY of its coordinates is non-zero.
    # Looking only at ymin would miss boxes touching the top edge (ymin == 0)
    # and thereby drop a trailing real box.
    # Presumably all-zero rows are padding added when batching - verify
    # against the caller.
    object_num = np.count_nonzero(np.any(np.asarray(bbox) != 0, axis=1))

    labels = []
    for i in range(object_num):
        # bbox rows are (ymin, xmin, ymax, xmax), normalized to [0, 1]
        xmin = bbox[i][1] * original_w
        ymin = bbox[i][0] * original_h
        xmax = bbox[i][3] * original_w
        ymax = bbox[i][2] * original_h

        class_num = class_labels[i]

        # convert corner form to center/size form in resized-image pixels
        xcenter = (xmin + xmax) / 2 * width_rate
        ycenter = (ymin + ymax) / 2 * height_rate

        box_w = (xmax - xmin) * width_rate
        box_h = (ymax - ymin) * height_rate

        labels.append([xcenter, ycenter, box_w, box_h, class_num])

    return [image.numpy(), labels, object_num]