Skip to content

Commit

Permalink
[reland] Kill capture_pre_autograd_graph API (pytorch#143426)
Browse files Browse the repository at this point in the history
Summary:
Delete the following API:

- capture_pre_autograd_graph()
- capture_pre_autograd_graph_using_training_ir()
- gm_using_training_ir()

Update XLA pin to include pytorch/xla#8398

There's no more call sites to `capture_pre_autograd_graph`.

Except
1) two test cases in coreml, guarded by version guard, PR to remove: apple/coremltools#2400
2) a few call sites guarded by version guard (< 2.5.0)

Test Plan: CI

Differential Revision: D67354440

Pull Request resolved: pytorch#143426
Approved by: https://github.com/gmagogsfm
  • Loading branch information
yushangdi authored and aditew01 committed Dec 18, 2024
1 parent c24f0c4 commit 8837af9
Show file tree
Hide file tree
Showing 8 changed files with 6 additions and 250 deletions.
2 changes: 1 addition & 1 deletion .github/ci_commit_pins/xla.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
73f54ba5bd7fb83d7ba81fe6f5e05fb6ee815d6f
b2b890e962f5fb6f481e5da2eb4a43bb990d0f1b
209 changes: 0 additions & 209 deletions torch/_export/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,215 +58,6 @@ class ExportDynamoConfig:
allow_rnn: bool = True


# We only want to print this once to avoid flooding logs in workflows where capture_pre_autograd_graph
# is called multiple times.
@lru_cache
def capture_pre_autograd_graph_warning():
from torch._inductor import config

log.warning("+============================+")
log.warning("| !!! WARNING !!! |")
log.warning("+============================+")
log.warning("capture_pre_autograd_graph() is deprecated and doesn't provide any function guarantee moving forward.")
log.warning("Please switch to use torch.export.export_for_training instead.")
if config.is_fbcode():
log.warning("For unittest, capture_pre_autograd_graph() will fallback to torch.export.export_for_training.") # noqa: B950

@lru_cache
def print_export_warning():
log.warning("Using torch.export.export_for_training(...,strict=True)")

def gm_using_training_ir(graph_module: torch.fx.GraphModule) -> bool:
"""
Returns true if the graph module is detected to use training IR.
This function checks for two specific conditions within the nodes of the graph module:
1. The presence of the `torch.ops.aten.batch_norm.default` operation which indicates the use of training IR.
2. The presence of deprecated IR tags on node meta or batch norm ops produced by the deprecated IR.
The function raises a RuntimeError if both conditions are met, indicating a conflict in the IR.
"""
# TODO: clean up this code after training IR migration.
# T199018392
has_training_ir_batch_norm = False
has_deprecated_ir_tag = getattr(graph_module, "capture_pre_autograd_graph_tag", False)
for node in graph_module.graph.nodes:
if node.op == "call_function":
if node.target == torch.ops.aten.batch_norm.default:
has_training_ir_batch_norm = True
if node.meta.get("capture_pre_autograd_graph_tag", False):
has_deprecated_ir_tag = True
if node.target in [
torch.ops.aten._native_batch_norm_legit.default,
torch.ops.aten.cudnn_batch_norm.default,
torch.ops.aten.miopen_batch_norm.default,
]:
has_deprecated_ir_tag = True

if has_deprecated_ir_tag and has_training_ir_batch_norm:
raise RuntimeError("Conflicting IR detected.")
return has_training_ir_batch_norm or not has_deprecated_ir_tag

@compatibility(is_backward_compatible=False)
def capture_pre_autograd_graph(
f: torch.nn.Module,
args: Tuple[Any],
kwargs: Optional[Dict[str, Any]] = None,
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
) -> torch.nn.Module:
"""
A helper function that is intended to trace a module before any pre-autograd
decomposition is run. The produced module will be "non-functional" and
composed of aten operators. Later this API will be deleted in favor of more general
torch.export API.
Args:
f: nn.Module to be traced
args: example positional inputs.
kwargs: optional example keyword inputs.
dynamic_shapes: Should either be:
1) a dict from argument names of ``f`` to their dynamic shape specifications,
2) a tuple that specifies dynamic shape specifications for each input in original order.
If you are specifying dynamism on keyword args, you will need to pass them in the order that
is defined in the original function signature.
The dynamic shape of a tensor argument can be specified as either
(1) a dict from dynamic dimension indices to :func:`Dim` types, where it is
not required to include static dimension indices in this dict, but when they are,
they should be mapped to None; or (2) a tuple / list of :func:`Dim` types or None,
where the :func:`Dim` types correspond to dynamic dimensions, and static dimensions
are denoted by None. Arguments that are dicts or tuples / lists of tensors are
recursively specified by using mappings or sequences of contained specifications.
Returns:
An nn.Module containing the traced method.
"""
from torch.export._trace import _extract_fake_inputs, DEFAULT_EXPORT_DYNAMO_CONFIG, _ignore_backend_decomps
from torch._utils_internal import capture_pre_autograd_graph_using_training_ir
from torch._export.non_strict_utils import make_constraints
from torch._subclasses.functional_tensor import FunctionalTensor
from torch.export._unlift import _create_stateful_graph_module
from torch.export.dynamic_shapes import _combine_args

capture_pre_autograd_graph_warning()

if sys.platform == "win32":
raise RuntimeError("capture_pre_autograd_graph not yet supported on Windows")

assert isinstance(f, torch.nn.Module), "Expected an nn.Module instance."

if kwargs is None:
kwargs = {}

if capture_pre_autograd_graph_using_training_ir():
print_export_warning()
module = torch.export.export_for_training(f, args, kwargs, dynamic_shapes=dynamic_shapes, strict=True).module()
else:
log_export_usage(event="export.private_api", flags={"capture_pre_autograd_graph"})

# Do not decompose dropout for exported models, because in eval mode the dropout
# op disappears from the graph, which makes it difficult to switch to train mode.
# See https://github.com/pytorch/pytorch/pull/115258#issuecomment-1900755832.

# We force create native_batch_norm because the below materialization logic
# only applies to CIA ops.
maybe_aliasing_or_mutating_ops = [torch.ops.aten.native_batch_norm.default]

_materialize_cpp_cia_ops()

for op in torch.ops.aten:
op_obj = getattr(torch.ops.aten, op)
for overload in op_obj.overloads():
op_overload = getattr(op_obj, overload)
if torch.Tag.maybe_aliasing_or_mutating in op_overload.tags:
maybe_aliasing_or_mutating_ops.append(op_overload)

decomp_table = {
op: op.decompose
for op in maybe_aliasing_or_mutating_ops
if op != torch.ops.aten.dropout.default
}
with torch._dynamo.config.patch(dataclasses.asdict(DEFAULT_EXPORT_DYNAMO_CONFIG)), _ignore_backend_decomps():
m = torch._dynamo.export(
f,
dynamic_shapes=dynamic_shapes,
assume_static_by_default=True,
tracing_mode="symbolic",
decomposition_table=decomp_table,
pre_dispatch=True,
aten_graph=True,
_log_export_usage=False,
)(
*args,
**kwargs,
)[0]

_, _, fake_mode = _extract_fake_inputs(m, args, kwargs)

m.meta["inline_constraints"] = {
k: v
for k, v in fake_mode.shape_env.var_to_range.items()
if re.match(r"^[if]\d+$", str(k))
}

if isinstance(f, torch.nn.Module):
from torch.export._trace import _restore_state_dict
_restore_state_dict(f, m)

combined_args = _combine_args(f, args, kwargs)
range_constraints = make_constraints(
fake_mode,
m,
combined_args,
dynamic_shapes,
0,
)

module = _create_stateful_graph_module(
m,
range_constraints=range_constraints,
)

setattr(module, "capture_pre_autograd_graph_tag", True) # noqa: B010
for node in module.graph.nodes:
node.meta["capture_pre_autograd_graph_tag"] = True

error_message = \
"""
Calling train() or eval() is not supported for exported models.
Alternatively, you may override these methods to do custom user behavior as follows:
def _my_train(self, mode: bool = True):
...
def _my_eval(self):
...
model.train = types.MethodType(_my_train, model)
model.eval = types.MethodType(_my_eval, model)
"""

def _train(self, mode: bool = True):
raise NotImplementedError(error_message)

def _eval(self, mode: bool = True):
raise NotImplementedError(error_message)

module.train = types.MethodType(_train, module) # type: ignore[method-assign]
module.eval = types.MethodType(_eval, module) # type: ignore[method-assign]

# Remove Proxy because they cannot be deepcopied or pickled.
if hasattr(module, "_buffers"):
torch._export.utils.remove_proxy_from_state_dict(
module._buffers, in_place=True
)
return module


# We only want to print this once to avoid flooding logs in workflows where aot_compile_warning
# is called multiple times.
@lru_cache
Expand Down
4 changes: 0 additions & 4 deletions torch/_utils_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,10 +167,6 @@ def log_torch_jit_trace_exportability(
return


def capture_pre_autograd_graph_using_training_ir() -> bool:
return False


def justknobs_check(name: str, default: bool = True) -> bool:
"""
This function can be used to killswitch functionality in FB prod,
Expand Down
14 changes: 0 additions & 14 deletions torch/ao/quantization/pt2e/export_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,6 @@ def _replace_dropout(m: torch.fx.GraphModule, train_to_eval: bool):
m.graph.eliminate_dead_code()
m.recompile()

from torch._export import gm_using_training_ir

using_training_ir = gm_using_training_ir(m)

for inplace in [False, True]:

def dropout_train(x):
Expand All @@ -72,23 +68,19 @@ def dropout_eval(x):
match_pattern = _get_aten_graph_module_for_pattern(
_WrapperModule(dropout_train),
example_inputs,
using_training_ir=using_training_ir,
)
replacement_pattern = _get_aten_graph_module_for_pattern(
_WrapperModule(dropout_eval),
example_inputs,
using_training_ir=using_training_ir,
)
else:
match_pattern = _get_aten_graph_module_for_pattern(
_WrapperModule(dropout_eval),
example_inputs,
using_training_ir=using_training_ir,
)
replacement_pattern = _get_aten_graph_module_for_pattern(
_WrapperModule(dropout_train),
example_inputs,
using_training_ir=using_training_ir,
)

from torch.fx.subgraph_rewriter import replace_pattern_with_filters
Expand Down Expand Up @@ -122,10 +114,6 @@ def _replace_batchnorm(m: torch.fx.GraphModule, train_to_eval: bool):
m.graph.eliminate_dead_code()
m.recompile()

from torch._export import gm_using_training_ir

using_training_ir = gm_using_training_ir(m)

def bn_train(
x: torch.Tensor,
bn_weight: torch.Tensor,
Expand Down Expand Up @@ -162,13 +150,11 @@ def bn_eval(
_WrapperModule(bn_train),
example_inputs,
is_cuda,
using_training_ir=using_training_ir,
)
bn_eval_aten = _get_aten_graph_module_for_pattern(
_WrapperModule(bn_eval),
example_inputs,
is_cuda,
using_training_ir=using_training_ir,
)

if train_to_eval:
Expand Down
12 changes: 0 additions & 12 deletions torch/ao/quantization/pt2e/qat_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,16 +667,11 @@ def _fuse_conv_bn_qat_helper(
m.graph.eliminate_dead_code()
m.recompile()

from torch._export import gm_using_training_ir

using_training_ir = gm_using_training_ir(m)

conv_bn_pattern = _get_conv_bn_pattern(conv_fn)
match_pattern = _get_aten_graph_module_for_pattern(
conv_bn_pattern,
example_inputs,
is_cuda,
using_training_ir=using_training_ir,
)

# Step (1): Replace patterns with conv bias
Expand All @@ -690,7 +685,6 @@ def _fuse_conv_bn_qat_helper(
qat_conv_bn_pattern,
example_inputs,
is_cuda,
using_training_ir=using_training_ir,
)
replacements_with_conv_bias = replace_pattern_with_filters(
m,
Expand All @@ -708,7 +702,6 @@ def _fuse_conv_bn_qat_helper(
qat_conv_bn_pattern_no_conv_bias,
example_inputs,
is_cuda,
using_training_ir=using_training_ir,
)
replacements_no_conv_bias = replace_pattern_with_filters(
m,
Expand Down Expand Up @@ -922,9 +915,6 @@ def _fold_conv_bn_qat_helper(
"""
Replace the quantized (conv + bn) pattern with conv with bn weights folded into the weights of conv.
"""
from torch._export import gm_using_training_ir

using_training_ir = gm_using_training_ir(m)

m.graph.eliminate_dead_code()
m.recompile()
Expand Down Expand Up @@ -958,7 +948,6 @@ def _fold_conv_bn_qat_helper(
match_pattern,
example_inputs,
is_cuda,
using_training_ir=using_training_ir,
**kwargs,
)
replacement_pattern = _get_folded_quantized_qat_conv_bn_pattern(
Expand All @@ -968,7 +957,6 @@ def _fold_conv_bn_qat_helper(
replacement_pattern,
example_inputs,
is_cuda,
using_training_ir=using_training_ir,
**kwargs,
)
replacements.extend(
Expand Down
7 changes: 2 additions & 5 deletions torch/ao/quantization/pt2e/representation/rewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -797,19 +797,16 @@ def reference_representation_rewrite(model: GraphModule) -> GraphModule:
]

remove_tensor_overload_for_qdq_ops(model)
from torch._export import gm_using_training_ir

using_training_ir = gm_using_training_ir(model)

for rewrite_info in _REWRITE_INFO_LIST:
example_inputs = rewrite_info.example_inputs
pattern = rewrite_info.pattern
replacement = rewrite_info.replacement
pattern_post_trans = rewrite_info.pattern_post_trans
replacement_post_trans = rewrite_info.replacement_post_trans
pattern = _get_aten_graph_module_for_pattern(pattern, example_inputs, using_training_ir=using_training_ir) # type: ignore[arg-type, assignment]
pattern = _get_aten_graph_module_for_pattern(pattern, example_inputs) # type: ignore[arg-type, assignment]
remove_tensor_overload_for_qdq_ops(pattern) # type: ignore[arg-type]
replacement = _get_aten_graph_module_for_pattern(replacement, example_inputs, using_training_ir=using_training_ir) # type: ignore[arg-type, assignment]
replacement = _get_aten_graph_module_for_pattern(replacement, example_inputs) # type: ignore[arg-type, assignment]
remove_tensor_overload_for_qdq_ops(replacement) # type: ignore[arg-type]
if pattern_post_trans:
pattern = pattern_post_trans(pattern)
Expand Down
2 changes: 2 additions & 0 deletions torch/ao/quantization/pt2e/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,8 @@ def _get_aten_graph_module_for_pattern(
[x.cuda() if isinstance(x, torch.Tensor) else x for x in example_inputs]
)

# T199018392
# TODO: remove the using_training_ir flag from function
if using_training_ir:
aten_pattern = torch.export.export_for_training(
pattern, # type: ignore[arg-type]
Expand Down
6 changes: 1 addition & 5 deletions torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,10 +530,6 @@ def _conv_bn(x, conv_weight, conv_bias, bn_weight, bn_bias, bn_rm, bn_rv):
gm.graph.eliminate_dead_code()
gm.recompile()

from torch._export import gm_using_training_ir

using_training_ir = gm_using_training_ir(gm)

matches = []
if is_conv_transpose:
combinations = [
Expand All @@ -556,7 +552,7 @@ def _conv_bn(x, conv_weight, conv_bias, bn_weight, bn_bias, bn_rm, bn_rv):
# Match against all conv dimensions and cuda variants
for (conv_fn, example_inputs), is_cuda, relu_is_inplace in combinations: # type: ignore[misc]
pattern = get_pattern(conv_fn, relu_is_inplace) # type: ignore[has-type]
pattern = _get_aten_graph_module_for_pattern(pattern, example_inputs, is_cuda, using_training_ir=using_training_ir) # type: ignore[has-type]
pattern = _get_aten_graph_module_for_pattern(pattern, example_inputs, is_cuda) # type: ignore[has-type]
pattern.graph.eliminate_dead_code()
pattern.recompile()
matcher = SubgraphMatcherWithNameNodeMap(pattern, ignore_literals=True)
Expand Down

0 comments on commit 8837af9

Please sign in to comment.