forked from onnx/models
-
Notifications
You must be signed in to change notification settings - Fork 0
/
imagenet_preprocess.py
61 lines (52 loc) · 2.03 KB
/
imagenet_preprocess.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
# SPDX-License-Identifier: Apache-2.0
import numpy as np
from PIL import Image
import mxnet
from mxnet.gluon.data.vision import transforms
def preprocess(image):
    '''
    Turn an image array into a normalized, batched, channel-first tensor.

    Steps: resize so the shorter side is 256 pixels (aspect ratio kept),
    center-crop to 224x224, move channels first, divide by 255, normalize
    each channel with the mean/stddev constants below, and add a leading
    batch axis.

    Args:
        image: array accepted by ``PIL.Image.fromarray``. Presumably a uint8
            HxWx3 RGB array -- TODO confirm against callers. 2-D grayscale
            and RGBA arrays are converted to RGB instead of failing later.

    Returns:
        C-contiguous float32 ``numpy.ndarray`` of shape (1, 3, 224, 224).
    '''
    # resize so that the shorter side is min_len, maintaining aspect ratio
    def image_resize(image, min_len):
        # Force three channels: without this a grayscale array breaks the
        # (h, w, c) unpacking below and an RGBA array breaks normalization.
        image = Image.fromarray(image).convert('RGB')
        width, height = image.size  # PIL reports (width, height)
        ratio = float(min_len) / min(width, height)
        if width > height:
            new_size = (int(round(ratio * width)), min_len)
        else:
            new_size = (min_len, int(round(ratio * height)))
        image = image.resize(new_size, Image.BILINEAR)
        return np.array(image)

    # crop a centered window of crop_w x crop_h out of an HxWxC array
    def crop_center(image, crop_w, crop_h):
        h, w, c = image.shape
        start_x = w // 2 - crop_w // 2
        start_y = h // 2 - crop_h // 2
        return image[start_y:start_y + crop_h, start_x:start_x + crop_w, :]

    image = crop_center(image_resize(image, 256), 224, 224)

    # HWC -> CHW as float32; made contiguous so the result is laid out in
    # C order rather than inheriting the transposed strides.
    img_data = np.ascontiguousarray(image.transpose(2, 0, 1), dtype='float32')

    # Per-channel statistics shaped (3, 1, 1) so they broadcast over H and W.
    # Held as float32 so the arithmetic below stays in float32.
    mean_vec = np.array([0.485, 0.456, 0.406], dtype='float32').reshape(3, 1, 1)
    stddev_vec = np.array([0.229, 0.224, 0.225], dtype='float32').reshape(3, 1, 1)
    norm_img_data = (img_data / 255 - mean_vec) / stddev_vec

    # add batch channel
    norm_img_data = norm_img_data.reshape(1, 3, 224, 224).astype('float32')
    return norm_img_data
# Pre-processing function for ImageNet models
def preprocess_mxnet(img):
    '''
    Run the mxnet gluon ImageNet inference transforms on a single image.

    Applies, in order, ``Resize(224)``, ``CenterCrop(224)``, ``ToTensor()``
    and ``Normalize`` with the per-channel mean/std listed below, then
    inserts a batch axis at position 0 and returns the result.

    NOTE(review): ``img`` is passed straight to the transforms, so it is
    presumably an already-loaded image array (e.g. an mxnet NDArray) rather
    than a file path -- confirm against callers.
    '''
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    pipeline = transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])
    # transform, then batchify with a leading axis
    return pipeline(img).expand_dims(axis=0)