Skip to content

Commit

Permalink
Fix: fix test.py syntax error, fix global_step not change issue, fix …
Browse files Browse the repository at this point in the history
…infinite loop
  • Loading branch information
Armour committed Jul 3, 2018
1 parent 0e52c2c commit 9046ada
Showing 1 changed file with 13 additions and 9 deletions.
22 changes: 13 additions & 9 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,35 +18,39 @@


if __name__ == '__main__':
# Init model.
is_training, global_step, _, loss, predict_rgb, color_image_rgb, gray_image, file_paths = init_model(train=False)
# Init model
is_training, _, _, loss, predict_rgb, color_image_rgb, gray_image, file_paths = init_model(train=False)

# Init scaffold, hooks and config.
# Init scaffold, hooks and config
scaffold = tf.train.Scaffold()
summary_hook = tf.train.SummarySaverHook(output_dir=testing_summary, save_steps=display_step, scaffold=scaffold)
checkpoint_hook = tf.train.CheckpointSaverHook(checkpoint_dir=summary_path, save_steps=saving_step, scaffold=scaffold)
num_step_hook = tf.train.StopAtStepHook(num_steps=len(file_paths))
config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=True, gpu_options=(tf.GPUOptions(allow_growth=True))
config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=True, gpu_options=(tf.GPUOptions(allow_growth=True)))
session_creator = tf.train.ChiefSessionCreator(scaffold=scaffold, config=config, checkpoint_dir=summary_path)

# Create a session for running operations in the Graph.
# Create a session for running operations in the Graph
with tf.train.MonitoredSession(session_creator=session_creator, hooks=[checkpoint_hook, summary_hook]) as sess:
print("🤖 Start testing...")
step = 0
avg_loss = 0

while not sess.should_stop():
# Get global_step.
step, l, pred, color, gray = sess.run([global_step, loss, predict_rgb, color_image_rgb, gray_image], feed_dict={is_training: False})
step += 1

l, pred, color, gray = sess.run([loss, predict_rgb, color_image_rgb, gray_image], feed_dict={is_training: False})

if step % display_step == 0:
# Print batch loss.
print("📖 Iter %d, Minibatch Loss = %f" % (step, l))
print("📖 Testing iter %d, Minibatch Loss = %f" % (step, l))
avg_loss += float(l)

# Save testing image.
summary_image = concat_images(gray[0], pred[0])
summary_image = concat_images(summary_image, color[0])
plt.imsave("%s/images/%d.png" % (testing_summary, step), summary_image)

if step >= len(file_paths) / batch_size:
break

print("🎉 Testing finished!")
print("👀 Total average loss: %f" % (avg_loss / len(file_paths)))

0 comments on commit 9046ada

Please sign in to comment.