Skip to content

Commit

Permalink
update cfm to icfm
Browse files Browse the repository at this point in the history
  • Loading branch information
kilianFatras committed Nov 3, 2023
1 parent b4bbac9 commit bad11c1
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
6 changes: 3 additions & 3 deletions examples/cifar10/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ python3 train_cifar10.py --model "otcfm" --lr 2e-4 --ema_decay 0.9999 --batch_si
- For the Conditional Flow Matching method:

```bash
python3 train_cifar10.py --model "cfm" --lr 2e-4 --ema_decay 0.9999 --batch_size 128 --total_steps 400001 --save_step 20000
python3 train_cifar10.py --model "icfm" --lr 2e-4 --ema_decay 0.9999 --batch_size 128 --total_steps 400001 --save_step 20000
```

- For the original Flow Matching method:
Expand All @@ -32,9 +32,9 @@ To compute the FID from the OT-CFM model at end of training, run:
python3 compute_fid.py --model "otcfm" --step 400000 --integration_method dopri5
```

For the other models, change the "otcfm" argument by "cfm" or "fm". For easy reproducibility of our results, you can download the model weights at 400000 iterations here:
For the other models, change the "otcfm" argument by "icfm" or "fm". For easy reproducibility of our results, you can download the model weights at 400000 iterations here:

- [cfm weights](https://github.com/atong01/conditional-flow-matching/releases/download/1.0.4/cfm_cifar10_weights_step_400000.pt)
- [icfm weights](https://github.com/atong01/conditional-flow-matching/releases/download/1.0.4/cfm_cifar10_weights_step_400000.pt)

- [otcfm weights](https://github.com/atong01/conditional-flow-matching/releases/download/1.0.4/otcfm_cifar10_weights_step_400000.pt)

Expand Down
4 changes: 2 additions & 2 deletions examples/cifar10/train_cifar10.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,13 +128,13 @@ def train(argv):
sigma = 0.0
if FLAGS.model == "otcfm":
FM = ExactOptimalTransportConditionalFlowMatcher(sigma=sigma)
elif FLAGS.model == "cfm":
elif FLAGS.model == "icfm":
FM = ConditionalFlowMatcher(sigma=sigma)
elif FLAGS.model == "fm":
FM = TargetConditionalFlowMatcher(sigma=sigma)
else:
raise NotImplementedError(
f"Unknown model {FLAGS.model}, must be one of ['otcfm', 'cfm', 'fm']"
f"Unknown model {FLAGS.model}, must be one of ['otcfm', 'icfm', 'fm']"
)

savedir = FLAGS.output_dir + FLAGS.model + "/"
Expand Down

0 comments on commit bad11c1

Please sign in to comment.