From ef51f93df48213237bec1c764483168ab4a8df46 Mon Sep 17 00:00:00 2001
From: Yanming Wang
Date: Tue, 25 Jul 2023 23:53:17 +0000
Subject: [PATCH 1/2] [fuse_elementwise] Fix external outputs collection when one subgraph has multiple external outputs with different shapes

---
 .../aitemplate/compiler/transform/fuse_ops.py |  69 ++++++-----
 .../ops/test_fused_elementwise_broadcast.py   | 107 +++++++++++++++++-
 2 files changed, 140 insertions(+), 36 deletions(-)

diff --git a/python/aitemplate/compiler/transform/fuse_ops.py b/python/aitemplate/compiler/transform/fuse_ops.py
index 60a496553..5685736fe 100644
--- a/python/aitemplate/compiler/transform/fuse_ops.py
+++ b/python/aitemplate/compiler/transform/fuse_ops.py
@@ -17,7 +17,7 @@
 """
 import collections
 import logging
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import Any, Dict, List, Optional, Set
 
 from aitemplate.compiler.base import Operator, Tensor
@@ -75,7 +75,7 @@ def get_node_groups(self) -> List[Set[Any]]:
         return node_groups
 
 
-def _find_fusable_elementwise_ops(op: Operator) -> Set[Operator]:
+def _find_fusable_elementwise_ops(src_op: Operator) -> Set[Operator]:
     """
     Given an elementwise op, returns a list of parent elementwise ops
     which can be fused with this elementwise op.
@@ -83,7 +83,7 @@ def _find_fusable_elementwise_ops(op: Operator) -> Set[Operator]:
 
     # Get parent ops.
     dependent_ops = set()
-    for input_tensor in op._attrs["inputs"]:
+    for input_tensor in src_op._attrs["inputs"]:
         dependent_ops.update(input_tensor._attrs["src_ops"])
     original_ops = set(dependent_ops)
 
@@ -147,16 +147,23 @@ class FusedElementwiseInfo:
     external_outputs: Set[Tensor]
 
 
-def _partition_subgraphs(ops: Set[Operator]) -> Dict[str, Set[Operator]]:
+@dataclass
+class SubgraphInfo:
+    partitioned_ops: Set[Operator] = field(default_factory=set)
+    external_outputs: Set[Tensor] = field(default_factory=set)
+
+
+def _partition_subgraphs(ops: Set[Operator]) -> Dict[str, SubgraphInfo]:
     """
     Given ops of candidate graph of fused_elementwise op graph and partition into
     subgraph based on output shape, returns dict of
-    {output shape: ops to form subgraph based on the shape}
+    {output shape: ops to form subgraph based on the shape and external outputs of the subgraph}
     """
     # Partition graph of elementwise into subgraph based on output shape.
-    output_op_map = collections.defaultdict(set)
+    subgraph_info_map = collections.defaultdict(SubgraphInfo)
     for op in ops:
         shapes = []
+        external_outputs = []
         # Find output nodes
         for output_tensor in op._attrs["outputs"]:
             if (
@@ -164,16 +171,19 @@ def _partition_subgraphs(ops: Set[Operator]) -> Dict[str, Set[Operator]]:
                 or len(output_tensor._attrs["dst_ops"] - ops) > 0
             ):
                 shapes.append("_".join(map(str, output_tensor._attrs["shape"])))
+                external_outputs.append(output_tensor)
         # Find anscestor of output node.
         # Outputs with the same shape should form the same graph
         if shapes:
             key = "|".join(shapes)
-            op_set = output_op_map[key]
+            subgraph_info = subgraph_info_map[key]
+            subgraph_info.external_outputs.update(external_outputs)
+            op_set = subgraph_info.partitioned_ops
             for anc_op in ops:
                 if transform_utils.is_ancestor(anc_op, op):
                     op_set.add(anc_op)
             op_set.add(op)
-    return output_op_map
+    return subgraph_info_map
 
 
 def _get_inputs_outputs(
@@ -182,11 +192,9 @@ def _get_inputs_outputs(
     """
     Given ops of a partitioned subgraph based on output shape, and ops of full graph
     to form a complete graph with fused_elementwise op, returns all inputs/outputs of
-    the ops and the external input/output of the subgraph, which will serve as input/output
-    of fused_elementwise op.
+    the ops and the external input of the subgraph, which will serve as input of fused_elementwise op.
     """
     external_inputs = set()
-    external_outputs = set()
     tmp_inputs = set()
     tmp_outputs = set()
 
@@ -201,9 +209,6 @@ def _get_inputs_outputs(
             assert op in input_tensor._attrs["dst_ops"]
         for output_tensor in op._attrs["outputs"]:
             tmp_outputs.add(output_tensor)
-            dst_ops = set(output_tensor._attrs["dst_ops"])
-            if output_tensor._attrs["is_output"] or len(dst_ops - all_ops) > 0:
-                external_outputs.add(output_tensor)
             assert len(output_tensor._attrs["src_ops"]) == 1
             assert list(output_tensor._attrs["src_ops"])[0] == op
 
@@ -212,22 +217,11 @@ def _get_inputs_outputs(
         ), "external_inputs: {} is not equal to tmp_inputs: {} - tmp_outputs: {}.".format(
             external_inputs, tmp_inputs, tmp_outputs
         )
-    assert (
-        len(tmp_outputs - tmp_inputs - external_outputs) == 0
-    ), "tmp_outputs: {} - tmp_inputs: {} - external_outputs: {} is not empty.".format(
-        tmp_outputs, tmp_inputs, external_outputs
-    )
-    assert (
-        len(external_outputs - tmp_outputs) == 0
-    ), "external_outputs: {} - tmp_outputs: {} is not empty.".format(
-        external_outputs, tmp_outputs
-    )
-
-    return [tmp_inputs, tmp_outputs, external_inputs, external_outputs]
+    return [tmp_inputs, tmp_outputs, external_inputs]
 
 
 def _collect_info(
-    output_op_map: Dict[str, Set[Operator]],
+    subgraph_info_map: Dict[str, SubgraphInfo],
     all_ops: Set[Operator],
     sorted_graph: List[Tensor],
 ) -> List[FusedElementwiseInfo]:
@@ -241,9 +235,10 @@ def _collect_info(
     their external input/output, serving as input/output of fused_elementwise op.
     """
     info_list = []
-    for op_set in output_op_map.values():
+    for subgraph_info in subgraph_info_map.values():
         # Toposort the op_set into op_list
         # because fuse_elementwise stores elementwise ops in topological order
+        op_set = subgraph_info.partitioned_ops
         topo_set = set()
         op_list = []
         for tensor in sorted_graph:
@@ -259,8 +254,13 @@ def _collect_info(
         ), "Unable to find topological order of op list for fused_elementwise!"
         # Get all inputs/outputs of elementwise ops and their external input/output,
         # which will serve as input/output of fused_elementwise op.
-        inputs_outputs = _get_inputs_outputs(op_list, all_ops)
-        fused_op_info = FusedElementwiseInfo(op_list, *inputs_outputs)
+        tmp_inputs, tmp_outputs, external_inputs = _get_inputs_outputs(op_list, all_ops)
+        # Note external outputs were generated earlier because we need to group
+        # them by their shapes.
+        external_outputs = subgraph_info.external_outputs
+        fused_op_info = FusedElementwiseInfo(
+            op_list, tmp_inputs, tmp_outputs, external_inputs, external_outputs
+        )
         info_list.append(fused_op_info)
 
     return info_list
@@ -321,9 +321,9 @@ def fuse_elementwise(sorted_graph: List[Tensor], workdir: str = None) -> List[Te
 
     for ops in to_be_fused_op_groups:
         # Partition subgraph based on output shape.
-        output_op_map = _partition_subgraphs(ops)
+        subgraph_info_map = _partition_subgraphs(ops)
         # Collect information to create fuse ops.
-        info_list = _collect_info(output_op_map, ops, sorted_graph)
+        info_list = _collect_info(subgraph_info_map, ops, sorted_graph)
         # Create fuse ops.
         _create_fuse_ops(info_list)
 
@@ -353,10 +353,9 @@ def process_singleton_elementwise(
 
     for ops in to_be_fused_op_groups:
         # Partition subgraph based on output shape.
-        # output_op_map = {op._attrs["op"]: set(op) for op in ops}
-        output_op_map = _partition_subgraphs(ops)
+        subgraph_info_map = _partition_subgraphs(ops)
         # Collect information to create fuse ops.
-        info_list = _collect_info(output_op_map, set(ops), sorted_graph)
+        info_list = _collect_info(subgraph_info_map, set(ops), sorted_graph)
         # Create fuse ops.
         _create_fuse_ops(info_list)
 
diff --git a/tests/unittest/ops/test_fused_elementwise_broadcast.py b/tests/unittest/ops/test_fused_elementwise_broadcast.py
index d0432bc46..9fb778480 100644
--- a/tests/unittest/ops/test_fused_elementwise_broadcast.py
+++ b/tests/unittest/ops/test_fused_elementwise_broadcast.py
@@ -25,9 +25,14 @@
 from aitemplate.compiler.ops.common.epilogue import FuncEnum
 from aitemplate.frontend import Tensor
 from aitemplate.testing import detect_target
-from aitemplate.testing.test_utils import get_random_torch_tensor
+from aitemplate.testing.test_utils import (
+    get_random_torch_tensor,
+    get_torch_empty_tensor,
+)
 from aitemplate.utils import graph_utils, shape_utils
 
+from parameterized import parameterized
+
 
 class FusedElementwiseBroadcastTestCase(unittest.TestCase):
     @classmethod
@@ -957,6 +962,106 @@ def test_vectorization_fp32(self):
             dtype="float",
         )
 
+    @parameterized.expand([("float16"), ("float")])
+    def test_fused_elementwise_broadcast_with_skip_connection(self, dtype):
+        r"""
+        X0   X1 (8)  X2 (1)  X3
+          \  /         \   /
+         Div_0 (R0)   Sub_1 (R1)
+             \            |     X4 (-1)
+              \           |    /
+               \        Mul_2 (R2)
+                \        /  \
+                 \      /    \
+                 Add_3 (R3)   \
+                     |         \
+               Softmax_4 (R4)  /
+                      \       /
+                       \     /
+                        \   /
+                       Add_5 (R5) (output)
+
+        X0 ([1,12,512,512]) and X3 ([1,1,1,512]) have different but broadcastable shapes.
+ """ + target = detect_target() + if dtype == "float" and target.name == "rocm": + self.skipTest("float tensors not supported by rocm") + shape0 = [1, 12, 512, 512] + shape1 = [1, 1, 1, 512] + X0 = Tensor( + shape=shape0, + dtype=dtype, + name="X0", + is_input=True, + ) + X1 = Tensor( + shape=[], + dtype=dtype, + name="X1", + value=8.0, + ) + X2 = Tensor( + shape=[], + dtype=dtype, + name="X2", + value=1.0, + ) + X3 = Tensor( + shape=shape1, + dtype=dtype, + name="X3", + is_input=True, + ) + X4 = Tensor( + shape=[], + dtype=dtype, + name="X4", + value=-1.0, + ) + + R0 = ops.elementwise(FuncEnum.DIV)(X0, X1) # Div_0 + R1 = ops.elementwise(FuncEnum.SUB)(X2, X3) # Sub_1 + R2 = ops.elementwise(FuncEnum.MUL)(R1, X4) # Mul_2 + R3 = ops.elementwise(FuncEnum.ADD)(R0, R2) # Add_3 + R4 = ops.softmax()(R3, -1) # Softmax_4 + R5 = ops.elementwise(FuncEnum.ADD)(R4, R2) # Add_5 + R5._attrs["name"] = "R5" + R5._attrs["is_output"] = True + + module = compile_model( + [R5], + target, + "./tmp", + f"test_fused_elementwise_broadcast_with_skip_connection_{dtype}", + ) + debug_sorted_graph = module.debug_sorted_graph + sorted_ops = graph_utils.get_sorted_ops(debug_sorted_graph) + self.assertEqual(len(sorted_ops), 4) + + x0_pt = get_random_torch_tensor(shape0, dtype) + x3_pt = get_random_torch_tensor(shape1, dtype) + + r0_pt = x0_pt / 8.0 + r1_pt = 1.0 - x3_pt + r2_pt = r1_pt * (-1.0) + r3_pt = r0_pt + r2_pt + r4_pt = torch.nn.functional.softmax(r3_pt, -1) + r5_pt = r4_pt + r2_pt + + r5 = get_torch_empty_tensor(x0_pt.shape, dtype) + + input_name_to_idx_mapping = module.get_input_name_to_index_map() + inputs = [None] * len(input_name_to_idx_mapping) + input_name_to_pt_mapping = { + "X0": x0_pt, + "X3": x3_pt, + } + for input_name, pt in input_name_to_pt_mapping.items(): + inputs[input_name_to_idx_mapping[input_name]] = pt + module.run_with_tensors(inputs, {"R5": r5}) + + self.assertTrue(torch.allclose(r5, r5_pt, atol=1e-2, rtol=1e-2)) + if __name__ == "__main__": unittest.main() From 0fb3ba5b10be0d76288c2ec15fb27a69db1f0b4d Mon Sep 17 00:00:00 2001 From: Yanming Wang Date: Wed, 26 Jul 2023 03:01:32 +0000 Subject: [PATCH 2/2] Refactor to avoid conflict with another test case --- .../aitemplate/compiler/transform/fuse_ops.py | 28 +++++++++++++++---- 1 file changed, 23 insertions(+), 5 deletions(-) diff --git a/python/aitemplate/compiler/transform/fuse_ops.py b/python/aitemplate/compiler/transform/fuse_ops.py index 5685736fe..b27ec4c40 100644 --- a/python/aitemplate/compiler/transform/fuse_ops.py +++ b/python/aitemplate/compiler/transform/fuse_ops.py @@ -192,9 +192,11 @@ def _get_inputs_outputs( """ Given ops of a partitioned subgraph based on output shape, and ops of full graph to form a complete graph with fused_elementwise op, returns all inputs/outputs of - the ops and the external input of the subgraph, which will serve as input of fused_elementwise op. + the ops and the external input/output of the subgraph, which will serve as input/output + of fused_elementwise op. 
""" external_inputs = set() + external_outputs = set() tmp_inputs = set() tmp_outputs = set() @@ -209,6 +211,9 @@ def _get_inputs_outputs( assert op in input_tensor._attrs["dst_ops"] for output_tensor in op._attrs["outputs"]: tmp_outputs.add(output_tensor) + dst_ops = set(output_tensor._attrs["dst_ops"]) + if output_tensor._attrs["is_output"] or len(dst_ops - all_ops) > 0: + external_outputs.add(output_tensor) assert len(output_tensor._attrs["src_ops"]) == 1 assert list(output_tensor._attrs["src_ops"])[0] == op @@ -217,7 +222,18 @@ def _get_inputs_outputs( ), "external_inputs: {} is not equal to tmp_inputs: {} - tmp_outputs: {}.".format( external_inputs, tmp_inputs, tmp_outputs ) - return [tmp_inputs, tmp_outputs, external_inputs] + assert ( + len(tmp_outputs - tmp_inputs - external_outputs) == 0 + ), "tmp_outputs: {} - tmp_inputs: {} - external_outputs: {} is not empty.".format( + tmp_outputs, tmp_inputs, external_outputs + ) + assert ( + len(external_outputs - tmp_outputs) == 0 + ), "external_outputs: {} - tmp_outputs: {} is not empty.".format( + external_outputs, tmp_outputs + ) + + return [tmp_inputs, tmp_outputs, external_inputs, external_outputs] def _collect_info( @@ -254,9 +270,11 @@ def _collect_info( ), "Unable to find topological order of op list for fused_elementwise!" # Get all inputs/outputs of elementwise ops and their external input/output, # which will serve as input/output of fused_elementwise op. - tmp_inputs, tmp_outputs, external_inputs = _get_inputs_outputs(op_list, all_ops) - # Note external outputs were generated earlier because we need to group - # them by their shapes. + tmp_inputs, tmp_outputs, external_inputs, _ = _get_inputs_outputs( + op_list, all_ops + ) + # Use the external outputs we already collected because the external outputs returned by + # _get_inputs_outputs may have different shapes. external_outputs = subgraph_info.external_outputs fused_op_info = FusedElementwiseInfo( op_list, tmp_inputs, tmp_outputs, external_inputs, external_outputs