-
Notifications
You must be signed in to change notification settings - Fork 5
/
helper_functions.py
74 lines (60 loc) · 2.58 KB
/
helper_functions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import torch as T
from PIL import Image, ImageOps
import os
from joblib import Parallel, delayed
# Directory containing the source images. get_image() joins this path with a
# per-sample filename; being a relative path, it is resolved against the
# current working directory of the process.
images_folder = "data/Original/Images"
def get_cuda(tensor):
    """Return *tensor* on the CUDA device when one is available.

    When CUDA is not available the tensor is handed back unchanged.
    """
    if not T.cuda.is_available():
        return tensor
    return tensor.cuda()
def get_image(opt, info, istrain):
    """Load one image as RGB and crop a square region out of it.

    Args:
        opt: options object; ``opt.train_files`` is read when *istrain* is
            truthy, ``opt.test_files`` otherwise. The filename is taken from
            ``files[ind][0][0]``.
        info: indexable record whose first three entries are the crop's left
            offset, the crop's top offset and the index into the file list.
        istrain: selects the training or the test file list.

    Returns:
        A ``(img, size)`` tuple: the cropped RGB ``PIL.Image`` and the side
        length of the crop, which is the shorter side of the full image.
    """
    from_w, from_h, ind = info[0], info[1], info[2]
    files = opt.train_files if istrain else opt.test_files
    filename = files[ind][0][0]

    # Image.open() is lazy and keeps the file handle around; convert() forces
    # the pixel data to be read and returns an independent image, so the
    # handle can be released deterministically as soon as the block exits.
    with Image.open(os.path.join(images_folder, filename)) as source:
        img = source.convert('RGB')

    width, height = img.size
    size = min(width, height)
    # Square window of side `size` whose top-left corner is (from_w, from_h).
    img = img.crop((from_w, from_h, from_w + size, from_h + size))
    return img, size
def retina(l, info, opt, istrain):
    """Extract square patches at several scales around a location, per batch item.

    For every item in the batch the context image is fetched with
    ``get_image`` and ``opt.k`` square patches centred on the item's location
    are cut out of it. The first patch has side ``(imgsize // 4) *
    opt.start_size`` and the side doubles for every following scale. Patches
    that reach past the image border are taken from a black-padded copy.

    Args:
        l: tensor of locations, one row per batch item. The first two
            components of each row are used as (x, y) and are mapped from
            [-1, 1] to [0, imgsize] pixel coordinates.
        info: indexable collection of per-item records forwarded to
            ``get_image``.
        opt: options object; ``start_size``, ``k``, ``n_jobs`` and
            ``my_transform`` are read here (``get_image`` reads more).
        istrain: forwarded to ``get_image``.

    Returns:
        A list with ``opt.k`` entries, smallest scale first. Each entry is the
        concatenation along dim 0 of the transformed patches of all batch
        items, passed through ``get_cuda``.
    """
    # detach() first: Tensor.numpy() refuses tensors that require grad.
    locations = l.detach().cpu().numpy()
    batch_size = len(locations)
    num_scales = opt.k
    start_size = opt.start_size

    if num_scales <= 0:
        # No scale requested: nothing to extract, so do not load any image.
        return []

    def get_patches(i):
        """Return the patches of batch item *i*, one per scale."""
        # Load and decode the context image once and reuse it for all scales.
        context, imgsize = get_image(opt, info[i], istrain)
        # Location resized from [-1, 1] to [0, imgsize].
        center = (0.5 * imgsize * (1 + locations[i])).astype(int)

        patches = []
        scale = start_size
        for _ in range(num_scales):
            # Side of the patch in context-image pixels, before my_transform
            # is applied (the original comment says the transform compresses
            # it to 96x96 -- that is defined outside this function).
            patch_size = (imgsize // 4) * scale
            half = patch_size // 2
            from_x, from_y = center[0] - half, center[1] - half
            to_x, to_y = from_x + patch_size, from_y + patch_size

            canvas = context
            if from_x < 0 or from_y < 0 or to_x > imgsize or to_y > imgsize:
                # Pad the context image when a corner of the patch exceeds
                # its borders, and shift the box into the padded frame.
                border = half + 1
                canvas = ImageOps.expand(context, border=border, fill='black')
                from_x += border
                from_y += border
                to_x += border
                to_y += border

            patch = canvas.crop((from_x, from_y, to_x, to_y))
            patches.append(opt.my_transform(patch).unsqueeze(0))
            scale *= 2
        return patches

    # Items are independent of each other, so they are processed in parallel.
    per_item = Parallel(n_jobs=opt.n_jobs, backend="threading")(
        delayed(get_patches)(i) for i in range(batch_size)
    )

    # Regroup from "per item, per scale" to "per scale, per item".
    return [
        get_cuda(T.cat([item_patches[s] for item_patches in per_item], dim=0))
        for s in range(num_scales)
    ]
def get_color(i):
    """Return the colour name for index *i*.

    Index 0 gives "red", index 1 gives "blue" and any other value gives
    "green".
    """
    for index, name in enumerate(("red", "blue")):
        if i == index:
            return name
    return "green"