diff --git a/aten/src/ATen/BatchingRegistrations.cpp b/aten/src/ATen/BatchingRegistrations.cpp
index fab2c9e607625c..02bbbb7088d6ed 100644
--- a/aten/src/ATen/BatchingRegistrations.cpp
+++ b/aten/src/ATen/BatchingRegistrations.cpp
@@ -185,11 +185,6 @@ Tensor expand_batching_rule(const Tensor& self, IntArrayRef size, bool implicit)
   return self_physical.getPhysicalToLogicalMap().apply(result);
 }
 
-Tensor expand_symint_batching_rule(const Tensor& self, SymIntArrayRef psize, bool implicit) {
-  // TODO: properly support this
-  return expand_batching_rule(self, asIntArrayRefSlow(psize), implicit);
-}
-
 std::vector<Tensor> chunk_batching_rule(const Tensor& self, int64_t chunks, int64_t dim) {
   auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
   auto dim_physical = self_physical.getPhysicalDim(dim);
@@ -469,11 +464,6 @@ Tensor view_batching_rule(const Tensor& self, IntArrayRef size) {
   return self_physical.getPhysicalToLogicalMap().apply(result);
 }
 
-Tensor view_symint_batching_rule(const Tensor& self, c10::SymIntArrayRef size) {
-  // TODO: properly support this
-  return view_batching_rule(self, asIntArrayRefSlow(size));
-}
-
 Tensor view_as_complex_batching_rule(const Tensor& self) {
   // guard against the user passing in a batch of scalar tensors with batch
   // size equal to 2.
@@ -1004,17 +994,6 @@ Tensor new_empty_batching_rule(
   return physical_view.getPhysicalToLogicalMap().apply(result);
 }
 
-Tensor new_empty_symint_batching_rule(
-    const Tensor& self,
-    c10::SymIntArrayRef size,
-    c10::optional<ScalarType> dtype,
-    c10::optional<Layout> layout,
-    c10::optional<Device> device,
-    c10::optional<bool> pin_memory) {
-  // TODO: properly support this
-  return new_empty_batching_rule(self, asIntArrayRefSlow(size), dtype, layout, device, pin_memory);
-}
-
 Tensor new_empty_strided_batching_rule(
     const Tensor& self,
     IntArrayRef size,
@@ -1112,7 +1091,7 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) {
   m.impl("tensor_split.sections", tensor_split_sections_batching_rule);
   m.impl("tensor_split.indices", tensor_split_indices_batching_rule);
   m.impl("diagonal", diagonal_batching_rule);
-  m.impl("expand", expand_symint_batching_rule);
+  m.impl("expand", expand_batching_rule);
   m.impl("expand_as", native::expand_as); // composite wrt autograd
   m.impl("movedim.intlist", movedim_batching_rule);
   m.impl("movedim.int", static_cast<Tensor(*)(const Tensor&,int64_t,int64_t)>(native::movedim)); // composite wrt autograd
@@ -1140,7 +1119,7 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) {
   m.impl("unbind.int", unbind_batching_rule);
   m.impl("unfold", unfold_batching_rule);
   m.impl("unsqueeze", unsqueeze_batching_rule);
-  m.impl("view", view_symint_batching_rule);
+  m.impl("view", view_batching_rule);
   m.impl("view_as", native::view_as); // composite wrt autograd
 
   // clamp operations
@@ -1278,7 +1257,7 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) {
   m.impl("diagonal_backward", diagonal_backward_batching_rule);
 
   // Tensor.new_* operators
-  m.impl("new_empty", new_empty_symint_batching_rule);
+  m.impl("new_empty", new_empty_batching_rule);
   m.impl("new_empty_strided", new_empty_strided_batching_rule);
   m.impl("new_zeros", new_zeros_batching_rule);
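The three deleted `*_symint_batching_rule` wrappers were the same stopgap: eagerly flatten symbolic sizes to plain ints and forward to the existing int kernel. The dispatcher changes below do that conversion once, centrally. A minimal standalone sketch of the conversion, using a stand-in `SymIntV` type rather than the real `c10::SymInt`:

```cpp
// Standalone sketch of the conversion the dispatcher now performs centrally.
// SymIntV and as_ints are simplified stand-ins, not the c10 types/helpers.
#include <cassert>
#include <cstdint>
#include <vector>

struct SymIntV {                            // stand-in for c10::SymInt
  int64_t v;
  int64_t expect_int() const { return v; }  // c10 would assert "not symbolic"
};

// stand-in for c10::asIntArrayRefSlow: materialize plain ints
std::vector<int64_t> as_ints(const std::vector<SymIntV>& xs) {
  std::vector<int64_t> out;
  out.reserve(xs.size());
  for (const auto& x : xs) out.push_back(x.expect_int());
  return out;
}

int main() {
  std::vector<SymIntV> sym_size{{2}, {3}};
  // Previously every *_symint_batching_rule did this by hand; after this
  // patch KernelFunction::call does it once before hitting the int kernel.
  assert((as_ints(sym_size) == std::vector<int64_t>{2, 3}));
}
```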
diff --git a/aten/src/ATen/core/boxing/KernelFunction.h b/aten/src/ATen/core/boxing/KernelFunction.h
index 8ab34e95046abb..0f48c7560d6b06 100644
--- a/aten/src/ATen/core/boxing/KernelFunction.h
+++ b/aten/src/ATen/core/boxing/KernelFunction.h
@@ -14,6 +14,40 @@ class OperatorHandle;
 struct OperatorKernel;
 class KernelFunction;
 
+template <typename T>
+using has_symint =
+  guts::disjunction<
+    std::is_same<c10::SymInt, std::decay_t<T>>,
+    std::is_same<c10::SymIntArrayRef, std::decay_t<T>>,
+    std::is_same<c10::optional<c10::SymInt>, std::decay_t<T>>
+  >;
+
+template <typename T>
+struct remove_symint {
+  using type = T;
+};
+
+template <>
+struct remove_symint<c10::SymInt> {
+  using type = int64_t;
+};
+
+template <>
+struct remove_symint<c10::SymIntArrayRef> {
+  using type = c10::IntArrayRef;
+};
+
+template <>
+struct remove_symint<c10::optional<c10::SymInt>> {
+  using type = c10::optional<int64_t>;
+};
+
+template <typename FuncType>
+using fn_has_symint = typename guts::typelist::true_for_any_type<
+  has_symint,
+  typename guts::infer_function_traits<FuncType>::type::parameter_types
+>;
+
 /**
  * KernelFunction is similar to std::function but stores a kernel function.
  * You can create a KernelFunction from a boxed or unboxed function/functor/lambda
@@ -31,6 +65,7 @@ class TORCH_API KernelFunction final {
   // Fast path for dispatch to allow not touching the boxed kernel in
   // the common case where unboxed is available.
   bool isValidUnboxed() const;
+  bool isValidSymUnboxed() const;
   bool isValid() const;
   bool isFallthrough() const;
 
@@ -182,13 +217,16 @@ class TORCH_API KernelFunction final {
   explicit KernelFunction(
       std::unique_ptr<OperatorKernel> functor,
       InternalBoxedKernelFunction* boxed_kernel_func,
-      void* unboxed_kernel_func);
+      void* unboxed_kernel_func,
+      void* sym_unboxed_kernel_func);
   explicit KernelFunction(
       BoxedKernel boxed_fn,
-      void* unboxed_kernel_func);
+      void* unboxed_kernel_func,
+      void* sym_unboxed_kernel_func);
 
   BoxedKernel boxed_kernel_func_;
   void* unboxed_kernel_func_;
+  void* sym_unboxed_kernel_func_;
 };
 
 }
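The traits above drive everything else in this patch: `has_symint` detects a symbolic parameter type, `remove_symint` maps each symbolic type to its plain twin, and `fn_has_symint` asks whether any parameter of a function type is symbolic. Restated as a self-contained snippet that compiles outside PyTorch (`std::disjunction` replaces `guts::disjunction`, and the c10 types are reduced to empty stand-in structs):

```cpp
// Compile-time behaviour of the new traits, using standard-library pieces
// only so it builds without PyTorch. SymInt/SymIntArrayRef/IntArrayRef are
// local stand-ins, not the c10 types.
#include <cstdint>
#include <optional>
#include <type_traits>

struct SymInt {};
struct SymIntArrayRef {};
struct IntArrayRef {};

template <typename T>
using has_symint = std::disjunction<
    std::is_same<SymInt, std::decay_t<T>>,
    std::is_same<SymIntArrayRef, std::decay_t<T>>,
    std::is_same<std::optional<SymInt>, std::decay_t<T>>>;

template <typename T> struct remove_symint { using type = T; };
template <> struct remove_symint<SymInt> { using type = int64_t; };
template <> struct remove_symint<SymIntArrayRef> { using type = IntArrayRef; };

static_assert(has_symint<SymInt>::value, "symbolic types are detected");
static_assert(!has_symint<int64_t>::value, "plain ints are not");
static_assert(std::is_same_v<remove_symint<SymIntArrayRef>::type, IntArrayRef>,
              "remove_symint maps each symbolic type to its plain twin");

int main() {}  // nothing to run; all checks happen at compile time
```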
diff --git a/aten/src/ATen/core/boxing/KernelFunction_impl.h b/aten/src/ATen/core/boxing/KernelFunction_impl.h
index c33175e4b99abe..8c968e835fa60b 100644
--- a/aten/src/ATen/core/boxing/KernelFunction_impl.h
+++ b/aten/src/ATen/core/boxing/KernelFunction_impl.h
@@ -8,22 +8,29 @@ namespace c10 {
 inline KernelFunction::KernelFunction()
     : boxed_kernel_func_()
     , unboxed_kernel_func_(nullptr)
+    , sym_unboxed_kernel_func_(nullptr)
 {}
 
-inline KernelFunction::KernelFunction(std::unique_ptr<OperatorKernel> functor, InternalBoxedKernelFunction* boxed_kernel_func, void* unboxed_kernel_func)
+inline KernelFunction::KernelFunction(std::unique_ptr<OperatorKernel> functor, InternalBoxedKernelFunction* boxed_kernel_func, void* unboxed_kernel_func, void* sym_unboxed_kernel_func = nullptr)
     : boxed_kernel_func_(std::move(functor), boxed_kernel_func)
     , unboxed_kernel_func_(unboxed_kernel_func)
+    , sym_unboxed_kernel_func_(sym_unboxed_kernel_func)
 {}
 
-inline KernelFunction::KernelFunction(BoxedKernel boxed_fn, void* unboxed_kernel_func)
+inline KernelFunction::KernelFunction(BoxedKernel boxed_fn, void* unboxed_kernel_func, void* sym_unboxed_kernel_func = nullptr)
     : boxed_kernel_func_(std::move(boxed_fn))
     , unboxed_kernel_func_(unboxed_kernel_func)
+    , sym_unboxed_kernel_func_(sym_unboxed_kernel_func)
 {}
 
 inline bool KernelFunction::isValidUnboxed() const {
   return unboxed_kernel_func_ != nullptr;
 }
 
+inline bool KernelFunction::isValidSymUnboxed() const {
+  return sym_unboxed_kernel_func_ != nullptr;
+}
+
 inline bool KernelFunction::isValid() const {
   return boxed_kernel_func_.isValid();
 }
@@ -43,16 +50,52 @@ inline Return callUnboxedKernelFunction(void* unboxed_kernel_func, OperatorKernel* functor, DispatchKeySet dispatchKeySet, Args&&... args) {
   return (*func)(functor, dispatchKeySet, std::forward<Args>(args)...);
 }
 
+// This template requires you to explicitly specify the argument you want to
+// forward; it doesn't work if you try to deduce it
+
+template <typename T>
+inline typename remove_symint<T>::type unpackSymInt(T x) { return x; }
+
+template <>
+inline typename remove_symint<c10::SymInt>::type unpackSymInt(c10::SymInt x) {
+  return x.expect_int();
+}
+
+template <>
+inline typename remove_symint<c10::SymIntArrayRef>::type unpackSymInt(c10::SymIntArrayRef x) {
+  return c10::asIntArrayRefSlow(x);
+}
+
+template <>
+inline typename remove_symint<c10::optional<c10::SymInt>>::type unpackSymInt(c10::optional<c10::SymInt> x) {
+  return x.has_value() ? c10::make_optional(x->expect_int()) : c10::nullopt;
+}
+
 template <class Return, class... Args>
 C10_ALWAYS_INLINE Return KernelFunction::call(const OperatorHandle& opHandle, DispatchKeySet dispatchKeySet, Args... args) const {
   // note: Args above is intentionally not Args&&. We don't want perfect
   // forwarding, which would require Args to be deduced, but instead we
   // want callers to explicitly specify the Args.
 
-  if (C10_LIKELY(unboxed_kernel_func_ != nullptr)) {
-    auto *functor = boxed_kernel_func_.getFunctor();
-    return callUnboxedKernelFunction<Return, Args...>(
-        unboxed_kernel_func_, functor, dispatchKeySet, std::forward<Args>(args)...);
+  // This should get inlined by compiler
+  if (guts::disjunction<has_symint<Args>...>::value) {
+    if (sym_unboxed_kernel_func_ != nullptr) {
+      auto *functor = boxed_kernel_func_.getFunctor();
+      return callUnboxedKernelFunction<Return, Args...>(
+          sym_unboxed_kernel_func_, functor, dispatchKeySet, std::forward<Args>(args)...);
+    }
+
+    if (unboxed_kernel_func_ != nullptr) {
+      auto *functor = boxed_kernel_func_.getFunctor();
+      return callUnboxedKernelFunction<Return, typename remove_symint<Args>::type...>(
+          unboxed_kernel_func_, functor, dispatchKeySet, unpackSymInt<Args>(args)...);
+    }
+  } else {
+    if (C10_LIKELY(unboxed_kernel_func_ != nullptr)) {
+      auto *functor = boxed_kernel_func_.getFunctor();
+      return callUnboxedKernelFunction<Return, Args...>(
+          unboxed_kernel_func_, functor, dispatchKeySet, std::forward<Args>(args)...);
+    }
   }
 
   return impl::BoxedKernelWrapper<Return(Args...)>::call(
@@ -102,10 +145,14 @@ inline KernelFunction KernelFunction::makeFromUnboxedFunctor(std::unique_ptr<OperatorKernel> kernelFunctor) {
   static_assert(std::is_base_of<OperatorKernel, KernelFunctor>::value, "Tried to call KernelFunction::makeFromUnboxedFunctor<KernelFunctor>, but the functor doesn't inherit from c10::OperatorKernel. Please have the functor inherit from it.");
 
+  auto* unboxed_fn = &impl::wrap_kernel_functor_unboxed<KernelFunctor>::call;
+  void* void_unboxed_fn = reinterpret_cast<void*>(unboxed_fn);
+  bool is_symint = fn_has_symint<decltype(unboxed_fn)>::value;
   return KernelFunction(
       std::move(kernelFunctor),
       &impl::make_boxed_from_unboxed_functor<KernelFunctor, AllowLegacyTypes>::call,
-      reinterpret_cast<void*>(&impl::wrap_kernel_functor_unboxed<KernelFunctor>::call)
+      is_symint ? nullptr : void_unboxed_fn,
+      is_symint ? void_unboxed_fn : nullptr
   );
 }
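The new branch in `KernelFunction::call` reduces to a simple decision: if any argument type is symbolic, prefer the SymInt kernel when one is registered, otherwise unpack every argument and fall back to the plain-int kernel. A runnable sketch of that decision, with `std::function` standing in for the type-erased `void*` kernel pointers:

```cpp
// The dispatch decision in KernelFunction::call, reduced to one function.
// Simplified stand-ins; the real code does this with void* kernels and
// template machinery, and expect_int() would fail on truly symbolic values.
#include <cstdint>
#include <functional>
#include <iostream>

struct SymInt {
  int64_t v;
  int64_t expect_int() const { return v; }
};

int64_t call(const std::function<int64_t(SymInt)>& sym_kernel,
             const std::function<int64_t(int64_t)>& int_kernel,
             SymInt arg) {
  if (sym_kernel) {
    return sym_kernel(arg);             // sym_unboxed_kernel_func_ path
  }
  return int_kernel(arg.expect_int());  // unpackSymInt + unboxed_kernel_func_
}

int main() {
  auto int_kernel = [](int64_t x) { return x * 2; };
  auto sym_kernel = [](SymInt x) { return x.v * 2; };
  // No sym kernel registered: the SymInt argument is unpacked first.
  std::cout << call(nullptr, int_kernel, SymInt{21}) << "\n";     // 42
  // A sym kernel takes priority and sees the SymInt unmodified.
  std::cout << call(sym_kernel, int_kernel, SymInt{21}) << "\n";  // 42
}
```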
diff --git a/aten/src/ATen/core/dispatch/OperatorEntry.cpp b/aten/src/ATen/core/dispatch/OperatorEntry.cpp
index 139880c6d7fa3c..01d30c888db2c1 100644
--- a/aten/src/ATen/core/dispatch/OperatorEntry.cpp
+++ b/aten/src/ATen/core/dispatch/OperatorEntry.cpp
@@ -26,6 +26,7 @@ OperatorEntry::OperatorEntry(OperatorName&& operator_name)
 , dispatchKeyExtractor_(DispatchKeyExtractor::makeUninitialized())
 , kernels_()
 , cpp_signature_()
+, sym_cpp_signature_()
 , is_observed_(ObservedOperators::isObserved(name_))
 {
   // Pick up any backend fallbacks that were registered prior to this
@@ -34,12 +35,11 @@ OperatorEntry::OperatorEntry(OperatorName&& operator_name)
 }
 
 namespace {
-  void checkSchema(const OperatorName& name, const FunctionSchema& from_def, const std::string& from_def_debug, const FunctionSchema& inferred, const std::string& inferred_debug) {
+  void checkSchema(const OperatorName& name, const FunctionSchema& from_def_, const std::string& from_def_debug, const KernelFunction& kernel, const FunctionSchema& inferred_, const std::string& inferred_debug) {
     // TODO: figure out if we can just directly save real schema at def time
-    c10::optional<std::string> schema_difference = findSchemaDifferences(
-        from_def.cloneWithRealTypes(),
-        inferred.cloneWithRealTypes()
-    );
+    FunctionSchema from_def = from_def_.cloneWithRealTypes(kernel.isValidSymUnboxed());
+    FunctionSchema inferred = inferred_.cloneWithRealTypes();
+    c10::optional<std::string> schema_difference = findSchemaDifferences(from_def, inferred);
     if (schema_difference.has_value()) {
       TORCH_CHECK(false,
         "Inferred operator schema for a C++ kernel function doesn't match the expected function schema.\n"
@@ -64,12 +64,24 @@ const AnnotatedKernel& OperatorEntry::ambiguousAutogradOtherKernel() const {
   return kernel;
 }
 
+void OperatorEntry::assertSignatureIsCorrect(const CppSignature call_signature, bool has_symint) const {
+  if (has_symint) {
+    if (C10_UNLIKELY(sym_cpp_signature_.has_value() && (call_signature != sym_cpp_signature_->signature))) {
+      reportSignatureError(call_signature, *sym_cpp_signature_);
+    }
+  } else {
+    if (C10_UNLIKELY(cpp_signature_.has_value() && (call_signature != cpp_signature_->signature))) {
+      reportSignatureError(call_signature, *cpp_signature_);
+    }
+  }
+}
+
 void OperatorEntry::registerSchema(FunctionSchema&& schema, std::string&& debug, std::vector<at::Tag> tags) {
   TORCH_INTERNAL_ASSERT(!schema_.has_value());
   for (const auto& kernel : kernels_) {
     for (const auto &j : kernel.second) {
       if (j.inferred_function_schema != nullptr) {
-        checkSchema(name_, schema, debug, *j.inferred_function_schema, j.debug);
+        checkSchema(name_, schema, debug, j.kernel, *j.inferred_function_schema, j.debug);
       }
     }
   }
@@ -103,25 +115,26 @@ OperatorEntry::AnnotatedKernelContainerIterator OperatorEntry::registerKernel(
   // which means if you could validly change the type of a cpp_signature, then
   // that would also invalidate the old TypedOperatorHandles.
   if (cpp_signature.has_value()) {
-    if (cpp_signature_.has_value()) {
-      TORCH_CHECK(*cpp_signature == cpp_signature_->signature,
+    auto& local_cpp_signature = kernel.isValidSymUnboxed() ? sym_cpp_signature_ : cpp_signature_;
+    if (local_cpp_signature.has_value()) {
+      TORCH_CHECK(*cpp_signature == local_cpp_signature->signature,
         "\nMismatch in kernel C++ signatures\n",
         "  operator: ", (this->schema_.has_value() ? toString(this->schema_->schema) : toString(name_)), "\n",
         "    ", (this->schema_.has_value() ? this->schema_->debug : "no debug info"), "\n",
-        "  kernel 1: ", cpp_signature_->signature.name(), "\n",
-        "    dispatch key: ", toString(cpp_signature_->dispatch_key), "\n",
-        "    ", cpp_signature_->debug, "\n",
+        "  kernel 1: ", local_cpp_signature->signature.name(), "\n",
+        "    dispatch key: ", toString(local_cpp_signature->dispatch_key), "\n",
+        "    ", local_cpp_signature->debug, "\n",
         "  kernel 2: ", cpp_signature->name(), "\n",
         "    dispatch key: ", toString(dispatch_key), "\n",
        "    ", debug, "\n"
       );
     } else {
-      cpp_signature_ = CppSignatureWithDebug { *cpp_signature, debug, dispatch_key };
+      local_cpp_signature = CppSignatureWithDebug { *cpp_signature, debug, dispatch_key };
     }
   }
 
   if (schema_ && inferred_function_schema) {
-    checkSchema(name_, schema_->schema, schema_->debug, *inferred_function_schema, debug);
+    checkSchema(name_, schema_->schema, schema_->debug, kernel, *inferred_function_schema, debug);
   }
 
   // Add the kernel to the kernels list,
@@ -138,7 +151,7 @@ OperatorEntry::AnnotatedKernelContainerIterator OperatorEntry::registerKernel(
       "  operator: ", (schema_.has_value() ? toString(schema_->schema) : toString(name_)), "\n",
       "    ", (this->schema_.has_value() ? this->schema_->debug : "no debug info"), "\n",
       "  dispatch key: ", toString(dispatch_key), "\n",
-      "  previous kernel: ", (cpp_signature_.has_value() ? cpp_signature_->debug : "no debug info"), "\n",
+      "  previous kernel: ", (cpp_signature_.has_value() ? cpp_signature_->debug : (sym_cpp_signature_.has_value() ? sym_cpp_signature_->debug : "no debug info")), "\n",
       "       new kernel: ", debug
     );
   }
@@ -471,13 +484,13 @@ std::string OperatorEntry::listAllDispatchKeys() const {
   return str.str();
 }
 
-void OperatorEntry::reportSignatureError(const CppSignature call_signature) const {
+void OperatorEntry::reportSignatureError(const CppSignature& call_signature, const CppSignatureWithDebug& saved_signature) const {
   TORCH_CHECK(false,
         "\nTried to access or call an operator with a wrong signature.\n",
         "  operator: ", (schema_.has_value() ? toString(schema_->schema) : toString(name_)), "\n",
         "    ", (schema_.has_value() ? schema_->debug : "unknown debug info"), "\n",
-        "  correct signature: ", cpp_signature_->signature.name(), "\n",
-        "    ", cpp_signature_->debug, "\n",
+        "  correct signature: ", saved_signature.signature.name(), "\n",
+        "    ", saved_signature.debug, "\n",
        "  accessed/called as: ", call_signature.name(), "\n",
        "This likely happened in a call to OperatorHandle::typed<Return (Args...)>(). ",
        "Please make sure that the function signature matches the signature in the operator registration call."
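`OperatorEntry` now tracks one C++ signature per kernel flavour, since an operator can legitimately have both an `int64_t` kernel and a `SymInt` kernel registered, and each caller must be checked against the matching one. A minimal stand-in illustrating the two-slot check (illustrative strings, not the real c10 classes):

```cpp
// Why OperatorEntry keeps two signature slots after this patch.
#include <cassert>
#include <optional>
#include <string>

struct Entry {
  std::optional<std::string> cpp_signature_;      // non-symint kernels
  std::optional<std::string> sym_cpp_signature_;  // symint kernels

  void check(const std::string& caller, bool has_symint) const {
    const auto& slot = has_symint ? sym_cpp_signature_ : cpp_signature_;
    // mirrors assertSignatureIsCorrect: only complain if a kernel of this
    // flavour was registered and its signature differs from the caller's
    assert(!slot.has_value() || *slot == caller);
  }
};

int main() {
  Entry e;
  e.cpp_signature_ = "Tensor(Tensor, IntArrayRef)";
  e.sym_cpp_signature_ = "Tensor(Tensor, SymIntArrayRef)";
  e.check("Tensor(Tensor, IntArrayRef)", /*has_symint=*/false);
  e.check("Tensor(Tensor, SymIntArrayRef)", /*has_symint=*/true);
}
```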
this->schema_->debug : "no debug info"), "\n", - " kernel 1: ", cpp_signature_->signature.name(), "\n", - " dispatch key: ", toString(cpp_signature_->dispatch_key), "\n", - " ", cpp_signature_->debug, "\n", + " kernel 1: ", local_cpp_signature->signature.name(), "\n", + " dispatch key: ", toString(local_cpp_signature->dispatch_key), "\n", + " ", local_cpp_signature->debug, "\n", " kernel 2: ", cpp_signature->name(), "\n", " dispatch key: ", toString(dispatch_key), "\n", " ", debug, "\n" ); } else { - cpp_signature_ = CppSignatureWithDebug { *cpp_signature, debug, dispatch_key }; + local_cpp_signature = CppSignatureWithDebug { *cpp_signature, debug, dispatch_key }; } } if (schema_ && inferred_function_schema) { - checkSchema(name_, schema_->schema, schema_->debug, *inferred_function_schema, debug); + checkSchema(name_, schema_->schema, schema_->debug, kernel, *inferred_function_schema, debug); } // Add the kernel to the kernels list, @@ -138,7 +151,7 @@ OperatorEntry::AnnotatedKernelContainerIterator OperatorEntry::registerKernel( " operator: ", (schema_.has_value() ? toString(schema_->schema) : toString(name_)), "\n", " ", (this->schema_.has_value() ? this->schema_->debug : "no debug info"), "\n", " dispatch key: ", toString(dispatch_key), "\n", - " previous kernel: ", (cpp_signature_.has_value() ? cpp_signature_->debug : "no debug info"), "\n", + " previous kernel: ", (cpp_signature_.has_value() ? cpp_signature_->debug : (sym_cpp_signature_.has_value() ? sym_cpp_signature_->debug : "no debug info")), "\n", " new kernel: ", debug ); } @@ -471,13 +484,13 @@ std::string OperatorEntry::listAllDispatchKeys() const { return str.str(); } -void OperatorEntry::reportSignatureError(const CppSignature call_signature) const { +void OperatorEntry::reportSignatureError(const CppSignature& call_signature, const CppSignatureWithDebug& saved_signature) const { TORCH_CHECK(false, "\nTried to access or call an operator with a wrong signature.\n", " operator: ", (schema_.has_value() ? toString(schema_->schema) : toString(name_)), "\n", " ", (schema_.has_value() ? schema_->debug : "unknown debug info"), "\n", - " correct signature: ", cpp_signature_->signature.name(), "\n", - " ", cpp_signature_->debug, "\n", + " correct signature: ", saved_signature.signature.name(), "\n", + " ", saved_signature.debug, "\n", " accessed/called as: ", call_signature.name(), "\n", "This likely happened in a call to OperatorHandle::typed(). ", "Please make sure that the function signature matches the signature in the operator registration call." diff --git a/aten/src/ATen/core/dispatch/OperatorEntry.h b/aten/src/ATen/core/dispatch/OperatorEntry.h index 1d9f1495f3c74e..a964423d6aa85c 100644 --- a/aten/src/ATen/core/dispatch/OperatorEntry.h +++ b/aten/src/ATen/core/dispatch/OperatorEntry.h @@ -163,14 +163,10 @@ class TORCH_API OperatorEntry final { // Asserts that the given FuncType is correct for calling this operator in an unboxed way. 
   template<class FuncType>
   inline void assertSignatureIsCorrect() {
-    assertSignatureIsCorrect(CppSignature::make<FuncType>());
+    assertSignatureIsCorrect(CppSignature::make<FuncType>(), fn_has_symint<FuncType>::value);
   }
 
-  void assertSignatureIsCorrect(const CppSignature call_signature) {
-    if (C10_UNLIKELY(cpp_signature_.has_value() && (call_signature != cpp_signature_->signature))) {
-      reportSignatureError(call_signature);
-    }
-  }
+  void assertSignatureIsCorrect(const CppSignature call_signature, bool has_symint) const;
 
   [[noreturn]] void reportError(DispatchKey dispatchKey) const;
@@ -280,11 +276,12 @@ class TORCH_API OperatorEntry final {
     c10::optional<DispatchKey> dispatch_key;
   };
   c10::optional<CppSignatureWithDebug> cpp_signature_;
+  c10::optional<CppSignatureWithDebug> sym_cpp_signature_;
 
   // Whether this operator needs to be observed with RecordFunction
   const bool is_observed_;
 
-  [[noreturn]] void reportSignatureError(CppSignature call_signature) const;
+  [[noreturn]] void reportSignatureError(const CppSignature& call_signature, const CppSignatureWithDebug& saved_signature) const;
   const KernelFunction& computeDispatchTableEntry(const c10::Dispatcher& dispatcher, DispatchKey dispatch_key) const;
   std::pair<const AnnotatedKernel&, const char*> computeDispatchTableEntryWithDebug(
     const c10::Dispatcher& dispatcher, DispatchKey dispatch_key
diff --git a/aten/src/ATen/core/function_schema.cpp b/aten/src/ATen/core/function_schema.cpp
index 00a31224a48393..7a743c225fcb02 100644
--- a/aten/src/ATen/core/function_schema.cpp
+++ b/aten/src/ATen/core/function_schema.cpp
@@ -17,9 +17,23 @@ const std::vector<Argument>& FunctionSchema::getCorrectList(SchemaArgType type) const {
   }
 }
 
-FunctionSchema FunctionSchema::cloneWithRealTypes() const {
-  auto cloneWithRealTypes = [](const Argument& a) {
-    return a.cloneWithType(a.real_type());
+FunctionSchema FunctionSchema::cloneWithRealTypes(bool with_symint) const {
+  auto cloneWithRealTypes = [&](const Argument& a) {
+    if (with_symint) {
+      return a.cloneWithType(a.real_type());
+    }
+    // Don't use real type if it looks like a SymInt
+    // NB: keep this in sync with unpackSymInt in KernelFunction_impl.h
+    if (
+      *a.real_type() == *getTypePtr<c10::SymInt>() ||
+      *a.real_type() == *getTypePtr<c10::optional<c10::SymInt>>() ||
+      *a.real_type() == *getTypePtr<c10::SymIntArrayRef>()
+    ) {
+      // Keep the fake type
+      return a.cloneWithType(a.type());
+    } else {
+      return a.cloneWithType(a.real_type());
+    }
   };
   std::vector<Argument> new_arguments, new_returns;
   std::transform(arguments().begin(), arguments().end(), std::back_inserter(new_arguments), cloneWithRealTypes);
diff --git a/aten/src/ATen/core/function_schema.h b/aten/src/ATen/core/function_schema.h
index bafc0d81032036..14f134939d76ef 100644
--- a/aten/src/ATen/core/function_schema.h
+++ b/aten/src/ATen/core/function_schema.h
@@ -474,7 +474,7 @@ struct TORCH_API FunctionSchema {
   FunctionSchema cloneWithRemappedTypes(
       const std::function<TypePtr(TypePtr)> type_map) const;
 
-  FunctionSchema cloneWithRealTypes() const;
+  FunctionSchema cloneWithRealTypes(bool with_symint=true) const;
 
   // Check that inputs have the correct types and appends any missing default
   // values.
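The effect of the new `with_symint` flag on schema checking, reduced to strings for illustration: when a kernel was registered against plain ints, any argument whose real type is one of the SymInt types keeps its declared "fake" int type, so schema comparison sees what the kernel actually receives. Types here are illustrative schema strings, not `c10::TypePtr`:

```cpp
// Per-argument behaviour of cloneWithRealTypes(with_symint).
#include <cassert>
#include <string>

std::string clone_type(const std::string& fake, const std::string& real,
                       bool with_symint) {
  if (with_symint) return real;
  // keep in sync with unpackSymInt, as the NB comment in the patch says
  if (real == "SymInt" || real == "SymInt[]" || real == "SymInt?") {
    return fake;  // keep the fake type, e.g. "int[]"
  }
  return real;
}

int main() {
  assert(clone_type("int[]", "SymInt[]", /*with_symint=*/true) == "SymInt[]");
  assert(clone_type("int[]", "SymInt[]", /*with_symint=*/false) == "int[]");
  assert(clone_type("Tensor", "Tensor", /*with_symint=*/false) == "Tensor");
}
```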
diff --git a/aten/src/ATen/native/vulkan/ops/Factory.cpp b/aten/src/ATen/native/vulkan/ops/Factory.cpp
index ce09521668f4f3..06d44ec0619353 100644
--- a/aten/src/ATen/native/vulkan/ops/Factory.cpp
+++ b/aten/src/ATen/native/vulkan/ops/Factory.cpp
@@ -29,13 +29,12 @@ Tensor _empty_affine_quantized(
 }
 
 Tensor empty_memory_format(
-    const SymIntArrayRef sym_sizes,
+    const IntArrayRef sizes,
     const c10::optional<ScalarType> dtype,
     const c10::optional<c10::Layout> layout,
     const c10::optional<Device> device,
     const c10::optional<bool> pin_memory,
     const optional<MemoryFormat> memory_format) {
-  auto sizes = c10::asIntArrayRefSlow(sym_sizes);
   return convert(vTensor{
       api::context(),
       sizes,
@@ -56,12 +55,7 @@ Tensor empty_strided(
     const optional<Device> device,
     const optional<bool> pin_memory) {
   return empty_memory_format(
-      c10::SymIntArrayRef::fromIntArrayRef(sizes),
-      dtype,
-      layout,
-      device,
-      pin_memory,
-      c10::MemoryFormat::Contiguous);
+      sizes, dtype, layout, device, pin_memory, c10::MemoryFormat::Contiguous);
 }
 
 #ifdef USE_VULKAN_API
diff --git a/aten/src/ATen/native/vulkan/ops/Shape.cpp b/aten/src/ATen/native/vulkan/ops/Shape.cpp
index e1bda761749d7d..d8263e59668e6f 100644
--- a/aten/src/ATen/native/vulkan/ops/Shape.cpp
+++ b/aten/src/ATen/native/vulkan/ops/Shape.cpp
@@ -42,8 +42,7 @@ Tensor view_internal(const Tensor& self_arg, const IntArrayRef shape) {
   return convert(v_output);
 }
 
-inline Tensor view(const Tensor& self_arg, const SymIntArrayRef sym_shape) {
-  auto shape = c10::asIntArrayRefSlow(sym_shape);
+inline Tensor view(const Tensor& self_arg, IntArrayRef shape) {
   return view_internal(self_arg, shape);
 }
diff --git a/functorch/functorch/csrc/BatchRulesViews.cpp b/functorch/functorch/csrc/BatchRulesViews.cpp
index df42086acef39f..44f1134486c577 100644
--- a/functorch/functorch/csrc/BatchRulesViews.cpp
+++ b/functorch/functorch/csrc/BatchRulesViews.cpp
@@ -439,12 +439,6 @@ std::tuple<Tensor, optional<int64_t>> view_batching_rule(
   return std::make_tuple(self_.view_symint(size_), 0);
 }
 
-Tensor view_symint_decomposition(const Tensor& self,
-                                 c10::SymIntArrayRef size) {
-  return self.view( c10::asIntArrayRefSlow(size));
-}
-
-
 template <typename F, F Func>
 std::tuple<Tensor, optional<int64_t>> expand_batch_rule(
     const Tensor &self, optional<int64_t> self_bdim, SymIntArrayRef size, bool implicit)
@@ -512,14 +506,6 @@ std::tuple<Tensor, optional<int64_t>> diag_embed_batch_rule(const Tensor& self,
   return std::make_tuple(at::diag_embed(self_, offset, dim1, dim2), 0);
 }
 
-// We need to write a real batching rule to fully support symint.
-// This requires symint variants of other operations, like `view`,
-// which don't exist yet.
-Tensor expand_symint_decomp_hack(const Tensor& self, SymIntArrayRef packed_size, bool implicit) {
-  auto size = asIntArrayRefSlow(packed_size);
-  return self.expand(size, implicit);
-}
-
 TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   VMAP_SUPPORT(diag, diag_batch_rule);
   VMAP_SUPPORT(chunk, chunk_batching_rule);
diff --git a/test/cpp_extensions/open_registration_extension.cpp b/test/cpp_extensions/open_registration_extension.cpp
index 7f43e60a6b393d..ad036109903d4c 100644
--- a/test/cpp_extensions/open_registration_extension.cpp
+++ b/test/cpp_extensions/open_registration_extension.cpp
@@ -49,9 +49,9 @@ at::Tensor custom_empty_memory_format(at::IntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory, c10::optional<at::MemoryFormat> memory_format) {
   return at::detail::empty_generic(size, &global_custom_alloc, private_use_ks, c10::dtype_or_default(dtype), memory_format);
 }
 
-at::Tensor custom_empty_symint(c10::SymIntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory, c10::optional<at::MemoryFormat> memory_format) {
+at::Tensor custom_empty_symint(c10::IntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory, c10::optional<at::MemoryFormat> memory_format) {
   constexpr c10::DispatchKeySet private_use_ks(c10::DispatchKey::PrivateUse1);
-  return at::detail::empty_generic(c10::asIntArrayRefSlow(size), &global_custom_alloc, private_use_ks, c10::dtype_or_default(dtype), memory_format);
+  return at::detail::empty_generic(size, &global_custom_alloc, private_use_ks, c10::dtype_or_default(dtype), memory_format);
 }
 
 at::Tensor & custom_fill__scalar(at::Tensor & self, const at::Scalar & value) {
diff --git a/test/cpp_extensions/ort_extension.cpp b/test/cpp_extensions/ort_extension.cpp
index 3422bccd6d38c6..b646f3b14939dc 100644
--- a/test/cpp_extensions/ort_extension.cpp
+++ b/test/cpp_extensions/ort_extension.cpp
@@ -20,10 +20,10 @@ Tensor get_tensor(caffe2::TypeMeta dtype, IntArrayRef size) {
   return Tensor(std::move(tensor_impl));
 }
 
-Tensor empty_override(SymIntArrayRef size, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device,
+Tensor empty_override(IntArrayRef size, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device,
                       c10::optional<bool> pin_memory, c10::optional<c10::MemoryFormat> optional_memory_format) {
   test_int = 0;
-  return get_tensor(scalarTypeToTypeMeta(dtype_or_default(dtype)), c10::asIntArrayRefSlow(size));
+  return get_tensor(scalarTypeToTypeMeta(dtype_or_default(dtype)), size);
 }
 
 Tensor& add_out_override(const Tensor & a, const Tensor & b , const Scalar& c, Tensor & out) {
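Taken together, the Vulkan, functorch, and test-extension changes show the payoff for backends with no notion of symbolic shapes: a kernel written against `IntArrayRef` can be registered directly for an operator whose schema uses `SymInt[]`, and the dispatcher unpacks before the call. A sketch modelled on `custom_empty_symint` above; `my_empty`, the CPU allocator choice, and registering under `PrivateUse1` are illustrative assumptions, not part of the patch:

```cpp
#include <ATen/ATen.h>
#include <ATen/EmptyTensor.h>
#include <torch/library.h>

// Hypothetical out-of-tree kernel: plain IntArrayRef, no SymInt anywhere.
at::Tensor my_empty(c10::IntArrayRef size,
                    c10::optional<at::ScalarType> dtype,
                    c10::optional<at::Layout> layout,
                    c10::optional<at::Device> device,
                    c10::optional<bool> pin_memory,
                    c10::optional<at::MemoryFormat> memory_format) {
  (void)layout; (void)device; (void)pin_memory;  // unused in this sketch
  constexpr c10::DispatchKeySet ks(c10::DispatchKey::PrivateUse1);
  return at::detail::empty_generic(
      size, c10::GetAllocator(c10::DeviceType::CPU), ks,
      c10::dtype_or_default(dtype), memory_format);
}

TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
  // Registered against the int signature; callers passing symbolic sizes go
  // through unpackSymInt first (which raises if a size is truly symbolic).
  m.impl("empty.memory_format", my_empty);
}
```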