sigma type & doc #72

Merged · 6 commits merged into atong01:main on Nov 14, 2023
Conversation

guillaumehu (Contributor)

Adding docstrings for the __init__ of OTPlanSampler, and fixing issue #56 by checking the type of sigma during initialization.
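For context, a minimal sketch of the kind of init-time check being discussed (illustrative only, not the merged code; the real implementation lives in torchcfm/conditional_flow_matching.py and changed over the course of this PR):

```python
# Hypothetical sketch: cast an int sigma to float at init so later
# arithmetic behaves the same regardless of how sigma was passed.
class ConditionalFlowMatcher:
    def __init__(self, sigma: float = 0.0):
        if isinstance(sigma, int):
            sigma = float(sigma)
        self.sigma = sigma
```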

guillaumehu marked this pull request as ready for review on November 13, 2023 at 13:21
kilianFatras (Collaborator)

Hello Guillaume!

Thanks for the PR. I think the sigma type should be checked for all classes: that includes OT-CFM, SB-CFM, FM, and Stochastic Interpolants. @josephdviviano, what do you think? We want to ensure that sigma is a float, since when it is an int we get a bug.

atong01 (Owner) commented Nov 13, 2023

I think we want to make sure sigma and t have the same type as x at runtime, rather than type-checking on init. If someone has an x that is a float64 tensor, we might also have problems right now. Good place to put a test @kilianFatras 😆.
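A rough sketch of the runtime-casting idea (illustrative only; `match_dtype` is a hypothetical helper, not a torchcfm function):

```python
import torch

def match_dtype(value, x: torch.Tensor) -> torch.Tensor:
    """Cast a Python number or tensor to x's dtype and device."""
    return torch.as_tensor(value).type_as(x)

x = torch.randn(8, 2, dtype=torch.float64)
sigma_t = match_dtype(1, x)         # int sigma -> float64 tensor
t = match_dtype(torch.rand(8), x)   # float32 t  -> float64 tensor
```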

guillaumehu (Contributor, Author)

@atong01 I just tested it, and it works with float64 even if sigma is initialized as an int, a float, or a tensor (32- or 64-bit). t already has the right dtype, since you specify t = torch.rand(x0.shape[0]).type_as(x0).
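A quick way to reproduce this check (sketch only; shapes are arbitrary, the method name follows the torchcfm API, and an int sigma is assumed to work because the fix in this PR is applied):

```python
import torch
from torchcfm.conditional_flow_matching import ConditionalFlowMatcher

x0 = torch.randn(8, 2, dtype=torch.float64)
x1 = torch.randn(8, 2, dtype=torch.float64)
fm = ConditionalFlowMatcher(sigma=1)  # sigma deliberately passed as an int
t, xt, ut = fm.sample_location_and_conditional_flow(x0, x1)
print(t.dtype, xt.dtype, ut.dtype)  # expected: all torch.float64
```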

kilianFatras (Collaborator)

@atong01 Maybe a better solution would be to make the type of sigma_t match the type of x within the compute_xt function?

guillaumehu (Contributor, Author)

But they already have the same type, since torch works with mixed precision.
The problem is here, on L34:

```python
if isinstance(t, float):
    return t
return t.reshape(-1, *([1] * (x.dim() - 1)))
```

A simple fix is replacing L34 with `if isinstance(t, (float, int)):`; this works without checking the type of sigma in the initialization.
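For reference, a sketch of how the helper would read with that one-line change applied (assuming the snippet above comes from the pad_t_like_x function in torchcfm/conditional_flow_matching.py; the rest of the body is unchanged):

```python
def pad_t_like_x(t, x):
    """Reshape t so it broadcasts against x; pass scalars through unchanged."""
    if isinstance(t, (float, int)):  # proposed fix: also accept Python ints
        return t
    return t.reshape(-1, *([1] * (x.dim() - 1)))
```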

josephdviviano (Collaborator)

I think I don't fully understand how the code works. Just a moment. But initial thoughts:

> torch works with mixed precision.

  • When I read "mixed precision" I think of using a mix of 16- and 32-bit floating-point arithmetic during training so you can get a bigger / faster model out of some hardware. Here I think you're referring to multiplying int and float, which should always return a float (see the short illustration after this list). This might be confusing for the user if they submit an int and expect an int in return. It's fine if things need to be this way, but I think it would be good for it to be consistent and documented (either types are always conserved or they're always cast to float). Sorry in advance if I misunderstand the code and one of those two conditions already holds.
  • I still think the @property should not do any type casting, because self.sigma should always equal self._sigma; otherwise it leaves the door open to a very confusing user experience if someone is grappling with a strange type error.
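A short illustration of the promotion in question (nothing torchcfm-specific, just PyTorch's default type-promotion rules):

```python
import torch

x = torch.randn(4)                   # float32 tensor
sigma = 1                            # Python int
print((sigma * x).dtype)             # torch.float32: the int is promoted
print((torch.tensor(2) * x).dtype)   # torch.float32 as well (int64 * float32)
```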

josephdviviano (Collaborator) left a comment:

I'm not entirely sure what the correct solution is but I don't want part of it to include the type casting in the @property, sorry to be annoying about this.

```diff
@@ -54,7 +54,14 @@ def __init__(self, sigma: float = 0.0):
         ----------
         sigma : float
         """
-        self.sigma = sigma
+        self._sigma = sigma
```
Collaborator:

Why aren't we enforcing the type here?

My concern is that self.sigma != self._sigma, which will lead to very very confusing bugs.

Collaborator:

I agree with @atong01 - you can use x to typecast all other variables if required, so that types don't change magically for the user. I stand by what I said before though, i.e., sigma should be enforced to be a particular type if required, and the property should not do any typecasting.

```python
if isinstance(self._sigma, int):
    return float(self._sigma)
else:
    return self._sigma
```
Collaborator:

I don't think we should do this, for reasons I've explained -- I think it is confusing. If we need sigma to always be a float, we should typecast/typecheck in the init, or we should write the logic such that for any input the typing is conserved.

Contributor (Author):

Yes, good point, I agree that it becomes confusing. I think we should just fix L34 mentioned above, as it does not work for type(t) = int; this way we don't have to typecheck sigma in the init or with the @property.

Collaborator:

OK. I'm not sure if this is intentional, but you are type checking sigma -- maybe you want to remove this as well?

https://github.com/atong01/conditional-flow-matching/blob/5975eac652b47b9f8ba3d3aa1883aaac64e73193/torchcfm/conditional_flow_matching.py#L50C6-L50C6

Contributor (Author):

I just changed it, thanks for pointing it out.

kilianFatras (Collaborator)

I'll let @josephdviviano decide when this PR is ready, as he has more knowledge than me on how to fix a variable's type.

@guillaumehu once this PR is ready, please prepare a test to add to the add_tests branch (or to this PR, but I would prefer all tests within the same PR). I have made tests for all classes within torchcfm, and we need a test on sigma's type as well.
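A hedged sketch of the kind of test being requested (the test name and parametrization are illustrative, not the contents of the add_tests branch):

```python
import pytest
import torch
from torchcfm.conditional_flow_matching import ConditionalFlowMatcher

@pytest.mark.parametrize("sigma", [0, 1, 0.5, torch.tensor(0.5)])
def test_sampling_is_robust_to_sigma_type(sigma):
    fm = ConditionalFlowMatcher(sigma=sigma)
    x0, x1 = torch.randn(8, 2), torch.randn(8, 2)
    t, xt, ut = fm.sample_location_and_conditional_flow(x0, x1)
    assert xt.dtype == x0.dtype and ut.dtype == x0.dtype
```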

josephdviviano (Collaborator) left a review:

Approved with a comment :)

"""
self._sigma = sigma

@property
josephdviviano (Collaborator) commented Nov 14, 2023:

Looks good. One comment - the @property method is useful for allowing one to "get" a variable without allowing the user to change it. So if someone did

```python
obj = FlowMatcher(sigma=1.0)
obj.sigma = 2
```

they would get an error. They could still do something like obj._sigma = 2 to overwrite the underlying attribute, but they would be doing so very intentionally. So it's mostly a way of telling the user how not to shoot themselves in the foot.

If you think it would be bad practice for people to overwrite the obj.sigma property, then adding the @property logic back in still makes sense (sans the type casting).
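A minimal sketch of that read-only pattern, with no type casting in the getter (using the FlowMatcher name from the example above; the actual torchcfm class names differ):

```python
class FlowMatcher:
    def __init__(self, sigma: float = 0.0):
        self._sigma = sigma

    @property
    def sigma(self):
        # Expose sigma read-only; no casting, so self.sigma always equals self._sigma.
        return self._sigma

obj = FlowMatcher(sigma=1.0)
print(obj.sigma)   # 1.0
obj.sigma = 2      # raises AttributeError: the property has no setter
```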

Contributor (Author):

Thank you @josephdviviano for the clarification! I am not sure what is best for sigma; I think it is fine if we keep it like this for now, but I'm happy to make the change if someone thinks otherwise.

kilianFatras merged commit 7cb209d into atong01:main on Nov 14, 2023 · 16 checks passed