-
Notifications
You must be signed in to change notification settings - Fork 0
/
sprite_model.py
95 lines (77 loc) · 2.49 KB
/
sprite_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
# -*- coding: utf-8 -*-
"""sprite_preprocessing_protomodel.ipynb
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/1mOBKUwc1RGFROZN4DNyUdWqCBnPfSzS9
"""
import pathlib

import matplotlib.pyplot as plt
import PIL
import PIL.Image
import tensorflow as tf
from tensorflow.keras import layers
print(tf.__version__)

# Dataset root: one sub-directory per class (this script reads images/sprite
# and images/noevent); image_dataset_from_directory infers the class labels
# from these sub-directory names.
data_dir = pathlib.Path('images')

# Suffixes image_dataset_from_directory loads, compared case-insensitively.
# The previous '*/*.jpg' glob skipped .jpeg/.png/.bmp/.gif and upper-case
# '.JPG' files, so the printed count could disagree with the loaded dataset.
IMAGE_SUFFIXES = ('.bmp', '.gif', '.jpeg', '.jpg', '.png')


def _count_images(root):
    """Return the number of image files inside the class sub-directories of *root*.

    Files placed directly in *root* belong to no class and are ignored; files
    nested deeper inside a class directory are counted.
    """
    return sum(
        1
        for path in pathlib.Path(root).glob('*/**/*')
        if path.is_file() and path.suffix.lower() in IMAGE_SUFFIXES
    )


image_count = _count_images(data_dir)
print(image_count)
# File paths for each class sub-directory. They are sorted because glob order
# is filesystem-dependent, and the indexed spot-check below should pick the
# same file on every run.
sprite = sorted(data_dir.glob('sprite/*'))
noevent = sorted(data_dir.glob('noevent/*'))

# Spot-check one image per class. The original bare PIL.Image.open(...)
# expressions (presumably meant to display the image inline in the Colab
# notebook) discarded the opened image without closing it. They also raised
# IndexError when a class had too few files. Open each sample in a context
# manager so the file is closed, clamp the index to the available files, and
# print a short summary instead.
for class_label, sample_paths, sample_index in (
        ('sprite', sprite, 1),
        ('noevent', noevent, 50)):
    if not sample_paths:
        print(f'no sample images found for class {class_label!r}')
        continue
    sample_path = sample_paths[min(sample_index, len(sample_paths) - 1)]
    with PIL.Image.open(str(sample_path)) as sample_image:
        print(sample_path, sample_image.format, sample_image.size, sample_image.mode)
# Build the training and validation tf.data.Datasets from the same directory.
# Both calls must pass the same validation_split and seed. The seed drives
# the shuffle performed before the split, so mismatched values would let the
# "training" and "validation" subsets overlap. Keeping the values in shared
# constants prevents them from drifting apart.
VALIDATION_SPLIT = 0.2  # fraction (between 0 and 1) of images held out for validation
SPLIT_SEED = 123        # random seed for shuffling/transformations, shared by both subsets

# "training" subset: the remaining (1 - VALIDATION_SPLIT) of the images.
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=VALIDATION_SPLIT,
    subset="training",
    seed=SPLIT_SEED)

# "validation" subset: the held-out VALIDATION_SPLIT fraction of the images.
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=VALIDATION_SPLIT,
    subset="validation",
    seed=SPLIT_SEED)

# Class names come from the sub-directory names; their count sizes the
# model's output layer.
class_names = train_ds.class_names
print(class_names)
num_classes = len(class_names)
# Optional debugging aid (previously commented-out code). Set SHOW_SAMPLE_BATCH
# to True to plot up to nine training images titled with their class names,
# and to print the tensor shapes of one batch, before training.
SHOW_SAMPLE_BATCH = False
if SHOW_SAMPLE_BATCH:
    plt.figure(figsize=(10, 10))
    for images, labels in train_ds.take(1):
        # A small dataset can yield a batch with fewer than nine images.
        for i in range(min(9, images.shape[0])):
            plt.subplot(3, 3, i + 1)
            plt.imshow(images[i].numpy().astype("uint8"))
            plt.title(class_names[labels[i]])
            plt.axis("off")
    plt.show()
    for image_batch, labels_batch in train_ds.take(1):
        print(image_batch.shape)
        print(labels_batch.shape)

# Performance increase (see https://www.tensorflow.org/tutorials/load_data/images):
# cache() keeps the loaded batches in memory after the first pass, and
# prefetch() overlaps input loading with training.
AUTOTUNE = tf.data.experimental.AUTOTUNE
# A cached dataset yields exactly the same elements, in the same order, on
# every iteration. Shuffle after caching so the training batches are reordered
# each epoch (buffer of up to 1000 batches). The validation order does not
# affect its metrics, so it is left unshuffled.
train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)
# Small CNN classifier.
# - Rescaling(1./255) maps the loader's 0-255 pixel values to 0-1.
# - Three Conv2D stages (32 filters, 3x3 kernel, ReLU) each followed by
#   2D max pooling.
# - A 128-unit ReLU dense layer.
# - A final Dense layer with no activation, so the model outputs one raw logit
#   per class.
model = tf.keras.Sequential([
    layers.experimental.preprocessing.Rescaling(1./255),
    layers.Conv2D(32, 3, activation='relu'),
    layers.MaxPooling2D(),
    layers.Conv2D(32, 3, activation='relu'),
    layers.MaxPooling2D(),
    layers.Conv2D(32, 3, activation='relu'),
    layers.MaxPooling2D(),
    layers.Flatten(),
    layers.Dense(128, activation='relu'),
    layers.Dense(num_classes)
])

# from_logits=True matches the activation-free output layer. The sparse loss
# expects integer class-index labels.
model.compile(
    optimizer='adam',
    loss=tf.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'])

EPOCHS = 4  # number of full passes over the training set
# Keep the History object so per-epoch loss/accuracy can be inspected or plotted.
history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=EPOCHS,
)

# predict() returns raw logits, one row per validation image and one column per
# class. Convert them to probabilities and predicted class names for inspection.
prediction = model.predict(val_ds)
probabilities = tf.nn.softmax(prediction, axis=-1)
predicted_classes = [class_names[i] for i in tf.argmax(prediction, axis=-1).numpy()]