[BUG]: Piecewise not in torch_mappings #639

Open
tbuckworth opened this issue Jun 3, 2024 · 5 comments
Labels
bug Something isn't working

Comments

@tbuckworth

What happened?

After fitting a PySR model with "greater" as a binary operator, exporting to torch failed with the following error:

KeyError: 'Function Piecewise was not found in Torch function mappings.Please add it to extra_torch_mappings in the format, e.g., {sympy.sqrt: torch.sqrt}.'

I've seen that Piecewise was added to the mappings in #433, so I'm surprised to see this error.

I attempted to fix it myself, but it didn't work out. I tried adding mappings such as:

{sympy.Piecewise: lambda x, y: torch.where(x[1], x[0], y[0])}

but then the same error arose for sympy.functions.elementary.piecewise.ExprCondPair, and then for sympy.logic.boolalg.BooleanTrue.
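To see why each added mapping only uncovers the next missing node type, here is a toy recursive evaluator. It is hypothetical and not sympytorch's or PySR's actual code: it simply looks up every node's `func` in a mapping dict, as the error message suggests. It assumes comparisons such as `x > y` are already mapped, and uses an expression shaped like `Piecewise((1.0, x > y), (0.0, True))` as a stand-in for what a 0/1 "greater" operator might export to; the `candidates` mappings are illustrative only.

```python
import sympy
import torch

ExprCondPair = sympy.functions.elementary.piecewise.ExprCondPair
BooleanTrue = sympy.logic.boolalg.BooleanTrue


def toy_eval(expr, mapping, env):
    # Hypothetical evaluator: symbols come from env, numbers become constants,
    # and every other node needs an entry for its class in `mapping`.
    if isinstance(expr, sympy.Symbol):
        return env[expr.name]
    if isinstance(expr, sympy.Number):
        return torch.tensor(float(expr))
    if expr.func not in mapping:
        raise KeyError(expr.func)
    args = [toy_eval(arg, mapping, env) for arg in expr.args]
    return mapping[expr.func](*args)


def piecewise(*pairs):
    # First true condition wins; positions where no condition holds get NaN.
    result = torch.tensor(float("nan"))
    for value, cond in reversed(pairs):
        result = torch.where(cond, value, result)
    return result


candidates = {
    sympy.Piecewise: piecewise,
    ExprCondPair: lambda value, cond: (value, cond),
    BooleanTrue: lambda: torch.tensor(True),
}

x, y = sympy.symbols("x y")
expr = sympy.Piecewise((1.0, x > y), (0.0, True))
env = {"x": torch.tensor([0.5, 2.0]), "y": torch.tensor([1.0, 1.0])}
mapping = {sympy.StrictGreaterThan: torch.gt}  # assume comparisons are mapped

# Add one mapping per KeyError, mimicking the sequence of errors in this issue.
needed = []
while True:
    try:
        result = toy_eval(expr, mapping, env)
        break
    except KeyError as err:
        missing = err.args[0]  # the sympy class that had no mapping
        needed.append(missing.__name__)
        mapping[missing] = candidates[missing]

print(needed)  # ['Piecewise', 'ExprCondPair', 'BooleanTrue']
print(result)  # tensor([0., 1.])
```

The point of the toy is that a Piecewise is a tree of several node types, so a mapping for `sympy.Piecewise` alone is not enough for a node-by-node translator.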

In the end, I added:

extra_torch_mappings = {
    sympy.Piecewise: lambda x, y: torch.where(x[1], x[0], y[0]),
    sympy.functions.elementary.piecewise.ExprCondPair: tuple,
    sympy.logic.boolalg.BooleanTrue: torch.BoolTensor,
    "greater": lambda x, y: torch.where(x > y, 1.0, 0.0),
}

But even this produced the following error:

KeyError: 'Function ITE was not found in Torch function mappings.Please add it to extra_torch_mappings in the format, e.g., {sympy.sqrt: torch.sqrt}.'

Hopefully, I am missing something obvious?

Version

0.18.4

Operating System

Linux

Package Manager

pip

Interface

Script (i.e., python my_script.py)

Relevant log output

No response

Extra Info

No response

tbuckworth added the bug label (Something isn't working) on Jun 3, 2024
@tbuckworth
Author

I just realised that #433 is a pull request, so I copied the code and used it to add the mappings manually.
However, I'm still getting the error:
KeyError: 'Function ITE was not found in Torch function mappings.Please add it to extra_torch_mappings in the format, e.g., {sympy.sqrt: torch.sqrt}.'

@tbuckworth
Author

I've added this mapping, which seems to circumvent the error, but I haven't fully tested it yet:

import sympy
import torch

def if_then_else(*conds):
    # sympy's ITE(a, b, c) means "if a then b else c"
    a, b, c = conds
    return torch.where(a, torch.where(b, True, False), torch.where(c, True, False))

extra_torch_mappings = {sympy.logic.boolalg.ITE: if_then_else}
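When `b` and `c` are already boolean tensors, `torch.where(b, True, False)` just reproduces `b`, so a possible simplification of the workaround above is a single `torch.where(a, b, c)`. The sketch below (`ite_to_torch` is a hypothetical name, and the `.bool()` casts are an assumption in case conditions arrive as floats) checks that version against sympy's own `ITE` on all eight truth assignments.

```python
import itertools

import sympy
import torch


def ite_to_torch(a, b, c):
    # Hypothetical mapping for sympy.logic.boolalg.ITE: "if a then b else c".
    # The .bool() casts assume conditions may also arrive as 0./1. floats.
    return torch.where(a.bool(), b.bool(), c.bool())


p, q, r = sympy.symbols("p q r")
ite = sympy.ITE(p, q, r)

# All 8 truth assignments, one column tensor per ITE argument.
rows = list(itertools.product([False, True], repeat=3))
a, b, c = (torch.tensor(col) for col in zip(*rows))

torch_out = ite_to_torch(a, b, c).tolist()
sympy_out = [bool(ite.subs({p: u, q: v, r: w})) for u, v, w in rows]
print(torch_out == sympy_out)  # True
```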

@MilesCranmer
Owner

Nice! Yeah, that should be added to the pull request. Feel free to suggest it on the PR via the review system, and you will be credited as a co-author of the PR.

@tbuckworth
Author

Thanks! I'll add a review comment on the PR.

There was another error with Piecewise when cond is a float (1.), which I fixed by replacing cond with cond.bool():

output += torch.where(
    cond.bool() & ~already_used, expr, torch.zeros_like(expr)
)
already_used = already_used | cond.bool()
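For context, a fragment like this sits inside a first-match loop over the Piecewise branches. A minimal standalone sketch of such a loop, assuming each branch arrives as an `(expr, cond)` pair of same-shaped tensors (`piecewise_to_torch` is a hypothetical name, not the PR's function), shows how `cond.bool()` lets a float `1.` condition act as the catch-all `True` branch.

```python
import torch


def piecewise_to_torch(*pairs):
    # Hypothetical helper: each pair is (expr, cond), tensors of one shape.
    # Conditions may arrive as floats (e.g. 1.0), so they are cast with .bool().
    output = torch.zeros_like(pairs[0][0])
    already_used = torch.zeros_like(pairs[0][1], dtype=torch.bool)
    for expr, cond in pairs:
        cond = cond.bool()
        # Only fill positions that no earlier branch has claimed yet.
        output += torch.where(cond & ~already_used, expr, torch.zeros_like(expr))
        already_used = already_used | cond
    return output


x = torch.tensor([0.5, 2.0, -1.0])
y = torch.tensor([1.0, 1.0, 1.0])
# Same shape as Piecewise((1.0, x > y), (0.0, True)).
result = piecewise_to_torch(
    (torch.ones_like(x), x > y),
    (torch.zeros_like(x), torch.ones_like(x)),  # float "True" condition
)
print(result)  # tensor([0., 1., 0.])
```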

Now it works as long as I use a single batch dimension, but multiple batch dimensions fail.

I believe this is due to export_torch.py, where _SingleSymPyModule.forward is:

def forward(self, X):
    if self._selection is not None:
        X = X[:, self._selection]
    symbols = {symbol: X[:, i] for i, symbol in enumerate(self.symbols_in)}
    return self._node(symbols)

If every X[:, is replaced with X[..., then I believe it will work. This is a separate issue, though, I suppose.
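A minimal sketch of that change, using a hypothetical stand-in class (`BatchedForwardSketch`, with a plain Python callable in place of the sympytorch node) rather than PySR's actual `_SingleSymPyModule`. Indexing with `X[..., i]` always picks from the last (feature) axis, whereas `X[:, i]` picks from dimension 1, which is a batch axis once there is more than one batch dimension.

```python
import torch


class BatchedForwardSketch(torch.nn.Module):
    # Hypothetical stand-in for _SingleSymPyModule: `node` is any callable
    # that takes a {symbol_name: tensor} dict, in place of a sympytorch node.
    def __init__(self, symbols_in, node, selection=None):
        super().__init__()
        self.symbols_in = symbols_in
        self._node = node
        self._selection = selection

    def forward(self, X):
        # Ellipsis indexing targets the last (feature) axis for any batch shape.
        if self._selection is not None:
            X = X[..., self._selection]
        symbols = {symbol: X[..., i] for i, symbol in enumerate(self.symbols_in)}
        return self._node(symbols)


X = torch.rand(4, 7, 3)  # two batch dimensions, three features
print(X[:, 0].shape)  # torch.Size([4, 3]): indexes batch axis 1, not a feature

module = BatchedForwardSketch(
    ["x0", "x1"], lambda s: s["x0"] * s["x1"], selection=[0, 2]
)
out = module(X)
print(out.shape)  # torch.Size([4, 7])
```

With a single batch dimension, `X[:, i]` and `X[..., i]` select the same values, which is consistent with the single-batch case working.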

@MilesCranmer
Owner

(Just leaving it open until that PR is closed, since there are still some TODO items)
