remove unused parameter "bn_mode" #79

Open

wants to merge 1 commit into master
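The `bn_mode` argument is threaded through every model builder in `models.py` but never used inside any of them (it looks like a leftover from the Keras 1 `BatchNormalization` `mode` argument), so this commit drops it from the signatures and from every call site. Below is a minimal caller-side sketch, assuming a hypothetical training script that builds the generator through `models.load`; only the `load` signature itself comes from this diff, the import path and variable names are illustrative.

```python
# Hypothetical caller-side update; only models.load itself is defined in this diff.
from model import models  # assumed import path for pix2pix/src/model/models.py

img_dim = (256, 256, 3)
nb_patch = 16
use_mbd = False
batch_size = 32

# Before this PR, an extra bn_mode argument (here 2) was accepted but never used:
# generator = models.load("generator_unet_upsampling", img_dim, nb_patch, 2, use_mbd, batch_size)

# After this PR, bn_mode is gone from the signature:
generator = models.load("generator_unet_upsampling", img_dim, nb_patch, use_mbd, batch_size)
```

Any other script that still passes `bn_mode` positionally to `load` or to the builder functions needs the same one-argument removal; those call sites are outside this diff.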
38 changes: 19 additions & 19 deletions pix2pix/src/model/models.py
@@ -21,7 +21,7 @@ def lambda_output(input_shape):
return input_shape[:2]


-# def conv_block_unet(x, f, name, bn_mode, bn_axis, bn=True, dropout=False, strides=(2,2)):
+# def conv_block_unet(x, f, name, bn_axis, bn=True, dropout=False, strides=(2,2)):

# x = Conv2D(f, (3, 3), strides=strides, name=name, padding="same")(x)
# if bn:
@@ -33,7 +33,7 @@ def lambda_output(input_shape):
# return x


-# def up_conv_block_unet(x1, x2, f, name, bn_mode, bn_axis, bn=True, dropout=False):
+# def up_conv_block_unet(x1, x2, f, name, bn_axis, bn=True, dropout=False):

# x1 = UpSampling2D(size=(2, 2))(x1)
# x = merge([x1, x2], mode="concat", concat_axis=bn_axis)
@@ -47,7 +47,7 @@ def lambda_output(input_shape):

# return x

-def conv_block_unet(x, f, name, bn_mode, bn_axis, bn=True, strides=(2,2)):
+def conv_block_unet(x, f, name, bn_axis, bn=True, strides=(2,2)):

x = LeakyReLU(0.2)(x)
x = Conv2D(f, (3, 3), strides=strides, name=name, padding="same")(x)
@@ -57,7 +57,7 @@ def conv_block_unet(x, f, name, bn_mode, bn_axis, bn=True, strides=(2,2)):
return x


-def up_conv_block_unet(x, x2, f, name, bn_mode, bn_axis, bn=True, dropout=False):
+def up_conv_block_unet(x, x2, f, name, bn_axis, bn=True, dropout=False):

x = Activation("relu")(x)
x = UpSampling2D(size=(2, 2))(x)
@@ -71,7 +71,7 @@ def up_conv_block_unet(x, x2, f, name, bn_mode, bn_axis, bn=True, dropout=False)
return x


-def deconv_block_unet(x, x2, f, h, w, batch_size, name, bn_mode, bn_axis, bn=True, dropout=False):
+def deconv_block_unet(x, x2, f, h, w, batch_size, name, bn_axis, bn=True, dropout=False):

o_shape = (batch_size, h * 2, w * 2, f)
x = Activation("relu")(x)
@@ -85,7 +85,7 @@ def deconv_block_unet(x, x2, f, h, w, batch_size, name, bn_mode, bn_axis, bn=Tru
return x


-def generator_unet_upsampling(img_dim, bn_mode, model_name="generator_unet_upsampling"):
+def generator_unet_upsampling(img_dim, model_name="generator_unet_upsampling"):

nb_filters = 64

@@ -109,7 +109,7 @@ def generator_unet_upsampling(img_dim, bn_mode, model_name="generator_unet_upsam
strides=(2, 2), name="unet_conv2D_1", padding="same")(unet_input)]
for i, f in enumerate(list_nb_filters[1:]):
name = "unet_conv2D_%s" % (i + 2)
-conv = conv_block_unet(list_encoder[-1], f, name, bn_mode, bn_axis)
+conv = conv_block_unet(list_encoder[-1], f, name, bn_axis)
list_encoder.append(conv)

# Prepare decoder filters
@@ -119,15 +119,15 @@ def generator_unet_upsampling(img_dim, bn_mode, model_name="generator_unet_upsam

# Decoder
list_decoder = [up_conv_block_unet(list_encoder[-1], list_encoder[-2],
list_nb_filters[0], "unet_upconv2D_1", bn_mode, bn_axis, dropout=True)]
list_nb_filters[0], "unet_upconv2D_1", bn_axis, dropout=True)]
for i, f in enumerate(list_nb_filters[1:]):
name = "unet_upconv2D_%s" % (i + 2)
# Dropout only on first few layers
if i < 2:
d = True
else:
d = False
-conv = up_conv_block_unet(list_decoder[-1], list_encoder[-(i + 3)], f, name, bn_mode, bn_axis, dropout=d)
+conv = up_conv_block_unet(list_decoder[-1], list_encoder[-(i + 3)], f, name, bn_axis, dropout=d)
list_decoder.append(conv)

x = Activation("relu")(list_decoder[-1])
@@ -140,7 +140,7 @@ def generator_unet_upsampling(img_dim, bn_mode, model_name="generator_unet_upsam
return generator_unet


-def generator_unet_deconv(img_dim, bn_mode, batch_size, model_name="generator_unet_deconv"):
+def generator_unet_deconv(img_dim, batch_size, model_name="generator_unet_deconv"):

assert K.backend() == "tensorflow", "Not implemented with theano backend"

@@ -162,7 +162,7 @@ def generator_unet_deconv(img_dim, bn_mode, batch_size, model_name="generator_un
h, w = h / 2, w / 2
for i, f in enumerate(list_nb_filters[1:]):
name = "unet_conv2D_%s" % (i + 2)
-conv = conv_block_unet(list_encoder[-1], f, name, bn_mode, bn_axis)
+conv = conv_block_unet(list_encoder[-1], f, name, bn_axis)
list_encoder.append(conv)
h, w = h / 2, w / 2

@@ -174,7 +174,7 @@ def generator_unet_deconv(img_dim, bn_mode, batch_size, model_name="generator_un
# Decoder
list_decoder = [deconv_block_unet(list_encoder[-1], list_encoder[-2],
list_nb_filters[0], h, w, batch_size,
"unet_upconv2D_1", bn_mode, bn_axis, dropout=True)]
"unet_upconv2D_1", bn_axis, dropout=True)]
h, w = h * 2, w * 2
for i, f in enumerate(list_nb_filters[1:]):
name = "unet_upconv2D_%s" % (i + 2)
@@ -184,7 +184,7 @@ def generator_unet_deconv(img_dim, bn_mode, batch_size, model_name="generator_un
else:
d = False
conv = deconv_block_unet(list_decoder[-1], list_encoder[-(i + 3)], f, h,
-w, batch_size, name, bn_mode, bn_axis, dropout=d)
+w, batch_size, name, bn_axis, dropout=d)
list_decoder.append(conv)
h, w = h * 2, w * 2

@@ -198,7 +198,7 @@ def generator_unet_deconv(img_dim, bn_mode, batch_size, model_name="generator_un
return generator_unet


-def DCGAN_discriminator(img_dim, nb_patch, bn_mode, model_name="DCGAN_discriminator", use_mbd=True):
+def DCGAN_discriminator(img_dim, nb_patch, model_name="DCGAN_discriminator", use_mbd=True):
"""
Discriminator model of the DCGAN

@@ -304,24 +304,24 @@ def DCGAN(generator, discriminator_model, img_dim, patch_size, image_dim_orderin
return DCGAN


-def load(model_name, img_dim, nb_patch, bn_mode, use_mbd, batch_size):
+def load(model_name, img_dim, nb_patch, use_mbd, batch_size):

if model_name == "generator_unet_upsampling":
-model = generator_unet_upsampling(img_dim, bn_mode, model_name=model_name)
+model = generator_unet_upsampling(img_dim, model_name=model_name)
model.summary()
from keras.utils import plot_model
plot_model(model, to_file="../../figures/%s.png" % model_name, show_shapes=True, show_layer_names=True)
return model

if model_name == "generator_unet_deconv":
-model = generator_unet_deconv(img_dim, bn_mode, batch_size, model_name=model_name)
+model = generator_unet_deconv(img_dim, batch_size, model_name=model_name)
model.summary()
from keras.utils import plot_model
plot_model(model, to_file="../../figures/%s.png" % model_name, show_shapes=True, show_layer_names=True)
return model

if model_name == "DCGAN_discriminator":
-model = DCGAN_discriminator(img_dim, nb_patch, bn_mode, model_name=model_name, use_mbd=use_mbd)
+model = DCGAN_discriminator(img_dim, nb_patch, model_name=model_name, use_mbd=use_mbd)
model.summary()
from keras.utils import plot_model
plot_model(model, to_file="../../figures/%s.png" % model_name, show_shapes=True, show_layer_names=True)
@@ -331,4 +331,4 @@ def load(model_name, img_dim, nb_patch, bn_mode, use_mbd, batch_size):
if __name__ == "__main__":

# load("generator_unet_deconv", (256, 256, 3), 16, 2, False, 32)
load("generator_unet_upsampling", (256, 256, 3), 16, 2, False, 32)
load("generator_unet_upsampling", (256, 256, 3), 16, False, 32)