Skip to content

Commit

Permalink
pep 8
Browse files Browse the repository at this point in the history
  • Loading branch information
kilianFatras committed Oct 27, 2023
1 parent 5ffda36 commit 6dd7a13
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 25 deletions.
1 change: 1 addition & 0 deletions examples/cifar10/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ python3 train_cifar10.py --lr 2e-4 --ema_decay 0.9999 --batch_size 128 --total_s
```

To run a script closer to the original Flow Matching paper, use the following script(might require several GPUs):

```bash
python3 train_cifar10.py --lr 2e-4 --ema_decay 0.9999 --num_channel 256 --batch_size 256 --total_steps 400001 --save_step 20000
```
Expand Down
15 changes: 11 additions & 4 deletions examples/cifar10/compute_fid.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
# Inspired from https://github.com/w86763777/pytorch-ddpm/tree/master.

# Authors: Kilian Fatras
# Alexander Tong

import os
import sys

import matplotlib.pyplot as plt
import torch
from absl import app, flags
from cleanfid import fid
from torchcfm.models.unet.unet import UNetModelWrapper
from torchdiffeq import odeint
from torchdyn.core import NeuralODE

from torchcfm.models.unet.unet import UNetModelWrapper

FLAGS = flags.FLAGS
# UNet
flags.DEFINE_integer("num_channel", 128, help="base channel of UNet")
Expand All @@ -21,7 +27,7 @@
flags.DEFINE_string("integration_method", "dopri5", help="integration method to use")
flags.DEFINE_integer("step", 400000, help="training steps")
flags.DEFINE_integer("num_gen", 50000, help="number of samples to generate")
flags.DEFINE_float("tol", 1e-5, help="Integrator tolerence (absolute and relative)")
flags.DEFINE_float("tol", 1e-5, help="Integrator tolerance (absolute and relative)")
FLAGS(sys.argv)


Expand Down Expand Up @@ -73,7 +79,9 @@ def gen_1_img(unused_latent):
else:
print("Use method: ", FLAGS.integration_method)
t_span = torch.linspace(0, 1, 2).to(device)
traj = odeint(new_net, x, t_span, rtol=FLAGS.tol, atol=FLAGS.tol, method=FLAGS.integration_method)
traj = odeint(
new_net, x, t_span, rtol=FLAGS.tol, atol=FLAGS.tol, method=FLAGS.integration_method
)
traj = traj[-1, :] # .view([-1, 3, 32, 32]).clip(-1, 1)
img = (traj * 127.5 + 128).clip(0, 255).to(torch.uint8) # .permute(1, 2, 0)
return img
Expand All @@ -95,4 +103,3 @@ def gen_1_img(unused_latent):
print("Total NFE: ", new_net.nfe)
print()
print("FID: ", score)

30 changes: 14 additions & 16 deletions examples/cifar10/train_cifar10.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,17 @@

import torch
from absl import app, flags
from torchcfm.conditional_flow_matching import (
ConditionalFlowMatcher, ExactOptimalTransportConditionalFlowMatcher,
TargetConditionalFlowMatcher)
from torchcfm.models.unet.unet import UNetModelWrapper
from torchdyn.core import NeuralODE
from torchvision import datasets, transforms
from tqdm import trange
from utils_cifar import ema, generate_samples, infiniteloop

from utils_cifar import *
from torchcfm.conditional_flow_matching import (
ConditionalFlowMatcher,
ExactOptimalTransportConditionalFlowMatcher,
TargetConditionalFlowMatcher,
)
from torchcfm.models.unet.unet import UNetModelWrapper

FLAGS = flags.FLAGS

Expand All @@ -26,14 +28,14 @@
flags.DEFINE_integer("num_channel", 128, help="base channel of UNet")

# Training
flags.DEFINE_float("lr", 2e-4, help="target learning rate") ## TRY 2e-4
flags.DEFINE_float("lr", 2e-4, help="target learning rate") # TRY 2e-4
flags.DEFINE_float("grad_clip", 1.0, help="gradient norm clipping")
flags.DEFINE_integer(
"total_steps", 400001, help="total training steps"
) # Lipman et al uses 400k but double batch size
flags.DEFINE_integer("img_size", 32, help="image size")
flags.DEFINE_integer("warmup", 5000, help="learning rate warmup")
flags.DEFINE_integer("batch_size", 128, help="batch size") ##Lipman et al uses 128
flags.DEFINE_integer("batch_size", 128, help="batch size") # Lipman et al uses 128
flags.DEFINE_integer("num_workers", 4, help="workers of Dataloader")
flags.DEFINE_float("ema_decay", 0.9999, help="ema decay rate")
flags.DEFINE_bool("parallel", False, help="multi gpu training")
Expand All @@ -47,9 +49,7 @@
flags.DEFINE_integer(
"eval_step", 0, help="frequency of evaluating model, 0 to disable during training"
)
flags.DEFINE_integer(
"num_images", 50000, help="the number of generated images for evaluation"
)
flags.DEFINE_integer("num_images", 50000, help="the number of generated images for evaluation")


use_cuda = torch.cuda.is_available()
Expand All @@ -69,7 +69,7 @@ def train(argv):
FLAGS.save_step,
)

#### DATASETS/DATALOADER
# DATASETS/DATALOADER
dataset = datasets.CIFAR10(
root="./data",
train=True,
Expand All @@ -92,7 +92,7 @@ def train(argv):

datalooper = infiniteloop(dataloader)

#### MODELS
# MODELS
net_model = UNetModelWrapper(
dim=(3, 32, 32),
num_res_blocks=2,
Expand Down Expand Up @@ -149,9 +149,7 @@ def train(argv):
vt = net_model(t, xt)
loss = torch.mean((vt - ut) ** 2)
loss.backward()
torch.nn.utils.clip_grad_norm_(
net_model.parameters(), FLAGS.grad_clip
) # new
torch.nn.utils.clip_grad_norm_(net_model.parameters(), FLAGS.grad_clip) # new
optim.step()
sched.step()
ema(net_model, ema_model, FLAGS.ema_decay) # new
Expand All @@ -168,7 +166,7 @@ def train(argv):
"optim": optim.state_dict(),
"step": step,
},
savedir + "cifar10_weights_step_{}.pt".format(step),
savedir + f"cifar10_weights_step_{step}.pt",
)


Expand Down
8 changes: 4 additions & 4 deletions examples/cifar10/utils_cifar.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import torch
from torchdyn.core import NeuralODE
#from torchvision.transforms import ToPILImage

# from torchvision.transforms import ToPILImage
from torchvision.utils import make_grid, save_image

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")


def generate_samples(node_, model, savedir, step, net_="normal"):
model.eval()
with torch.no_grad():
Expand All @@ -15,9 +17,7 @@ def generate_samples(node_, model, savedir, step, net_="normal"):
)
traj = traj[-1, :].view([-1, 3, 32, 32]).clip(-1, 1)
traj = traj / 2 + 0.5
save_image(
traj, savedir + "{}_generated_FM_images_step_{}.png".format(net_, step), nrow=8
)
save_image(traj, savedir + f"{net_}_generated_FM_images_step_{step}.png", nrow=8)

model.train()

Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@ torchdyn>=1.0.7 # 1.0.4 is broken on pypi
pot
torchdiffeq==0.2.3
absl-py
clean-fid
clean-fid

0 comments on commit 6dd7a13

Please sign in to comment.