Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Multi-level masking #125

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 92 additions & 15 deletions infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,33 +44,103 @@ def parsed_args():
required=True,
)
parser.add_argument(
"--new_size", type=int, help="Size of generated masks",
"--new_size",
type=int,
help="Size of generated masks",
)
parser.add_argument(
"--output_dir",
default="./outputs/",
type=str,
help="Directory to write images to",
)
parser.add_argument(
"--path_to_masks",
type=str,
help="Path of masks to be used for painting",
required=False,
)
parser.add_argument(
"--apply_mask",
action="store_true",
help="Apply mask to image to save",
)

return parser.parse_args()


def eval_folder(path_to_images, output_dir, paint=False):
images = [path_to_images / Path(i) for i in os.listdir(path_to_images)]
for img_path in images:
def eval_folder(
output_dir,
path_to_images,
path_to_masks=None,
paint=False,
masker=False,
apply_mask=False,
):

image_list = os.listdir(path_to_images)
image_list.sort()
images = [path_to_images / Path(i) for i in image_list]
if not masker:
mask_list = os.listdir(path_to_masks)
mask_list.sort()
masks = [path_to_masks / Path(i) for i in mask_list]

for i, img_path in enumerate(images):
img = tensor_loader(img_path, task="x", domain="val")

# Resize img:
img = F.interpolate(img, (new_size, new_size), mode="nearest")
img = img.squeeze(0)
for tf in transforms:
img = tf(img)

img = img.unsqueeze(0).to(device)
z = model.encode(img)
mask = model.decoders["m"](z)

vutils.save_image(mask, output_dir / ("mask_" + img_path.name), normalize=True)
if not masker:
mask = tensor_loader(masks[i], task="m", domain="val")
mask = F.interpolate(mask, (new_size, new_size), mode="nearest")
mask = mask.squeeze()
mask = mask.unsqueeze(0).to(device)

if masker:
if "m2" in opts.tasks:
z = model.encode(img)
z_aug_1 = torch.cat(
(z, trainer.label_1[0, :, :, :].unsqueeze(0)),
dim=1,
)
z_aug_2 = torch.cat(
(z, trainer.label_2[0, :, :, :].unsqueeze(0)),
dim=1,
)
mask_1 = model.decoders["m"](z_aug_1)
mask_2 = model.decoders["m"](z_aug_2)
vutils.save_image(
mask_1, output_dir / ("mask1_" + img_path.name), normalize=True
)
vutils.save_image(
mask_2, output_dir / ("mask2_" + img_path.name), normalize=True
)

if apply_mask:
vutils.save_image(
img * (1.0 - mask_1) + mask_1,
output_dir / (img_path.stem + "img_masked_1" + ".jpg"),
normalize=True,
)
vutils.save_image(
img * (1.0 - mask_2) + mask_2,
output_dir / (img_path.stem + "img_masked_2" + ".jpg"),
normalize=True,
)

else:
z = model.encode(img)
mask = model.decoders["m"](z)
vutils.save_image(
mask, output_dir / ("mask_" + img_path.name), normalize=True
)

if paint:
z_painter = trainer.sample_z(1)
Expand Down Expand Up @@ -113,10 +183,12 @@ def isimg(path_file):
else:
new_size = args.new_size

if "m" in opts.tasks and "p" in opts.tasks:
paint = False
masker = False
if "p" in opts.tasks:
paint = True
else:
paint = False
if "m" in opts.tasks:
masker = True
# ------------------------
# ----- Define model -----
# ------------------------
Expand All @@ -142,6 +214,7 @@ def isimg(path_file):
# eval_folder(args.path_to_images, output_dir)

rootdir = args.path_to_images
maskdir = args.path_to_masks
writedir = args.output_dir

for root, subdirs, files in tqdm(os.walk(rootdir)):
Expand All @@ -151,10 +224,6 @@ def isimg(path_file):
has_imgs = False
for f in files:
if isimg(f):
# read_path = root / f
# rel_path = read_path.relative_to(rootdir)
# write_path = writedir / rel_path
# write_path.mkdir(parents=True, exist_ok=True)
has_imgs = True
break

Expand All @@ -163,4 +232,12 @@ def isimg(path_file):
rel_path = root.relative_to(rootdir)
write_path = writedir / rel_path
write_path.mkdir(parents=True, exist_ok=True)
eval_folder(root, write_path, paint)
print("root: ", root)
eval_folder(
write_path,
root,
path_to_masks=maskdir,
paint=paint,
masker=masker,
apply_mask=args.apply_mask,
)
2 changes: 1 addition & 1 deletion omnigan/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ def tensor_loader(path, task, domain):
arr = np.moveaxis(arr, 2, 0)
elif task == "s":
arr = np.moveaxis(arr, 2, 0)
elif task == "m":
elif task == "m" or task == "m2":
arr[arr != 0] = 1
# Make sure mask is single-channel
if len(arr.shape) >= 3:
Expand Down
125 changes: 125 additions & 0 deletions omnigan/discriminator.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,116 @@ def forward(self, input):
return result


class AuxiliaryClassifier(nn.Module):
    """Spectral-normalised conv classifier with an adversarial and a 2-class head.

    The input mask goes through a stack of strided convolutions, the result
    is concatenated channel-wise with a latent tensor ``z``, projected with a
    1x1 convolution, flattened, and fed to two linear heads.

    Args:
        input_size: Spatial side length of the (square) input mask.
        input_nc: Number of channels of the input mask.
        ndf: Base number of convolution filters.
        n_layers: Controls depth; the stack holds ``max(n_layers, 1)``
            stride-2 convolutions followed by one stride-1 convolution.
        norm_layer: Normalisation layer class, or a ``functools.partial``
            wrapping one.
        use_sigmoid: Stored on the instance; not applied anywhere in this
            class.
        z_nc: Number of channels of the latent ``z`` passed to ``forward``.
            Defaults to 2048, the value that used to be hard-coded.
    """

    def __init__(
        self,
        input_size=640,
        input_nc=1,
        ndf=64,
        n_layers=3,
        norm_layer=nn.BatchNorm2d,
        use_sigmoid=False,
        z_nc=2048,
    ):
        super().__init__()
        self.input_nc = input_nc
        self.ndf = ndf
        self.n_layers = n_layers
        self.norm_layer = norm_layer
        self.use_sigmoid = use_sigmoid
        self.z_nc = z_nc

        # Convolutions followed by InstanceNorm keep their bias; otherwise
        # the bias is dropped.
        if isinstance(norm_layer, functools.partial):
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        kw = 3
        padw = 1
        sequence = [
            # Use spectral normalization
            SpectralNorm(
                nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw)
            ),
            nn.LeakyReLU(0.2, True),
        ]

        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                # Use spectral normalization
                SpectralNorm(  # TODO replace with Conv2dBlock
                    nn.Conv2d(
                        ndf * nf_mult_prev,
                        ndf * nf_mult,
                        kernel_size=kw,
                        stride=2,
                        padding=padw,
                        bias=use_bias,
                    )
                ),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True),
            ]

        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        sequence += [
            # Use spectral normalization
            SpectralNorm(
                nn.Conv2d(
                    ndf * nf_mult_prev,
                    ndf * nf_mult,
                    kernel_size=kw,
                    stride=1,
                    padding=padw,
                    bias=use_bias,
                )
            ),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True),
        ]

        self.shared_layers = nn.Sequential(*sequence)

        proj_dim = ndf

        # 1x1 convolution mixing mask features with the latent channels.
        self.projection = SpectralNorm(
            nn.Conv2d(
                ndf * nf_mult + z_nc, proj_dim, kernel_size=1, stride=1, padding=0
            )
        )

        # The first conv always has stride 2, and the loop adds
        # (n_layers - 1) more; the final stride-1 conv keeps the size.
        n_downsample = max(n_layers, 1)
        latent_size = self._latent_size(
            input_size, n_downsample, kernel_size=kw, padding=padw
        )
        self.linear_size = int(proj_dim * latent_size * latent_size)
        self.gan_layer = nn.Linear(self.linear_size, 1)
        self.ac_layer = nn.Linear(self.linear_size, 2)

    @staticmethod
    def _latent_size(input_size, n_downsample, kernel_size=3, padding=1):
        """Return the side length left after ``n_downsample`` stride-2 convs.

        Applies the Conv2d output-size formula once per layer instead of
        dividing by ``2 ** n``: the latter rounds down, whereas these
        convolutions round up on odd sizes.
        """
        size = int(input_size)
        for _ in range(n_downsample):
            size = (size + 2 * padding - kernel_size) // 2 + 1
        return size

    def forward(self, mask, z):
        """Return ``[gan_output, class_output]`` for a mask and its latent.

        Args:
            mask: Input tensor fed to the convolution stack.
            z: Latent tensor concatenated to the conv features along dim 1;
                it must have ``z_nc`` channels and the same spatial size as
                those features.

        Returns:
            A two-element list: the output of the 1-unit linear head and the
            output of the 2-unit linear head.
        """
        x = self.shared_layers(mask)
        x = torch.cat((x, z), dim=1)
        x = self.projection(x)
        # Flatten per sample so a size mismatch raises in the linear layers
        # instead of silently changing the batch dimension.
        x = x.view(x.size(0), -1)

        return [self.gan_layer(x), self.ac_layer(x)]


def get_AC(
    input_size, input_nc, ndf, n_layers=3, norm="batch", use_sigmoid=False,
):
    """Build an ``AuxiliaryClassifier`` from plain configuration values.

    Args:
        input_size: Forwarded as ``input_size``.
        input_nc: Forwarded as ``input_nc``.
        ndf: Forwarded as ``ndf``.
        n_layers: Forwarded as ``n_layers``.
        norm: Passed to ``get_norm_layer(norm_type=...)``; the result is
            used as the classifier's ``norm_layer``.
        use_sigmoid: Forwarded as ``use_sigmoid``.

    Returns:
        The newly constructed ``AuxiliaryClassifier``.
    """
    return AuxiliaryClassifier(
        input_size=input_size,
        input_nc=input_nc,
        ndf=ndf,
        n_layers=n_layers,
        norm_layer=get_norm_layer(norm_type=norm),
        use_sigmoid=use_sigmoid,
    )


class OmniDiscriminator(nn.ModuleDict):
def __init__(self, opts):
super().__init__()
Expand Down Expand Up @@ -281,6 +391,21 @@ def __init__(self, opts):
)
else:
raise Exception("This Discriminator is currently not supported!")
if "m2" in opts.tasks:
# Create a flood-level discriminator / classifier
self["m2"] = nn.ModuleDict(
vict0rsch marked this conversation as resolved.
Show resolved Hide resolved
{
"FloodLevel": get_AC(
input_size=640,
input_nc=1,
ndf=opts.dis.m2.ndf,
n_layers=opts.dis.m2.n_layers,
norm=opts.dis.m2.norm,
use_sigmoid=opts.dis.m2.use_sigmoid,
)
}
)

if "s" in opts.tasks:
if opts.gen.s.use_advent:
self["s"] = nn.ModuleDict(
Expand Down
20 changes: 19 additions & 1 deletion omnigan/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,10 @@ def __init__(self, opts, latent_shape=None, verbose=None):
self.decoders["s"] = SegmentationDecoder(opts)

if "m" in opts.tasks and not opts.gen.m.ignore:
self.decoders["m"] = MaskDecoder(opts)
if "m2" in opts.tasks:
self.decoders["m"] = ConditionalMasker(opts)
else:
self.decoders["m"] = MaskDecoder(opts)

self.decoders = nn.ModuleDict(self.decoders)

Expand Down Expand Up @@ -116,6 +119,21 @@ def __init__(self, opts):
)


class ConditionalMasker(BaseDecoder):
    """Mask decoder whose input has one channel more than the encoder output.

    All settings are read from ``opts.gen.m``, except the input width, which
    is ``opts.gen.encoder.res_dim + 1``. The output activation is fixed to
    ``"sigmoid"``.
    """

    def __init__(self, opts):
        mask_opts = opts.gen.m
        # One extra input channel on top of the encoder features
        # (presumably the conditioning label concatenated to z; confirm
        # against callers).
        conditioned_dim = opts.gen.encoder.res_dim + 1
        decoder_kwargs = dict(
            n_upsample=mask_opts.n_upsample,
            n_res=mask_opts.n_res,
            input_dim=conditioned_dim,
            proj_dim=mask_opts.proj_dim,
            output_dim=mask_opts.output_dim,
            res_norm=mask_opts.res_norm,
            activ=mask_opts.activ,
            pad_type=mask_opts.pad_type,
            output_activ="sigmoid",
        )
        super().__init__(**decoder_kwargs)


class DepthDecoder(BaseDecoder):
def __init__(self, opts):
super().__init__(
Expand Down
7 changes: 3 additions & 4 deletions omnigan/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ def get_losses(opts, verbose, device=None):

losses = {
"G": {"a": {}, "p": {}, "tasks": {}},
"D": {"default": {}, "advent": {}},
"D": {"default": {}, "advent": {}, "multilevel": {}},
"C": {},
}

Expand Down Expand Up @@ -417,6 +417,7 @@ def get_losses(opts, verbose, device=None):
soft_shift=opts.dis.soft_shift, flip_prob=opts.dis.flip_prob, verbose=verbose
)
losses["D"]["advent"] = ADVENTAdversarialLoss(opts)
losses["D"]["multilevel"] = CrossEntropy()
return losses


Expand All @@ -441,9 +442,7 @@ def __init__(self):
def __call__(self, prediction, target):
return self.loss(
prediction,
torch.FloatTensor(prediction.size())
.fill_(target)
.to(prediction.get_device()),
torch.FloatTensor(prediction.size()).fill_(target).to(prediction.device),
)


Expand Down
Loading