diff --git a/python/aitemplate/backend/codegen.py b/python/aitemplate/backend/codegen.py index e3b79cafe..8353fe93f 100644 --- a/python/aitemplate/backend/codegen.py +++ b/python/aitemplate/backend/codegen.py @@ -631,19 +631,25 @@ def _codegen_output_tensor(self, tensor: Tensor) -> None: if is_param: self._codegen_param_setup(tensor) self.device_to_device_copies.append(device_copy(tensor, tensor, output_idx)) - elif external_tensor is not None: - # Special view cases for outputs; we can hit this case if the output - # is a view of a constant, input, or another output. + elif is_view or external_tensor is not None: assert ( is_view - ), f"orig_tensor is not None, but node {name} is not marked as a view! Node: {tensor}" + ), f"External tensor is not None, but node {name} is not marked as a view! Node: {tensor}" + view_name = view._attrs["name"] + self.set_inputs.append(set_value(name, view_name)) self.set_inputs.append( - check_not_null(tensor, output_idx, skip_if_lower_bound_is_zero=True) + check_not_null(tensor, skip_if_lower_bound_is_zero=True) ) - self.set_inputs.append(set_value(name, view._attrs["name"])) - self.device_to_device_copies.append( - device_copy(tensor, external_tensor, output_idx) + + view_assigned_to_another_output = ( + self._get_output_idx(view_name) != output_idx ) + if external_tensor or view_assigned_to_another_output: + # Copy from original tensor so this output can also have the data. + original_tensor = external_tensor if external_tensor else view + self.device_to_device_copies.append( + device_copy(tensor, original_tensor, output_idx) + ) elif is_input: # Inputs that are also outputs require an extra copy self.set_inputs.append( @@ -655,11 +661,6 @@ def _codegen_output_tensor(self, tensor: Tensor) -> None: self._record_param_tensor_info(tensor, self.input_idx) self.device_to_device_copies.append(device_copy(tensor, tensor, output_idx)) self.input_idx += 1 - elif is_view: - self.set_inputs.append(set_value(name, view._attrs["name"])) - self.set_inputs.append( - check_not_null(tensor, skip_if_lower_bound_is_zero=True) - ) else: self.set_inputs.append( set_value( diff --git a/tests/unittest/backend/test_codegen_output_tensor.py b/tests/unittest/backend/test_codegen_output_tensor.py index b2589ecc5..d290edad7 100644 --- a/tests/unittest/backend/test_codegen_output_tensor.py +++ b/tests/unittest/backend/test_codegen_output_tensor.py @@ -5,13 +5,15 @@ import unittest from typing import Sequence +import torch + from aitemplate.backend.codegen import device_copy, set_value from aitemplate.compiler import compile_model, ops from aitemplate.compiler.ops.common.epilogue import FuncEnum from aitemplate.testing import detect_target -from aitemplate.testing.test_utils import gen_input_tensor +from aitemplate.testing.test_utils import gen_input_tensor, get_random_torch_tensor class TestCodegenOutput(unittest.TestCase): @@ -89,6 +91,8 @@ def test_double_alias(self, test_name="double_alias"): Case: Two outputs are a view of the same tensor. Graph: ( gelu ) <--view-- ( output_0 ) <--view-- ( output_1 ) + Expect: If a tensor is a view for multiple outputs, then it's assigned to + only one of the outputs' ptrs. We expect D2D copies for the remaining outputs. """ # AIT, two outputs. x = gen_input_tensor(shape=self.SHAPE, name="input_x") @@ -101,7 +105,7 @@ def test_double_alias(self, test_name="double_alias"): output1._attrs["is_output"] = True output1._attrs["name"] = "output_1" - compile_model( + model = compile_model( [output0, output1], detect_target(), self.WORKDIR, @@ -122,11 +126,26 @@ def test_double_alias(self, test_name="double_alias"): expected_codegen = ( set_value("output_0", view_name), set_value("output_1", view_name), + device_copy(output0, view, dst_idx=1), ) self._assert_codegen_exists( test_name, expected_codegen, self.MODEL_GENERATED_FILE ) + # This is an edge case -- test the accuracy. + x_pt = get_random_torch_tensor(self.SHAPE) + gelu_pt = torch.nn.functional.gelu(x_pt) + output0_pt = torch.unsqueeze(gelu_pt, dim=0) + output1_pt = torch.flatten(gelu_pt) + output0_ait = torch.empty_like(output0_pt) + output1_ait = torch.empty_like(output1_pt) + + model.run_with_tensors( + {"input_x": x_pt}, {"output_0": output0_ait, "output_1": output1_ait} + ) + self.assertTrue(torch.allclose(output0_ait, output0_pt, atol=1e-2, rtol=1e-2)) + self.assertTrue(torch.allclose(output1_ait, output1_pt, atol=1e-2, rtol=1e-2)) + def test_output_is_view_of_output(self, test_name="output_is_view_of_output"): """ Case: An output is a view of an output.