Skip to content

Commit

Permalink
Make dispatcher registrations of SymInt functions backwards compatible (
Browse files Browse the repository at this point in the history
pytorch#84557)

Previously, when we SymInt-ify a schema, this is a BC-breaking change
for all people who registered functions for that function; they
must accept c10::SymInt where they previously accepted int64_t.
This is not great.

With this change, I accept old type registrations transparently.  The
idea is in several parts:

- At the registration site, at compile time I have no idea whether or not
  if the function being registered has a SymInt schema or not.  So I
  must defer the exact compatibility check.  What I do instead is
  check if the function pointer registered to me has SymInt in the
  argument or not.  If it does, I assume it is new-style and ensure
  it is also registered to a special sym_ slot on KernelFunction.
  If not, it only goes in the conventional slot.

- At the dispatcher site, I know at compile time whether or not this
  is a SymInt function.  If it is, I check for a sym_ slot on the
  KernelFunction, and preferentially use that.  If no such slot
  exists, I then fall back to the regular slot... but I convert
  all SymInt arguments to int64_t arguments (doing assertions that
  no true symbolic integer was passed.)  I can skip this test entirely
  if the function doesn't have any SymInts in it; in that case I know
  that only the original slot could have been registered. Fortunately,
  both branches of the short circuit typecheck, so I didn't have to
  use SFINAE or if-constexpr to make it work; just a plain if statement
  that I expect the compiler to optimize away.

- Schema validation is now modestly more complicated. There are two parts. First, function schema validation proceeds by checking if the signature in question has any SymInt-like types in it or not. If it does, we do function schema validation against the real types; if it doesn't, we do validation against the fake types (but only for symint; MemoryFormat is always MemoryFormat). Second, cpp signature validation also keeps track of a "symint" cpp signature and a "non-symint" cpp signature. We only compare symint with symint, and non-symint with non-symint. I did not implement checking a conflict between a symint and non-symint cpp signature, though in principle you could try converting the SymInt types to non-SymInt types and doing the comparison that way.

To show it is working, I remove a bunch of c10::asIntArrayRefSlow shims, as the dispatcher is able to insert them automatically now.

I didn't update the Metal registrations (though they can get similar treatment) as OSS CI coverage is insufficient for this case.

Signed-off-by: Edward Z. Yang <[email protected]>

Differential Revision: [D39280965](https://our.internmc.facebook.com/intern/diff/D39280965)
Pull Request resolved: pytorch#84557
Approved by: https://github.com/wconstab
  • Loading branch information
ezyang authored and pytorchmergebot committed Sep 7, 2022
1 parent ed46b96 commit 19e27b1
Show file tree
Hide file tree
Showing 12 changed files with 156 additions and 89 deletions.
27 changes: 3 additions & 24 deletions aten/src/ATen/BatchingRegistrations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -185,11 +185,6 @@ Tensor expand_batching_rule(const Tensor& self, IntArrayRef size, bool implicit)
return self_physical.getPhysicalToLogicalMap().apply(result);
}

Tensor expand_symint_batching_rule(const Tensor& self, SymIntArrayRef psize, bool implicit) {
// TODO: properly support this
return expand_batching_rule(self, asIntArrayRefSlow(psize), implicit);
}

std::vector<Tensor> chunk_batching_rule(const Tensor& self, int64_t chunks, int64_t dim) {
auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
auto dim_physical = self_physical.getPhysicalDim(dim);
Expand Down Expand Up @@ -469,11 +464,6 @@ Tensor view_batching_rule(const Tensor& self, IntArrayRef size) {
return self_physical.getPhysicalToLogicalMap().apply(result);
}

Tensor view_symint_batching_rule(const Tensor& self, c10::SymIntArrayRef size) {
// TODO: properly support this
return view_batching_rule(self, asIntArrayRefSlow(size));
}

Tensor view_as_complex_batching_rule(const Tensor& self) {
// guard against the user passing in a batch of scalar tensors with batch
// size equal to 2.
Expand Down Expand Up @@ -1004,17 +994,6 @@ Tensor new_empty_batching_rule(
return physical_view.getPhysicalToLogicalMap().apply(result);
}

Tensor new_empty_symint_batching_rule(
const Tensor& self,
c10::SymIntArrayRef size,
c10::optional<ScalarType> dtype,
c10::optional<Layout> layout,
c10::optional<Device> device,
c10::optional<bool> pin_memory) {
// TODO: properly support this
return new_empty_batching_rule(self, asIntArrayRefSlow(size), dtype, layout, device, pin_memory);
}

Tensor new_empty_strided_batching_rule(
const Tensor& self,
IntArrayRef size,
Expand Down Expand Up @@ -1112,7 +1091,7 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) {
m.impl("tensor_split.sections", tensor_split_sections_batching_rule);
m.impl("tensor_split.indices", tensor_split_indices_batching_rule);
m.impl("diagonal", diagonal_batching_rule);
m.impl("expand", expand_symint_batching_rule);
m.impl("expand", expand_batching_rule);
m.impl("expand_as", native::expand_as); // composite wrt autograd
m.impl("movedim.intlist", movedim_batching_rule);
m.impl("movedim.int", static_cast<Tensor(*)(const Tensor&,int64_t,int64_t)>(native::movedim)); // composite wrt autograd
Expand Down Expand Up @@ -1140,7 +1119,7 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) {
m.impl("unbind.int", unbind_batching_rule);
m.impl("unfold", unfold_batching_rule);
m.impl("unsqueeze", unsqueeze_batching_rule);
m.impl("view", view_symint_batching_rule);
m.impl("view", view_batching_rule);
m.impl("view_as", native::view_as); // composite wrt autograd

// clamp operations
Expand Down Expand Up @@ -1278,7 +1257,7 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) {
m.impl("diagonal_backward", diagonal_backward_batching_rule);

// Tensor.new_* operators
m.impl("new_empty", new_empty_symint_batching_rule);
m.impl("new_empty", new_empty_batching_rule);
m.impl("new_empty_strided", new_empty_strided_batching_rule);
m.impl("new_zeros", new_zeros_batching_rule);

Expand Down
42 changes: 40 additions & 2 deletions aten/src/ATen/core/boxing/KernelFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,40 @@ class OperatorHandle;
struct OperatorKernel;
class KernelFunction;

template <typename T>
using has_symint =
guts::disjunction<
std::is_same<c10::SymInt, std::decay_t<T>>,
std::is_same<c10::SymIntArrayRef, std::decay_t<T>>,
std::is_same<c10::optional<c10::SymInt>, std::decay_t<T>>
>;

template <typename T>
struct remove_symint {
using type = T;
};

template <>
struct remove_symint<c10::SymInt> {
using type = int64_t;
};

template <>
struct remove_symint<c10::SymIntArrayRef> {
using type = c10::IntArrayRef;
};

template <>
struct remove_symint<c10::optional<c10::SymInt>> {
using type = c10::optional<int64_t>;
};

template <typename T>
using fn_has_symint = typename guts::typelist::true_for_any_type<
has_symint,
typename guts::infer_function_traits<T>::type::parameter_types
>;

/**
* KernelFunction is similar to std::function but stores a kernel function.
* You can create a KernelFunction from a boxed or unboxed function/functor/lambda
Expand All @@ -31,6 +65,7 @@ class TORCH_API KernelFunction final {
// Fast path for dispatch to allow not touching the boxed kernel in
// the common case where unboxed is available.
bool isValidUnboxed() const;
bool isValidSymUnboxed() const;
bool isValid() const;
bool isFallthrough() const;

Expand Down Expand Up @@ -182,13 +217,16 @@ class TORCH_API KernelFunction final {
explicit KernelFunction(
std::unique_ptr<OperatorKernel> functor,
InternalBoxedKernelFunction* boxed_kernel_func,
void* unboxed_kernel_func);
void* unboxed_kernel_func,
void* sym_unboxed_kernel_func);
explicit KernelFunction(
BoxedKernel boxed_fn,
void* unboxed_kernel_func);
void* unboxed_kernel_func,
void* sym_unboxed_kernel_func);

BoxedKernel boxed_kernel_func_;
void* unboxed_kernel_func_;
void* sym_unboxed_kernel_func_;
};

}
Expand Down
61 changes: 54 additions & 7 deletions aten/src/ATen/core/boxing/KernelFunction_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,29 @@ namespace c10 {
inline KernelFunction::KernelFunction()
: boxed_kernel_func_()
, unboxed_kernel_func_(nullptr)
, sym_unboxed_kernel_func_(nullptr)
{}

inline KernelFunction::KernelFunction(std::unique_ptr<OperatorKernel> functor, InternalBoxedKernelFunction* boxed_kernel_func, void* unboxed_kernel_func)
inline KernelFunction::KernelFunction(std::unique_ptr<OperatorKernel> functor, InternalBoxedKernelFunction* boxed_kernel_func, void* unboxed_kernel_func, void* sym_unboxed_kernel_func = nullptr)
: boxed_kernel_func_(std::move(functor), boxed_kernel_func)
, unboxed_kernel_func_(unboxed_kernel_func)
, sym_unboxed_kernel_func_(sym_unboxed_kernel_func)
{}

inline KernelFunction::KernelFunction(BoxedKernel boxed_fn, void* unboxed_kernel_func)
inline KernelFunction::KernelFunction(BoxedKernel boxed_fn, void* unboxed_kernel_func, void* sym_unboxed_kernel_func = nullptr)
: boxed_kernel_func_(std::move(boxed_fn))
, unboxed_kernel_func_(unboxed_kernel_func)
, sym_unboxed_kernel_func_(sym_unboxed_kernel_func)
{}

inline bool KernelFunction::isValidUnboxed() const {
return unboxed_kernel_func_ != nullptr;
}

inline bool KernelFunction::isValidSymUnboxed() const {
return sym_unboxed_kernel_func_ != nullptr;
}

inline bool KernelFunction::isValid() const {
return boxed_kernel_func_.isValid();
}
Expand All @@ -43,16 +50,52 @@ inline Return callUnboxedKernelFunction(void* unboxed_kernel_func, OperatorKerne
return (*func)(functor, dispatchKeySet, std::forward<Args>(args)...);
}

// This template requires you to explicitly specify the argument you want to
// forward; it doesn't work if you try to deduce it

template <typename T>
inline typename remove_symint<T>::type unpackSymInt(T x) { return x; }

template <>
inline typename remove_symint<c10::SymInt>::type unpackSymInt(c10::SymInt x) {
return x.expect_int();
}

template <>
inline typename remove_symint<c10::SymIntArrayRef>::type unpackSymInt(c10::SymIntArrayRef x) {
return c10::asIntArrayRefSlow(x);
}

template <>
inline typename remove_symint<c10::optional<c10::SymInt>>::type unpackSymInt(c10::optional<c10::SymInt> x) {
return x.has_value() ? c10::make_optional(x->expect_int()) : c10::nullopt;
}

template<class Return, class... Args>
C10_ALWAYS_INLINE Return KernelFunction::call(const OperatorHandle& opHandle, DispatchKeySet dispatchKeySet, Args... args) const {
// note: Args above is intentionally not Args&&. We don't want perfect
// forwarding, which would require Args to be deduced, but instead we
// want callers to explicitly specify the Args.

if (C10_LIKELY(unboxed_kernel_func_ != nullptr)) {
auto *functor = boxed_kernel_func_.getFunctor();
return callUnboxedKernelFunction<Return, Args...>(
unboxed_kernel_func_, functor, dispatchKeySet, std::forward<Args>(args)...);
// This should get inlined by compiler
if (guts::disjunction<has_symint<Args>...>::value) {
if (sym_unboxed_kernel_func_ != nullptr) {
auto *functor = boxed_kernel_func_.getFunctor();
return callUnboxedKernelFunction<Return, Args...>(
sym_unboxed_kernel_func_, functor, dispatchKeySet, std::forward<Args>(args)...);
}

if (unboxed_kernel_func_ != nullptr) {
auto *functor = boxed_kernel_func_.getFunctor();
return callUnboxedKernelFunction<Return, typename remove_symint<Args>::type...>(
unboxed_kernel_func_, functor, dispatchKeySet, unpackSymInt<Args>(args)...);
}
} else {
if (C10_LIKELY(unboxed_kernel_func_ != nullptr)) {
auto *functor = boxed_kernel_func_.getFunctor();
return callUnboxedKernelFunction<Return, Args...>(
unboxed_kernel_func_, functor, dispatchKeySet, std::forward<Args>(args)...);
}
}

return impl::BoxedKernelWrapper<Return(Args...)>::call(
Expand Down Expand Up @@ -102,10 +145,14 @@ inline KernelFunction KernelFunction::makeFromUnboxedFunctor(std::unique_ptr<Ope
#endif
static_assert(std::is_base_of<OperatorKernel, KernelFunctor>::value, "Tried to call KernelFunction::makeFromUnboxedFunctor<KernelFunctor>, but the functor doesn't inherit from c10::OperatorKernel. Please have the functor inherit from it.");

auto* unboxed_fn = &impl::wrap_kernel_functor_unboxed<KernelFunctor>::call;
void* void_unboxed_fn = reinterpret_cast<void*>(unboxed_fn);
bool is_symint = fn_has_symint<decltype(unboxed_fn)>::value;
return KernelFunction(
std::move(kernelFunctor),
&impl::make_boxed_from_unboxed_functor<KernelFunctor, AllowLegacyTypes>::call,
reinterpret_cast<void*>(&impl::wrap_kernel_functor_unboxed<KernelFunctor>::call)
is_symint ? nullptr : void_unboxed_fn,
is_symint ? void_unboxed_fn : nullptr
);
}

Expand Down
47 changes: 30 additions & 17 deletions aten/src/ATen/core/dispatch/OperatorEntry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ OperatorEntry::OperatorEntry(OperatorName&& operator_name)
, dispatchKeyExtractor_(DispatchKeyExtractor::makeUninitialized())
, kernels_()
, cpp_signature_()
, sym_cpp_signature_()
, is_observed_(ObservedOperators::isObserved(name_))
{
// Pick up any backend fallbacks that were registered prior to this
Expand All @@ -34,12 +35,11 @@ OperatorEntry::OperatorEntry(OperatorName&& operator_name)
}

namespace {
void checkSchema(const OperatorName& name, const FunctionSchema& from_def, const std::string& from_def_debug, const FunctionSchema& inferred, const std::string& inferred_debug) {
void checkSchema(const OperatorName& name, const FunctionSchema& from_def_, const std::string& from_def_debug, const KernelFunction& kernel, const FunctionSchema& inferred_, const std::string& inferred_debug) {
// TODO: figure out if we can just directly save real schema at def time
c10::optional<std::string> schema_difference = findSchemaDifferences(
from_def.cloneWithRealTypes(),
inferred.cloneWithRealTypes()
);
FunctionSchema from_def = from_def_.cloneWithRealTypes(kernel.isValidSymUnboxed());
FunctionSchema inferred = inferred_.cloneWithRealTypes();
c10::optional<std::string> schema_difference = findSchemaDifferences(from_def, inferred);
if (schema_difference.has_value()) {
TORCH_CHECK(false,
"Inferred operator schema for a C++ kernel function doesn't match the expected function schema.\n"
Expand All @@ -64,12 +64,24 @@ const AnnotatedKernel& OperatorEntry::ambiguousAutogradOtherKernel() const {
return kernel;
}

void OperatorEntry::assertSignatureIsCorrect(const CppSignature call_signature, bool has_symint) const {
if (has_symint) {
if (C10_UNLIKELY(sym_cpp_signature_.has_value() && (call_signature != sym_cpp_signature_->signature))) {
reportSignatureError(call_signature, *sym_cpp_signature_);
}
} else {
if (C10_UNLIKELY(cpp_signature_.has_value() && (call_signature != cpp_signature_->signature))) {
reportSignatureError(call_signature, *cpp_signature_);
}
}
}

void OperatorEntry::registerSchema(FunctionSchema&& schema, std::string&& debug, std::vector<at::Tag> tags) {
TORCH_INTERNAL_ASSERT(!schema_.has_value());
for (const auto& kernel : kernels_) {
for (const auto &j : kernel.second) {
if (j.inferred_function_schema != nullptr) {
checkSchema(name_, schema, debug, *j.inferred_function_schema, j.debug);
checkSchema(name_, schema, debug, j.kernel, *j.inferred_function_schema, j.debug);
}
}
}
Expand Down Expand Up @@ -103,25 +115,26 @@ OperatorEntry::AnnotatedKernelContainerIterator OperatorEntry::registerKernel(
// which means if you could validly change the type of a cpp_signature, then
// that would also invalidate the old TypedOperatorHandles.
if (cpp_signature.has_value()) {
if (cpp_signature_.has_value()) {
TORCH_CHECK(*cpp_signature == cpp_signature_->signature,
auto& local_cpp_signature = kernel.isValidSymUnboxed() ? sym_cpp_signature_ : cpp_signature_;
if (local_cpp_signature.has_value()) {
TORCH_CHECK(*cpp_signature == local_cpp_signature->signature,
"\nMismatch in kernel C++ signatures\n",
" operator: ", (this->schema_.has_value() ? toString(this->schema_->schema) : toString(name_)), "\n",
" ", (this->schema_.has_value() ? this->schema_->debug : "no debug info"), "\n",
" kernel 1: ", cpp_signature_->signature.name(), "\n",
" dispatch key: ", toString(cpp_signature_->dispatch_key), "\n",
" ", cpp_signature_->debug, "\n",
" kernel 1: ", local_cpp_signature->signature.name(), "\n",
" dispatch key: ", toString(local_cpp_signature->dispatch_key), "\n",
" ", local_cpp_signature->debug, "\n",
" kernel 2: ", cpp_signature->name(), "\n",
" dispatch key: ", toString(dispatch_key), "\n",
" ", debug, "\n"
);
} else {
cpp_signature_ = CppSignatureWithDebug { *cpp_signature, debug, dispatch_key };
local_cpp_signature = CppSignatureWithDebug { *cpp_signature, debug, dispatch_key };
}
}

if (schema_ && inferred_function_schema) {
checkSchema(name_, schema_->schema, schema_->debug, *inferred_function_schema, debug);
checkSchema(name_, schema_->schema, schema_->debug, kernel, *inferred_function_schema, debug);
}

// Add the kernel to the kernels list,
Expand All @@ -138,7 +151,7 @@ OperatorEntry::AnnotatedKernelContainerIterator OperatorEntry::registerKernel(
" operator: ", (schema_.has_value() ? toString(schema_->schema) : toString(name_)), "\n",
" ", (this->schema_.has_value() ? this->schema_->debug : "no debug info"), "\n",
" dispatch key: ", toString(dispatch_key), "\n",
" previous kernel: ", (cpp_signature_.has_value() ? cpp_signature_->debug : "no debug info"), "\n",
" previous kernel: ", (cpp_signature_.has_value() ? cpp_signature_->debug : (sym_cpp_signature_.has_value() ? sym_cpp_signature_->debug : "no debug info")), "\n",
" new kernel: ", debug
);
}
Expand Down Expand Up @@ -471,13 +484,13 @@ std::string OperatorEntry::listAllDispatchKeys() const {
return str.str();
}

void OperatorEntry::reportSignatureError(const CppSignature call_signature) const {
void OperatorEntry::reportSignatureError(const CppSignature& call_signature, const CppSignatureWithDebug& saved_signature) const {
TORCH_CHECK(false,
"\nTried to access or call an operator with a wrong signature.\n",
" operator: ", (schema_.has_value() ? toString(schema_->schema) : toString(name_)), "\n",
" ", (schema_.has_value() ? schema_->debug : "unknown debug info"), "\n",
" correct signature: ", cpp_signature_->signature.name(), "\n",
" ", cpp_signature_->debug, "\n",
" correct signature: ", saved_signature.signature.name(), "\n",
" ", saved_signature.debug, "\n",
" accessed/called as: ", call_signature.name(), "\n",
"This likely happened in a call to OperatorHandle::typed<Return (Args...)>(). ",
"Please make sure that the function signature matches the signature in the operator registration call."
Expand Down
11 changes: 4 additions & 7 deletions aten/src/ATen/core/dispatch/OperatorEntry.h
Original file line number Diff line number Diff line change
Expand Up @@ -163,14 +163,10 @@ class TORCH_API OperatorEntry final {
// Asserts that the given FuncType is correct for calling this operator in an unboxed way.
template<class FuncType>
inline void assertSignatureIsCorrect() {
assertSignatureIsCorrect(CppSignature::make<FuncType>());
assertSignatureIsCorrect(CppSignature::make<FuncType>(), fn_has_symint<FuncType>::value);
}

void assertSignatureIsCorrect(const CppSignature call_signature) {
if (C10_UNLIKELY(cpp_signature_.has_value() && (call_signature != cpp_signature_->signature))) {
reportSignatureError(call_signature);
}
}
void assertSignatureIsCorrect(const CppSignature call_signature, bool has_symint) const;

[[noreturn]] void reportError(DispatchKey dispatchKey) const;

Expand Down Expand Up @@ -280,11 +276,12 @@ class TORCH_API OperatorEntry final {
c10::optional<DispatchKey> dispatch_key;
};
c10::optional<CppSignatureWithDebug> cpp_signature_;
c10::optional<CppSignatureWithDebug> sym_cpp_signature_;

// Whether this operator needs to be observed with RecordFunction
const bool is_observed_;

[[noreturn]] void reportSignatureError(CppSignature call_signature) const;
[[noreturn]] void reportSignatureError(const CppSignature& call_signature, const CppSignatureWithDebug& saved_signature) const;
const KernelFunction& computeDispatchTableEntry(const c10::Dispatcher& dispatcher, DispatchKey dispatch_key) const;
std::pair<const AnnotatedKernel&, const char*> computeDispatchTableEntryWithDebug(
const c10::Dispatcher& dispatcher, DispatchKey dispatch_key
Expand Down
Loading

0 comments on commit 19e27b1

Please sign in to comment.