This is a multi-GPU PyTorch implementation of the paper Denoising Diffusion Autoencoders are Unified Self-supervised Learners:
```bibtex
@inproceedings{ddae2023,
  title={Denoising Diffusion Autoencoders are Unified Self-supervised Learners},
  author={Xiang, Weilai and Yang, Hongyu and Huang, Di and Wang, Yunhong},
  booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
  year={2023}
}
```
⭐ (News) Our paper is cited by Kaiming He's new paper Deconstructing Denoising Diffusion Models for Self-Supervised Learning, check it out! 🔥
This repo contains:
- Pre-training, sampling and FID evaluation code for diffusion models, including
  - Frameworks:
    - DDPM & DDIM
    - EDM (w/ or w/o data augmentation)
  - Networks:
    - The basic 35.7M DDPM UNet
    - A larger 56.5M DDPM++ UNet
  - Datasets:
    - CIFAR-10
    - Tiny-ImageNet
- Feature quality evaluation code (see the feature-extraction sketch after this list), including
  - Linear probing and grid searching
  - Contrastive metrics, i.e., alignment and uniformity
  - Fine-tuning
- Noise-conditional classifier training and evaluation, including
  - MLP classifiers based on DDPM/EDM features
  - WideResNets with VP/VE perturbation
- Evaluation code for the ImageNet-256 pre-trained DiT-XL/2 checkpoint
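For orientation, "using a diffusion model as a self-supervised learner" amounts to: perturb an image at a chosen noise level, run it through the denoising UNet, and pool an intermediate activation as the feature vector. A minimal sketch, assuming a VP schedule and a hypothetical `mid_block` hook point (the real scripts select the layer-noise combination by grid search):

```python
# Sketch of DDAE-style feature extraction (not the repo's exact API).
import torch

@torch.no_grad()
def extract_features(unet, x0, t, alpha_bar):
    """x0: (B, 3, 32, 32) images in [-1, 1]; t: int timestep; alpha_bar: (T,) schedule."""
    eps = torch.randn_like(x0)
    # VP perturbation: x_t = sqrt(a_bar_t) * x0 + sqrt(1 - a_bar_t) * eps
    xt = alpha_bar[t].sqrt() * x0 + (1 - alpha_bar[t]).sqrt() * eps
    acts = {}
    # `mid_block` is a hypothetical attribute; the layer is really a grid-searched choice
    hook = unet.mid_block.register_forward_hook(
        lambda module, inputs, output: acts.__setitem__("feat", output))
    unet(xt, torch.full((x0.size(0),), t, device=x0.device, dtype=torch.long))
    hook.remove()
    return acts["feat"].mean(dim=(2, 3))  # global average pooling -> (B, C)
```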
- In addition to a PyTorch environment, please install:

  ```bash
  conda install pyyaml
  pip install pytorch-fid ema-pytorch
  ```
- We use 4 or 8 RTX 3080 Ti GPUs for all experiments presented in the paper. With automatic mixed precision enabled and 4 GPUs, training the basic 35.7M UNet on CIFAR-10 takes ~14 hours.
- The `pytorch-fid` package requires image files to calculate the FID metric. Please refer to `extract_cifar10_pngs.ipynb` to unpack the CIFAR-10 training dataset into 50000 `.png` image files.
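For reference, the unpacking itself boils down to a few lines of torchvision (a sketch of what the notebook does; its exact filenames may differ, and the output directory here is chosen to match the FID command below):

```python
# Dump the CIFAR-10 training set to 50000 individual .png files.
import os
from torchvision.datasets import CIFAR10

out_dir = "data/cifar10-pngs"
os.makedirs(out_dir, exist_ok=True)
dataset = CIFAR10(root="data", train=True, download=True)  # no transform -> PIL images
for i, (img, _) in enumerate(dataset):
    img.save(os.path.join(out_dir, f"{i:05d}.png"))
```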
We present the generative and discriminative evaluation results that can be obtained by this codebase. The `EDM_ddpmpp_aug.yaml` training is performed on 8 GPUs, while other models are trained on 4 GPUs.
Please note that this is an over-simplified DDPM / EDM implementation; some network details, initialization, and hyper-parameters may differ from the official ones. Please refer to the respective official codebases to reproduce their exact reported results.
| Config | Model | Network | Best linear probe checkpoint (epoch / FID / acc %) | Best FID checkpoint (epoch / FID / acc %) |
|---|---|---|---|---|
| `DDPM_ddpm.yaml` | DDPM | 35.7M UNet | 800 / 4.09 / 90.05 | 1999 / 3.62 / 88.23 |
| `EDM_ddpm.yaml` | EDM | 35.7M UNet | 1200 / 3.97 / 90.44 | 1999 / 3.56 / 89.71 |
| `DDPM_ddpmpp.yaml` | DDPM | 56.5M DDPM++ | 1200 / 3.08 / 93.97 | 1999 / 2.98 / 93.03 |
| `EDM_ddpmpp.yaml` | EDM | 56.5M DDPM++ | 1200 / 2.23 / 94.50 | (same checkpoint) |
| `EDM_ddpmpp_aug.yaml` | EDM + data aug | 56.5M DDPM++ | 2000 / 2.34 / 95.49 | 3200 / 2.12 / 95.19 |
FIDs are calculated using 50000 images generated by the deterministic fast sampler (DDIM 100 steps or EDM 18 steps).
We evaluate pre-trained Transformer-based diffusion networks, DiT, from the perspective of transfer learning. Please refer to the `ddae/DiT` subfolder.
To train a DDAE model and generate 50000 image samples with 4 GPUs, for example, run:
```bash
# diffusion pre-training with AMP enabled
python -m torch.distributed.launch --nproc_per_node=4 train.py --config config/DDPM_ddpm.yaml --use_amp

# deterministic fast sampling (i.e. DDIM 100 steps / EDM 18 steps)
python -m torch.distributed.launch --nproc_per_node=4 sample.py --config config/DDPM_ddpm.yaml --use_amp --epoch 400

# stochastic sampling (i.e. DDPM 1000 steps)
python -m torch.distributed.launch --nproc_per_node=4 sample.py --config config/DDPM_ddpm.yaml --use_amp --epoch 400 --mode DDPM
```
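For context, the deterministic fast sampler is the DDIM update with eta = 0: estimate x0 from the predicted noise, then step to the previous timestep without injecting fresh noise. One step, sketched with generic names rather than the repo's:

```python
# A single deterministic DDIM step (eta = 0); variable names are illustrative.
import torch

@torch.no_grad()
def ddim_step(unet, xt, t, t_prev, alpha_bar):
    t_batch = torch.full((xt.size(0),), t, device=xt.device, dtype=torch.long)
    eps = unet(xt, t_batch)                                   # predicted noise
    ab_t, ab_prev = alpha_bar[t], alpha_bar[t_prev]
    x0_pred = (xt - (1 - ab_t).sqrt() * eps) / ab_t.sqrt()    # estimate of x0
    return ab_prev.sqrt() * x0_pred + (1 - ab_prev).sqrt() * eps
```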
To calculate the FID metric on the training set, for example, run:
```bash
python -m pytorch_fid data/cifar10-pngs/ output_DDPM_ddpm/EMAgenerated_ep400_ddim_steps100_eta0.0/pngs/
```
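The same metric can also be computed from Python via `pytorch_fid.fid_score.calculate_fid_given_paths`; note that the argument list has changed across `pytorch-fid` releases, so treat this as a sketch for recent versions:

```python
# Programmatic FID between real and generated image folders (recent pytorch-fid).
from pytorch_fid.fid_score import calculate_fid_given_paths

fid = calculate_fid_given_paths(
    ["data/cifar10-pngs/",
     "output_DDPM_ddpm/EMAgenerated_ep400_ddim_steps100_eta0.0/pngs/"],
    batch_size=50,
    device="cuda",
    dims=2048,  # pool3 features of the standard InceptionV3
)
print(f"FID: {fid:.2f}")
```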
To evaluate the features produced by pre-trained DDAE, for example, run:
```bash
# grid searching for a proper layer-noise combination
python -m torch.distributed.launch --nproc_per_node=4 linear.py --config config/DDPM_ddpm.yaml --use_amp --epoch 400 --grid

# linear probing, using the layer-noise combination specified by the config .yaml
python -m torch.distributed.launch --nproc_per_node=4 linear.py --config config/DDPM_ddpm.yaml --use_amp --epoch 400

# showing the alignment-uniformity metrics with respect to different checkpoints
python -m torch.distributed.launch --nproc_per_node=4 contrastive.py --config config/DDPM_ddpm.yaml --use_amp
```
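The alignment and uniformity metrics follow Wang & Isola (2020) (cf. the align_uniform codebase credited below): alignment measures how close the features of two views of the same image are, uniformity how evenly features spread over the unit hypersphere. A sketch of the standard definitions:

```python
# Alignment / uniformity on L2-normalized features (standard definitions,
# not necessarily the repo's exact implementation).
import torch
import torch.nn.functional as F

def alignment(f1, f2, alpha=2):
    """f1, f2: (N, D) features of two views of the same N images."""
    return (f1 - f2).norm(dim=1).pow(alpha).mean()

def uniformity(f, t=2):
    """f: (N, D) L2-normalized features."""
    return torch.pdist(f, p=2).pow(2).mul(-t).exp().mean().log()

f1 = F.normalize(torch.randn(128, 512), dim=1)
f2 = F.normalize(f1 + 0.1 * torch.randn_like(f1), dim=1)
print(alignment(f1, f2).item(), uniformity(f1).item())
```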
To train WideResNet-based classifiers from scratch:
```bash
# VP (DDPM) perturbation
python -m torch.distributed.launch --nproc_per_node=4 noisy_classifier_WRN.py --mode DDPM

# VE (EDM) perturbation
python -m torch.distributed.launch --nproc_per_node=4 noisy_classifier_WRN.py --mode EDM
```
and compare their noise-conditional recognition rates with DDAE-based MLP classifier heads:
```bash
# using the DDPM DDAE encoder
python -m torch.distributed.launch --nproc_per_node=4 noisy_classifier_DDAE.py --config config/DDPM_ddpm.yaml --use_amp --epoch 1999

# using the EDM DDAE encoder
python -m torch.distributed.launch --nproc_per_node=4 noisy_classifier_DDAE.py --config config/EDM_ddpmpp.yaml --use_amp --epoch 1200
```
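For context, the two perturbation types differ only in how inputs are corrupted before classification: VP (DDPM) scales the signal down while mixing in noise, whereas VE (EDM) adds noise of growing magnitude to the unscaled image. A sketch (scalars are assumed to be tensors; the actual schedules follow DDPM and EDM):

```python
# VP vs. VE input perturbation for noise-conditional classifiers (sketch).
import torch

def vp_perturb(x0, alpha_bar_t):
    # VP (DDPM): x_t = sqrt(a_bar_t) * x0 + sqrt(1 - a_bar_t) * eps
    return alpha_bar_t.sqrt() * x0 + (1 - alpha_bar_t).sqrt() * torch.randn_like(x0)

def ve_perturb(x0, sigma):
    # VE (EDM): x_t = x0 + sigma * eps, with sigma from EDM's noise distribution
    return x0 + sigma * torch.randn_like(x0)
```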
This repository is built on numerous open-source codebases such as DDPM, DDPM-pytorch, DDIM, EDM, Score-based SDE, DiT, and align_uniform.