From 0ef7c77672f74a52b9cc326e247b3b16f4a6ce43 Mon Sep 17 00:00:00 2001 From: Kilian Date: Tue, 31 Oct 2023 15:34:28 -0400 Subject: [PATCH] update README within cifar experiments --- examples/cifar10/README.md | 33 ++++++++++++++++++++++++++++----- 1 file changed, 28 insertions(+), 5 deletions(-) diff --git a/examples/cifar10/README.md b/examples/cifar10/README.md index 204b9c6..ddf3f1e 100644 --- a/examples/cifar10/README.md +++ b/examples/cifar10/README.md @@ -1,19 +1,42 @@ # CIFAR-10 experiments using TorchCFM -This repository is used to reproduce the CIFAR-10 experiments from [1](https://arxiv.org/abs/2302.00482). It is a repository in construction and we will add more features and details in the future (including FID computations and pre-trained weights). We have followed the experimental details provided in [2](https://openreview.net/forum?id=PqvMRDCJT9t). +This repository is used to reproduce the CIFAR-10 experiments from [1](https://arxiv.org/abs/2302.00482). We have designed a novel experimental procedure that helps us to reach an __FID of 3.5__ on the Cifar10 dataset. -To reproduce the experiment and save the weights, install the requirements from the main repository and then run (runs on a single RTX 2080 GPU): +

+ +

+To reproduce the experiments and save the weights, install the requirements from the main repository and then run (runs on a single RTX 2080 GPU): + +- For the OT-Conditional Flow Matching method: +```bash +python3 train_cifar10.py --model "otcfm" --lr 2e-4 --ema_decay 0.9999 --batch_size 128 --total_steps 400001 --save_step 20000 +``` + +- For the Conditional Flow Matching method: +```bash +python3 train_cifar10.py --model "cfm" --lr 2e-4 --ema_decay 0.9999 --batch_size 128 --total_steps 400001 --save_step 20000 +``` + +- For the original Flow Matching method: ```bash -python3 train_cifar10.py --lr 2e-4 --ema_decay 0.9999 --batch_size 128 --total_steps 800001 --save_step 20000 +python3 train_cifar10.py --model "fm" --lr 2e-4 --ema_decay 0.9999 --batch_size 128 --total_steps 400001 --save_step 20000 ``` -To run a script closer to the original Flow Matching paper, use the following script(might require several GPUs): +To compute the FID from the OT-CFM model at end of training, run: ```bash -python3 train_cifar10.py --lr 2e-4 --ema_decay 0.9999 --num_channel 256 --batch_size 256 --total_steps 400001 --save_step 20000 +python3 compute_fid.py --model "otcfm" --step 400000 --integration_method dopri5 ``` +For the other models, change the "otcfm" argument by "cfm" or "fm". For easy reproducibility of our results, we provide the saved weights of our different methods trained for 400000 iterations. + +- [Trained OT-CFM weights](http://alextong.net) +- [Trained CFM weights](http://alextong.net) +- [Trained FM weights](http://alextong.net) + +To recompute the FID, change the PATH variable with where you have saved the downloaded weights. + If you find this code useful in your research, please cite the following papers (expand for BibTeX):