Generate n^2 not n^3 inputs for batch and instance norm; small batch norm fix (#951)

* refactor batch norm exhaustive inputs

* fix typo in batch rule

* fix expand issue, add without cudnn xfail
Samantha Andow committed Aug 4, 2022
1 parent ba6ebd7 commit 0331c43
Showing 3 changed files with 16 additions and 7 deletions.
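Before the per-file diffs, a rough sketch of the combinatorics behind the commit title. The variable names are made up and the exact pairing of choices is an assumption; the real enumeration lives in the test helpers changed below.

```python
import itertools

# With bdims=(0, -1) plus "not batched", every tensor argument of batch norm
# has n = 3 possible batch-dim placements.
placements = [None, 0, -1]
n = len(placements)

# Enumerating input, running_mean and running_var independently is cubic in n...
cubic = list(itertools.product(placements, repeat=3))

# ...while tying running_mean and running_var to one shared choice is quadratic.
quadratic = list(itertools.product(placements, repeat=2))

print(len(cubic), len(quadratic))  # 27 9
```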
8 changes: 7 additions & 1 deletion functorch/csrc/BatchRulesNorm.cpp
@@ -75,7 +75,7 @@ batch_norm_batch_rule(
    mean = std::get<1>(result);
    rstd = std::get<2>(result);
  } else {
-    bdim_size = get_bdim_size3(input, input_bdim, running_mean, running_mean_bdim, running_var, running_mean_bdim);
+    bdim_size = get_bdim_size3(input, input_bdim, running_mean, running_mean_bdim, running_var, running_var_bdim);
    auto input_ = moveBatchDimToFront(input, input_bdim);
    input_ = ensure_has_bdim(input_, input_bdim.has_value(), bdim_size.value());
    input_ = reshape_dim_into(0, /*channels dim*/1, input_);
@@ -86,11 +86,17 @@
      running_mean_ = moveBatchDimToFront(running_mean, running_mean_bdim);
      running_mean_ = ensure_has_bdim(*running_mean_, running_mean_bdim.has_value(), bdim_size.value());
      running_mean_ = reshape_dim_into(0, 0, *running_mean_);
+      if (training) {
+        running_mean_ = running_mean_->contiguous();
+      }
    }
    if (running_var.defined()) {
      running_var_ = moveBatchDimToFront(running_var, running_var_bdim);
      running_var_ = ensure_has_bdim(*running_var_, running_var_bdim.has_value(), bdim_size.value());
      running_var_ = reshape_dim_into(0, 0, *running_var_);
+      if (training) {
+        running_var_ = running_var_->contiguous();
+      }
    }

    const auto dummy_weight = at::ones(input_.size(1), input_.options()); // cudnn and miopen require a weight
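The new `contiguous()` calls are guarded by `training` because that is when batch norm updates its running statistics in place, and an `expand`-ed (broadcast) stat cannot accept such writes. A minimal PyTorch illustration of that general behavior, not of the exact batch-rule code path:

```python
import torch

running_mean = torch.zeros(3)
batched = running_mean.expand(4, 3)  # broadcast view: every row aliases the same storage

try:
    batched.add_(1.0)  # the kind of in-place update training-mode batch norm performs
except RuntimeError as err:
    print(err)  # "... more than one element of the written-to tensor refers to a single memory location ..."

batched.contiguous().add_(1.0)  # a materialized copy accepts the update
```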
2 changes: 0 additions & 2 deletions test/common_utils.py
@@ -92,7 +92,6 @@ def add_batch_choices(a):

        if all([i is None for i in in_dims]):
            continue
-
        yield pytree.tree_unflatten(batched_args, arg_spec), pytree.tree_unflatten(in_dims, arg_spec), kwarg_values

    if for_batch_norm and len(orig_flat_args) >= 2:
@@ -111,7 +110,6 @@ def add_batch_choices(a):
        in_dims_tuple = pytree.tree_unflatten(in_dims, arg_spec)
        yield batched_args_tuple, in_dims_tuple, kwarg_values

-
def get_exhaustive_batched_inputs_for_batch_norm(arg_values, kwarg_values, batch_size=3, bdims=(0, -1)):
    return get_exhaustive_batched_inputs(arg_values, kwarg_values,
                                         batch_size=batch_size, bdims=bdims, for_batch_norm=True)
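The yields above rebuild the op's argument structure with `pytree.tree_unflatten`. A toy round-trip showing that mechanism, with illustrative tensors and in_dims rather than the helper's actual batching logic:

```python
import torch
from torch.utils import _pytree as pytree

# Flatten the op's arguments, batch every leaf at dim 0, then rebuild both the
# batched arguments and a matching in_dims structure for vmap.
args = (torch.randn(2, 3), {"weight": torch.randn(3)})
flat_args, arg_spec = pytree.tree_flatten(args)

batch_size = 3
batched_flat = [a.expand(batch_size, *a.shape) for a in flat_args]
in_dims_flat = [0] * len(flat_args)

batched_args = pytree.tree_unflatten(batched_flat, arg_spec)
in_dims = pytree.tree_unflatten(in_dims_flat, arg_spec)
```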
13 changes: 9 additions & 4 deletions test/test_ops.py
@@ -19,6 +19,7 @@
from common_utils import (
    get_fallback_and_vmap_exhaustive,
    get_exhaustive_batched_inputs,
+    get_exhaustive_batched_inputs_for_batch_norm,
    xfail,
    skip,
    skipOps,
@@ -746,14 +747,11 @@ def test_vmapjvp(self, device, dtype, op):

        # The following are bugs that we should fix
        skip('nn.functional.max_pool1d'), # fails on cpu, runs on cuda
-        xfail('nn.functional.batch_norm', device_type='cuda'),
-        xfail('nn.functional.batch_norm', 'without_cudnn', device_type='cuda'),
        xfail('nn.functional.hinge_embedding_loss', device_type='cuda'),
        xfail('_masked.mean'),
        xfail('_masked.prod'),

        # Causing issues with multiple cpu levels of forward mode AD
-        xfail('nn.functional.batch_norm', device_type='cpu'),
        xfail('nn.functional.hinge_embedding_loss', device_type='cpu'),

        xfail('nn.functional.soft_margin_loss', ''),
@@ -788,6 +786,10 @@ def test_vmapjvp(self, device, dtype, op):
        xfail('nn.functional.embedding'), # embedding_renorm_ does not support fwd AD
        xfail('put'), # calls put_ during vmap with only vmaps over other, not self
        xfail('nn.functional.prelu'), # Call Tensor.as_strided
+
+        # erroring because running_mean and running_var aren't differentiable
+        xfail('nn.functional.batch_norm'),
+        xfail('nn.functional.batch_norm', 'without_cudnn'),
    }

    @ops(functorch_lagging_op_db, allowed_dtypes=(torch.float,))
@@ -1082,7 +1084,10 @@ def test_vjpvmap(self, device, dtype, op):
        for sample in samples:
            args = [sample.input] + list(sample.args)
            kwargs = sample.kwargs
-            generator = get_exhaustive_batched_inputs(args, kwargs, for_batch_norm=is_batch_norm)
+            if is_batch_norm:
+                generator = get_exhaustive_batched_inputs_for_batch_norm(args, kwargs)
+            else:
+                generator = get_exhaustive_batched_inputs(args, kwargs)

            for batched_args, in_dims, kwargs in generator:
                vmapped_op = vmap(op, in_dims)
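For orientation, `test_vjpvmap` goes on to take a vjp through each `vmapped_op` (the continuation is not shown in the hunk). A minimal sketch of that composition with a stand-in op, `torch.sin`, rather than the OpInfo under test, and a trivial ones cotangent:

```python
import torch
from functorch import vmap, vjp

op = torch.sin
x = torch.randn(3, 5)  # batch of 3 samples

vmapped_op = vmap(op, in_dims=0)
out, vjp_fn = vjp(vmapped_op, x)
(grad_x,) = vjp_fn(torch.ones_like(out))  # pull the cotangent back to the input
```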
