diff --git a/functorch/functorch/csrc/BatchRulesHelper.cpp b/functorch/functorch/csrc/BatchRulesHelper.cpp
index d49ecd5e87378..dfd690ac21688 100644
--- a/functorch/functorch/csrc/BatchRulesHelper.cpp
+++ b/functorch/functorch/csrc/BatchRulesHelper.cpp
@@ -133,6 +133,20 @@ void vmapIncompatibleInplaceError(const char* schema_name) {
     "please file a bug report instead.");
 }
 
+void run_jit_decomposition(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
+  const auto& schema = op.schema();
+  // TODO: templatize based on op and keep static trace_exec
+  auto * trace_exec = torch::jit::GetDecompositionExecutor(schema);
+  trace_exec->run((*stack));
+  if (stack->back().isTuple()) {
+    IValue tup = stack->back();
+    stack->pop_back();
+    for (const auto& elem: tup.toTuple()->elements()) {
+      stack->push_back(elem);
+    }
+  }
+}
+
 static void handleScalarTypePromotion(Tensor& logical_scalar_tensor, Tensor& second) {
   auto result_type = at::native::result_type(logical_scalar_tensor[0], second);
   if (logical_scalar_tensor.scalar_type() != result_type) {
diff --git a/functorch/functorch/csrc/BatchRulesHelper.h b/functorch/functorch/csrc/BatchRulesHelper.h
index 329d0db42b50f..552a38b20e205 100644
--- a/functorch/functorch/csrc/BatchRulesHelper.h
+++ b/functorch/functorch/csrc/BatchRulesHelper.h
@@ -195,6 +195,12 @@ inline void handle_variadic_bdims(std::vector>());
+void run_jit_decomposition(const c10::OperatorHandle& op, torch::jit::Stack* stack);
+
+#define RUN_JIT_DECOMPOSITION(op) \
+  m.impl(#op, torch::CppFunction::makeFromBoxedFunction<&run_jit_decomposition>());
+
+
 using UnpackedBatchedTensor = std::tuple>;
 inline void find_and_unpack_tensors(
diff --git a/functorch/functorch/csrc/BatchRulesViews.cpp b/functorch/functorch/csrc/BatchRulesViews.cpp
index 44f1134486c57..68a6c377f7504 100644
--- a/functorch/functorch/csrc/BatchRulesViews.cpp
+++ b/functorch/functorch/csrc/BatchRulesViews.cpp
@@ -15,7 +15,6 @@
 #include
 #include
 #include
-#include
 
 namespace at { namespace functorch {
 
@@ -511,7 +510,7 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   VMAP_SUPPORT(chunk, chunk_batching_rule);
   m.impl("flatten.using_ints", static_cast(native::flatten));
   VMAP_SUPPORT(flip, flip_batch_rule);
-  m.impl("trace", torch::CppFunction::makeFromBoxedFunction<&torch::jit::run_jit_decomposition>());
+  RUN_JIT_DECOMPOSITION(trace)
   VMAP_SUPPORT(tril, VARIADIC_BDIMS_BATCH_RULE(ATEN_FN(tril)));
   VMAP_SUPPORT(triu, VARIADIC_BDIMS_BATCH_RULE(ATEN_FN(triu)));
   VMAP_SUPPORT(repeat, repeat_batch_rule);
diff --git a/functorch/functorch/csrc/DynamicLayer.cpp b/functorch/functorch/csrc/DynamicLayer.cpp
index c83edf327b2cb..08cd4d7a7d6b3 100644
--- a/functorch/functorch/csrc/DynamicLayer.cpp
+++ b/functorch/functorch/csrc/DynamicLayer.cpp
@@ -389,9 +389,43 @@ WithoutTop::~WithoutTop() {
   pushDynamicLayer(std::move(layer_));
 }
 
-static void dynamicLayerFrontFallback(
+// NOTE: [forward-mode AD decompositions hack]
+//
+// The mechanism is: in DynamicLayerFrontMode, IF we are dispatching on the
+// jvp transform, AND we have a decomposition for the operation, then run
+// the decomposition.
+//
+// Let's break that down. There are a couple of moving pieces.
+//
+// 0. How do we know what transform we're dispatching on?
+// Easy, check the top of the DynamicLayerStack and read the transform.
+//
+// 1. Next, we must identify when an operation (e.g. nll_loss_backward)
+// gets dispatched to.
+// - register a special kernel to the DynamicLayerFrontMode key +// (see JVP_DECOMP) +// - that special kernel invokes dynamicLayerFrontFallbackOperator with +// an arg indicating we're going to use a decomp +// +// 2. Next, we need to call the decomposition. See call_decomposition_for_jvp. +// We currently use python decompositions that we torchscript. + +// Ideally c10::OperatorHandle would have a field like this +// to identify the operator. +// The stuff here should map 1:1 with the operator name. +// aten::nll_loss_backward -> nll_loss_backward +// aten::add.Tensor -> add_Tensor + +static void call_decomposition_for_jvp( const c10::OperatorHandle& op, torch::jit::Stack* stack) { + run_jit_decomposition(op, stack); +} + +static void dynamicLayerFrontFallbackOperator( + const c10::OperatorHandle& op, + torch::jit::Stack* stack, + bool decomp_jvp) { auto& dynamicLayerStack = dynamicLayerStackAccessor(); TORCH_INTERNAL_ASSERT(dynamicLayerStack.size() > 0); #ifdef HAS_TORCH_SHOW_DISPATCH_TRACE @@ -400,6 +434,13 @@ static void dynamicLayerFrontFallback( dump_local_tls(); } #endif + + // Hack: if jvp and we have a decomposition registered, then do the decomposition + if (dynamicLayerStack.back().interpreter().key() == TransformType::Jvp && + decomp_jvp) { + return call_decomposition_for_jvp(op, stack); + } + // Save the current LocalDispatchKeySet (to the current DynamicLayer). // Upon exiting the current scope, that LocalDispatchKeySet gets restored. // When the current DynamicLayer dispatches to the next (inner) DynamicLayer, @@ -419,6 +460,16 @@ restoreLocalDispatchKeySetRAII(const c10::impl::LocalDispatchKeySet& key_set) { return c10::impl::ForceDispatchKeyGuard(key_set); } +void dynamicLayerFrontFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) { + return dynamicLayerFrontFallbackOperator(op, stack, false); +} + +void dynamicLayerFrontFallBackWithDecomp( + const c10::OperatorHandle& op, + torch::jit::Stack* stack) { + return dynamicLayerFrontFallbackOperator(op, stack, true); +} + void dynamicLayerBackFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) { auto& layer = dynamicLayerStackAccessor().back(); auto restore_guard = restoreLocalDispatchKeySetRAII(layer.interpreter().getSavedLocalDispatchKeySet()); @@ -435,5 +486,24 @@ TORCH_LIBRARY_IMPL(_, FT_DYNAMIC_LAYER_BACK_MODE_KEY, m) { m.fallback(torch::CppFunction::makeFromBoxedFunction<&dynamicLayerBackFallback>()); } +#define JVP_DECOMP(op) \ + m.impl(#op, torch::CppFunction::makeFromBoxedFunction<&dynamicLayerFrontFallBackWithDecomp>()); + +#define JVP_DECOMP2(op, overload) \ + m.impl(#op "." #overload, torch::CppFunction::makeFromBoxedFunction<&dynamicLayerFrontFallBackWithDecomp>()); + +TORCH_LIBRARY_IMPL(aten, FT_DYNAMIC_LAYER_FRONT_MODE_KEY, m) { + JVP_DECOMP(nll_loss_backward); + JVP_DECOMP(nll_loss2d_backward); + JVP_DECOMP(_log_softmax_backward_data); + JVP_DECOMP(_softmax_backward_data); + OP_DECOMPOSE(log_sigmoid); + JVP_DECOMP(log_sigmoid_forward); + JVP_DECOMP(native_layer_norm_backward); + JVP_DECOMP(native_batch_norm_backward); + JVP_DECOMP(cudnn_batch_norm_backward); +} + + } } // namespace at diff --git a/functorch/test/test_ops.py b/functorch/test/test_ops.py index 218ba47b46ed9..8d69fe7e22b56 100644 --- a/functorch/test/test_ops.py +++ b/functorch/test/test_ops.py @@ -1047,6 +1047,9 @@ def get_vjp(cotangents, *primals): # RuntimeError: Trying to set a forward gradient that has a different size than that of the original Tensor, # this is not supported. 
Tensor is of size [5, 2, 3] while the given forward gradient is of size [1, 2, 3]. xfail('normal', ''), + xfail('_masked.log_softmax', ''), # NYI: forward-AD for _log_softmax_backward_data + xfail('_masked.softmax', ''), # NYI: forward-AD for _softmax_backward_data + xfail('_masked.softmin', ''), # NYI: forward-AD for _softmax_backward_data xfail('cdist', ''), # NYI: forward-AD for _cdist_forward xfail('cholesky', ''), # NYI: forward-AD for cholesky xfail('eig', ''), # NYI: forward-AD for eig @@ -1055,7 +1058,10 @@ def get_vjp(cotangents, *primals): xfail('nn.functional.grid_sample', ''), # NYI: forward AD for grid_sampler_2d xfail('nn.functional.hardsigmoid', ''), # NYI: forward AD for hardsigmoid_backward xfail('nn.functional.huber_loss', ''), # NYI: forward AD for huber_loss_backward + xfail('nn.functional.instance_norm', ''), # NYI: forward AD for native_batch_norm_backward xfail('nn.functional.logsigmoid', ''), # not differentiable w.r.t. buffer + xfail('nn.functional.softmin', ''), # NYI: forward-AD for _softmax_backward_data + xfail('nn.functional.softmin', 'with_dtype'), # NYI: forward-AD for _softmax_backward_data xfail('renorm', ''), # NYI: forward AD for renorm xfail('symeig', ''), # NYI: forward AD for symeig xfail('nn.functional.multilabel_margin_loss', ''), # NYI: multilabel_margin_loss_forward @@ -1069,6 +1075,7 @@ def get_vjp(cotangents, *primals): xfail('scatter_reduce', 'mean'), # NYI: forward-AD for scatter_reduce xfail('scatter_reduce', 'prod'), # NYI: forward-AD for scatter_reduce skip('linalg.householder_product', '', device_type='cuda'), # flaky, I'm not sure why + xfail('native_layer_norm', ''), # NYI: forward-AD for native_layer_norm_backward xfail('sparse.sampled_addmm', ''), # Sparse tensors have no strides skip('as_strided_scatter', ''), # seems flaky xfail('segment_reduce', 'offsets'), # NYI: forward-AD for segment_reduce @@ -1129,8 +1136,37 @@ def reference(primals, cotangents, primals_tangents, cotangents_tangents): expected = (tree_unflatten(primals_out, spec), tree_unflatten(tangents_out, spec)) return expected - expected = reference(primals, cotangents, primals_tangents, cotangents_tangents) - self.assertEqual(result, expected) + # HACK: obviously pytorch should also have the same coverage + # For things that do have the same coverage, we test that jvp x vjp + # are the same between PyTorch and functorch. For things that don't, + # we check that jacfwd(vjp) and jacrev(vjp) are the same. This results + # in slower tests. + FUNCTORCH_HAS_FORMULA_BUT_NOT_PYTORCH = { + 'nn.functional.nll_loss', + 'softmax', + 'log_softmax', + 'nn.functional.cross_entropy', + 'nn.functional.layer_norm', + 'nn.functional.batch_norm', + } + if op.name in FUNCTORCH_HAS_FORMULA_BUT_NOT_PYTORCH: + self.assertFalse(op.supports_fwgrad_bwgrad, + f"{op.name} now supports forward over reverse without a decomposition. 
" + + "Please remove the decomposition version") + + def is_differentiable(t): + return isinstance(t, torch.Tensor) and t.dtype == torch.float32 + args = (cotangents, *primals) + if op.name == 'nn.functional.binary_cross_entropy': + argnums = (0, 1) # targets is float32 but isn't differentiable + atol_rtol = 1.5e-4, 1.3e-06 + else: + argnums = tuple(i for i in range(len(args)) if is_differentiable(args[i])) + atol_rtol = None + self._compare_jacobians_of_vjp(fn, args, argnums, atol_rtol) + else: + expected = reference(primals, cotangents, primals_tangents, cotangents_tangents) + self.assertEqual(result, expected) def _make_extremal_inputs(self, shape, device): if shape is None: diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index c5b1ec04fd875..5a8bf46319f02 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -1956,20 +1956,7 @@ - name: log_sigmoid_forward(Tensor self) -> (Tensor output, Tensor buffer) self: log_sigmoid_backward(grad, self, buffer) - # HACK: This is just auto_element_wise followed by a view_as. The reason we have - # this is bc forward AD was complaining here about the shapes not being the same: - # the primal/tangent are 0-D/1-D respectively. This started happening after moving the - # jvp decomposition mechanism from functorch to core, possibly due to a batching rule. - # In functorch we rely on OP_DECOMPOSE, but now we compute forward AD using an actual - # formula. - # - # We'd like to avoid keeping the entire jvp decomposition mechanism in functorch, - # just for this single decomposition, but also want to avoid any cases from regressing: - # e.g. test_vmapjvpall_nn_functional_logsigmoid_cuda_float32 (passes on cpu, fails on CUDA). - # - # We should either figure out what is going on with vmap or perhaps fwd AD could - # be more tolerant about 0-dim vs 1-dim tensors - output: log_sigmoid_backward(self_t.conj(), self_p, buffer).conj().view_as(self_p) + output: auto_element_wise - name: _log_softmax(Tensor self, int dim, bool half_to_float) -> Tensor self: _log_softmax_backward_data(grad, result, dim, self.scalar_type()) diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py index 35987ca24266d..f9afe838203de 100644 --- a/tools/autograd/gen_variable_type.py +++ b/tools/autograd/gen_variable_type.py @@ -31,7 +31,6 @@ from torchgen.api.autograd import ( DifferentiableInput, dispatch_strategy, - ForwardDerivative, gen_differentiable_outputs, is_differentiable, NativeFunctionWithDifferentiabilityInfo, @@ -598,14 +597,8 @@ DISPATCH_TO_NON_VAR_TYPE_WITH_TMP_RETURN_VALUES = CodeTemplate( """\ auto ${tmp_var} = ([&]() { - if (${try_jit_decomposition_bool} && ${any_has_forward_grad}) { - static c10::OperatorName full_name("aten::${op_name}", "${op_overload}"); - static c10::optional opt_op = c10::Dispatcher::singleton().findSchema(full_name); - return impl::run_jit_decomposition_with_args_for_jvp<${returns_and_args}>("${op_name}", *opt_op, ks, ${arg_names}); - } else { - ${guard} - return ${base_type_call}; - } + ${guard} + return ${base_type_call}; })(); """ ) @@ -649,12 +642,6 @@ """ ) -FW_DERIVATIVE_TENSORLIST_CHECK_TEMPLATE = CodeTemplate( - """\ -isFwGradDefinedTensorList(${req_inp})\ -""" -) - FW_DERIVATIVE_DEFINED_GRAD_TEMPLATE = CodeTemplate( """\ auto ${inp}_t_raw = toNonOptFwGrad(${inp}); @@ -985,23 +972,6 @@ def find_args_with_derivatives( f"ERROR: derivative ignored for {name} -- specified an autograd function without derivative" ) - if requires_derivative and not 
len(fw_derivatives) == 0: - assert sum(len(derivative.var_names) for derivative in fw_derivatives) == len( - differentiable_outputs - ), ( - "Expected the number of forward derivatives implemented to match the " - "number of differentiable outputs. NB: This only applies when at least " - "one forward derivative is implemented. Not implementing any forward " - "derivatives is also okay, and we would require inputs to the op to " - "not have associated tangents in that case." - ) - try_jit_decomposition = ( - requires_derivative - and len(fw_derivatives) == 0 - and (not modifies_arguments(f)) - and (not returns_void) - ) - def emit_save_inputs() -> List[str]: setup: List[str] = [] if info is None or not info.has_derivatives: @@ -1368,9 +1338,7 @@ def check_tensorimpl_and_storage( ) return call - def emit_call( - f: NativeFunction, unpacked_bindings: List[Binding], try_jit_decomposition: bool - ) -> str: + def emit_call(f: NativeFunction, unpacked_bindings: List[Binding]) -> str: # We only care about adding `at::AutoDispatchBelowAutograd` guard for non-variable dispatch # (which corresponds to 'use_derived' strategy). The purpose of this guard is to make sure # the baseType operations still dispatch to non-Variable type, even if the arguments passed @@ -1384,51 +1352,13 @@ def emit_call( else: guard = "at::AutoDispatchBelowADInplaceOrView guard;" - try_jit_decomposition_bool = "true" if try_jit_decomposition else "false" - any_has_forward_grad = ( - get_any_has_fw_grad_cond(derivative=None) - if requires_derivative - else "false" - ) - return_types = ", ".join( - [cpp.return_type(a, symint=True).cpp_type() for a in f.func.returns] - ) - if len(f.func.returns) > 1: - return_types = f"std::tuple<{return_types}>" - - arg_types = [ - cpp.argument_type(a, binds="", symint=True).cpp_type() - for a in f.func.arguments.flat_all - ] - arg_names = [ - a.name - for a in cpp.arguments( - f.func.arguments, - faithful=True, - symint=True, - method=False, - cpp_no_default_args=set(), - ) - ] - if not modifies_arguments(f) and not returns_void: - # Just to keep things simple here, we only care about this path - # and always emit the if/else for now call = DISPATCH_TO_NON_VAR_TYPE_WITH_TMP_RETURN_VALUES.substitute( - base_type_call=base_type_call, - tmp_var=TMP_VAR, - guard=guard, - try_jit_decomposition_bool=try_jit_decomposition_bool, - any_has_forward_grad=any_has_forward_grad, - op_name=cpp.name(f.func), - op_overload=f.func.name.overload_name, - returns_and_args=return_types + ", " + ", ".join(arg_types), - arg_names=arg_names, + base_type_call=base_type_call, tmp_var=TMP_VAR, guard=guard ) call += wrap_output(f, unpacked_bindings, TMP_VAR) else: - assert not try_jit_decomposition call = DISPATCH_TO_NON_VAR_TYPE_WITHOUT_RETURN_VALUES.substitute( base_type_call=base_type_call, guard=guard ) @@ -1476,14 +1406,38 @@ def get_any_has_forward_grad_name(var_names: Tuple[str, ...]) -> str: def emit_any_has_forward_grad() -> List[str]: content: List[str] = [] for derivative in fw_derivatives: - requires_fw_grad = get_any_has_fw_grad_cond(derivative=derivative) + assert derivative.required_inputs_fw_grad is not None + requires_fw_grad = " || ".join( + [ + FW_DERIVATIVE_CHECK_TEMPLATE.substitute(req_inp=inp.name) + for inp in differentiable_inputs + if inp.name in derivative.required_inputs_fw_grad + ] + ) + if not requires_fw_grad: + # Handle functions like stack + # For these, we don't unpack anything and always call the user function + if not ( + len(differentiable_inputs) == 1 + and 
is_tensor_list_type(differentiable_inputs[0].type) + ): + raise RuntimeError( + f'No differentiable input to "{name}" is a differentiable Tensor (as the provided ' + "forward AD formula does not use any input tangent) even though a forward gradient " + "formula has been defined for it. This case should only happen for function that " + "take a single TensorList as input. All other cases are not supported right now." + ) + requires_fw_grad = "true" + if info and info.output_differentiability_conditions: assert len(info.output_differentiability_conditions) == 1 - requires_fw_grad = f"({info.output_differentiability_conditions[0]}) && {requires_fw_grad}" + requires_fw_grad = f"({info.output_differentiability_conditions[0]}) && ({requires_fw_grad})" + content.append( f"auto {get_any_has_forward_grad_name(derivative.var_names)} = {requires_fw_grad};\n" f"(void){get_any_has_forward_grad_name(derivative.var_names)};" ) + return content def emit_check_inplace() -> List[str]: @@ -1606,83 +1560,46 @@ def emit_fw_derivatives() -> List[str]: content.append("\n".join(fw_grad_setters)) return content - def get_any_has_fw_grad_cond(derivative: Optional[ForwardDerivative]) -> str: - # - # Produces a condition string (e.g, "isFwGradDefined(grad_output) || isFwGradDefined(output)") - # - if derivative is None: - # (1) If a derivative is NOT provided, cond will check fw_grad of ALL differentiable inputs - # - Used in the out_fn case when we want to forbid fw derivatives - # - Used in the case where the fw_derivative is not defined, but we want - # To check if there is a decomposition registered for jvp - to_check: List[str] = [] - for inp in list( - mapMaybe( - gen_differentiable_input, - f.func.arguments.non_out + list(f.func.arguments.out), # type: ignore[operator] + def emit_forbid_fw_derivatives(is_out_fn: bool = False) -> str: + def get_msg() -> str: + if is_out_fn: + msg = "because it is an out= function" + else: + msg = ( + "because it has not been implemented yet.\\nPlease file an issue " + "to PyTorch at https://github.com/pytorch/pytorch/issues/new?template=feature-request.yml " + "so that we can prioritize its implementation." + ) + return msg + + res = "" + to_check: List[str] = [] + for inp in list( + mapMaybe( + gen_differentiable_input, + f.func.arguments.non_out + list(f.func.arguments.out), # type: ignore[operator] + ) + ): + if is_tensor_type(inp.type): + to_check.append( + FW_DERIVATIVE_CHECK_TEMPLATE.substitute(req_inp=inp.name) + ) + elif is_tensor_list_type(inp.type): + cond = FW_DERIVATIVE_CHECK_TEMPLATE.substitute(req_inp="_t") + res += FW_DERIVATIVE_FORBID_LIST_TEMPLATE.substitute( + arg=inp.name, cond=cond, name=name, msg=get_msg() ) - ): - if is_tensor_type(inp.type): - to_check.append( - FW_DERIVATIVE_CHECK_TEMPLATE.substitute(req_inp=inp.name) - ) - elif is_tensor_list_type(inp.type): - to_check.append( - FW_DERIVATIVE_TENSORLIST_CHECK_TEMPLATE.substitute( - req_inp=inp.name - ) - ) - else: - raise RuntimeError( - f'Unsupported input type for "{name}" when forbidding forward AD usage.' 
- ) - return f'({" || ".join(to_check)})' - else: - # (2) If derivative is provided, use that information to determine which inputs - # to check fw_grad for - assert derivative.required_inputs_fw_grad is not None - - if len(derivative.required_inputs_fw_grad) == 0: - # Handle functions like stack - # For these, we don't unpack anything and always call the user function - if not ( - len(differentiable_inputs) == 1 - and is_tensor_list_type(differentiable_inputs[0].type) - ): - raise RuntimeError( - f'No differentiable input to "{name}" is a differentiable Tensor (as the provided ' - "forward AD formula does not use any input tangent) even though a forward gradient " - "formula has been defined for it. This case should only happen for function that " - "take a single TensorList as input. All other cases are not supported right now." - ) - any_has_fw_grad = "true" else: - any_has_fw_grad = " || ".join( - [ - FW_DERIVATIVE_CHECK_TEMPLATE.substitute(req_inp=inp.name) - for inp in differentiable_inputs - if inp.name in derivative.required_inputs_fw_grad - ] + raise RuntimeError( + f'Unsupported input type for "{name}" when forbidding forward AD usage.' ) - any_has_fw_grad = f"({any_has_fw_grad})" - - return any_has_fw_grad - def emit_forbid_fw_derivatives(is_out_fn: bool = False) -> str: - if is_out_fn: - msg = "because it is an out= function" - else: - msg = ( - "because it has not been implemented yet.\\nPlease file an issue " - "to PyTorch at https://github.com/pytorch/pytorch/issues/new?template=feature-request.yml " - "so that we can prioritize its implementation." + if len(to_check) > 0: + cond = " || ".join(to_check) + res += FW_DERIVATIVE_FORBID_TEMPLATE.substitute( + cond=cond, name=name, msg=get_msg() ) - cond = get_any_has_fw_grad_cond(derivative=None) - return ( - FW_DERIVATIVE_FORBID_TEMPLATE.substitute(cond=cond, name=name, msg=msg) - if cond != "" - else "" - ) + return res body: List[str] = [] unpack_args_stats, unpacked_bindings = unpack_args(f) @@ -1696,7 +1613,7 @@ def emit_forbid_fw_derivatives(is_out_fn: bool = False) -> str: body.extend(setup_derivative(differentiable_inputs)) body.append(declare_returned_variables(f)) - body.append(emit_call(f, unpacked_bindings, try_jit_decomposition)) + body.append(emit_call(f, unpacked_bindings)) if requires_derivative: # set_flags has to appear after version_counter, because rebase_history # requires that the counter is incremented before it is called @@ -1706,11 +1623,20 @@ def emit_forbid_fw_derivatives(is_out_fn: bool = False) -> str: if is_out_fn: body.append(emit_forbid_fw_derivatives(is_out_fn=True)) else: - if requires_derivative and not try_jit_decomposition: - if len(fw_derivatives) > 0: - body.extend(emit_fw_derivatives()) - else: + if requires_derivative: + body.extend(emit_fw_derivatives()) + if len(fw_derivatives) == 0: body.append(emit_forbid_fw_derivatives()) + else: + assert sum( + len(derivative.var_names) for derivative in fw_derivatives + ) == len(differentiable_outputs), ( + "Expected the number of forward derivatives implemented to match the " + "number of differentiable outputs. NB: This only applies when at least " + "one forward derivative is implemented. Not implementing any forward " + "derivatives is also okay, and we would require inputs to the op to " + "not have associated tangents in that case." 
+ ) if requires_derivative: # Save only after the forward AD has been set up diff --git a/tools/autograd/templates/VariableType.cpp b/tools/autograd/templates/VariableType.cpp index 3c467f83c3182..9cd2d5c40de79 100644 --- a/tools/autograd/templates/VariableType.cpp +++ b/tools/autograd/templates/VariableType.cpp @@ -1,5 +1,4 @@ #include "torch/csrc/autograd/VariableTypeUtils.h" -#include "torch/csrc/autograd/VariableTypeUtilsDependOnOps.h" #include "torch/csrc/autograd/generated/VariableType.h" #include "torch/csrc/autograd/FunctionsManual.h" diff --git a/torch/csrc/autograd/VariableTypeUtilsDependOnOps.h b/torch/csrc/autograd/VariableTypeUtilsDependOnOps.h deleted file mode 100644 index f2569c9d64631..0000000000000 --- a/torch/csrc/autograd/VariableTypeUtilsDependOnOps.h +++ /dev/null @@ -1,40 +0,0 @@ -#pragma once - -#include - -// This is the set of helpers in VariableTypeUtils have a dependency on -// native_functions.yaml meaning the file will need to be re-compiled every time -// an operator is changed or added. We cannot simply put these functions in -// VariableType.h and VariableTypeutils.h, since they are included in files like -// ADInplaceOrViewType_X.cpp which don't always want to be recompiled. - -namespace torch { -namespace autograd { -namespace impl { - -// Depends on torch/csrc/jit/ir/ir.h -> aten/src/ATen/core/interned_strings.h -template -Return run_jit_decomposition_with_args_for_jvp( - c10::string_view name, - const c10::OperatorHandle& opHandle, - c10::DispatchKeySet dispatchKeySet, - Args... args) { - bool has_decomp = jit::has_jit_decomposition(opHandle.schema()); - - TORCH_CHECK_NOT_IMPLEMENTED( - has_decomp, - "Trying to use forward AD with ", - name, - " that does not support it" - "because it has not been implemented yet and does not have a decomposition.\\nPlease file an issue " - "to PyTorch at https://github.com/pytorch/pytorch/issues/new?template=feature-request.yml " - "so that we can prioritize its implementation."); - - return c10::KernelFunction::makeFromBoxedKernel( - c10::BoxedKernel::makeFromFunction<&jit::run_jit_decomposition>()) - .call(opHandle, dispatchKeySet, args...); -} - -} // namespace impl -} // namespace autograd -} // namespace torch diff --git a/torch/csrc/autograd/functions/utils.h b/torch/csrc/autograd/functions/utils.h index 75df1a0302c95..a2169f18656f0 100644 --- a/torch/csrc/autograd/functions/utils.h +++ b/torch/csrc/autograd/functions/utils.h @@ -100,23 +100,5 @@ inline bool isFwGradDefined(const c10::optional& t) { return t.has_value() && t->defined() && t->_fw_grad(/*level */ 0).defined(); } -inline bool isFwGradDefinedTensorList(const at::TensorList& variables) { - bool ret = false; - for (auto& variable : variables) { - ret |= isFwGradDefined(variable); - } - return ret; -} - -inline bool isFwGradDefinedTensorList( - const c10::List> li) { - bool ret = false; - for (auto i : c10::irange(li.size())) { - auto t = li.get(i); - ret |= (t.has_value() && isFwGradDefined(t.value())); - } - return ret; -} - } // namespace autograd } // namespace torch diff --git a/torch/csrc/jit/runtime/decomposition_registry.cpp b/torch/csrc/jit/runtime/decomposition_registry.cpp index bfad602ef2f2f..d55ac7eac9be5 100644 --- a/torch/csrc/jit/runtime/decomposition_registry.cpp +++ b/torch/csrc/jit/runtime/decomposition_registry.cpp @@ -160,26 +160,6 @@ void RegisterDecomposition( schema_to_decomposition[&schema] = g; } -void run_jit_decomposition( - const c10::OperatorHandle& op, - torch::jit::Stack* stack) { - const auto& schema = op.schema(); - // 
TODO: templatize based on op and keep static trace_exec - auto* trace_exec = torch::jit::GetDecompositionExecutor(schema); - trace_exec->run((*stack)); - if (stack->back().isTuple()) { - at::IValue tup = stack->back(); - stack->pop_back(); - for (const auto& elem : tup.toTuple()->elements()) { - stack->push_back(elem); - } - } -} - -bool has_jit_decomposition(const FunctionSchema& schema) { - return GetDecompositionFunction(schema).has_value(); -} - Function* GetDecompositionExecutor(const FunctionSchema& schema) { auto maybe_func = GetDecompositionFunction(schema); TORCH_INTERNAL_ASSERT(maybe_func); diff --git a/torch/csrc/jit/runtime/decomposition_registry.h b/torch/csrc/jit/runtime/decomposition_registry.h index 225204cf60de3..4c6ef3029a0bc 100644 --- a/torch/csrc/jit/runtime/decomposition_registry.h +++ b/torch/csrc/jit/runtime/decomposition_registry.h @@ -25,11 +25,5 @@ TORCH_API Function* GetDecompositionExecutor(const char* schema_literal); TORCH_API Function* GetDecompositionExecutor(const FunctionSchema& schema); -TORCH_API void run_jit_decomposition( - const c10::OperatorHandle& op, - torch::jit::Stack* stack); - -TORCH_API bool has_jit_decomposition(const FunctionSchema& schema); - } // namespace jit } // namespace torch diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index a9e98a44dcaac..3f152354e6d21 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -39,10 +39,6 @@ import torch._refs.special import torch._refs.linalg -# Make sure that decompositions used for test_forward_mode_AD and -# test_fn_fwgrad_bwgrad are registered to the jit -import torch._decomp.decompositions_for_jvp - import torch._prims as prims # noqa: F401 from torch.utils._pytree import tree_flatten @@ -10168,7 +10164,6 @@ def reference_flatten(input, start_dim=0, end_dim=-1): assert_jit_shape_analysis=True, assert_autodiffed=True, supports_forward_ad=True, - supports_fwgrad_bwgrad=True, supports_out=True), OpInfo('softmax', aliases=('special.softmax', 'nn.functional.softmax',), @@ -10178,7 +10173,6 @@ def reference_flatten(input, start_dim=0, end_dim=-1): sample_inputs_func=partial(sample_inputs_softmax_variant, with_dtype=True), assert_autodiffed=True, supports_forward_ad=True, - supports_fwgrad_bwgrad=True, supports_out=True), # `softmin` supports different dtypes based on whether `dtype` argument, # is passed or not. Hence two OpInfo entries, one with dtype and other without. 
@@ -10191,7 +10185,6 @@ def reference_flatten(input, start_dim=0, end_dim=-1): assert_jit_shape_analysis=False, assert_autodiffed=False, supports_forward_ad=True, - supports_fwgrad_bwgrad=True, supports_out=False), OpInfo('nn.functional.softmin', variant_test_name="with_dtype", @@ -10200,7 +10193,6 @@ def reference_flatten(input, start_dim=0, end_dim=-1): sample_inputs_func=partial(sample_inputs_softmax_variant, with_dtype=True), assert_autodiffed=False, supports_forward_ad=True, - supports_fwgrad_bwgrad=True, supports_out=False), OpInfo( "nn.functional.cross_entropy", @@ -10209,7 +10201,6 @@ def reference_flatten(input, start_dim=0, end_dim=-1): sample_inputs_func=sample_inputs_cross_entropy, supports_out=False, supports_forward_ad=True, - supports_fwgrad_bwgrad=True, decorators=( DecorateInfo( toleranceOverride({torch.float32: tol(atol=1e-5, rtol=1e-3)}), @@ -10301,7 +10292,6 @@ def reference_flatten(input, start_dim=0, end_dim=-1): dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), supports_out=False, assert_jit_shape_analysis=True, - supports_fwgrad_bwgrad=True, sample_inputs_func=sample_inputs_native_layer_norm, error_inputs_func=error_inputs_native_layer_norm, skips=( @@ -10673,7 +10663,6 @@ def reference_flatten(input, start_dim=0, end_dim=-1): dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), supports_out=False, supports_forward_ad=True, - supports_fwgrad_bwgrad=True, decorators=[ # RuntimeError: Cannot insert a Tensor that requires grad as a constant. # Consider making it a parameter or input, or detaching the gradient @@ -10692,7 +10681,6 @@ def reference_flatten(input, start_dim=0, end_dim=-1): dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), supports_out=False, supports_forward_ad=True, - supports_fwgrad_bwgrad=True, assert_jit_shape_analysis=True, decorators=[ DecorateInfo( @@ -11732,7 +11720,6 @@ def reference_flatten(input, start_dim=0, end_dim=-1): dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), supports_out=False, supports_forward_ad=True, - supports_fwgrad_bwgrad=True, assert_jit_shape_analysis=True, sample_inputs_func=sample_inputs_batch_norm, skips=( @@ -11755,7 +11742,6 @@ def reference_flatten(input, start_dim=0, end_dim=-1): dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), supports_out=False, supports_forward_ad=True, - supports_fwgrad_bwgrad=True, decorators=[onlyCUDA, disablecuDNN], skips=( DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-02)}), @@ -14718,7 +14704,6 @@ def reference_flatten(input, start_dim=0, end_dim=-1): dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), sample_inputs_func=sample_inputs_softmax_variant, supports_forward_ad=True, - supports_fwgrad_bwgrad=True, assert_autodiffed=True), OpInfo( 'log_softmax', @@ -14728,7 +14713,6 @@ def reference_flatten(input, start_dim=0, end_dim=-1): dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), sample_inputs_func=partial(sample_inputs_softmax_variant, with_dtype=True), supports_forward_ad=True, - supports_fwgrad_bwgrad=True, assert_autodiffed=True), UnaryUfuncInfo('logit', aten_backward_name='logit_backward', @@ -15605,7 +15589,6 @@ def reference_flatten(input, start_dim=0, end_dim=-1): supports_out=False, sample_inputs_func=sample_inputs_nll_loss, supports_forward_ad=True, - supports_fwgrad_bwgrad=True, assert_jit_shape_analysis=True, skips=( # RuntimeError: diff --git a/torch/testing/_internal/opinfo/definitions/_masked.py 
b/torch/testing/_internal/opinfo/definitions/_masked.py
index cb88766e70c64..d8a3e8aa948de 100644
--- a/torch/testing/_internal/opinfo/definitions/_masked.py
+++ b/torch/testing/_internal/opinfo/definitions/_masked.py
@@ -990,7 +990,6 @@ def sample_inputs_masked_normalize(op_info, device, dtype, requires_grad, **kwar
         ),
         gradcheck_wrapper=gradcheck_wrapper_masked_operation,
         supports_forward_ad=True,
-        supports_fwgrad_bwgrad=True,
         supports_out=False,
     ),
     OpInfo(
@@ -1018,7 +1017,6 @@ def sample_inputs_masked_normalize(op_info, device, dtype, requires_grad, **kwar
         ],
         gradcheck_wrapper=gradcheck_wrapper_masked_operation,
         supports_forward_ad=True,
-        supports_fwgrad_bwgrad=True,
         supports_out=False,
     ),
     OpInfo(
@@ -1039,7 +1037,6 @@ def sample_inputs_masked_normalize(op_info, device, dtype, requires_grad, **kwar
         ),
         gradcheck_wrapper=gradcheck_wrapper_masked_operation,
         supports_forward_ad=True,
-        supports_fwgrad_bwgrad=True,
         supports_out=False,
     ),
     OpInfo(
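
The test_ops.py hunk above replaces the direct comparison against a core PyTorch reference with a self-consistency check for the ops in FUNCTORCH_HAS_FORMULA_BUT_NOT_PYTORCH: when core has no forward-over-reverse formula, the test instead verifies that jacfwd and jacrev of the op's vjp agree. Below is a minimal standalone sketch of that cross-check, assuming the functorch 0.x API (`vjp`, `jacfwd`, `jacrev`); the op choice (`nll_loss`), the `apply_vjp` helper, and the tolerances are illustrative and are not the exact code used by `_compare_jacobians_of_vjp`.

```python
import torch
from functorch import vjp, jacfwd, jacrev
from torch.utils._pytree import tree_flatten

# Example op: nll_loss on log-probabilities. `targets` is integer-valued,
# so only `logits` participates in differentiation.
logits = torch.randn(4, 5)
targets = torch.randint(0, 5, (4,))

def f(inp):
    return torch.nn.functional.nll_loss(torch.log_softmax(inp, dim=-1), targets)

primal_out, _ = vjp(f, logits)
cotangent = torch.randn_like(primal_out)

def apply_vjp(ct, primal):
    # Rebuild the vjp at `primal` and apply it to `ct`, so both the cotangent
    # and the primal are inputs we can differentiate through.
    _, vjp_fn = vjp(f, primal)
    return vjp_fn(ct)

# Forward-over-reverse (what the jvp decompositions enable) should match
# reverse-over-reverse.
fwd = jacfwd(apply_vjp, argnums=(0, 1))(cotangent, logits)
rev = jacrev(apply_vjp, argnums=(0, 1))(cotangent, logits)
for a, b in zip(tree_flatten(fwd)[0], tree_flatten(rev)[0]):
    assert torch.allclose(a, b, atol=1e-4, rtol=1e-4)
```

This is slower than comparing against a core reference (it computes full Jacobians both ways), which is why the test comment calls it a hack and asserts that the decomposition path should be removed once `supports_fwgrad_bwgrad` becomes true for these ops in core.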