Skip to content

Commit

Permalink
update README within cifar experiments
Browse files Browse the repository at this point in the history
  • Loading branch information
kilianFatras committed Oct 31, 2023
1 parent c0c7181 commit 0ef7c77
Showing 1 changed file with 28 additions and 5 deletions.
33 changes: 28 additions & 5 deletions examples/cifar10/README.md
Original file line number Diff line number Diff line change
@@ -1,19 +1,42 @@
# CIFAR-10 experiments using TorchCFM

This repository is used to reproduce the CIFAR-10 experiments from [1](https://arxiv.org/abs/2302.00482). It is a repository in construction and we will add more features and details in the future (including FID computations and pre-trained weights). We have followed the experimental details provided in [2](https://openreview.net/forum?id=PqvMRDCJT9t).
This repository is used to reproduce the CIFAR-10 experiments from [1](https://arxiv.org/abs/2302.00482). We have designed a novel experimental procedure that helps us to reach an __FID of 3.5__ on the Cifar10 dataset.

To reproduce the experiment and save the weights, install the requirements from the main repository and then run (runs on a single RTX 2080 GPU):
<p align="center">
<img src="../../assets/169_generated_samples_otcfm.gif" width="600"/>
</p>

To reproduce the experiments and save the weights, install the requirements from the main repository and then run (runs on a single RTX 2080 GPU):

- For the OT-Conditional Flow Matching method:
```bash
python3 train_cifar10.py --model "otcfm" --lr 2e-4 --ema_decay 0.9999 --batch_size 128 --total_steps 400001 --save_step 20000
```

- For the Conditional Flow Matching method:
```bash
python3 train_cifar10.py --model "cfm" --lr 2e-4 --ema_decay 0.9999 --batch_size 128 --total_steps 400001 --save_step 20000
```

- For the original Flow Matching method:
```bash
python3 train_cifar10.py --lr 2e-4 --ema_decay 0.9999 --batch_size 128 --total_steps 800001 --save_step 20000
python3 train_cifar10.py --model "fm" --lr 2e-4 --ema_decay 0.9999 --batch_size 128 --total_steps 400001 --save_step 20000
```

To run a script closer to the original Flow Matching paper, use the following script(might require several GPUs):
To compute the FID from the OT-CFM model at end of training, run:

```bash
python3 train_cifar10.py --lr 2e-4 --ema_decay 0.9999 --num_channel 256 --batch_size 256 --total_steps 400001 --save_step 20000
python3 compute_fid.py --model "otcfm" --step 400000 --integration_method dopri5
```

For the other models, change the "otcfm" argument by "cfm" or "fm". For easy reproducibility of our results, we provide the saved weights of our different methods trained for 400000 iterations.

- [Trained OT-CFM weights](http://alextong.net)
- [Trained CFM weights](http://alextong.net)
- [Trained FM weights](http://alextong.net)

To recompute the FID, change the PATH variable with where you have saved the downloaded weights.

If you find this code useful in your research, please cite the following papers (expand for BibTeX):

<details>
Expand Down

0 comments on commit 0ef7c77

Please sign in to comment.