Evaluation and saving the best checkpoint #209

Open

wants to merge 2 commits into master
29 changes: 20 additions & 9 deletions .gitignore
@@ -1,12 +1,23 @@
# datasets

*.json
*.h5
*.hdf5
.DS_Store
data/celeba/
data/mnist/

# output folders

implementations/*/temp/
implementations/*/results*/
implementations/*/images/


#pycache
implementations/*/__pycache__/
implementations/gan/__pycache__/
implementations/wgan/__pycache__/
implementations/wgan_div/__pycache__/
implementations/wgan_gp/__pycache__/
implementations/__pycache__/

# checkpoints

data/*/
implementations/*/data
implementations/*/images
implementations/*/saved_models

__pycache__
18 changes: 18 additions & 0 deletions README.md
@@ -7,6 +7,24 @@ Collection of PyTorch implementations of Generative Adversarial Network varieties

<b>See also:</b> [Keras-GAN](https://github.com/eriklindernoren/Keras-GAN)

## Quick Commands for Any GAN Type

```
cd implementations
export PYTHONPATH=$(pwd):$PYTHONPATH
```
Train

```
cd gan_type  # replace gan_type with the implementation folder, e.g. cgan or wgan
python gan_type.py --dataset="celeba" --img_size=128 --channels=3  # for MNIST, just run python gan_type.py with the default parameters
```
Inference
```
python inference.py --img_size=128 --channels=3 --model=wgan --model_checkpoint="wgan/generator_best_fid.pth"
```
The generated images will be saved under the `gan_type/inference` folder.
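
`inference.py` itself is not part of this diff, so a minimal sketch of what such a script might look like is shown below. The flag names mirror the command above, but the `Generator` constructor signature and the per-folder module layout are assumptions.

```
# Hypothetical sketch of inference.py (referenced above but not included in this diff).
import argparse
import importlib
import os

import torch
from torchvision.utils import save_image

parser = argparse.ArgumentParser()
parser.add_argument("--img_size", type=int, default=28)
parser.add_argument("--channels", type=int, default=1)
parser.add_argument("--latent_dim", type=int, default=100)
parser.add_argument("--model", type=str, default="wgan", help="implementation folder, e.g. wgan or cgan")
parser.add_argument("--model_checkpoint", type=str, required=True)
parser.add_argument("--n_samples", type=int, default=25)
opt = parser.parse_args()

# Import the Generator from the chosen implementation (e.g. wgan/wgan.py).
# This assumes each script keeps its training code under an
# `if __name__ == "__main__":` guard, as this PR does for cgan.py,
# so importing the module has no side effects.
module = importlib.import_module(f"{opt.model}.{opt.model}")
img_shape = (opt.channels, opt.img_size, opt.img_size)
generator = module.Generator(opt.latent_dim, img_shape)  # constructor signature assumed

generator.load_state_dict(torch.load(opt.model_checkpoint, map_location="cpu"))
generator.eval()

os.makedirs(os.path.join(opt.model, "inference"), exist_ok=True)
with torch.no_grad():
    z = torch.randn(opt.n_samples, opt.latent_dim)
    samples = generator(z)
save_image(samples, os.path.join(opt.model, "inference", "samples.png"), nrow=5, normalize=True)
```

Run it from the `implementations` directory with the PYTHONPATH export above so the implementation packages can be imported.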

## Table of Contents
* [Installation](#installation)
* [Implementations](#implementations)
227 changes: 118 additions & 109 deletions implementations/cgan/cgan.py
@@ -14,43 +14,31 @@
import torch.nn.functional as F
import torch

os.makedirs("images", exist_ok=True)

parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--n_classes", type=int, default=10, help="number of classes for dataset")
parser.add_argument("--img_size", type=int, default=32, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=400, help="interval between image sampling")
opt = parser.parse_args()
print(opt)

img_shape = (opt.channels, opt.img_size, opt.img_size)

cuda = True if torch.cuda.is_available() else False
from utils import evaluate_model, save_best_model, save_images_with_labels
from datasets import getDataloader


class Generator(nn.Module):
def __init__(self):
def __init__(self, latent_dim, img_shape, n_classes):
super(Generator, self).__init__()

self.label_emb = nn.Embedding(opt.n_classes, opt.n_classes)
self.label_emb = nn.Embedding(n_classes, n_classes)

def block(in_feat, out_feat, normalize=True):
layers = [nn.Linear(in_feat, out_feat)]
if normalize:
layers.append(nn.BatchNorm1d(out_feat, 0.8))
layers.append(nn.LeakyReLU(0.2, inplace=True))
return layers

self.img_shape = img_shape
self.best_is = 0
self.best_fid = float('inf')
self.best_kid = float('inf')


self.model = nn.Sequential(
*block(opt.latent_dim + opt.n_classes, 128, normalize=False),
*block(latent_dim + n_classes, 128, normalize=False),
*block(128, 256),
*block(256, 512),
*block(512, 1024),
@@ -67,10 +55,10 @@ def forward(self, noise, labels):


class Discriminator(nn.Module):
def __init__(self):
def __init__(self, n_classes):
super(Discriminator, self).__init__()

self.label_embedding = nn.Embedding(opt.n_classes, opt.n_classes)
self.label_embedding = nn.Embedding(n_classes, n_classes)

self.model = nn.Sequential(
nn.Linear(opt.n_classes + int(np.prod(img_shape)), 512),
@@ -90,115 +78,136 @@ def forward(self, img, labels):
validity = self.model(d_in)
return validity

if __name__ == "__main__":
os.makedirs("images", exist_ok=True)

# Loss functions
adversarial_loss = torch.nn.MSELoss()
parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--n_classes", type=int, default=10, help="number of classes for dataset")
parser.add_argument("--img_size", type=int, default=28, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
parser.add_argument("--eval_interval", type=int, default=400, help="interval between image sampling")
parser.add_argument("--dataset", type=str, default="mnist", help="dataset type: mnist or celeba for now")

# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()
opt = parser.parse_args()
print(opt)

if cuda:
generator.cuda()
discriminator.cuda()
adversarial_loss.cuda()
img_shape = (opt.channels, opt.img_size, opt.img_size)

# Configure data loader
os.makedirs("../../data/mnist", exist_ok=True)
dataloader = torch.utils.data.DataLoader(
datasets.MNIST(
"../../data/mnist",
train=True,
download=True,
transform=transforms.Compose(
[transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
),
),
batch_size=opt.batch_size,
shuffle=True,
)
cuda = True if torch.cuda.is_available() else False

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if cuda else torch.LongTensor
# Loss functions
adversarial_loss = torch.nn.MSELoss()

# Initialize generator and discriminator
generator = Generator(opt.latent_dim, img_shape, opt.n_classes)
discriminator = Discriminator(opt.n_classes)

def sample_image(n_row, batches_done):
"""Saves a grid of generated digits ranging from 0 to n_classes"""
# Sample noise
z = Variable(FloatTensor(np.random.normal(0, 1, (n_row ** 2, opt.latent_dim))))
# Get labels ranging from 0 to n_classes for n rows
labels = np.array([num for _ in range(n_row) for num in range(n_row)])
labels = Variable(LongTensor(labels))
gen_imgs = generator(z, labels)
save_image(gen_imgs.data, "images/%d.png" % batches_done, nrow=n_row, normalize=True)
if cuda:
generator.cuda()
discriminator.cuda()
adversarial_loss.cuda()

# Configure data loader

# ----------
# Training
# ----------
dataloader = getDataloader(opt.dataset, opt.batch_size)


for epoch in range(opt.n_epochs):
for i, (imgs, labels) in enumerate(dataloader):
# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

batch_size = imgs.shape[0]
FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if cuda else torch.LongTensor

# Adversarial ground truths
valid = Variable(FloatTensor(batch_size, 1).fill_(1.0), requires_grad=False)
fake = Variable(FloatTensor(batch_size, 1).fill_(0.0), requires_grad=False)

# Configure input
real_imgs = Variable(imgs.type(FloatTensor))
labels = Variable(labels.type(LongTensor))
def sample_image(n_row, batches_done):
"""Saves a grid of generated digits ranging from 0 to n_classes"""
# Sample noise
z = Variable(FloatTensor(np.random.normal(0, 1, (n_row ** 2, opt.latent_dim))))
# Get labels ranging from 0 to n_classes for n rows
labels = np.array([num for _ in range(n_row) for num in range(n_row)])
labels = Variable(LongTensor(labels))
gen_imgs = generator(z, labels)
save_image(gen_imgs.data, "images/%d.png" % batches_done, nrow=n_row, normalize=True)

# -----------------
# Train Generator
# -----------------

optimizer_G.zero_grad()
# ----------
# Training
# ----------

# Sample noise and labels as generator input
z = Variable(FloatTensor(np.random.normal(0, 1, (batch_size, opt.latent_dim))))
gen_labels = Variable(LongTensor(np.random.randint(0, opt.n_classes, batch_size)))
for epoch in range(opt.n_epochs):
for i, (imgs, labels) in enumerate(dataloader):

# Generate a batch of images
gen_imgs = generator(z, gen_labels)
batch_size = imgs.shape[0]

# Loss measures generator's ability to fool the discriminator
validity = discriminator(gen_imgs, gen_labels)
g_loss = adversarial_loss(validity, valid)
# Adversarial ground truths
valid = Variable(FloatTensor(batch_size, 1).fill_(1.0), requires_grad=False)
fake = Variable(FloatTensor(batch_size, 1).fill_(0.0), requires_grad=False)

g_loss.backward()
optimizer_G.step()
# Configure input
real_imgs = Variable(imgs.type(FloatTensor))
labels = Variable(labels.type(LongTensor))

# ---------------------
# Train Discriminator
# ---------------------
# -----------------
# Train Generator
# -----------------

optimizer_D.zero_grad()
optimizer_G.zero_grad()

# Loss for real images
validity_real = discriminator(real_imgs, labels)
d_real_loss = adversarial_loss(validity_real, valid)
# Sample noise and labels as generator input
z = Variable(FloatTensor(np.random.normal(0, 1, (batch_size, opt.latent_dim))))
gen_labels = Variable(LongTensor(np.random.randint(0, opt.n_classes, batch_size)))

# Loss for fake images
validity_fake = discriminator(gen_imgs.detach(), gen_labels)
d_fake_loss = adversarial_loss(validity_fake, fake)
# Generate a batch of images
gen_imgs = generator(z, gen_labels)

# Total discriminator loss
d_loss = (d_real_loss + d_fake_loss) / 2
# Loss measures generator's ability to fool the discriminator
validity = discriminator(gen_imgs, gen_labels)
g_loss = adversarial_loss(validity, valid)

d_loss.backward()
optimizer_D.step()
g_loss.backward()
optimizer_G.step()

# ---------------------
# Train Discriminator
# ---------------------

optimizer_D.zero_grad()

# Loss for real images
validity_real = discriminator(real_imgs, labels)
d_real_loss = adversarial_loss(validity_real, valid)

# Loss for fake images
validity_fake = discriminator(gen_imgs.detach(), gen_labels)
d_fake_loss = adversarial_loss(validity_fake, fake)

# Total discriminator loss
d_loss = (d_real_loss + d_fake_loss) / 2

d_loss.backward()
optimizer_D.step()

print(
"[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
% (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item())
)

save_images_with_labels(gen_imgs.data[:25], gen_labels[:25], epoch)

#save_image(gen_imgs.data[:25], "images/epoch_%d.png" % epoch, nrow=5, normalize=True)

print(
"[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
% (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item())
)

batches_done = epoch * len(dataloader) + i
if batches_done % opt.sample_interval == 0:
sample_image(n_row=10, batches_done=batches_done)
if epoch % opt.eval_interval == 0:
is_score, fid_score, kid_score = evaluate_model(generator, dataloader, opt.latent_dim, num_samples=1000, conditional=True, num_classes=10)
print(f"Epoch {epoch}: IS = {is_score:.2f}, FID = {fid_score:.2f}, KID = {kid_score:.4f}")
save_best_model(generator, is_score, fid_score, kid_score, epoch)
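
The `evaluate_model`, `save_best_model`, and `save_images_with_labels` helpers imported at the top of `cgan.py` live in a new `utils` module that is not shown in this excerpt. A minimal sketch of the first two, assuming `torchmetrics` is used for IS/FID/KID (the PR may compute these metrics differently), could look like this:

```
# Hypothetical sketch of utils.evaluate_model / utils.save_best_model; the real
# utils.py is not part of this excerpt, and torchmetrics is an assumption here.
import torch
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore
from torchmetrics.image.kid import KernelInceptionDistance

# The diff stores the best scores on the Generator; a module-level dict is used
# here only to keep the sketch short.
_best = {"is": 0.0, "fid": float("inf"), "kid": float("inf")}


def _to_uint8_rgb(imgs):
    # The metrics expect 3-channel uint8 images in [0, 255]; training images are
    # normalized to [-1, 1], so rescale and repeat the channel dimension if needed.
    imgs = (imgs.clamp(-1, 1) + 1) / 2
    if imgs.size(1) == 1:
        imgs = imgs.repeat(1, 3, 1, 1)
    return (imgs * 255).to(torch.uint8)


@torch.no_grad()
def evaluate_model(generator, dataloader, latent_dim, num_samples=1000,
                   conditional=False, num_classes=10):
    device = next(generator.parameters()).device
    fid = FrechetInceptionDistance(feature=2048).to(device)
    inception = InceptionScore().to(device)
    kid = KernelInceptionDistance(subset_size=100).to(device)

    generator.eval()
    seen = 0
    for real_imgs, _ in dataloader:
        real_imgs = real_imgs.to(device)
        z = torch.randn(real_imgs.size(0), latent_dim, device=device)
        if conditional:
            labels = torch.randint(0, num_classes, (real_imgs.size(0),), device=device)
            fake_imgs = generator(z, labels)
        else:
            fake_imgs = generator(z)

        real_u8, fake_u8 = _to_uint8_rgb(real_imgs), _to_uint8_rgb(fake_imgs)
        fid.update(real_u8, real=True)
        fid.update(fake_u8, real=False)
        kid.update(real_u8, real=True)
        kid.update(fake_u8, real=False)
        inception.update(fake_u8)

        seen += real_imgs.size(0)
        if seen >= num_samples:
            break
    generator.train()

    is_mean, _ = inception.compute()
    kid_mean, _ = kid.compute()
    return is_mean.item(), fid.compute().item(), kid_mean.item()


def save_best_model(generator, is_score, fid_score, kid_score, epoch):
    # Overwrite one checkpoint per metric whenever that metric improves;
    # filenames match the checkpoint referenced in the README (generator_best_fid.pth).
    if fid_score < _best["fid"]:
        _best["fid"] = fid_score
        torch.save(generator.state_dict(), "generator_best_fid.pth")
        print(f"Epoch {epoch}: new best FID {fid_score:.2f}")
    if is_score > _best["is"]:
        _best["is"] = is_score
        torch.save(generator.state_dict(), "generator_best_is.pth")
    if kid_score < _best["kid"]:
        _best["kid"] = kid_score
        torch.save(generator.state_dict(), "generator_best_kid.pth")
```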

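Likewise, `getDataloader` comes from a new `datasets` module not shown here. A plausible sketch that matches the two-argument call `getDataloader(opt.dataset, opt.batch_size)` in the training script and the `mnist`/`celeba` options in the argparse help (the CelebA folder layout and transforms are assumptions):

```
# Hypothetical sketch of datasets.getDataloader; the real module is not part of
# this excerpt. Paths mirror the data/mnist and data/celeba folders in .gitignore.
import torch
from torchvision import datasets, transforms


def getDataloader(dataset, batch_size, img_size=28, n_cpu=8):
    if dataset == "mnist":
        transform = transforms.Compose([
            transforms.Resize(img_size),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ])
        data = datasets.MNIST("../../data/mnist", train=True, download=True, transform=transform)
    elif dataset == "celeba":
        transform = transforms.Compose([
            transforms.Resize(img_size),
            transforms.CenterCrop(img_size),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ])
        # Assumes images are arranged for ImageFolder under data/celeba/<subdir>/...
        data = datasets.ImageFolder("../../data/celeba", transform=transform)
    else:
        raise ValueError(f"Unknown dataset: {dataset}")
    return torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=True, num_workers=n_cpu)
```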