-
Notifications
You must be signed in to change notification settings - Fork 0
/
vizualize_samples.py
88 lines (63 loc) · 2.64 KB
/
vizualize_samples.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
"""
"""
import os
import numpy as np
from PIL import Image
from numpy import load
import matplotlib.pyplot as plt
import matplotlib as mpl
import argparse
import fnmatch
def save_images(im, cl, DIR, n):
    """Save one image as ``DIR/<cl>_<NNN>.tif`` and log the path written.

    Args:
        im: image object exposing a ``save(path)`` method (a PIL image
            when called from this script).
        cl: class label used as the file-name prefix.
        DIR: directory the file is written into; it must already exist.
        n: sample index; formatted zero-padded to three digits.

    Returns:
        The path the image was saved to.
    """
    # Build the path once so the logged name is exactly the file written.
    # Previously the log used ``n + 1`` while the file used ``n``, so the
    # message always pointed at the wrong file.
    path = os.path.join(DIR, cl + '_' + str(n).zfill(3) + '.tif')
    print('saving %s' % path)
    im.save(path)
    return path
def _resolve_inputs(npz_file, npz_dir):
    """Turn the CLI arguments into ``(file_names, directory)``.

    Exactly one of ``npz_file`` / ``npz_dir`` must be given. A single file
    must have the ``.npz`` extension. On invalid input a usage message is
    printed and ``(None, None)`` is returned.
    """
    if npz_file is None and npz_dir is not None:
        files = sorted(fnmatch.filter(os.listdir(npz_dir), "*.npz"))
        print(files)
        return files, npz_dir
    if npz_file is not None and npz_dir is None and os.path.splitext(npz_file)[-1] == '.npz':
        return [os.path.basename(npz_file)], os.path.dirname(os.path.abspath(npz_file))
    print("you must provide a *.npz file or a folder containing *.npz files for vizualization")
    return None, None


def _grid_shape(count):
    """Return ``(rows, cols)`` of a near-square grid with room for ``count`` panels."""
    rows = max(1, int(np.rint(np.sqrt(count))))
    # Round the column count UP: rounding to nearest can give
    # rows * cols < count (e.g. count = 5, 7, 10), which makes
    # fig.add_subplot raise for the trailing panels.
    cols = int(np.ceil(count / rows))
    return rows, cols


def _class_label(sample):
    """Return the class label taken from an ``*.npz`` file name.

    The label is the second ``'_'``-separated field of the base name. If the
    name contains no ``'_'``, the name without its extension is used instead
    of raising IndexError.
    """
    parts = os.path.basename(sample).split('_')
    if len(parts) < 2:
        return os.path.splitext(os.path.basename(sample))[0]
    # NOTE(review): when the label is the last field it keeps the ".npz"
    # suffix (original behavior, preserved so output folders don't move).
    return parts[1]


def _visualize_sample(directory, sample):
    """Plot every image in ``directory/sample`` on one figure and save each.

    The images are read from the archive's ``'arr_0'`` entry. Each one is
    saved via ``save_images`` into ``directory/<label>/``.
    """
    print('creating images from %s' % (sample))
    # Close the archive deterministically. Indexing reads the whole array
    # into memory, so it stays valid after the file is closed.
    with load(os.path.join(directory, sample)) as data:
        imgs = list(data['arr_0'])
    if not imgs:
        print('no images found in %s' % sample)
        return
    cl = _class_label(sample)
    out_dir = os.path.join(directory, cl)
    os.makedirs(out_dir, exist_ok=True)
    rows, cols = _grid_shape(len(imgs))
    fig = plt.figure(figsize=(1.25 * cols, 1.375 * rows))
    for n, img in enumerate(imgs):
        # Drop a trailing singleton channel so PIL builds a single-band
        # image; 2-D arrays are passed through unchanged.
        if img.ndim == 3 and img.shape[2] == 1:
            img = img.reshape(img.shape[0], img.shape[1])
        ax = fig.add_subplot(rows, cols, n + 1)
        im = Image.fromarray(img)
        ax.imshow(im.convert('RGB'), cmap='gray')
        ax.set_title(cl)
        ax.axis('off')
        save_images(im, cl, out_dir, n)


def main():
    """Command-line entry point: render sample grids from ``*.npz`` files.

    Flags:
        --file      path to a single ``*.npz`` file.
        --file_dir  directory whose ``*.npz`` files are all processed.

    One figure is built per file. Every image is also saved to a per-label
    sub-directory next to the input. All figures are shown at the end.
    """
    parser = argparse.ArgumentParser(description='ShearDetect')
    parser.add_argument('--file', default=None, help='path to *.npz-file')
    parser.add_argument('--file_dir', default=None, help='path to directory with *.npz-files')
    args = parser.parse_args()

    files, directory = _resolve_inputs(args.file, args.file_dir)
    if files is None:
        return
    for sample in files:
        _visualize_sample(directory, sample)
    plt.show()
if __name__ == "__main__":
    # Run the command-line tool only when executed as a script, not on import.
    main()