Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sigma type & doc #72

Merged
merged 6 commits into from
Nov 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 10 additions & 9 deletions torchcfm/conditional_flow_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
# License: MIT License

import math
from typing import Union

import torch

Expand All @@ -31,7 +32,7 @@ def pad_t_like_x(t, x):
t: Vector (bs)
pad_t_like_x(t, x): Tensor (bs, 1, 1, 1)
"""
if isinstance(t, float):
if isinstance(t, (float, int)):
return t
return t.reshape(-1, *([1] * (x.dim() - 1)))

Expand All @@ -47,12 +48,12 @@ class ConditionalFlowMatcher:
- score function $\nabla log p_t(x|x0, x1)$
"""

def __init__(self, sigma: float = 0.0):
def __init__(self, sigma: Union[float, int] = 0.0):
r"""Initialize the ConditionalFlowMatcher class. It requires the hyper-parameter $\sigma$.

Parameters
----------
sigma : float
sigma : Union[float, int]
"""
self.sigma = sigma

Expand Down Expand Up @@ -215,15 +216,15 @@ class ExactOptimalTransportConditionalFlowMatcher(ConditionalFlowMatcher):
It overrides the sample_location_and_conditional_flow.
"""

def __init__(self, sigma: float = 0.0):
def __init__(self, sigma: Union[float, int] = 0.0):
r"""Initialize the ConditionalFlowMatcher class. It requires the hyper-parameter $\sigma$.

Parameters
----------
sigma : float
sigma : Union[float, int]
ot_sampler: exact OT method to draw couplings (x0, x1) (see Eq.(17) [1]).
"""
self.sigma = sigma
super().__init__(sigma)
self.ot_sampler = OTPlanSampler(method="exact")

def sample_location_and_conditional_flow(self, x0, x1, return_noise=False):
Expand Down Expand Up @@ -382,16 +383,16 @@ class SchrodingerBridgeConditionalFlowMatcher(ConditionalFlowMatcher):
sample_location_and_conditional_flow functions.
"""

def __init__(self, sigma: float = 1.0, ot_method="exact"):
def __init__(self, sigma: Union[float, int] = 1.0, ot_method="exact"):
r"""Initialize the SchrodingerBridgeConditionalFlowMatcher class. It requires the hyper-
parameter $\sigma$ and the entropic OT map.

Parameters
----------
sigma : float
sigma : Union[float, int]
ot_sampler: exact OT method to draw couplings (x0, x1) (see Eq.(17) [1]).
"""
self.sigma = sigma
super().__init__(sigma)
self.ot_method = ot_method
self.ot_sampler = OTPlanSampler(method=ot_method, reg=2 * self.sigma**2)

Expand Down
15 changes: 15 additions & 0 deletions torchcfm/optimal_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,21 @@ def __init__(
normalize_cost=False,
**kwargs,
):
r"""Initialize the OTPlanSampler class.

Parameters
----------
method : str
The method used to compute the OT plan. Can be one of "exact", "sinkhorn",
"unbalanced", or "partial".
reg : float (default : 0.05)
Entropic regularization coefficients.
reg_m : float (default : 1.0)
Marginal relaxation term for unbalanced OT (`method='unbalanced'`).
normalize_cost : bool (default : False)
Whether to normalize the cost matrix by its maximum value.
It should be set to `False` when using minibatches.
"""
# ot_fn should take (a, b, M) as arguments where a, b are marginals and
# M is a cost matrix
if method == "exact":
Expand Down
Loading