Allow batch norm with all variations of batching when training=False (#958)

* allow batch norm with all variations of batching when training=False

* make running mean/var always call contiguous
Samantha Andow committed Aug 4, 2022
1 parent 0331c43 commit a6e0e61
Showing 4 changed files with 59 additions and 25 deletions.
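
For context, a minimal sketch of what the relaxed check permits (the shapes, tensor names, and the use of `functorch.vmap` below are illustrative, not taken from this commit): with `training=False` the running statistics are only read, so a batched input no longer needs batched running stats.

```python
import torch
import torch.nn.functional as F
from functorch import vmap

B, N, C = 4, 8, 3
x = torch.randn(B, N, C, 5, 5)    # B is the vmap (per-example) dimension
running_mean = torch.zeros(C)     # shared, unbatched running stats
running_var = torch.ones(C)

# Previously the batch rule rejected a batched input whose running_mean/var
# were unbatched, even in eval mode; with training=False the stats are only
# read, never updated in place, so this combination is now accepted.
out = vmap(lambda xi: F.batch_norm(xi, running_mean, running_var, training=False))(x)
assert out.shape == x.shape
```

In training mode the original restriction still applies, since the running stats would be updated in place.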
12 changes: 3 additions & 9 deletions functorch/csrc/BatchRulesNorm.cpp
@@ -58,7 +58,7 @@ batch_norm_batch_rule(
   auto running_mean = *running_mean_maybe_owned;
   c10::MaybeOwned<Tensor> running_var_maybe_owned = at::borrow_from_optional_tensor(running_var_opt);
   auto running_var = *running_var_maybe_owned;
-  TORCH_CHECK(!input_bdim || ((!running_mean.defined() || running_mean_bdim) && (!running_var.defined() || running_var_bdim)),
+  TORCH_CHECK(!training || (!input_bdim || ((!running_mean.defined() || running_mean_bdim) && (!running_var.defined() || running_var_bdim))),
       "Batch norm got a batched tensor as input while the running_mean or running_var, which will be updated in place, ",
       "were not batched.\nIf you are using a module and do not need eval mode, please set `track_running_stats` to be False.",
       "If you are using a prebuilt module and do not need eval mode, please see the functorch website for resources on ",
@@ -85,18 +85,12 @@
   if (running_mean.defined()) {
     running_mean_ = moveBatchDimToFront(running_mean, running_mean_bdim);
     running_mean_ = ensure_has_bdim(*running_mean_, running_mean_bdim.has_value(), bdim_size.value());
-    running_mean_ = reshape_dim_into(0, 0, *running_mean_);
-    if (training) {
-      running_mean_ = running_mean_->contiguous();
-    }
+    running_mean_ = reshape_dim_into(0, 0, *running_mean_).contiguous();
   }
   if (running_var.defined()) {
     running_var_ = moveBatchDimToFront(running_var, running_var_bdim);
     running_var_ = ensure_has_bdim(*running_var_, running_var_bdim.has_value(), bdim_size.value());
-    running_var_ = reshape_dim_into(0, 0, *running_var_);
-    if (training) {
-      running_var_ = running_var_->contiguous();
-    }
+    running_var_ = reshape_dim_into(0, 0, *running_var_).contiguous();
   }

   const auto dummy_weight = at::ones(input_.size(1), input_.options()); // cudnn and miopen require a weight
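
A rough Python sketch of the transformation this batch rule performs when the running stats themselves carry a batch dimension: the vmap dimension is folded into the channel dimension so a single native eval-mode `batch_norm` call covers every vmap slice, with the stats flattened and made contiguous as in the `reshape_dim_into(0, 0, ...).contiguous()` lines above. The shapes, the explicit transpose/reshape, and the default `eps` are illustrative assumptions, not the C++ implementation itself.

```python
import torch
import torch.nn.functional as F

B, N, C, H, W = 3, 4, 5, 6, 6
x = torch.randn(B, N, C, H, W)           # B is the vmap dimension
running_mean = torch.randn(B, C)         # per-slice running stats
running_var = torch.rand(B, C) + 0.5

# Loop reference: one eval-mode batch_norm call per vmap slice.
expected = torch.stack([
    F.batch_norm(x[i], running_mean[i], running_var[i], training=False)
    for i in range(B)
])

# Batched version: fold the vmap dim into the channel dim so a single call
# normalizes all slices at once; the stats are flattened to (B*C,) and made
# contiguous, mirroring the reshape_dim_into(0, 0, ...).contiguous() calls.
x_flat = x.transpose(0, 1).reshape(N, B * C, H, W)
out = F.batch_norm(x_flat,
                   running_mean.reshape(-1).contiguous(),
                   running_var.reshape(-1).contiguous(),
                   training=False)
out = out.reshape(N, B, C, H, W).transpose(0, 1)

torch.testing.assert_close(out, expected)
```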
24 changes: 20 additions & 4 deletions test/common_utils.py
@@ -115,13 +115,29 @@ def get_exhaustive_batched_inputs_for_batch_norm(arg_values, kwarg_values, batch
         batch_size=batch_size, bdims=bdims, for_batch_norm=True)
 
 
-def get_fallback_and_vmap_exhaustive(op, arg_values, kwarg_values, opinfo=None, compute_loop_out=True, bdims=(0, -1)):
+def is_batch_norm_training(op_name, kwarg_values):
+    batch_norm_fns = ("nn.functional.batch_norm", "nn.functional.instance_norm")  # instance norm calls batch norm
+    if op_name not in batch_norm_fns:
+        return False
+
+    # batch norm and instance norm require the value to be a plain bool
+    default_training = op_name == "nn.functional.instance_norm"  # instance norm defaults to training, batch norm doesn't
+    is_training = tuple(arg for arg in tuple(kwarg_values.values()) if isinstance(arg, bool))
+    if len(is_training) == 0:
+        return default_training
+    else:
+        assert len(is_training) == 1
+        return is_training[0]
+
+
+def get_fallback_and_vmap_exhaustive(op, arg_values, kwarg_values, is_batch_norm_and_training=False, compute_loop_out=True, bdims=(0, -1)):
     out_dim = 0
     batch_size = 4
-    generator = get_exhaustive_batched_inputs(arg_values, kwarg_values, batch_size, bdims=bdims)
-    batch_norm_fns = ("nn.functional.batch_norm", "nn.functional.instance_norm")  # instance norm calls batch norm
-    if opinfo is not None and opinfo.name in batch_norm_fns:
+    if is_batch_norm_and_training:
         generator = get_exhaustive_batched_inputs_for_batch_norm(arg_values, kwarg_values, batch_size, bdims=bdims)
+    else:
+        generator = get_exhaustive_batched_inputs(arg_values, kwarg_values, batch_size, bdims=bdims)
 
     for batched_args, in_dims, kwarg_values in generator:
         if compute_loop_out:
             loop_out = loop(op, in_dims, out_dim, batch_size, *batched_args, **kwarg_values)
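
A small usage sketch of the helper added above. It assumes you run from functorch's `test/` directory so that `common_utils` is importable; note the helper only inspects keyword arguments, so a `training` flag passed positionally is not seen.

```python
from common_utils import is_batch_norm_training

print(is_batch_norm_training("nn.functional.batch_norm", {}))                  # False: batch norm defaults to eval
print(is_batch_norm_training("nn.functional.instance_norm", {}))               # True: instance norm defaults to training
print(is_batch_norm_training("nn.functional.batch_norm", {"training": True}))  # True: the lone bool kwarg wins
print(is_batch_norm_training("nn.functional.layer_norm", {"training": True}))  # False: not batch/instance norm
```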
33 changes: 25 additions & 8 deletions test/test_ops.py
@@ -27,6 +27,7 @@
     # tol2,
     opsToleranceOverride,
     check_vmap_fallback,
+    is_batch_norm_training,
 )
 from torch.utils._pytree import tree_flatten, tree_unflatten, tree_map
 from functorch import grad, vjp, vmap, jacrev, jacfwd
@@ -570,7 +571,9 @@ def vjp_of_vjp(*args_and_cotangents):
                 result_vjps, _ = tree_flatten(result_vjps)
                 return (*result, *result_vjps)
 
-            generator = get_fallback_and_vmap_exhaustive(vjp_of_vjp, args_and_cotangents, {}, opinfo=op)
+            is_batch_norm_and_training = is_batch_norm_training(op.name, sample.kwargs)
+            generator = get_fallback_and_vmap_exhaustive(
+                vjp_of_vjp, args_and_cotangents, {}, is_batch_norm_and_training=is_batch_norm_and_training)
             for loop_out, batched_out in generator:
                 self.assertEqual(loop_out, batched_out)
 
@@ -642,7 +645,10 @@ def test_vmapvjp(self, device, dtype, op):
         for sample in samples:
             cotangents = get_sample_cotangents(op, sample)
             fn, args = get_vjp_fn_and_args_with_cotangents(op, sample, cotangents)
-            for loop_out, batched_out in get_fallback_and_vmap_exhaustive(fn, args, {}, opinfo=op):
+            is_batch_norm_and_training = is_batch_norm_training(op.name, sample.kwargs)
+            generator = get_fallback_and_vmap_exhaustive(
+                fn, args, {}, is_batch_norm_and_training=is_batch_norm_and_training)
+            for loop_out, batched_out in generator:
                 self.assertEqual(loop_out, batched_out)
 
     # There are several variations we care about
@@ -731,7 +737,10 @@ def test_vmapjvp(self, device, dtype, op):
             kwarg_values = sample.kwargs
             args = tuple([*arg_values, *kwarg_values])
             fn, args = get_jvp_variant(op, sample)
-            for loop_out, batched_out in get_fallback_and_vmap_exhaustive(fn, args, {}, opinfo=op, bdims=(0,)):
+            is_batch_norm_and_training = is_batch_norm_training(op.name, kwarg_values)
+            generator = get_fallback_and_vmap_exhaustive(
+                fn, args, {}, is_batch_norm_and_training=is_batch_norm_and_training, bdims=(0,))
+            for loop_out, batched_out in generator:
                 self.assertEqual(loop_out, batched_out)
 
     vmapjvpall_fail = {
@@ -819,7 +828,10 @@ def test_vmapjvpall(self, device, dtype, op):
             kwarg_values = sample.kwargs
             args = tuple([*arg_values, *kwarg_values])
             fn, args = get_jvp_variant_primals_tangents(op, sample)
-            for loop_out, batched_out in get_fallback_and_vmap_exhaustive(fn, args, {}, opinfo=op):
+            is_batch_norm_and_training = is_batch_norm_training(op.name, kwarg_values)
+            generator = get_fallback_and_vmap_exhaustive(
+                fn, args, {}, is_batch_norm_and_training=is_batch_norm_and_training)
+            for loop_out, batched_out in generator:
                 self.assertEqual(loop_out, batched_out)
 
     @ops(functorch_lagging_op_db, allowed_dtypes=(torch.float,))
@@ -896,8 +908,9 @@ def test():
                 kwarg_values = sample.kwargs
                 args = tuple([*arg_values, *kwarg_values])
                 fn, args = get_jvp_variant_primals_tangents(op, sample)
+                is_batch_norm_and_training = is_batch_norm_training(op.name, kwarg_values)
                 for loop_out, batched_out in get_fallback_and_vmap_exhaustive(
-                        fn, args, {}, opinfo=op, compute_loop_out=False):
+                        fn, args, {}, is_batch_norm_and_training=is_batch_norm_and_training, compute_loop_out=False):
                     pass
         check_vmap_fallback(self, test, op, dry_run=False)
 
@@ -1016,13 +1029,14 @@ def test():
             for sample in samples:
                 cotangents = get_sample_cotangents(op, sample)
                 fn, args = get_vjp_fn_and_args_with_cotangents(op, sample, cotangents)
+                is_batch_norm_and_training = is_batch_norm_training(op.name, sample.kwargs)
                 for loop_out, batched_out in get_fallback_and_vmap_exhaustive(
-                        fn, args, {}, opinfo=op, compute_loop_out=False):
+                        fn, args, {}, is_batch_norm_and_training=is_batch_norm_and_training, compute_loop_out=False):
                     pass
                 for a_op in op.aliases:
                     fn, args = get_vjp_fn_and_args_with_cotangents(a_op, sample, cotangents)
                     for loop_out, batched_out in get_fallback_and_vmap_exhaustive(
-                            fn, args, {}, opinfo=op, compute_loop_out=False):
+                            fn, args, {}, is_batch_norm_and_training=is_batch_norm_and_training, compute_loop_out=False):
                         pass
 
         check_vmap_fallback(self, test, op, dry_run=False)
@@ -1447,7 +1461,10 @@ def was_skipped_from_batched_tensors(batched_out, batch_size):
         for sample_input in sample_inputs:
             cotangents = get_sample_cotangents(op, sample_input)
             f, args = get_autograd_fn_and_args_with_cotangents(op, sample_input, cotangents)
-            for loop_out, batched_out in get_fallback_and_vmap_exhaustive(f, args, {}, opinfo=op):
+            is_batch_norm_and_training = is_batch_norm_training(op.name, sample_input.kwargs)
+            generator = get_fallback_and_vmap_exhaustive(
+                f, args, {}, is_batch_norm_and_training=is_batch_norm_and_training)
+            for loop_out, batched_out in generator:
                 if all(was_skipped_from_batched_tensors(bo, lo.shape[0]) for (bo, lo) in zip(batched_out, loop_out)):
                     continue  # we weren't able to use the batched tensor in autograd.grad
                 self.assertEqual(loop_out, batched_out)
15 changes: 11 additions & 4 deletions test/test_vmap.py
@@ -34,6 +34,7 @@
     check_vmap_fallback,
     tol1,
     opsToleranceOverride,
+    is_batch_norm_training,
 )
 import types
 from collections import namedtuple
@@ -3148,16 +3149,19 @@ def test_vmap_exhaustive(self, device, dtype, op):
         for sample_input in sample_inputs_itr:
             arg_values = [sample_input.input] + list(sample_input.args)
             kwarg_values = sample_input.kwargs
+            is_batch_norm_and_training = is_batch_norm_training(op.name, kwarg_values)
             try:
-                generator = get_fallback_and_vmap_exhaustive(op.op, arg_values, kwarg_values, opinfo=op)
+                generator = get_fallback_and_vmap_exhaustive(
+                    op.op, arg_values, kwarg_values, is_batch_norm_and_training=is_batch_norm_and_training)
                 for loop_out, batched_out in generator:
                     # empty_like and new_empty produce garbage values so we just check the shapes.
                     if op.name == 'empty_like' or op.name == 'new_empty':
                         self.assertEqual(loop_out.shape, batched_out.shape)
                         continue
                     self.assertEqual(loop_out, batched_out)
                 for a_op in op.aliases:
-                    a_generator = get_fallback_and_vmap_exhaustive(a_op, arg_values, kwarg_values, opinfo=op)
+                    a_generator = get_fallback_and_vmap_exhaustive(
+                        a_op, arg_values, kwarg_values, is_batch_norm_and_training=is_batch_norm_and_training)
                     for loop_out, batched_out in a_generator:
                         self.assertEqual(loop_out, batched_out)
             # todo(chilli): Garbage hack I added to deal with indexing not working
@@ -3294,15 +3298,18 @@ def test():
             for sample_input in sample_inputs_itr:
                 arg_values = [sample_input.input] + list(sample_input.args)
                 kwarg_values = sample_input.kwargs
-                generator = get_fallback_and_vmap_exhaustive(op.op, arg_values, kwarg_values, opinfo=op)
+                is_batch_norm_and_training = is_batch_norm_training(op.name, kwarg_values)
+                generator = get_fallback_and_vmap_exhaustive(
+                    op.op, arg_values, kwarg_values, is_batch_norm_and_training=is_batch_norm_and_training)
                 for loop_out, batched_out in generator:
                     # empty_like and new_empty produce garbage values so we just check the shapes.
                     if op.name == 'empty_like' or op.name == 'new_empty':
                         self.assertEqual(loop_out.shape, batched_out.shape)
                         continue
                     self.assertEqual(loop_out, batched_out)
                 for a_op in op.aliases:
-                    a_generator = get_fallback_and_vmap_exhaustive(a_op, arg_values, kwarg_values, opinfo=op)
+                    a_generator = get_fallback_and_vmap_exhaustive(
+                        a_op, arg_values, kwarg_values, is_batch_norm_and_training=is_batch_norm_and_training)
                     for loop_out, batched_out in a_generator:
                         self.assertEqual(loop_out, batched_out)
         check_vmap_fallback(self, test, op)
