[ICCV 2023 Oral] Official Implementation of "Denoising Diffusion Autoencoders are Unified Self-supervised Learners"

FutureXiang/ddae

Denoising Diffusion Autoencoders (DDAE)

This is a multi-GPU PyTorch implementation of the paper Denoising Diffusion Autoencoders are Unified Self-supervised Learners:

@inproceedings{ddae2023,
  title={Denoising Diffusion Autoencoders are Unified Self-supervised Learners},
  author={Xiang, Weilai and Yang, Hongyu and Huang, Di and Wang, Yunhong},
  booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
  year={2023}
}

⭐ (News) Our paper is cited by Kaiming He's new paper, Deconstructing Denoising Diffusion Models for Self-Supervised Learning. Check it out! 🔥

Overview

This repo contains:

  • Pre-training, sampling and FID evaluation code for diffusion models, including
    • Frameworks:
      • DDPM & DDIM
      • EDM (w/ or w/o data augmentation)
    • Networks:
      • The basic 35.7M DDPM UNet
      • A larger 56M DDPM++ UNet
    • Datasets:
      • CIFAR-10
      • Tiny-ImageNet
  • Feature quality evaluation code, including
    • Linear probing and grid searching
    • Contrastive metrics, i.e., alignment and uniformity
    • Fine-tuning
  • Noise-conditional classifier training and evaluation, including
    • MLP classifier based on DDPM/EDM features
    • WideResNet with VP/VE perturbation
  • Evaluation code for the ImageNet-256 pre-trained DiT-XL/2 checkpoint

Requirements

  • In addition to a PyTorch environment, please install:
    conda install pyyaml
    pip install pytorch-fid ema-pytorch
  • We used 4 or 8 RTX 3080 Ti GPUs for all experiments in the paper. With automatic mixed precision (AMP) enabled, training the basic 35.7M UNet on CIFAR-10 with 4 GPUs takes about 14 hours.
  • pytorch-fid computes FID from image files, so the CIFAR-10 training set must first be unpacked into 50000 .png files. Please refer to extract_cifar10_pngs.ipynb for this step.
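If Jupyter is not at hand, the unpacking step can be approximated with the standard library alone. The sketch below is a hypothetical stand-in for extract_cifar10_pngs.ipynb, not a copy of it. It assumes the standard "CIFAR-10 python version" archive has already been extracted (files data_batch_1 to data_batch_5), and unpickling those files still requires NumPy because the pixel arrays are stored as NumPy arrays. The function names and the <index>.png naming scheme are illustrative only.

```python
import pickle
import struct
import zlib
from pathlib import Path


def cifar_row_to_rgb(row):
    """Turn one 3072-byte CIFAR-10 row (1024 R, 1024 G, 1024 B values, each a
    row-major 32x32 plane) into 32 scanlines of interleaved RGB bytes."""
    assert len(row) == 3072
    r, g, b = row[:1024], row[1024:2048], row[2048:]
    scanlines = []
    for y in range(32):
        line = bytearray()
        for x in range(32):
            i = y * 32 + x
            line += bytes((r[i], g[i], b[i]))
        scanlines.append(bytes(line))
    return scanlines


def encode_png(scanlines, width=32, height=32):
    """Encode 8-bit RGB scanlines as a minimal PNG (filter type 0 on every row)."""
    def chunk(tag, data):
        body = tag + data
        crc = zlib.crc32(body) & 0xFFFFFFFF  # CRC covers chunk type + data
        return struct.pack(">I", len(data)) + body + struct.pack(">I", crc)

    # width, height, bit depth 8, color type 2 (RGB), compression 0, filter 0, no interlace
    ihdr = struct.pack(">IIBBBBB", width, height, 8, 2, 0, 0, 0)
    raw = b"".join(b"\x00" + line for line in scanlines)  # leading filter byte per row
    return (b"\x89PNG\r\n\x1a\n" + chunk(b"IHDR", ihdr)
            + chunk(b"IDAT", zlib.compress(raw)) + chunk(b"IEND", b""))


def extract_batches(cifar_dir, out_dir):
    """Write all training images from data_batch_1..5 as <index>.png; returns the count."""
    out = Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)
    idx = 0
    for k in range(1, 6):
        with open(Path(cifar_dir) / f"data_batch_{k}", "rb") as f:
            batch = pickle.load(f, encoding="bytes")  # keys are bytes, e.g. b"data"
        for row in batch[b"data"]:
            (out / f"{idx}.png").write_bytes(encode_png(cifar_row_to_rgb(bytes(row))))
            idx += 1
    return idx
```

For example, extract_batches("data/cifar-10-batches-py", "data/cifar10-pngs") would fill the folder used by the FID command further down, assuming the archive was extracted under data/.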

Main results

The table below reports the generative (FID) and discriminative (linear-probe accuracy) results obtained with this codebase. EDM_ddpmpp_aug.yaml was trained on 8 GPUs; all other models were trained on 4 GPUs.

Please note that this is an over-simplified DDPM / EDM implementation, and some network details, initialization, and hyper-parameters may differ from the official ones. Please refer to their respective official codebases to reproduce the exact results reported in the paper.

| Config | Model | Network | Best linear probe ckpt: epoch | FID | acc (%) | Best FID ckpt: epoch | FID | acc (%) |
|---|---|---|---|---|---|---|---|---|
| DDPM_ddpm.yaml | DDPM | 35.7M UNet | 800 | 4.09 | 90.05 | 1999 | 3.62 | 88.23 |
| EDM_ddpm.yaml | EDM | 35.7M UNet | 1200 | 3.97 | 90.44 | 1999 | 3.56 | 89.71 |
| DDPM_ddpmpp.yaml | DDPM | 56.5M DDPM++ | 1200 | 3.08 | 93.97 | 1999 | 2.98 | 93.03 |
| EDM_ddpmpp.yaml | EDM | 56.5M DDPM++ | 1200 | 2.23 | 94.50 | (same) | (same) | (same) |
| EDM_ddpmpp_aug.yaml | EDM + data aug | 56.5M DDPM++ | 2000 | 2.34 | 95.49 | 3200 | 2.12 | 95.19 |

FIDs are calculated using 50000 images generated by the deterministic fast sampler (DDIM 100 steps or EDM 18 steps).
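To make the two deterministic fast samplers concrete, the sketch below computes the noise levels each one visits. EDM's 18-step sampler uses the Karras et al. schedule, which interpolates σ^(1/ρ) linearly between σ_max and σ_min; the defaults σ_min = 0.002, σ_max = 80, ρ = 7 are the EDM paper's settings and are assumed here, not read from this repo's configs. With Heun's second-order method, 18 steps cost 35 network evaluations, because the final step to σ = 0 is a plain Euler step. For DDIM, the 100 steps are assumed to be an evenly spaced subsequence of the 1000 training timesteps, which is one common choice; the repo may select them differently. Both helper names are hypothetical.

```python
def edm_sigmas(num_steps=18, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    """Karras et al. (EDM) noise levels from sigma_max down to sigma_min, plus a final 0."""
    inv_rho = 1.0 / rho
    hi, lo = sigma_max ** inv_rho, sigma_min ** inv_rho
    # linear interpolation in sigma**(1/rho) space, then map back with **rho
    sigmas = [(hi + i / (num_steps - 1) * (lo - hi)) ** rho for i in range(num_steps)]
    return sigmas + [0.0]  # the sampler's last step lands exactly on sigma = 0


def ddim_timesteps(num_steps=100, num_train_steps=1000):
    """Evenly spaced subset of the training timesteps, ordered from noisiest to cleanest."""
    stride = num_train_steps // num_steps
    return list(range(0, num_train_steps, stride))[::-1]
```

The ρ = 7 warping concentrates steps at small σ, where most perceptible detail is resolved, which is why so few steps suffice.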

Latent-space DiT

We evaluate pre-trained Transformer-based diffusion networks, DiT, from the perspective of transfer learning. Please refer to the ddae/DiT subfolder.

Usage

Diffusion pre-training

To train a DDAE model and generate 50000 image samples with 4 GPUs, for example, run:

# diffusion pre-training with AMP enabled
python -m torch.distributed.launch --nproc_per_node=4 train.py --config config/DDPM_ddpm.yaml --use_amp

# deterministic fast sampling (i.e. DDIM 100 steps / EDM 18 steps)
python -m torch.distributed.launch --nproc_per_node=4 sample.py --config config/DDPM_ddpm.yaml --use_amp --epoch 400

# stochastic sampling (i.e. DDPM 1000 steps)
python -m torch.distributed.launch --nproc_per_node=4 sample.py --config config/DDPM_ddpm.yaml --use_amp --epoch 400 --mode DDPM
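The difference between the deterministic and stochastic sampling modes shows up in a single update step. The scalar sketch below assumes an ε-predicting network and the linear β schedule from the DDPM paper (β from 1e-4 to 0.02 over 1000 steps); eps stands in for the UNet output, and the helper names are hypothetical rather than functions from sample.py. DDIM with η = 0 re-noises its own x̂₀ estimate deterministically, whereas DDPM draws from the Gaussian posterior with variance β_t (one of the two variance choices in the DDPM paper).

```python
import math
import random


def linear_alpha_bars(num_train_steps=1000, beta_start=1e-4, beta_end=0.02):
    """alpha_bar_t = prod_{s<=t} (1 - beta_s) for a linear beta schedule."""
    alpha_bars, prod = [], 1.0
    for i in range(num_train_steps):
        beta = beta_start + (beta_end - beta_start) * i / (num_train_steps - 1)
        prod *= 1.0 - beta
        alpha_bars.append(prod)
    return alpha_bars


def ddim_step(x_t, eps, ab_t, ab_prev):
    """Deterministic DDIM update (eta = 0): predict x0, then re-noise it to level ab_prev."""
    x0_pred = (x_t - math.sqrt(1.0 - ab_t) * eps) / math.sqrt(ab_t)
    return math.sqrt(ab_prev) * x0_pred + math.sqrt(1.0 - ab_prev) * eps


def ddpm_step(x_t, eps, ab_t, ab_prev, rng=None):
    """Stochastic DDPM ancestral update from t to t-1 with variance beta_t.

    ab_prev must be alpha_bar_{t-1} (or 1.0 at t = 0). Pass rng=None on the
    final step, where DDPM adds no noise.
    """
    alpha_t = ab_t / ab_prev
    beta_t = 1.0 - alpha_t
    mean = (x_t - beta_t / math.sqrt(1.0 - ab_t) * eps) / math.sqrt(alpha_t)
    noise = 0.0 if rng is None else rng.gauss(0.0, 1.0)
    return mean + math.sqrt(beta_t) * noise
```

A quick sanity check for any reimplementation: with an oracle ε computed from a known x₀, looping ddim_step over a decreasing timestep subsequence, with alpha_bar_prev = 1 on the last step, returns that x₀ up to rounding.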

To calculate the FID metric on the training set, for example, run:

python -m pytorch_fid   data/cifar10-pngs/  output_DDPM_ddpm/EMAgenerated_ep400_ddim_steps100_eta0.0/pngs/
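For reference, the number pytorch-fid reports is the Fréchet distance between two Gaussians fitted to Inception-v3 features (the 2048-dimensional pool3 activations by default) of the two image folders. With means μ_r, μ_g and covariances Σ_r, Σ_g for the real and generated sets, lower is better:

```latex
\mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2
  + \operatorname{Tr}\!\left( \Sigma_r + \Sigma_g - 2\,(\Sigma_r \Sigma_g)^{1/2} \right)
```

Because μ and Σ are sample estimates, the score depends on the number of images, which is why the results above consistently use 50000 generated samples.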

Features produced by DDAE

To evaluate the features produced by pre-trained DDAE, for example, run:

# grid searching for proper layer-noise combination
python -m torch.distributed.launch --nproc_per_node=4 linear.py --config config/DDPM_ddpm.yaml --use_amp --epoch 400 --grid

# linear probing, using the layer-noise combination specified by config.yaml
python -m torch.distributed.launch --nproc_per_node=4 linear.py --config config/DDPM_ddpm.yaml --use_amp --epoch 400

# showing the alignment-uniformity metrics with respect to different checkpoints
python -m torch.distributed.launch --nproc_per_node=4 contrastive.py --config config/DDPM_ddpm.yaml --use_amp
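The alignment and uniformity metrics follow the align_uniform formulation listed in the Acknowledgments. The pure-Python sketch below is not contrastive.py: it assumes L2-normalized feature vectors, uses α = 2 and t = 2 (the defaults in that formulation), and treats positive pairs as features of two views of the same image. The function names are hypothetical. For both metrics, lower is better.

```python
import math


def l2_normalize(v):
    """Project a feature vector onto the unit hypersphere."""
    n = math.sqrt(sum(a * a for a in v))
    return [a / n for a in v]


def alignment(xs, ys, alpha=2.0):
    """Mean ||f(x) - f(y)||^alpha over positive pairs (xs[i], ys[i])."""
    total = 0.0
    for x, y in zip(xs, ys):
        d = math.sqrt(sum((a - b) ** 2 for a, b in zip(x, y)))
        total += d ** alpha
    return total / len(xs)


def uniformity(xs, t=2.0):
    """log of mean exp(-t * ||f(x_i) - f(x_j)||^2) over all distinct pairs i < j."""
    vals = []
    for i in range(len(xs)):
        for j in range(i + 1, len(xs)):
            sq = sum((a - b) ** 2 for a, b in zip(xs[i], xs[j]))
            vals.append(math.exp(-t * sq))
    return math.log(sum(vals) / len(vals))
```

Fully collapsed features score a uniformity of 0 (the worst value), while features spread evenly over the sphere score lower, so the two metrics trade off against each other across checkpoints.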

Noise-conditional classifier

To train WideResNet-based classifiers from scratch:

# VP (DDPM) perturbation
python -m torch.distributed.launch --nproc_per_node=4 noisy_classifier_WRN.py --mode DDPM

# VE (EDM) perturbation
python -m torch.distributed.launch --nproc_per_node=4 noisy_classifier_WRN.py --mode EDM

and compare their noise-conditional recognition rates with DDAE-based MLP classifier heads:

# using DDPM DDAE encoder
python -m torch.distributed.launch --nproc_per_node=4 noisy_classifier_DDAE.py --config config/DDPM_ddpm.yaml --use_amp --epoch 1999

# using EDM DDAE encoder
python -m torch.distributed.launch --nproc_per_node=4 noisy_classifier_DDAE.py --config config/EDM_ddpmpp.yaml --use_amp --epoch 1200
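The VP (DDPM) and VE (EDM) perturbations differ only in whether the clean signal is scaled down as noise is added. The sketch below is a hypothetical illustration, not code from noisy_classifier_WRN.py: VP computes x_t = sqrt(ᾱ) x₀ + sqrt(1 − ᾱ) ε, VE computes x = x₀ + σ ε, and dividing a VP sample by sqrt(ᾱ) yields a VE sample at σ = sqrt((1 − ᾱ)/ᾱ). The function names are illustrative.

```python
import math
import random


def vp_perturb(x0, alpha_bar, rng):
    """VP (DDPM-style): shrink the signal and add noise so total variance stays bounded."""
    s, n = math.sqrt(alpha_bar), math.sqrt(1.0 - alpha_bar)
    return [s * a + n * rng.gauss(0.0, 1.0) for a in x0]


def ve_perturb(x0, sigma, rng):
    """VE (EDM-style): add noise of std sigma; the signal is never scaled."""
    return [a + sigma * rng.gauss(0.0, 1.0) for a in x0]


def vp_to_ve_sigma(alpha_bar):
    """Noise level at which a rescaled VP sample (x_t / sqrt(alpha_bar)) matches VE."""
    return math.sqrt((1.0 - alpha_bar) / alpha_bar)
```

This equivalence is why classifiers trained under either perturbation can be compared at matched noise levels.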

Acknowledgments

This repository is built on numerous open-source codebases such as DDPM, DDPM-pytorch, DDIM, EDM, Score-based SDE, DiT, and align_uniform.