add inference, amp and channel_last supports for began
CaoE committed Mar 9, 2022
1 parent 36d3c77 commit f307843
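The flags this commit introduces (spellings taken from the argparse definitions in the diff below; the script path is assumed) suggest invocations such as:

python implementations/began/began.py                                    # training, as before
python implementations/began/began.py --inference --num-iterations 100   # float32 inference benchmark
python implementations/began/began.py --inference --precision bfloat16 --channels_last 1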
Showing 2 changed files with 310 additions and 186 deletions.
267 changes: 163 additions & 104 deletions implementations/began/began.py
@@ -2,6 +2,7 @@
import os
import numpy as np
import math
import time

import torchvision.transforms as transforms
from torchvision.utils import save_image
@@ -27,12 +28,17 @@
parser.add_argument("--img_size", type=int, default=32, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=400, help="number of image channels")
parser.add_argument('--inference', action='store_true', default=False)
parser.add_argument('--precision', default='float32', help='Precision, "float32" or "bfloat16"')
parser.add_argument('--channels_last', type=int, default=1, help='use channels last format')
parser.add_argument('--num-iterations', default=100, type=int)
opt = parser.parse_args()
print(opt)

img_shape = (opt.channels, opt.img_size, opt.img_size)

cuda = True if torch.cuda.is_available() else False
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor


def weights_init_normal(m):
@@ -68,6 +74,8 @@ def __init__(self):
    def forward(self, noise):
        out = self.l1(noise)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        if opt.channels_last:
            # switch activations to NHWC layout before the conv stack
            out = out.to(memory_format=torch.channels_last)
        img = self.conv_blocks(out)
        return img
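For context, torch.channels_last only rearranges how a 4-D tensor is laid out in memory (NHWC strides); the logical shape is unchanged, which is why the view() above still uses NCHW indices. A minimal sketch (assumes PyTorch >= 1.5):

import torch

x = torch.randn(8, 128, 16, 16)              # logical NCHW shape
y = x.to(memory_format=torch.channels_last)  # same shape, NHWC strides
print(y.shape)                               # torch.Size([8, 128, 16, 16])
print(y.is_contiguous(memory_format=torch.channels_last))  # True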

@@ -94,116 +102,167 @@ def __init__(self):

    def forward(self, img):
        out = self.down(img)
        if opt.channels_last:
            # view() below needs dense NCHW strides, which channels_last breaks
            out = out.contiguous()
        out = self.fc(out.view(out.size(0), -1))
        out = out.view(out.size(0), 64, self.down_size, self.down_size)
        if opt.channels_last:
            out = out.to(memory_format=torch.channels_last)
        out = self.up(out)
        return out

def main():
    # Initialize generator and discriminator
    generator = Generator()
    discriminator = Discriminator()

    if cuda:
        generator.cuda()
        discriminator.cuda()
    else:
        generator.cpu()
        discriminator.cpu()

    # Initialize weights
    generator.apply(weights_init_normal)
    discriminator.apply(weights_init_normal)

    device = torch.device('cuda') if cuda else torch.device('cpu')
    if opt.inference:
        print("----------------Generation---------------")
        if opt.precision == "bfloat16":
            cm = torch.cuda.amp.autocast if cuda else torch.cpu.amp.autocast
            with cm():
                generate(generator, device=device)
        else:
            generate(generator, device=device)
    else:
        print("-------------------Train-----------------")
        train(generator, discriminator)
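main() picks the autocast context by device; what autocast does inside such a region can be seen in a standalone sketch (assumes PyTorch >= 1.10, where torch.cpu.amp.autocast defaults to bfloat16):

import torch

x = torch.randn(8, 128)
w = torch.randn(128, 64)
with torch.cpu.amp.autocast():  # matmul is autocast-eligible on CPU
    y = x @ w
print(y.dtype)                  # torch.bfloat16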


def generate(netG, device):
    # fixed batch of 10 ** 2 latent vectors (note: the throughput print below
    # uses opt.batch_size, so the two only agree when batch_size is 100)
    fixed_noise = Variable(Tensor(np.random.normal(0, 1, (10 ** 2, opt.latent_dim))))
    if opt.channels_last:
        netG_oob = netG
        try:
            netG_oob = netG_oob.to(memory_format=torch.channels_last)
            print("[INFO] Use NHWC model")
        except:
            print("[WARN] Input NHWC failed! Use normal model")
        netG = netG_oob
    else:
        fixed_noise = fixed_noise.to(device=device)
    netG.eval()

    total_iters = opt.num_iterations
    with torch.no_grad():
        tic = time.time()
        for i in range(total_iters):
            fake = netG(fixed_noise)
        toc = time.time() - tic
        print("Throughput: %.2f image/sec, batchsize: %d, latency = %.2f ms"
              % ((opt.num_iterations * opt.batch_size) / toc, opt.batch_size, 1000 * toc / opt.num_iterations))

def train(netG, netD):
    # BEGAN hyper parameters
    gamma = 0.75
    lambda_k = 0.001
    k = 0.0

    # Configure data loader
    os.makedirs("../../data/mnist", exist_ok=True)
    dataloader = torch.utils.data.DataLoader(
        datasets.MNIST(
            "../../data/mnist",
            train=True,
            download=True,
            transform=transforms.Compose(
                [transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
            ),
        ),
        batch_size=opt.batch_size,
        shuffle=True,
    )

    # Optimizers
    optimizer_G = torch.optim.Adam(netG.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
    optimizer_D = torch.optim.Adam(netD.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

    for epoch in range(opt.n_epochs):
        for i, (imgs, _) in enumerate(dataloader):
            if opt.channels_last:
                imgs_oob = imgs
                try:
                    imgs_oob = imgs_oob.to(memory_format=torch.channels_last)
                    print("[INFO] Use NHWC input")
                except:
                    print("[WARN] Input NHWC failed! Use normal input")
                imgs = imgs_oob

            # Configure input
            real_imgs = Variable(imgs.type(Tensor))

            # -----------------
            #  Train Generator
            # -----------------

            optimizer_G.zero_grad()

            # Sample noise as generator input
            z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))

            # Generate a batch of images
            gen_imgs = netG(z)

            # Loss measures generator's ability to fool the discriminator
            g_loss = torch.mean(torch.abs(netD(gen_imgs) - gen_imgs))

            g_loss.backward()
            optimizer_G.step()

            # ---------------------
            #  Train Discriminator
            # ---------------------

            optimizer_D.zero_grad()

            # Measure discriminator's ability to classify real from generated samples
            d_real = netD(real_imgs)
            d_fake = netD(gen_imgs.detach())

            d_loss_real = torch.mean(torch.abs(d_real - real_imgs))
            d_loss_fake = torch.mean(torch.abs(d_fake - gen_imgs.detach()))
            d_loss = d_loss_real - k * d_loss_fake

            d_loss.backward()
            optimizer_D.step()

            # ----------------
            #  Update weights
            # ----------------

            diff = torch.mean(gamma * d_loss_real - d_loss_fake)

            # Update weight term for fake samples
            k = k + lambda_k * diff.item()
            k = min(max(k, 0), 1)  # Constraint to interval [0, 1]

            # Update convergence metric
            M = (d_loss_real + torch.abs(diff)).data.item()

            # --------------
            #  Log Progress
            # --------------

            print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] -- M: %f, k: %f"
                % (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item(), M, k)
            )

            batches_done = epoch * len(dataloader) + i
            if batches_done % opt.sample_interval == 0:
                save_image(gen_imgs.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True)
if __name__ == '__main__':
    main()
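For reference, the train() loop implements the balancing rule from the BEGAN paper (Berthelot et al., 2017). With L(v) = |D(v) - v| the discriminator's L1 reconstruction loss, the code's updates correspond to

L_D = L(x) - k_t \cdot L(G(z)), \quad L_G = L(G(z))
k_{t+1} = \mathrm{clip}(k_t + \lambda_k (\gamma L(x) - L(G(z))), 0, 1)
M = L(x) + |\gamma L(x) - L(G(z))|

with gamma = 0.75 and lambda_k = 0.001 as set at the top of train().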