1. Clone the repository and move into the DCGAN-PyTorch directory:
git clone https://github.com/moreh-dev/ml-workbench.git
cd ml-workbench/DCGAN-PyTorch
Note: some changes have been made compared to the original repo to fix issues and make training run well on our MAF:
- In train.py, lines 70-71, wrapped the integer label values in float() to avoid a Float dtype error:
real_label = float(1)
fake_label = float(0)
- In train.py, line 204 (at the bottom of the file), changed the animation writer from imagemagick to pillow:
anim.save('celeba.gif', dpi=80, writer='pillow')
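The real_label and fake_label values above serve as targets for the binary cross-entropy (BCE) losses shown later in the loss curves. The pure-Python sketch below assumes train.py follows the standard formulation from the PyTorch DCGAN tutorial: the discriminator learns to output real_label on real images and fake_label on generated ones, and the generator is trained against real_label. In PyTorch, these targets become tensors passed to nn.BCELoss, which is presumably why they must be floats. The bce and dcgan_losses helpers are illustrative and not part of the repo.

```python
import math

real_label = float(1)
fake_label = float(0)


def bce(prob: float, target: float, eps: float = 1e-12) -> float:
    """Binary cross-entropy for a single predicted probability."""
    prob = min(max(prob, eps), 1 - eps)  # clamp to avoid log(0)
    return -(target * math.log(prob) + (1 - target) * math.log(1 - prob))


def dcgan_losses(d_real: float, d_fake: float):
    """d_real = D(x) on a real image, d_fake = D(G(z)) on a generated image."""
    # Discriminator: push real images toward real_label, fakes toward fake_label.
    loss_d = bce(d_real, real_label) + bce(d_fake, fake_label)
    # Generator: wants the discriminator to label its fakes as real.
    loss_g = bce(d_fake, real_label)
    return loss_d, loss_g


loss_d, loss_g = dcgan_losses(d_real=0.9, d_fake=0.1)
print(f"loss_D={loss_d:.4f} loss_G={loss_g:.4f}")
```

A confident discriminator, like the one in this example, yields a small D loss but a large G loss. Training alternates between lowering each of the two losses.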
2. Prepare the data by downloading it manually as described below (or by another method), then update the directory location in the root variable in utils.py (line 6).
- This implementation uses the CelebA dataset. However, any other dataset can also be used. Download it from https://drive.google.com/drive/folders/0B7EVK8r0v71pWEZsZE9oNnFzTm8?resourcekey=0-5BR16BdXnb8hVj6CNHKzLg and preprocess it if needed.
- Alternative download methods: https://gist.github.com/SeitaroShinagawa/05083f971d3f1b88a19df32841e5cb25 or https://github.com/suvojit-0x55aa/celebA-HQ-dataset-download
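A common pitfall is pointing root at the wrong folder. The check below assumes utils.py loads images with torchvision's ImageFolder, as the PyTorch DCGAN tutorial does. ImageFolder treats each subdirectory of root as a class, so root should be the folder that contains img_align_celeba, not img_align_celeba itself. check_image_root is a hypothetical stdlib-only helper, and the extension list reflects what ImageFolder accepts by default, to the best of our knowledge.

```python
from pathlib import Path

# Extensions torchvision's ImageFolder accepts by default (assumed list).
IMG_EXTENSIONS = {".jpg", ".jpeg", ".png", ".ppm", ".bmp",
                  ".pgm", ".tif", ".tiff", ".webp"}


def check_image_root(root: str) -> dict:
    """Count images under each immediate subdirectory of `root`.

    ImageFolder-style loaders treat every subdirectory as a class and
    search it recursively, so images placed directly in `root` are ignored.
    """
    root_path = Path(root)
    if not root_path.is_dir():
        raise FileNotFoundError(f"root directory not found: {root}")
    counts = {}
    for sub in sorted(p for p in root_path.iterdir() if p.is_dir()):
        counts[sub.name] = sum(
            1 for f in sub.rglob("*")
            if f.is_file() and f.suffix.lower() in IMG_EXTENSIONS
        )
    if not any(counts.values()):
        raise ValueError(
            f"no images found in subdirectories of {root}; "
            "point root at the folder that contains img_align_celeba"
        )
    return counts
```

For example, check_image_root("data/celeba") should report a count for img_align_celeba before you start training. The path shown here is only an example.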
Install Pillow, which the pillow animation writer above relies on:
pip install pillow
To train the model, run train.py. To set the training parameters, update the values in the params dictionary in train.py.
By default, checkpoints are saved in the model directory every 2 epochs.
python train.py
PyTorch implementation of DCGAN introduced in the paper: Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks, Alec Radford, Luke Metz, Soumith Chintala.
Generative Adversarial Networks (GANs) are one of the most popular (and coolest) Machine Learning algorithms developed in recent times. They belong to a set of algorithms called generative models, which are widely used for unsupervised learning tasks that aim to learn the underlying structure of the given data. As the name suggests, GANs allow you to generate new, unseen data that mimics the given real data. However, GANs are difficult to train and require carefully tuned hyperparameters. This paper aims to solve this problem.
DCGAN is one of the most popular and successful network designs for GANs. It is mainly composed of convolution layers, without max pooling or fully connected layers. It uses strided convolutions for downsampling and transposed convolutions for upsampling.
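The down- and upsampling can be checked with the standard output-size formulas for Conv2d and ConvTranspose2d (dilation 1, no output padding). The layer settings below are assumptions: 4×4 kernels with stride 2 and padding 1, plus a stride-1, padding-0 layer at each end, following the PyTorch DCGAN tutorial for 64×64 images. Check the model code for the values this repo actually uses.

```python
def conv2d_out(size: int, kernel: int, stride: int, padding: int) -> int:
    """Spatial output size of a strided convolution (dilation=1)."""
    return (size + 2 * padding - kernel) // stride + 1


def conv_transpose2d_out(size: int, kernel: int, stride: int, padding: int) -> int:
    """Spatial output size of a transposed convolution (dilation=1, output_padding=0)."""
    return (size - 1) * stride - 2 * padding + kernel


# Generator: 1x1 latent vector -> 4x4 (k=4, s=1, p=0), then four upsampling steps (k=4, s=2, p=1).
g_sizes = [conv_transpose2d_out(1, 4, 1, 0)]
for _ in range(4):
    g_sizes.append(conv_transpose2d_out(g_sizes[-1], 4, 2, 1))

# Discriminator: 64x64 image -> four downsampling steps (k=4, s=2, p=1), then 4x4 -> 1x1 (k=4, s=1, p=0).
d_sizes = [64]
for _ in range(4):
    d_sizes.append(conv2d_out(d_sizes[-1], 4, 2, 1))
d_sizes.append(conv2d_out(d_sizes[-1], 4, 1, 0))

print("generator:", g_sizes)
print("discriminator:", d_sizes)
```

With stride 2 and padding 1, each layer exactly halves or doubles the spatial size. That is how DCGAN avoids pooling entirely.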
Generator architecture of DCGAN
Network Design of DCGAN:
- Replace all pooling layers with strided convolutions.
- Remove all fully connected layers.
- Use transposed convolutions for upsampling.
- Use Batch Normalization after every layer except after the output layer of the generator and the input layer of the discriminator.
- Use ReLU non-linearity for every layer in the generator except the output layer, which uses tanh.
- Use Leaky-ReLU non-linearity for every layer of the discriminator except the output layer, which uses sigmoid.
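The design rules above can be written down as a per-layer plan. This stdlib-only sketch is not the repo's model code. It assumes the PyTorch DCGAN tutorial's sizes (latent size nz=100, feature widths ngf=ndf=64, nc=3 image channels, 64×64 output). Following that tutorial, it also leaves batch normalization off the discriminator's final sigmoid layer. LayerSpec and the two plan functions are hypothetical names.

```python
from dataclasses import dataclass
from typing import List


@dataclass(frozen=True)
class LayerSpec:
    op: str          # "conv_transpose" (upsampling) or "conv" (strided downsampling)
    in_ch: int
    out_ch: int
    out_size: int    # spatial size (height = width) after this layer
    batch_norm: bool
    activation: str


def generator_plan(nz: int = 100, ngf: int = 64, nc: int = 3) -> List[LayerSpec]:
    chans = [nz, ngf * 8, ngf * 4, ngf * 2, ngf, nc]
    sizes = [4, 8, 16, 32, 64]
    plan = []
    for i, size in enumerate(sizes):
        is_output = i == len(sizes) - 1
        plan.append(LayerSpec("conv_transpose", chans[i], chans[i + 1], size,
                              batch_norm=not is_output,  # no BN on the output layer
                              activation="tanh" if is_output else "relu"))
    return plan


def discriminator_plan(nc: int = 3, ndf: int = 64) -> List[LayerSpec]:
    chans = [nc, ndf, ndf * 2, ndf * 4, ndf * 8, 1]
    sizes = [32, 16, 8, 4, 1]
    plan = []
    for i, size in enumerate(sizes):
        is_input, is_output = i == 0, i == len(sizes) - 1
        plan.append(LayerSpec("conv", chans[i], chans[i + 1], size,
                              # No BN on the input layer; the final sigmoid layer
                              # has none either, as in the PyTorch DCGAN tutorial.
                              batch_norm=not (is_input or is_output),
                              activation="sigmoid" if is_output else "leaky_relu(0.2)"))
    return plan


for name, plan in [("G", generator_plan()), ("D", discriminator_plan())]:
    for layer in plan:
        print(name, layer)
```

Every layer is a (transposed) convolution, so the plan has no pooling or fully connected layers.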
Hyperparameters are chosen as given in the paper.
- mini-batch size: 128
- learning rate: 0.0002
- momentum term beta1: 0.5
- slope of leak of LeakyReLU: 0.2
- For the optimizer, Adam (with beta2 = 0.999) is used instead of plain SGD, as described in the paper.
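To show what the learning rate and momentum term do, here is one Adam update for a single scalar parameter, using the hyperparameters listed above (lr = 0.0002, beta1 = 0.5, beta2 = 0.999). This is a textbook illustration of Adam, not code from train.py. The eps value of 1e-8 is an assumed default.

```python
import math


def adam_step(param, grad, m, v, t, lr=0.0002, beta1=0.5, beta2=0.999, eps=1e-8):
    """One Adam update for a scalar parameter; t is the 1-based step count."""
    m = beta1 * m + (1 - beta1) * grad          # first-moment (momentum) estimate
    v = beta2 * v + (1 - beta2) * grad * grad   # second-moment estimate
    m_hat = m / (1 - beta1 ** t)                # bias correction for zero init
    v_hat = v / (1 - beta2 ** t)
    param = param - lr * m_hat / (math.sqrt(v_hat) + eps)
    return param, m, v


# Because of bias correction, the first step moves the parameter by about lr,
# whatever the gradient's magnitude.
p, m, v = adam_step(param=1.0, grad=4.0, m=0.0, v=0.0, t=1)
print(p)
```

The lower beta1 of 0.5 (instead of the usual 0.9) makes the momentum estimate react faster to changing gradients. In PyTorch, the same settings correspond to torch.optim.Adam(model.parameters(), lr=0.0002, betas=(0.5, 0.999)).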
This implementation uses the CelebA dataset. However, any other dataset can also be used. Download the data and update the directory location in the root variable in utils.py.
CelebA Dataset
By default, the GPU is used for training if one is available.
Training takes a long time: it took me around 3 hours on an NVIDIA GeForce GTX 1060 GPU. Using a CPU is not recommended.
Loss Curves
D: Discriminator, G: Generator
To generate new unseen images, run generate.py:
python3 generate.py --load_path /path/to/pth/checkpoint --num_output n
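Checkpoints accumulate in the model directory during training. The hypothetical helper below (not part of the repo) picks the most recently modified .pth file there and builds the command line shown above. It assumes only that checkpoints use the .pth extension, as the --load_path example suggests. The default of 64 outputs is arbitrary.

```python
from pathlib import Path
from typing import List


def latest_checkpoint(model_dir: str = "model") -> Path:
    """Return the most recently modified .pth file in model_dir."""
    ckpts = sorted(Path(model_dir).glob("*.pth"), key=lambda p: p.stat().st_mtime)
    if not ckpts:
        raise FileNotFoundError(f"no .pth checkpoints found in {model_dir}/")
    return ckpts[-1]


def generate_command(model_dir: str = "model", num_output: int = 64) -> List[str]:
    """Build the generate.py command line for the newest checkpoint."""
    return ["python3", "generate.py",
            "--load_path", str(latest_checkpoint(model_dir)),
            "--num_output", str(num_output)]
```

From the DCGAN-PyTorch directory, you can pass the result to subprocess.run(generate_command(), check=True), or print it with shlex.join(generate_command()) to copy into a shell.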
Generated Images
After Epoch 1: After Epoch 10:
- Alec Radford, Luke Metz, Soumith Chintala. Unsupervised representation learning with deep convolutional generative adversarial networks. [arxiv]
- Ian Goodfellow, Jean Pouget-Abadie, Mehdi Mirza, Bing Xu, David Warde-Farley, Sherjil Ozair, Aaron Courville, Yoshua Bengio. Generative adversarial nets. NIPS 2014 [arxiv]
- Ian Goodfellow. Tutorial: Generative Adversarial Networks. NIPS 2016 [arxiv]
- DCGAN Tutorial. [https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html]
- PyTorch Docs. [https://pytorch.org/docs/stable/index.html]