diff --git a/functorch/csrc/BatchRulesNorm.cpp b/functorch/csrc/BatchRulesNorm.cpp index 0bc0a958d..c5b273652 100644 --- a/functorch/csrc/BatchRulesNorm.cpp +++ b/functorch/csrc/BatchRulesNorm.cpp @@ -75,7 +75,7 @@ batch_norm_batch_rule( mean = std::get<1>(result); rstd = std::get<2>(result); } else { - bdim_size = get_bdim_size3(input, input_bdim, running_mean, running_mean_bdim, running_var, running_mean_bdim); + bdim_size = get_bdim_size3(input, input_bdim, running_mean, running_mean_bdim, running_var, running_var_bdim); auto input_ = moveBatchDimToFront(input, input_bdim); input_ = ensure_has_bdim(input_, input_bdim.has_value(), bdim_size.value()); input_ = reshape_dim_into(0, /*channels dim*/1, input_); @@ -86,11 +86,17 @@ batch_norm_batch_rule( running_mean_ = moveBatchDimToFront(running_mean, running_mean_bdim); running_mean_ = ensure_has_bdim(*running_mean_, running_mean_bdim.has_value(), bdim_size.value()); running_mean_ = reshape_dim_into(0, 0, *running_mean_); + if (training) { + running_mean_ = running_mean_->contiguous(); + } } if (running_var.defined()) { running_var_ = moveBatchDimToFront(running_var, running_var_bdim); running_var_ = ensure_has_bdim(*running_var_, running_var_bdim.has_value(), bdim_size.value()); running_var_ = reshape_dim_into(0, 0, *running_var_); + if (training) { + running_var_ = running_var_->contiguous(); + } } const auto dummy_weight = at::ones(input_.size(1), input_.options()); // cudnn and miopen require a weight diff --git a/test/common_utils.py b/test/common_utils.py index f8153c7b9..19c702ae8 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -92,7 +92,6 @@ def add_batch_choices(a): if all([i is None for i in in_dims]): continue - yield pytree.tree_unflatten(batched_args, arg_spec), pytree.tree_unflatten(in_dims, arg_spec), kwarg_values if for_batch_norm and len(orig_flat_args) >= 2: @@ -111,7 +110,6 @@ def add_batch_choices(a): in_dims_tuple = pytree.tree_unflatten(in_dims, arg_spec) yield batched_args_tuple, in_dims_tuple, kwarg_values - def get_exhaustive_batched_inputs_for_batch_norm(arg_values, kwarg_values, batch_size=3, bdims=(0, -1)): return get_exhaustive_batched_inputs(arg_values, kwarg_values, batch_size=batch_size, bdims=bdims, for_batch_norm=True) diff --git a/test/test_ops.py b/test/test_ops.py index 23a4c63da..b994f1ac9 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -19,6 +19,7 @@ from common_utils import ( get_fallback_and_vmap_exhaustive, get_exhaustive_batched_inputs, + get_exhaustive_batched_inputs_for_batch_norm, xfail, skip, skipOps, @@ -746,14 +747,11 @@ def test_vmapjvp(self, device, dtype, op): # The following are bugs that we should fix skip('nn.functional.max_pool1d'), # fails on cpu, runs on cuda - xfail('nn.functional.batch_norm', device_type='cuda'), - xfail('nn.functional.batch_norm', 'without_cudnn', device_type='cuda'), xfail('nn.functional.hinge_embedding_loss', device_type='cuda'), xfail('_masked.mean'), xfail('_masked.prod'), # Causing issues with multiple cpu levels of forward mode AD - xfail('nn.functional.batch_norm', device_type='cpu'), xfail('nn.functional.hinge_embedding_loss', device_type='cpu'), xfail('nn.functional.soft_margin_loss', ''), @@ -788,6 +786,10 @@ def test_vmapjvp(self, device, dtype, op): xfail('nn.functional.embedding'), # embedding_renorm_ does not support fwd AD xfail('put'), # calls put_ during vmap with only vmaps over other, not self xfail('nn.functional.prelu'), # Call Tensor.as_strided + + # erroring because running_mean and running_var aren't differentiable + xfail('nn.functional.batch_norm'), + xfail('nn.functional.batch_norm', 'without_cudnn'), } @ops(functorch_lagging_op_db, allowed_dtypes=(torch.float,)) @@ -1082,7 +1084,10 @@ def test_vjpvmap(self, device, dtype, op): for sample in samples: args = [sample.input] + list(sample.args) kwargs = sample.kwargs - generator = get_exhaustive_batched_inputs(args, kwargs, for_batch_norm=is_batch_norm) + if is_batch_norm: + generator = get_exhaustive_batched_inputs_for_batch_norm(args, kwargs) + else: + generator = get_exhaustive_batched_inputs(args, kwargs) for batched_args, in_dims, kwargs in generator: vmapped_op = vmap(op, in_dims)