Skip to content

Commit

Permalink
Merge pull request #44 from kilianFatras/conditional_ex
Browse files Browse the repository at this point in the history
Add conditional MNIST examples notebook
  • Loading branch information
atong01 authored Sep 12, 2023
2 parents 3fd8d15 + 3c06643 commit 1588442
Show file tree
Hide file tree
Showing 6 changed files with 578 additions and 25 deletions.
446 changes: 446 additions & 0 deletions examples/notebooks/conditional_mnist.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ scanpy
timm
torchdyn>=1.0.5 # 1.0.4 is broken on pypi
pot
torchdiffeq==0.2.3
116 changes: 93 additions & 23 deletions torchcfm/conditional_flow_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def compute_mu_t(self, x0, x1, t):
x0 : Tensor, shape (bs, *dim)
represents the source minibatch
x1 : Tensor, shape (bs, *dim)
represents the source minibatch
represents the target minibatch
t : FloatTensor, shape (bs)
Returns
Expand All @@ -85,10 +85,6 @@ def compute_sigma_t(self, t):
Parameters
----------
x0 : Tensor, shape (bs, *dim)
represents the source minibatch
x1 : Tensor, shape (bs, *dim)
represents the source minibatch
t : FloatTensor, shape (bs)
Returns
Expand All @@ -111,7 +107,7 @@ def sample_xt(self, x0, x1, t, epsilon):
x0 : Tensor, shape (bs, *dim)
represents the source minibatch
x1 : Tensor, shape (bs, *dim)
represents the source minibatch
represents the target minibatch
t : FloatTensor, shape (bs)
epsilon : Tensor, shape (bs, *dim)
noise sample from N(0, 1)
Expand All @@ -138,7 +134,7 @@ def compute_conditional_flow(self, x0, x1, t, xt):
x0 : Tensor, shape (bs, *dim)
represents the source minibatch
x1 : Tensor, shape (bs, *dim)
represents the source minibatch
represents the target minibatch
t : FloatTensor, shape (bs)
xt : Tensor, shape (bs, *dim)
represents the samples drawn from probability path pt
Expand Down Expand Up @@ -167,7 +163,7 @@ def sample_location_and_conditional_flow(self, x0, x1, return_noise=False):
x0 : Tensor, shape (bs, *dim)
represents the source minibatch
x1 : Tensor, shape (bs, *dim)
represents the source minibatch
represents the target minibatch
return_noise : bool
return the noise sample epsilon
Expand Down Expand Up @@ -241,7 +237,7 @@ def sample_location_and_conditional_flow(self, x0, x1, return_noise=False):
x0 : Tensor, shape (bs, *dim)
represents the source minibatch
x1 : Tensor, shape (bs, *dim)
represents the source minibatch
represents the target minibatch
return_noise : bool
return the noise sample epsilon
Expand All @@ -260,6 +256,47 @@ def sample_location_and_conditional_flow(self, x0, x1, return_noise=False):
x0, x1 = self.ot_sampler.sample_plan(x0, x1)
return super().sample_location_and_conditional_flow(x0, x1, return_noise)

def guided_sample_location_and_conditional_flow(
self, x0, x1, y0=None, y1=None, return_noise=False
):
r"""
Compute the sample xt (drawn from N(t * x1 + (1 - t) * x0, sigma))
and the conditional vector field ut(x1|x0) = x1 - x0, see Eq.(15) [1]
with respect to the minibatch OT plan $\Pi$.
Parameters
----------
x0 : Tensor, shape (bs, *dim)
represents the source minibatch
x1 : Tensor, shape (bs, *dim)
represents the target minibatch
y0 : Tensor, shape (bs) (default: None)
represents the source label minibatch
y1 : Tensor, shape (bs) (default: None)
represents the target label minibatch
return_noise : bool
return the noise sample epsilon
Returns
-------
t : FloatTensor, shape (bs)
xt : Tensor, shape (bs, *dim)
represents the samples drawn from probability path pt
ut : conditional vector field ut(x1|x0) = x1 - x0
(optionally) epsilon : Tensor, shape (bs, *dim) such that xt = mu_t + sigma_t * epsilon
References
----------
[1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
"""
x0, x1, y0, y1 = self.ot_sampler.sample_plan_with_labels(x0, x1, y0, y1)
if return_noise:
t, xt, ut, eps = super().sample_location_and_conditional_flow(x0, x1, return_noise)
return t, xt, ut, y0, y1, eps
else:
t, xt, ut = super().sample_location_and_conditional_flow(x0, x1, return_noise)
return t, xt, ut, y0, y1


class TargetConditionalFlowMatcher(ConditionalFlowMatcher):
"""Lipman et al. 2023 style target OT conditional flow matching. This class inherits the
Expand All @@ -277,7 +314,7 @@ def compute_mu_t(self, x0, x1, t):
x0 : Tensor, shape (bs, *dim)
represents the source minibatch
x1 : Tensor, shape (bs, *dim)
represents the source minibatch
represents the target minibatch
t : FloatTensor, shape (bs)
Returns
Expand All @@ -297,10 +334,6 @@ def compute_sigma_t(self, t):
Parameters
----------
x0 : Tensor, shape (bs, *dim)
represents the source minibatch
x1 : Tensor, shape (bs, *dim)
represents the source minibatch
t : FloatTensor, shape (bs)
Returns
Expand All @@ -322,7 +355,7 @@ def compute_conditional_flow(self, x0, x1, t, xt):
x0 : Tensor, shape (bs, *dim)
represents the source minibatch
x1 : Tensor, shape (bs, *dim)
represents the source minibatch
represents the target minibatch
t : FloatTensor, shape (bs)
xt : Tensor, shape (bs, *dim)
represents the samples drawn from probability path pt
Expand Down Expand Up @@ -367,10 +400,6 @@ def compute_sigma_t(self, t):
Parameters
----------
x0 : Tensor, shape (bs, *dim)
represents the source minibatch
x1 : Tensor, shape (bs, *dim)
represents the source minibatch
t : FloatTensor, shape (bs)
Returns
Expand All @@ -394,7 +423,7 @@ def compute_conditional_flow(self, x0, x1, t, xt):
x0 : Tensor, shape (bs, *dim)
represents the source minibatch
x1 : Tensor, shape (bs, *dim)
represents the source minibatch
represents the target minibatch
t : FloatTensor, shape (bs)
xt : Tensor, shape (bs, *dim)
represents the samples drawn from probability path pt
Expand Down Expand Up @@ -426,7 +455,7 @@ def sample_location_and_conditional_flow(self, x0, x1, return_noise=False):
x0 : Tensor, shape (bs, *dim)
represents the source minibatch
x1 : Tensor, shape (bs, *dim)
represents the source minibatch
represents the target minibatch
return_noise: bool
return the noise sample epsilon
Expand All @@ -446,6 +475,47 @@ def sample_location_and_conditional_flow(self, x0, x1, return_noise=False):
x0, x1 = self.ot_sampler.sample_plan(x0, x1)
return super().sample_location_and_conditional_flow(x0, x1, return_noise)

def guided_sample_location_and_conditional_flow(
self, x0, x1, y0=None, y1=None, return_noise=False
):
r"""
Compute the sample xt (drawn from N(t * x1 + (1 - t) * x0, sigma))
and the conditional vector field ut(x1|x0) = x1 - x0, see Eq.(15) [1]
with respect to the minibatch entropic OT plan $\Pi$.
Parameters
----------
x0 : Tensor, shape (bs, *dim)
represents the source minibatch
x1 : Tensor, shape (bs, *dim)
represents the target minibatch
y0 : Tensor, shape (bs) (default: None)
represents the source label minibatch
y1 : Tensor, shape (bs) (default: None)
represents the target label minibatch
return_noise : bool
return the noise sample epsilon
Returns
-------
t : FloatTensor, shape (bs)
xt : Tensor, shape (bs, *dim)
represents the samples drawn from probability path pt
ut : conditional vector field ut(x1|x0) = x1 - x0
(optionally) epsilon : Tensor, shape (bs, *dim) such that xt = mu_t + sigma_t * epsilon
References
----------
[1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
"""
x0, x1, y0, y1 = self.ot_sampler.sample_plan_with_labels(x0, x1, y0, y1)
if return_noise:
t, xt, ut, eps = super().sample_location_and_conditional_flow(x0, x1, return_noise)
return t, xt, ut, y0, y1, eps
else:
t, xt, ut = super().sample_location_and_conditional_flow(x0, x1, return_noise)
return t, xt, ut, y0, y1


class VariancePreservingConditionalFlowMatcher(ConditionalFlowMatcher):
"""Albergo et al. 2023 trigonometric interpolants class. This class inherits the
Expand All @@ -463,7 +533,7 @@ def compute_mu_t(self, x0, x1, t):
x0 : Tensor, shape (bs, *dim)
represents the source minibatch
x1 : Tensor, shape (bs, *dim)
represents the source minibatch
represents the target minibatch
t : FloatTensor, shape (bs)
Returns
Expand All @@ -487,7 +557,7 @@ def compute_conditional_flow(self, x0, x1, t, xt):
x0 : Tensor, shape (bs, *dim)
represents the source minibatch
x1 : Tensor, shape (bs, *dim)
represents the source minibatch
represents the target minibatch
t : FloatTensor, shape (bs)
xt : Tensor, shape (bs, *dim)
represents the samples drawn from probability path pt
Expand Down
3 changes: 2 additions & 1 deletion torchcfm/models/unet/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -865,6 +865,7 @@ def __init__(
channel_mult=None,
learn_sigma=False,
class_cond=False,
num_classes=NUM_CLASSES,
use_checkpoint=False,
attention_resolutions="16",
num_heads=1,
Expand Down Expand Up @@ -909,7 +910,7 @@ def __init__(
attention_resolutions=tuple(attention_ds),
dropout=dropout,
channel_mult=channel_mult,
num_classes=(NUM_CLASSES if class_cond else None),
num_classes=(num_classes if class_cond else None),
use_checkpoint=use_checkpoint,
use_fp16=use_fp16,
num_heads=num_heads,
Expand Down
35 changes: 35 additions & 0 deletions torchcfm/optimal_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,41 @@ def sample_plan(self, x0, x1):
i, j = self.sample_map(pi, x0.shape[0])
return x0[i], x1[j]

def sample_plan_with_labels(self, x0, x1, y0=None, y1=None):
r"""Compute the OT plan $\pi$ (wrt squared Euclidean cost) between a source and a target
minibatch and draw source and target labeled samples from pi $(x,z) \sim \pi$
Parameters
----------
x0 : Tensor, shape (bs, *dim)
represents the source minibatch
x1 : Tensor, shape (bs, *dim)
represents the target minibatch
y0 : Tensor, shape (bs)
represents the source label minibatch
y1 : Tensor, shape (bs)
represents the target label minibatch
Returns
-------
x0[i] : Tensor, shape (bs, *dim)
represents the source minibatch drawn from $\pi$
x1[j] : Tensor, shape (bs, *dim)
represents the target minibatch drawn from $\pi$
y0[i] : Tensor, shape (bs, *dim)
represents the source label minibatch drawn from $\pi$
y1[j] : Tensor, shape (bs, *dim)
represents the target label minibatch drawn from $\pi$
"""
pi = self.get_map(x0, x1)
i, j = self.sample_map(pi, x0.shape[0])
return (
x0[i],
x1[j],
y0[i] if y0 is not None else None,
y1[j] if y1 is not None else None,
)

def sample_trajectory(self, X):
"""Compute the OT trajectories between different sample populations moving from the source
to the target distribution.
Expand Down
2 changes: 1 addition & 1 deletion torchcfm/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "1.0.1"
__version__ = "1.0.2"

0 comments on commit 1588442

Please sign in to comment.