
ExactOptimalTransportConditionalFlowMatcher with multiple conditions #132

Open
lukasschmit opened this issue Aug 11, 2024 · 1 comment

@lukasschmit

Hi all, huge fan of your work and this library; we've had great success training latent flow matching models.

Our model relies on multiple conditions to generate x1. It looks like the optimal transport class only supports a single condition each for the prior (y0) and the target distribution (y1).

If our model needs two conditions for x1 (ya_1, yb_1), I'd assume this would only require ot_sampler.sample_plan_with_labels() to accept, for example, a list[Tensor] for y1 and return:

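```python
def sample_plan_with_labels(self, x0, x1, y0=None, y1=None, replace=True):
    # Compute the OT plan between the two batches and draw index pairs from it.
    pi = self.get_map(x0, x1)
    i, j = self.sample_map(pi, x0.shape[0], replace=replace)

    # Reorder the target conditions with the same indices j used for x1,
    # whether y1 is a single tensor or a list of per-condition tensors.
    y1_pi = None
    if isinstance(y1, torch.Tensor):
        y1_pi = y1[j]
    elif isinstance(y1, list):
        y1_pi = [_y1[j] for _y1 in y1]

    return (
        x0[i],
        x1[j],
        # ...
        y1_pi
    )
```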

I think something like this would allow the model to correctly accept multiple conditions along with the sampled xt?
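A slightly more general version of the same idea could reorder conditions of any nesting (a single tensor, a list or tuple of tensors, or a dict of named conditions) with one recursive helper. This is a minimal sketch under stated assumptions: index_conditions is a hypothetical name, not part of torchcfm, and NumPy arrays stand in for tensors (torch.Tensor supports the same integer-array indexing). The invariant that matters is that every condition for x1 is indexed with the same j used for x1 itself, so each condition stays attached to its sample.

```python
from collections.abc import Mapping
from typing import Any

import numpy as np


def index_conditions(y: Any, idx: np.ndarray) -> Any:
    """Reorder per-sample conditions with the OT indices idx.

    Hypothetical helper (not part of torchcfm). Leaves are anything that
    supports integer-array indexing, e.g. np.ndarray or torch.Tensor.
    """
    if y is None:
        return None
    if isinstance(y, Mapping):
        # Named conditions, e.g. {"ya": ya_1, "yb": yb_1}.
        return {k: index_conditions(v, idx) for k, v in y.items()}
    if isinstance(y, list):
        return [index_conditions(v, idx) for v in y]
    if isinstance(y, tuple):
        return tuple(index_conditions(v, idx) for v in y)
    # Leaf: a single condition array/tensor with the batch on axis 0.
    return y[idx]


# Usage: two conditions for x1, reordered with the same j as x1.
j = np.array([2, 0, 1])
ya_1 = np.array([10, 11, 12])
yb_1 = np.array([[0.0], [1.0], [2.0]])
ya_pi, yb_pi = index_conditions([ya_1, yb_1], j)
print(ya_pi)  # [12 10 11]
```

Inside a method like the one proposed, the returned labels would then be index_conditions(y0, i) and index_conditions(y1, j), which also keeps the single-tensor and None cases working unchanged.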

It could also be helpful to return the OT map pi for cases where downstream logic depends on batch order (e.g., logging or loss aggregation done per batch sample).
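For per-sample bookkeeping, returning the sampled indices i and j may be more directly useful than pi itself: pi is a batch-by-batch plan, while i[k] and j[k] say exactly which x0 and x1 rows ended up at position k. A minimal sketch, assuming the plan pi has already been computed as a NumPy array; sample_map here is a stand-alone stand-in for the sampler's index draw (not the library's code), and sample_plan_with_indices and scatter_mean are hypothetical names.

```python
import numpy as np


def sample_map(pi, batch_size, replace=True, rng=None):
    """Draw (i, j) index pairs with probability proportional to the plan pi."""
    rng = np.random.default_rng() if rng is None else rng
    p = pi.flatten()
    p = p / p.sum()
    choices = rng.choice(pi.shape[0] * pi.shape[1], size=batch_size, replace=replace, p=p)
    # Flat index -> (row, column) = (index into x0, index into x1).
    return np.divmod(choices, pi.shape[1])


def sample_plan_with_indices(pi, x0, x1, y1=None, replace=True, rng=None):
    """Hypothetical variant that also returns the sampled indices i and j."""
    i, j = sample_map(pi, x0.shape[0], replace=replace, rng=rng)
    y1_pi = None if y1 is None else y1[j]
    return x0[i], x1[j], y1_pi, i, j


def scatter_mean(values, idx, n):
    """Average per-sample values back onto their original positions 0..n-1."""
    sums = np.bincount(idx, weights=values, minlength=n)
    counts = np.bincount(idx, minlength=n)
    # Positions that were never drawn (possible with replace=True) become NaN.
    with np.errstate(invalid="ignore", divide="ignore"):
        return sums / counts
```

With replace=True some x1 rows may be drawn more than once and others not at all, which is why scatter_mean averages duplicates and leaves undrawn positions as NaN rather than silently assuming a one-to-one reordering.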

@atong01
Owner

atong01 commented Aug 12, 2024

Yep, this looks correct to me. Happy to consider proposed changes to the interface! You may want to consider other batching strategies (depending on your conditions), but I haven't looked into this deeply.

--Alex
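One example of a condition-dependent batching strategy, offered as an assumption rather than something settled in this thread, is to solve OT separately within each group of samples that share a discrete condition, so the reordering never couples samples across conditions. This sketch uses hypothetical names (groupwise_ot_pairs, exact_assignment), a squared Euclidean cost, and a brute-force solver that is only feasible for tiny groups.

```python
import itertools

import numpy as np


def exact_assignment(cost):
    """Toy exact OT solver for equal-size, uniformly weighted groups.

    With uniform weights and equal sizes an optimal plan can be taken to be a
    permutation, so brute force over permutations is exact, but only feasible
    for tiny groups.
    """
    n = cost.shape[0]
    rows = np.arange(n)
    best = min(itertools.permutations(range(n)),
               key=lambda perm: cost[rows, list(perm)].sum())
    return np.array(best, dtype=int)


def groupwise_ot_pairs(x0, x1, labels0, labels1, solver=exact_assignment):
    """Couple x0 and x1 with OT, but only within groups sharing a discrete label.

    Hypothetical strategy: assumes each label occurs equally often in both
    batches. Returns (i, j) so that x0[i[k]] is paired with x1[j[k]].
    """
    if not np.array_equal(np.sort(labels0), np.sort(labels1)):
        raise ValueError("each label must occur equally often in both batches")
    i_parts, j_parts = [], []
    for label in np.unique(labels1):
        idx0 = np.flatnonzero(labels0 == label)
        idx1 = np.flatnonzero(labels1 == label)
        # Squared Euclidean cost between the two groups (batch on axis 0).
        diff = x0[idx0][:, None, :] - x1[idx1][None, :, :]
        cost = (diff ** 2).sum(axis=-1)
        perm = solver(cost)
        i_parts.append(idx0)
        j_parts.append(idx1[perm])
    return np.concatenate(i_parts), np.concatenate(j_parts)
```

The returned i and j can then index x0, x1, and every condition exactly as in the proposed sample_plan_with_labels. In practice a real solver such as scipy.optimize.linear_sum_assignment or POT's ot.emd would replace exact_assignment.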
