
Commit

Revert "Call jit decomposition in VariableType to increase forward AD…
Browse files Browse the repository at this point in the history
… coverage (pytorch#84151)"

This reverts commit 42d99e6.

Reverted pytorch#84151 on behalf of https://github.com/malfet due to Regressed test_jvpvjp_nn_functional_layer_norm_cuda_float32, see https://hud.pytorch.org/pytorch/pytorch/commit/42d99e6f196233627a28b8e9efb26a0a166fa370
pytorchmergebot committed Sep 7, 2022
1 parent 31ef8dd commit acb4a09
Showing 14 changed files with 210 additions and 277 deletions.
14 changes: 14 additions & 0 deletions functorch/functorch/csrc/BatchRulesHelper.cpp
@@ -133,6 +133,20 @@ void vmapIncompatibleInplaceError(const char* schema_name) {
"please file a bug report instead.");
}

void run_jit_decomposition(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
const auto& schema = op.schema();
// TODO: templatize based on op and keep static trace_exec
auto * trace_exec = torch::jit::GetDecompositionExecutor(schema);
trace_exec->run((*stack));
// Boxed multi-output ops come back from the executor as a single tuple;
// flatten it so the stack matches the schema's number of returns.
if (stack->back().isTuple()) {
IValue tup = stack->back();
stack->pop_back();
for (const auto& elem: tup.toTuple()->elements()) {
stack->push_back(elem);
}
}
}

static void handleScalarTypePromotion(Tensor& logical_scalar_tensor, Tensor& second) {
auto result_type = at::native::result_type(logical_scalar_tensor[0], second);
if (logical_scalar_tensor.scalar_type() != result_type) {
6 changes: 6 additions & 0 deletions functorch/functorch/csrc/BatchRulesHelper.h
@@ -195,6 +195,12 @@ inline void handle_variadic_bdims(std::vector<std::pair<Tensor, optional<int64_t
#define VARIADIC_BDIMS_BOXED(op) \
m.impl(#op, torch::CppFunction::makeFromBoxedFunction<boxed_tensor_inputs_batch_rule<decltype(&handle_variadic_bdims), &handle_variadic_bdims>>());

void run_jit_decomposition(const c10::OperatorHandle& op, torch::jit::Stack* stack);

#define RUN_JIT_DECOMPOSITION(op) \
m.impl(#op, torch::CppFunction::makeFromBoxedFunction<&run_jit_decomposition>());


using UnpackedBatchedTensor = std::tuple<Tensor,optional<int64_t>>;

inline void find_and_unpack_tensors(
3 changes: 1 addition & 2 deletions functorch/functorch/csrc/BatchRulesViews.cpp
@@ -15,7 +15,6 @@
#include <c10/core/SymIntArrayRef.h>
#include <c10/util/SmallBuffer.h>
#include <ATen/InferSize.h>
#include <torch/csrc/jit/runtime/decomposition_registry.h>

namespace at { namespace functorch {

@@ -511,7 +510,7 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
VMAP_SUPPORT(chunk, chunk_batching_rule);
m.impl("flatten.using_ints", static_cast<decltype(&ATEN_FN2(flatten, using_ints))>(native::flatten));
VMAP_SUPPORT(flip, flip_batch_rule);
m.impl("trace", torch::CppFunction::makeFromBoxedFunction<&torch::jit::run_jit_decomposition>());
RUN_JIT_DECOMPOSITION(trace)
VMAP_SUPPORT(tril, VARIADIC_BDIMS_BATCH_RULE(ATEN_FN(tril)));
VMAP_SUPPORT(triu, VARIADIC_BDIMS_BATCH_RULE(ATEN_FN(triu)));
VMAP_SUPPORT(repeat, repeat_batch_rule);
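To illustrate the registration above: `trace` has no dedicated batching rule, so under `vmap` the boxed `run_jit_decomposition` kernel re-expresses it through its TorchScript decomposition and the decomposed ops are batched instead. A minimal sketch of the user-visible behaviour (assuming a functorch-era install; shapes are illustrative):

```python
import torch
from functorch import vmap

# Batch of 8 square matrices; torch.trace itself only accepts 2-D input.
x = torch.randn(8, 4, 4)

# Under vmap, `trace` is routed through the boxed jit-decomposition kernel
# registered above rather than a hand-written batching rule.
batched = vmap(torch.trace)(x)

# Reference: apply trace to each sample in a plain Python loop.
expected = torch.stack([torch.trace(m) for m in x])
torch.testing.assert_close(batched, expected)
```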
72 changes: 71 additions & 1 deletion functorch/functorch/csrc/DynamicLayer.cpp
@@ -389,9 +389,43 @@ WithoutTop::~WithoutTop() {
pushDynamicLayer(std::move(layer_));
}

static void dynamicLayerFrontFallback(
// NOTE: [forward-mode AD decompositions hack]
//
// The mechanism is: in DynamicLayerFrontMode, IF we are dispatching on the
// jvp transform, AND we have a decomposition for the operation, then run
// the decomposition.
//
// Let's break that down. There are a couple of moving pieces.
//
// 0. How do we know what transform we're dispatching on?
// Easy, check the top of the DynamicLayerStack and read the transform.
//
// 1. Next, we must identify when an operation (e.g. nll_loss_backward)
// gets dispatched to.
// - register a special kernel to the DynamicLayerFrontMode key
// (see JVP_DECOMP)
// - that special kernel invokes dynamicLayerFrontFallbackOperator with
// an arg indicating we're going to use a decomp
//
// 2. Next, we need to call the decomposition. See call_decomposition_for_jvp.
// We currently use python decompositions that we torchscript.

// Ideally c10::OperatorHandle would have a field like this
// to identify the operator.
// The stuff here should map 1:1 with the operator name.
// aten::nll_loss_backward -> nll_loss_backward
// aten::add.Tensor -> add_Tensor

static void call_decomposition_for_jvp(
const c10::OperatorHandle& op,
torch::jit::Stack* stack) {
run_jit_decomposition(op, stack);
}

static void dynamicLayerFrontFallbackOperator(
const c10::OperatorHandle& op,
torch::jit::Stack* stack,
bool decomp_jvp) {
auto& dynamicLayerStack = dynamicLayerStackAccessor();
TORCH_INTERNAL_ASSERT(dynamicLayerStack.size() > 0);
#ifdef HAS_TORCH_SHOW_DISPATCH_TRACE
@@ -400,6 +434,13 @@ static void dynamicLayerFrontFallback(
dump_local_tls();
}
#endif

// Hack: if jvp and we have a decomposition registered, then do the decomposition
if (dynamicLayerStack.back().interpreter().key() == TransformType::Jvp &&
decomp_jvp) {
return call_decomposition_for_jvp(op, stack);
}

// Save the current LocalDispatchKeySet (to the current DynamicLayer).
// Upon exiting the current scope, that LocalDispatchKeySet gets restored.
// When the current DynamicLayer dispatches to the next (inner) DynamicLayer,
@@ -419,6 +460,16 @@ restoreLocalDispatchKeySetRAII(const c10::impl::LocalDispatchKeySet& key_set) {
return c10::impl::ForceDispatchKeyGuard(key_set);
}

void dynamicLayerFrontFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
return dynamicLayerFrontFallbackOperator(op, stack, false);
}

void dynamicLayerFrontFallBackWithDecomp(
const c10::OperatorHandle& op,
torch::jit::Stack* stack) {
return dynamicLayerFrontFallbackOperator(op, stack, true);
}

void dynamicLayerBackFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
auto& layer = dynamicLayerStackAccessor().back();
auto restore_guard = restoreLocalDispatchKeySetRAII(layer.interpreter().getSavedLocalDispatchKeySet());
@@ -435,5 +486,24 @@ TORCH_LIBRARY_IMPL(_, FT_DYNAMIC_LAYER_BACK_MODE_KEY, m) {
m.fallback(torch::CppFunction::makeFromBoxedFunction<&dynamicLayerBackFallback>());
}

#define JVP_DECOMP(op) \
m.impl(#op, torch::CppFunction::makeFromBoxedFunction<&dynamicLayerFrontFallBackWithDecomp>());

#define JVP_DECOMP2(op, overload) \
m.impl(#op "." #overload, torch::CppFunction::makeFromBoxedFunction<&dynamicLayerFrontFallBackWithDecomp>());

TORCH_LIBRARY_IMPL(aten, FT_DYNAMIC_LAYER_FRONT_MODE_KEY, m) {
JVP_DECOMP(nll_loss_backward);
JVP_DECOMP(nll_loss2d_backward);
JVP_DECOMP(_log_softmax_backward_data);
JVP_DECOMP(_softmax_backward_data);
OP_DECOMPOSE(log_sigmoid);
JVP_DECOMP(log_sigmoid_forward);
JVP_DECOMP(native_layer_norm_backward);
JVP_DECOMP(native_batch_norm_backward);
JVP_DECOMP(cudnn_batch_norm_backward);
}


}
} // namespace at
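The NOTE and the JVP_DECOMP table above are what this revert restores: under the jvp transform, backward ops such as nll_loss_backward and _log_softmax_backward_data are re-expressed through torchscripted decompositions instead of relying on native forward-AD formulas. A rough sketch of the kind of call that exercises this path (functorch-era API; the shapes and the lambda are illustrative, not part of the diff):

```python
import torch
import torch.nn.functional as F
from functorch import jvp, vjp

x = torch.randn(3, 5)
targets = torch.randint(0, 5, (3,))

def loss(logits):
    return F.nll_loss(F.log_softmax(logits, dim=-1), targets)

# Reverse mode first: vjp_fn's backward graph contains nll_loss_backward
# and _log_softmax_backward_data.
_, vjp_fn = vjp(loss, x)

# Forward mode over the backward pass: differentiating vjp_fn w.r.t. the
# cotangent needs forward AD through those backward ops, which functorch
# handles via the JVP_DECOMP decompositions above.
cotangent = torch.ones(())
tangent = torch.ones(())
grads, grads_tangent = jvp(lambda ct: vjp_fn(ct)[0], (cotangent,), (tangent,))
```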
40 changes: 38 additions & 2 deletions functorch/test/test_ops.py
@@ -1047,6 +1047,9 @@ def get_vjp(cotangents, *primals):
# RuntimeError: Trying to set a forward gradient that has a different size than that of the original Tensor,
# this is not supported. Tensor is of size [5, 2, 3] while the given forward gradient is of size [1, 2, 3].
xfail('normal', ''),
xfail('_masked.log_softmax', ''), # NYI: forward-AD for _log_softmax_backward_data
xfail('_masked.softmax', ''), # NYI: forward-AD for _softmax_backward_data
xfail('_masked.softmin', ''), # NYI: forward-AD for _softmax_backward_data
xfail('cdist', ''), # NYI: forward-AD for _cdist_forward
xfail('cholesky', ''), # NYI: forward-AD for cholesky
xfail('eig', ''), # NYI: forward-AD for eig
@@ -1055,7 +1058,10 @@ def get_vjp(cotangents, *primals):
xfail('nn.functional.grid_sample', ''), # NYI: forward AD for grid_sampler_2d
xfail('nn.functional.hardsigmoid', ''), # NYI: forward AD for hardsigmoid_backward
xfail('nn.functional.huber_loss', ''), # NYI: forward AD for huber_loss_backward
xfail('nn.functional.instance_norm', ''), # NYI: forward AD for native_batch_norm_backward
xfail('nn.functional.logsigmoid', ''), # not differentiable w.r.t. buffer
xfail('nn.functional.softmin', ''), # NYI: forward-AD for _softmax_backward_data
xfail('nn.functional.softmin', 'with_dtype'), # NYI: forward-AD for _softmax_backward_data
xfail('renorm', ''), # NYI: forward AD for renorm
xfail('symeig', ''), # NYI: forward AD for symeig
xfail('nn.functional.multilabel_margin_loss', ''), # NYI: multilabel_margin_loss_forward
@@ -1069,6 +1075,7 @@ def get_vjp(cotangents, *primals):
xfail('scatter_reduce', 'mean'), # NYI: forward-AD for scatter_reduce
xfail('scatter_reduce', 'prod'), # NYI: forward-AD for scatter_reduce
skip('linalg.householder_product', '', device_type='cuda'), # flaky, I'm not sure why
xfail('native_layer_norm', ''), # NYI: forward-AD for native_layer_norm_backward
xfail('sparse.sampled_addmm', ''), # Sparse tensors have no strides
skip('as_strided_scatter', ''), # seems flaky
xfail('segment_reduce', 'offsets'), # NYI: forward-AD for segment_reduce
@@ -1129,8 +1136,37 @@ def reference(primals, cotangents, primals_tangents, cotangents_tangents):
expected = (tree_unflatten(primals_out, spec), tree_unflatten(tangents_out, spec))
return expected

expected = reference(primals, cotangents, primals_tangents, cotangents_tangents)
self.assertEqual(result, expected)
# HACK: obviously pytorch should also have the same coverage
# For things that do have the same coverage, we test that jvp x vjp
# are the same between PyTorch and functorch. For things that don't,
# we check that jacfwd(vjp) and jacrev(vjp) are the same. This results
# in slower tests.
FUNCTORCH_HAS_FORMULA_BUT_NOT_PYTORCH = {
'nn.functional.nll_loss',
'softmax',
'log_softmax',
'nn.functional.cross_entropy',
'nn.functional.layer_norm',
'nn.functional.batch_norm',
}
if op.name in FUNCTORCH_HAS_FORMULA_BUT_NOT_PYTORCH:
self.assertFalse(op.supports_fwgrad_bwgrad,
f"{op.name} now supports forward over reverse without a decomposition. " +
"Please remove the decomposition version")

def is_differentiable(t):
return isinstance(t, torch.Tensor) and t.dtype == torch.float32
args = (cotangents, *primals)
if op.name == 'nn.functional.binary_cross_entropy':
argnums = (0, 1) # targets is float32 but isn't differentiable
atol_rtol = 1.5e-4, 1.3e-06
else:
argnums = tuple(i for i in range(len(args)) if is_differentiable(args[i]))
atol_rtol = None
self._compare_jacobians_of_vjp(fn, args, argnums, atol_rtol)
else:
expected = reference(primals, cotangents, primals_tangents, cotangents_tangents)
self.assertEqual(result, expected)

def _make_extremal_inputs(self, shape, device):
if shape is None:
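The test change above falls back to comparing jacfwd and jacrev of the vjp whenever functorch has a decomposition that PyTorch core lacks. A simplified sketch of that consistency check (my own example, not the test's exact helpers; `log_softmax` is one of the entries in FUNCTORCH_HAS_FORMULA_BUT_NOT_PYTORCH):

```python
import torch
import torch.nn.functional as F
from functorch import jacfwd, jacrev, vjp

x = torch.randn(3, 5)
out, vjp_fn = vjp(lambda t: F.log_softmax(t, dim=-1), x)
cotangents = torch.randn_like(out)

def apply_vjp(ct):
    return vjp_fn(ct)[0]

# Forward-over-reverse (goes through the _log_softmax_backward_data
# decomposition) should agree with reverse-over-reverse.
torch.testing.assert_close(jacfwd(apply_vjp)(cotangents),
                           jacrev(apply_vjp)(cotangents))
```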
15 changes: 1 addition & 14 deletions tools/autograd/derivatives.yaml
@@ -1956,20 +1956,7 @@

- name: log_sigmoid_forward(Tensor self) -> (Tensor output, Tensor buffer)
self: log_sigmoid_backward(grad, self, buffer)
# HACK: This is just auto_element_wise followed by a view_as. The reason we have
# this is bc forward AD was complaining here about the shapes not being the same:
# the primal/tangent are 0-D/1-D respectively. This started happening after moving the
# jvp decomposition mechanism from functorch to core, possibly due to a batching rule.
# In functorch we rely on OP_DECOMPOSE, but now we compute forward AD using an actual
# formula.
#
# We'd like to avoid keeping the entire jvp decomposition mechanism in functorch,
# just for this single decomposition, but also want to avoid any cases from regressing:
# e.g. test_vmapjvpall_nn_functional_logsigmoid_cuda_float32 (passes on cpu, fails on CUDA).
#
# We should either figure out what is going on with vmap or perhaps fwd AD could
# be more tolerant about 0-dim vs 1-dim tensors
output: log_sigmoid_backward(self_t.conj(), self_p, buffer).conj().view_as(self_p)
output: auto_element_wise

- name: _log_softmax(Tensor self, int dim, bool half_to_float) -> Tensor
self: _log_softmax_backward_data(grad, result, dim, self.scalar_type())
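The derivatives.yaml change above drops the hand-written jvp entry for log_sigmoid_forward and goes back to auto_element_wise, which generates the element-wise forward-mode rule from the existing backward formula. Since d/dx logsigmoid(x) = sigmoid(-x), the tangent rule is a simple element-wise product; a quick numerical sanity check (my own sketch, assuming a functorch-era install):

```python
import torch
import torch.nn.functional as F
from functorch import jvp

x = torch.randn(4)
t = torch.randn(4)

# jvp of logsigmoid should match the element-wise rule sigmoid(-x) * tangent.
_, tangent_out = jvp(F.logsigmoid, (x,), (t,))
torch.testing.assert_close(tangent_out, torch.sigmoid(-x) * t)
```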

