Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sigma type & doc #72

Merged
merged 6 commits into from
Nov 14, 2023
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 9 additions & 15 deletions torchcfm/conditional_flow_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
# License: MIT License

import math
from typing import Union

import torch

Expand All @@ -31,7 +32,7 @@ def pad_t_like_x(t, x):
t: Vector (bs)
pad_t_like_x(t, x): Tensor (bs, 1, 1, 1)
"""
if isinstance(t, float):
if isinstance(t, (float, int)):
return t
return t.reshape(-1, *([1] * (x.dim() - 1)))

Expand All @@ -47,21 +48,14 @@ class ConditionalFlowMatcher:
- score function $\nabla log p_t(x|x0, x1)$
"""

def __init__(self, sigma: float = 0.0):
def __init__(self, sigma: Union[float, int] = 0.0):
r"""Initialize the ConditionalFlowMatcher class. It requires the hyper-parameter $\sigma$.

Parameters
----------
sigma : float
sigma : Union[float, int]
"""
self._sigma = sigma

@property
Copy link
Collaborator

@josephdviviano josephdviviano Nov 14, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good. One comment - the @property method is useful for allowing one to "get" a variable, without allowing the user to change it. So if someone did

obj = FlowMatcher(sigma=1.0)
obj.sigma = 2

they would get an error. They could be doing something like obj._sigma = 2 to overwrite the property, but they would be very intentionally doing this. So it's mostly a way of telling the user how to not shoot themselves in the foot.

If you think it would be bad practice for people to be overwriting the obj.sigma property, then adding the @property logic back in still makes sense (sans the type casting).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you @josephdviviano for the clarification! I am not sure what is the best for sigma, I think it is fine for now if we keep it like this, but happy to make the change if someone thinks otherwise.

def sigma(self):
if isinstance(self._sigma, int):
return float(self._sigma)
else:
return self._sigma
self.sigma = sigma

def compute_mu_t(self, x0, x1, t):
"""
Expand Down Expand Up @@ -222,12 +216,12 @@ class ExactOptimalTransportConditionalFlowMatcher(ConditionalFlowMatcher):
It overrides the sample_location_and_conditional_flow.
"""

def __init__(self, sigma: float = 0.0):
def __init__(self, sigma: Union[float, int] = 0.0):
r"""Initialize the ConditionalFlowMatcher class. It requires the hyper-parameter $\sigma$.

Parameters
----------
sigma : float
sigma : Union[float, int]
ot_sampler: exact OT method to draw couplings (x0, x1) (see Eq.(17) [1]).
"""
super().__init__(sigma)
Expand Down Expand Up @@ -389,13 +383,13 @@ class SchrodingerBridgeConditionalFlowMatcher(ConditionalFlowMatcher):
sample_location_and_conditional_flow functions.
"""

def __init__(self, sigma: float = 1.0, ot_method="exact"):
def __init__(self, sigma: Union[float, int] = 1.0, ot_method="exact"):
r"""Initialize the SchrodingerBridgeConditionalFlowMatcher class. It requires the hyper-
parameter $\sigma$ and the entropic OT map.

Parameters
----------
sigma : float
sigma : Union[float, int]
ot_sampler: exact OT method to draw couplings (x0, x1) (see Eq.(17) [1]).
"""
super().__init__(sigma)
Expand Down
Loading