diff --git a/functorch/csrc/BatchRulesNorm.cpp b/functorch/csrc/BatchRulesNorm.cpp
index c5b273652..e78538329 100644
--- a/functorch/csrc/BatchRulesNorm.cpp
+++ b/functorch/csrc/BatchRulesNorm.cpp
@@ -58,7 +58,7 @@ batch_norm_batch_rule(
   auto running_mean = *running_mean_maybe_owned;
   c10::MaybeOwned<Tensor> running_var_maybe_owned = at::borrow_from_optional_tensor(running_var_opt);
   auto running_var = *running_var_maybe_owned;
-  TORCH_CHECK(!input_bdim || ((!running_mean.defined() || running_mean_bdim) && (!running_var.defined() || running_var_bdim)),
+  TORCH_CHECK(!training || (!input_bdim || ((!running_mean.defined() || running_mean_bdim) && (!running_var.defined() || running_var_bdim))),
       "Batch norm got a batched tensor as input while the running_mean or running_var, which will be updated in place, ",
       "were not batched.\nIf you are using a module and do not need eval mode, please set `track_running_stats` to be False.",
       "If you are using a prebuilt module and do not need eval mode, please see the functorch website for resources on ",
@@ -85,18 +85,12 @@ batch_norm_batch_rule(
   if (running_mean.defined()) {
     running_mean_ = moveBatchDimToFront(running_mean, running_mean_bdim);
     running_mean_ = ensure_has_bdim(*running_mean_, running_mean_bdim.has_value(), bdim_size.value());
-    running_mean_ = reshape_dim_into(0, 0, *running_mean_);
-    if (training) {
-      running_mean_ = running_mean_->contiguous();
-    }
+    running_mean_ = reshape_dim_into(0, 0, *running_mean_).contiguous();
   }
   if (running_var.defined()) {
     running_var_ = moveBatchDimToFront(running_var, running_var_bdim);
     running_var_ = ensure_has_bdim(*running_var_, running_var_bdim.has_value(), bdim_size.value());
-    running_var_ = reshape_dim_into(0, 0, *running_var_);
-    if (training) {
-      running_var_ = running_var_->contiguous();
-    }
+    running_var_ = reshape_dim_into(0, 0, *running_var_).contiguous();
   }
   const auto dummy_weight = at::ones(input_.size(1), input_.options());  // cudnn and miopen require a weight
diff --git a/test/common_utils.py b/test/common_utils.py
index 19c702ae8..7136df8fa 100644
--- a/test/common_utils.py
+++ b/test/common_utils.py
@@ -115,13 +115,29 @@ def get_exhaustive_batched_inputs_for_batch_norm(arg_values, kwarg_values, batch
                                          batch_size=batch_size, bdims=bdims, for_batch_norm=True)


-def get_fallback_and_vmap_exhaustive(op, arg_values, kwarg_values, opinfo=None, compute_loop_out=True, bdims=(0, -1)):
+def is_batch_norm_training(op_name, kwarg_values):
+    batch_norm_fns = ("nn.functional.batch_norm", "nn.functional.instance_norm")  # instance norm calls batch norm
+    if op_name not in batch_norm_fns:
+        return False
+
+    # batch norm and instance norm require the value to be a plain bool
+    default_training = op_name == "nn.functional.instance_norm"  # instance norm defaults to training, batch norm doesn't
+    is_training = tuple(arg for arg in tuple(kwarg_values.values()) if isinstance(arg, bool))
+    if len(is_training) == 0:
+        return default_training
+    else:
+        assert len(is_training) == 1
+        return is_training[0]
+
+
+def get_fallback_and_vmap_exhaustive(op, arg_values, kwarg_values, is_batch_norm_and_training=False, compute_loop_out=True, bdims=(0, -1)):
     out_dim = 0
     batch_size = 4
-    generator = get_exhaustive_batched_inputs(arg_values, kwarg_values, batch_size, bdims=bdims)
-    batch_norm_fns = ("nn.functional.batch_norm", "nn.functional.instance_norm")  # instance norm calls batch norm
-    if opinfo is not None and opinfo.name in batch_norm_fns:
+    if is_batch_norm_and_training:
         generator = get_exhaustive_batched_inputs_for_batch_norm(arg_values, kwarg_values, batch_size, bdims=bdims)
+    else:
+        generator = get_exhaustive_batched_inputs(arg_values, kwarg_values, batch_size, bdims=bdims)
+
     for batched_args, in_dims, kwarg_values in generator:
         if compute_loop_out:
             loop_out = loop(op, in_dims, out_dim, batch_size, *batched_args, **kwarg_values)
diff --git a/test/test_ops.py b/test/test_ops.py
index b994f1ac9..01527f428 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -27,6 +27,7 @@
     # tol2,
     opsToleranceOverride,
     check_vmap_fallback,
+    is_batch_norm_training,
 )
 from torch.utils._pytree import tree_flatten, tree_unflatten, tree_map
 from functorch import grad, vjp, vmap, jacrev, jacfwd
@@ -570,7 +571,9 @@ def vjp_of_vjp(*args_and_cotangents):
                 result_vjps, _ = tree_flatten(result_vjps)
                 return (*result, *result_vjps)

-            generator = get_fallback_and_vmap_exhaustive(vjp_of_vjp, args_and_cotangents, {}, opinfo=op)
+            is_batch_norm_and_training = is_batch_norm_training(op.name, sample.kwargs)
+            generator = get_fallback_and_vmap_exhaustive(
+                vjp_of_vjp, args_and_cotangents, {}, is_batch_norm_and_training=is_batch_norm_and_training)
             for loop_out, batched_out in generator:
                 self.assertEqual(loop_out, batched_out)
@@ -642,7 +645,10 @@ def test_vmapvjp(self, device, dtype, op):
         for sample in samples:
             cotangents = get_sample_cotangents(op, sample)
             fn, args = get_vjp_fn_and_args_with_cotangents(op, sample, cotangents)
-            for loop_out, batched_out in get_fallback_and_vmap_exhaustive(fn, args, {}, opinfo=op):
+            is_batch_norm_and_training = is_batch_norm_training(op.name, sample.kwargs)
+            generator = get_fallback_and_vmap_exhaustive(
+                fn, args, {}, is_batch_norm_and_training=is_batch_norm_and_training)
+            for loop_out, batched_out in generator:
                 self.assertEqual(loop_out, batched_out)

     # There are several variations we care about
@@ -731,7 +737,10 @@ def test_vmapjvp(self, device, dtype, op):
             kwarg_values = sample.kwargs
             args = tuple([*arg_values, *kwarg_values])
             fn, args = get_jvp_variant(op, sample)
-            for loop_out, batched_out in get_fallback_and_vmap_exhaustive(fn, args, {}, opinfo=op, bdims=(0,)):
+            is_batch_norm_and_training = is_batch_norm_training(op.name, kwarg_values)
+            generator = get_fallback_and_vmap_exhaustive(
+                fn, args, {}, is_batch_norm_and_training=is_batch_norm_and_training, bdims=(0,))
+            for loop_out, batched_out in generator:
                 self.assertEqual(loop_out, batched_out)

     vmapjvpall_fail = {
@@ -819,7 +828,10 @@ def test_vmapjvpall(self, device, dtype, op):
             kwarg_values = sample.kwargs
             args = tuple([*arg_values, *kwarg_values])
             fn, args = get_jvp_variant_primals_tangents(op, sample)
-            for loop_out, batched_out in get_fallback_and_vmap_exhaustive(fn, args, {}, opinfo=op):
+            is_batch_norm_and_training = is_batch_norm_training(op.name, kwarg_values)
+            generator = get_fallback_and_vmap_exhaustive(
+                fn, args, {}, is_batch_norm_and_training=is_batch_norm_and_training)
+            for loop_out, batched_out in generator:
                 self.assertEqual(loop_out, batched_out)

     @ops(functorch_lagging_op_db, allowed_dtypes=(torch.float,))
@@ -896,8 +908,9 @@ def test():
                 kwarg_values = sample.kwargs
                 args = tuple([*arg_values, *kwarg_values])
                 fn, args = get_jvp_variant_primals_tangents(op, sample)
+                is_batch_norm_and_training = is_batch_norm_training(op.name, kwarg_values)
                 for loop_out, batched_out in get_fallback_and_vmap_exhaustive(
-                        fn, args, {}, opinfo=op, compute_loop_out=False):
+                        fn, args, {}, is_batch_norm_and_training=is_batch_norm_and_training, compute_loop_out=False):
                     pass

         check_vmap_fallback(self, test, op, dry_run=False)
@@ -1016,13 +1029,14 @@ def test():
             for sample in samples:
                 cotangents = get_sample_cotangents(op, sample)
                 fn, args = get_vjp_fn_and_args_with_cotangents(op, sample, cotangents)
+                is_batch_norm_and_training = is_batch_norm_training(op.name, sample.kwargs)
                 for loop_out, batched_out in get_fallback_and_vmap_exhaustive(
-                        fn, args, {}, opinfo=op, compute_loop_out=False):
+                        fn, args, {}, is_batch_norm_and_training=is_batch_norm_and_training, compute_loop_out=False):
                     pass
                 for a_op in op.aliases:
                     fn, args = get_vjp_fn_and_args_with_cotangents(a_op, sample, cotangents)
                     for loop_out, batched_out in get_fallback_and_vmap_exhaustive(
-                            fn, args, {}, opinfo=op, compute_loop_out=False):
+                            fn, args, {}, is_batch_norm_and_training=is_batch_norm_and_training, compute_loop_out=False):
                         pass

         check_vmap_fallback(self, test, op, dry_run=False)
@@ -1447,7 +1461,10 @@ def was_skipped_from_batched_tensors(batched_out, batch_size):
         for sample_input in sample_inputs:
             cotangents = get_sample_cotangents(op, sample_input)
             f, args = get_autograd_fn_and_args_with_cotangents(op, sample_input, cotangents)
-            for loop_out, batched_out in get_fallback_and_vmap_exhaustive(f, args, {}, opinfo=op):
+            is_batch_norm_and_training = is_batch_norm_training(op.name, sample_input.kwargs)
+            generator = get_fallback_and_vmap_exhaustive(
+                f, args, {}, is_batch_norm_and_training=is_batch_norm_and_training)
+            for loop_out, batched_out in generator:
                 if all(was_skipped_from_batched_tensors(bo, lo.shape[0]) for (bo, lo) in zip(batched_out, loop_out)):
                     continue  # we weren't able to use the batched tensor in autograd.grad
                 self.assertEqual(loop_out, batched_out)
diff --git a/test/test_vmap.py b/test/test_vmap.py
index 3e32cb13f..cd61d7aea 100644
--- a/test/test_vmap.py
+++ b/test/test_vmap.py
@@ -34,6 +34,7 @@
     check_vmap_fallback,
     tol1,
     opsToleranceOverride,
+    is_batch_norm_training,
 )
 import types
 from collections import namedtuple
@@ -3148,8 +3149,10 @@ def test_vmap_exhaustive(self, device, dtype, op):
         for sample_input in sample_inputs_itr:
             arg_values = [sample_input.input] + list(sample_input.args)
             kwarg_values = sample_input.kwargs
+            is_batch_norm_and_training = is_batch_norm_training(op.name, kwarg_values)
             try:
-                generator = get_fallback_and_vmap_exhaustive(op.op, arg_values, kwarg_values, opinfo=op)
+                generator = get_fallback_and_vmap_exhaustive(
+                    op.op, arg_values, kwarg_values, is_batch_norm_and_training=is_batch_norm_and_training)
                 for loop_out, batched_out in generator:
                     # empty_like and new_empty produce garbage values so we just check the shapes.
                     if op.name == 'empty_like' or op.name == 'new_empty':
@@ -3157,7 +3160,8 @@
                         continue
                     self.assertEqual(loop_out, batched_out)
                 for a_op in op.aliases:
-                    a_generator = get_fallback_and_vmap_exhaustive(a_op, arg_values, kwarg_values, opinfo=op)
+                    a_generator = get_fallback_and_vmap_exhaustive(
+                        a_op, arg_values, kwarg_values, is_batch_norm_and_training=is_batch_norm_and_training)
                     for loop_out, batched_out in a_generator:
                         self.assertEqual(loop_out, batched_out)
             # todo(chilli): Garbage hack I added to deal with indexing not working
@@ -3294,7 +3298,9 @@ def test():
             for sample_input in sample_inputs_itr:
                 arg_values = [sample_input.input] + list(sample_input.args)
                 kwarg_values = sample_input.kwargs
-                generator = get_fallback_and_vmap_exhaustive(op.op, arg_values, kwarg_values, opinfo=op)
+                is_batch_norm_and_training = is_batch_norm_training(op.name, kwarg_values)
+                generator = get_fallback_and_vmap_exhaustive(
+                    op.op, arg_values, kwarg_values, is_batch_norm_and_training=is_batch_norm_and_training)
                 for loop_out, batched_out in generator:
                     # empty_like and new_empty produce garbage values so we just check the shapes.
                     if op.name == 'empty_like' or op.name == 'new_empty':
@@ -3302,7 +3308,8 @@
                         continue
                     self.assertEqual(loop_out, batched_out)
                 for a_op in op.aliases:
-                    a_generator = get_fallback_and_vmap_exhaustive(a_op, arg_values, kwarg_values, opinfo=op)
+                    a_generator = get_fallback_and_vmap_exhaustive(
+                        a_op, arg_values, kwarg_values, is_batch_norm_and_training=is_batch_norm_and_training)
                     for loop_out, batched_out in a_generator:
                         self.assertEqual(loop_out, batched_out)
         check_vmap_fallback(self, test, op)
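Illustrative usage, a minimal sketch and not part of the patch above: after this change the tests decide training vs. eval per sample and thread the result through the new `is_batch_norm_and_training` keyword instead of passing `opinfo=`. The wrapper name `check_against_loop` and the `test_case`/`op`/`sample` parameters are hypothetical stand-ins for a test case, an OpInfo entry, and one of its sample inputs; the imports assume test/common_utils.py as in the tests above.

    # Hypothetical wrapper mirroring the call pattern introduced by this patch.
    from common_utils import get_fallback_and_vmap_exhaustive, is_batch_norm_training

    def check_against_loop(test_case, op, sample, fn, args):
        # Only batch_norm/instance_norm in training mode need the specially batched
        # running stats, so the decision is made from this sample's kwargs.
        is_batch_norm_and_training = is_batch_norm_training(op.name, sample.kwargs)
        generator = get_fallback_and_vmap_exhaustive(
            fn, args, {}, is_batch_norm_and_training=is_batch_norm_and_training)
        for loop_out, batched_out in generator:
            test_case.assertEqual(loop_out, batched_out)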