forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Make dispatcher registrations of SymInt functions backwards compatible (
pytorch#84557) Previously, when we SymInt-ify a schema, this is a BC-breaking change for all people who registered functions for that function; they must accept c10::SymInt where they previously accepted int64_t. This is not great. With this change, I accept old type registrations transparently. The idea is in several parts: - At the registration site, at compile time I have no idea whether or not if the function being registered has a SymInt schema or not. So I must defer the exact compatibility check. What I do instead is check if the function pointer registered to me has SymInt in the argument or not. If it does, I assume it is new-style and ensure it is also registered to a special sym_ slot on KernelFunction. If not, it only goes in the conventional slot. - At the dispatcher site, I know at compile time whether or not this is a SymInt function. If it is, I check for a sym_ slot on the KernelFunction, and preferentially use that. If no such slot exists, I then fall back to the regular slot... but I convert all SymInt arguments to int64_t arguments (doing assertions that no true symbolic integer was passed.) I can skip this test entirely if the function doesn't have any SymInts in it; in that case I know that only the original slot could have been registered. Fortunately, both branches of the short circuit typecheck, so I didn't have to use SFINAE or if-constexpr to make it work; just a plain if statement that I expect the compiler to optimize away. - Schema validation is now modestly more complicated. There are two parts. First, function schema validation proceeds by checking if the signature in question has any SymInt-like types in it or not. If it does, we do function schema validation against the real types; if it doesn't, we do validation against the fake types (but only for symint; MemoryFormat is always MemoryFormat). Second, cpp signature validation also keeps track of a "symint" cpp signature and a "non-symint" cpp signature. We only compare symint with symint, and non-symint with non-symint. I did not implement checking a conflict between a symint and non-symint cpp signature, though in principle you could try converting the SymInt types to non-SymInt types and doing the comparison that way. To show it is working, I remove a bunch of c10::asIntArrayRefSlow shims, as the dispatcher is able to insert them automatically now. I didn't update the Metal registrations (though they can get similar treatment) as OSS CI coverage is insufficient for this case. Signed-off-by: Edward Z. Yang <[email protected]> Differential Revision: [D39280965](https://our.internmc.facebook.com/intern/diff/D39280965) Pull Request resolved: pytorch#84557 Approved by: https://github.com/wconstab
- Loading branch information
1 parent
ed46b96
commit 19e27b1
Showing
12 changed files
with
156 additions
and
89 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.