Skip to content

Commit

Permalink
Merge pull request #65 from atong01/cifar_10_FID_5
Browse files Browse the repository at this point in the history
Provide a new Cifar10 example, FID computation and trained model weights
  • Loading branch information
kilianFatras authored Nov 1, 2023
2 parents 1c44a6a + ccb0bef commit 7b0b2bc
Show file tree
Hide file tree
Showing 12 changed files with 811 additions and 521 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -171,3 +171,5 @@ slurm*.out
*.jpg

notebooks/figures/

.DS_Store
477 changes: 56 additions & 421 deletions README.md

Large diffs are not rendered by default.

Binary file added assets/169_generated_samples_otcfm.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added assets/169_generated_samples_otcfm.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
40 changes: 37 additions & 3 deletions examples/cifar10/README.md
Original file line number Diff line number Diff line change
@@ -1,13 +1,47 @@
# CIFAR-10 experiments using TorchCFM

This repository is used to reproduce the CIFAR-10 experiments from [1](https://arxiv.org/abs/2302.00482). It is a repository in construction and we will add more features and details in the future (including FID computations and pre-trained weights). We have followed the experimental details provided in [2](https://openreview.net/forum?id=PqvMRDCJT9t).
This repository is used to reproduce the CIFAR-10 experiments from [1](https://arxiv.org/abs/2302.00482). We have designed a novel experimental procedure that helps us to reach an __FID of 3.5__ on the Cifar10 dataset.

To reproduce the experiment and save the weights, install the requirements from the main repository and then run:
<p align="center">
<img src="../../assets/169_generated_samples_otcfm.png" width="600"/>
</p>

To reproduce the experiments and save the weights, install the requirements from the main repository and then run (runs on a single RTX 2080 GPU):

- For the OT-Conditional Flow Matching method:

```bash
python3 train_cifar10.py --model "otcfm" --lr 2e-4 --ema_decay 0.9999 --batch_size 128 --total_steps 400001 --save_step 20000
```

- For the Conditional Flow Matching method:

```bash
python3 train_cifar10.py --model "cfm" --lr 2e-4 --ema_decay 0.9999 --batch_size 128 --total_steps 400001 --save_step 20000
```

- For the original Flow Matching method:

```bash
python3 train_cifar10.py
python3 train_cifar10.py --model "fm" --lr 2e-4 --ema_decay 0.9999 --batch_size 128 --total_steps 400001 --save_step 20000
```

To compute the FID from the OT-CFM model at end of training, run:

```bash
python3 compute_fid.py --model "otcfm" --step 400000 --integration_method dopri5
```

For the other models, change the "otcfm" argument by "cfm" or "fm". For easy reproducibility of our results, you can download the model weights at 400000 iterations here:

- [cfm weights](https://github.com/atong01/conditional-flow-matching/releases/download/1.0.4/cfm_cifar10_weights_step_400000.pt)

- [otcfm weights](https://github.com/atong01/conditional-flow-matching/releases/download/1.0.4/otcfm_cifar10_weights_step_400000.pt)

- [fm weights](https://github.com/atong01/conditional-flow-matching/releases/download/1.0.4/fm_cifar10_weights_step_400000.pt)

To recompute the FID, change the PATH variable with where you have saved the downloaded weights.

If you find this code useful in your research, please cite the following papers (expand for BibTeX):

<details>
Expand Down
105 changes: 105 additions & 0 deletions examples/cifar10/compute_fid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
# Inspired from https://github.com/w86763777/pytorch-ddpm/tree/master.

# Authors: Kilian Fatras
# Alexander Tong

import os
import sys

import matplotlib.pyplot as plt
import torch
from absl import app, flags
from cleanfid import fid
from torchdiffeq import odeint
from torchdyn.core import NeuralODE

from torchcfm.models.unet.unet import UNetModelWrapper

FLAGS = flags.FLAGS
# UNet
flags.DEFINE_integer("num_channel", 128, help="base channel of UNet")

# Training
flags.DEFINE_bool("parallel", False, help="multi gpu training")
flags.DEFINE_string("input_dir", "./results", help="output_directory")
flags.DEFINE_string("model", "otcfm", help="flow matching model type")
flags.DEFINE_integer("integration_steps", 100, help="number of inference steps")
flags.DEFINE_string("integration_method", "dopri5", help="integration method to use")
flags.DEFINE_integer("step", 400000, help="training steps")
flags.DEFINE_integer("num_gen", 50000, help="number of samples to generate")
flags.DEFINE_float("tol", 1e-5, help="Integrator tolerance (absolute and relative)")
FLAGS(sys.argv)


# Define the model
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")

new_net = UNetModelWrapper(
dim=(3, 32, 32),
num_res_blocks=2,
num_channels=FLAGS.num_channel,
channel_mult=[1, 2, 2, 2],
num_heads=4,
num_head_channels=64,
attention_resolutions="16",
dropout=0.1,
).to(device)


# Load the model
PATH = f"{FLAGS.input_dir}/{FLAGS.model}/cifar10_weights_step_{FLAGS.step}.pt"
print("path: ", PATH)
checkpoint = torch.load(PATH)
state_dict = checkpoint["ema_model"]
try:
new_net.load_state_dict(state_dict)
except RuntimeError:
from collections import OrderedDict

new_state_dict = OrderedDict()
for k, v in state_dict.items():
new_state_dict[k[7:]] = v
new_net.load_state_dict(new_state_dict)
new_net.eval()


# Define the integration method if euler is used
if FLAGS.integration_method == "euler":
node = NeuralODE(new_net, solver=FLAGS.integration_method)


def gen_1_img(unused_latent):
with torch.no_grad():
x = torch.randn(500, 3, 32, 32).to(device)
if FLAGS.integration_method == "euler":
print("Use method: ", FLAGS.integration_method)
t_span = torch.linspace(0, 1, FLAGS.integration_steps + 1).to(device)
traj = node.trajectory(x, t_span=t_span)
else:
print("Use method: ", FLAGS.integration_method)
t_span = torch.linspace(0, 1, 2).to(device)
traj = odeint(
new_net, x, t_span, rtol=FLAGS.tol, atol=FLAGS.tol, method=FLAGS.integration_method
)
traj = traj[-1, :] # .view([-1, 3, 32, 32]).clip(-1, 1)
img = (traj * 127.5 + 128).clip(0, 255).to(torch.uint8) # .permute(1, 2, 0)
return img


print("Start computing FID")
score = fid.compute_fid(
gen=gen_1_img,
dataset_name="cifar10",
batch_size=500,
dataset_res=32,
num_gen=FLAGS.num_gen,
dataset_split="train",
mode="legacy_tensorflow",
)
print()
print("FID has been computed")
# print()
# print("Total NFE: ", new_net.nfe)
print()
print("FID: ", score)
Loading

0 comments on commit 7b0b2bc

Please sign in to comment.