diff --git a/python/aitemplate/compiler/ops/gemm_universal/gemm_rcr_bias.py b/python/aitemplate/compiler/ops/gemm_universal/gemm_rcr_bias.py index 1b327b93c..9df2c9222 100644 --- a/python/aitemplate/compiler/ops/gemm_universal/gemm_rcr_bias.py +++ b/python/aitemplate/compiler/ops/gemm_universal/gemm_rcr_bias.py @@ -24,6 +24,7 @@ class gemm_rcr_bias(gemm_rcr): """GEMM Specialization: GEMM_RCR(A, B) + Bias + A[RowMajor], B[ColMajor], Bias[RowMajor], C[RowMajor] This operator is equivalent to the following pytorch code: diff --git a/python/aitemplate/compiler/ops/gemm_universal/gemm_rrr_bias.py b/python/aitemplate/compiler/ops/gemm_universal/gemm_rrr_bias.py index 0eff459b8..a8c052d87 100644 --- a/python/aitemplate/compiler/ops/gemm_universal/gemm_rrr_bias.py +++ b/python/aitemplate/compiler/ops/gemm_universal/gemm_rrr_bias.py @@ -23,6 +23,20 @@ class gemm_rrr_bias(gemm_rrr): + """GEMM Specialization: GEMM_RRR(A, B) + Bias + A[RowMajor], B[RowMajor], Bias[RowMajor], C[RowMajor] + + This operator is equivalent to the following pytorch code: + + .. highlight:: python + .. 
code-block:: python + A = torch.randn(M, K).cuda().half() + B = torch.randn(K, N).cuda().half() + Bias = torch.randn(N).cuda().half() + + y = torch.nn.functional.linear(A, B.t(), bias=Bias) + """ + def __init__(self): super().__init__() self._attrs["op"] = "gemm_rrr_bias" diff --git a/python/aitemplate/compiler/ops/tensor/concatenate.py b/python/aitemplate/compiler/ops/tensor/concatenate.py index eaa544863..1606d1ffd 100644 --- a/python/aitemplate/compiler/ops/tensor/concatenate.py +++ b/python/aitemplate/compiler/ops/tensor/concatenate.py @@ -116,7 +116,7 @@ def _infer_shapes(self, inputs: List[Tensor], dim) -> List[IntVar]: output_shape.append(output_dim) return output_shape - def __call__(self, inputs: List[Tensor], dim=0) -> List[Tensor]: + def __call__(self, inputs: List[Tensor], dim=0) -> Tensor: self._attrs["inputs"] = list(inputs) self._attrs["input_accessors"] = [ TensorAccessor(t) for t in self._attrs["inputs"] diff --git a/python/aitemplate/compiler/ops/tensor/dynamic_slice.py b/python/aitemplate/compiler/ops/tensor/dynamic_slice.py index 547382913..4055a2c5d 100644 --- a/python/aitemplate/compiler/ops/tensor/dynamic_slice.py +++ b/python/aitemplate/compiler/ops/tensor/dynamic_slice.py @@ -133,9 +133,9 @@ def __call__( ---------- x : Tensor Input tensor. - start_indices : List[int] + start_indices : List[Union[IntVar, IntVarTensor, Optional[int]]] Similar to PyTorch and numpy, indices can be negative - end_indices : List[int] + end_indices : List[Union[IntVar, IntVarTensor, Optional[int]]] end_index is not included. Similar to PyTorch and numpy, indices can be negative. 
diff --git a/python/aitemplate/compiler/ops/tensor/permute0213.py b/python/aitemplate/compiler/ops/tensor/permute0213.py index b6b33c10d..3616bdf73 100644 --- a/python/aitemplate/compiler/ops/tensor/permute0213.py +++ b/python/aitemplate/compiler/ops/tensor/permute0213.py @@ -61,7 +61,7 @@ def _infer_shapes(self, x: Tensor) -> List[IntVar]: x_shape = x._attrs["shape"] return [x_shape[0], x_shape[2], x_shape[1], x_shape[3]] - def __call__(self, x: Tensor) -> List[Tensor]: + def __call__(self, x: Tensor) -> Tensor: """ Parameters ---------- diff --git a/python/aitemplate/compiler/ops/tensor/permute102.py b/python/aitemplate/compiler/ops/tensor/permute102.py index 3c9674186..f6d0af738 100644 --- a/python/aitemplate/compiler/ops/tensor/permute102.py +++ b/python/aitemplate/compiler/ops/tensor/permute102.py @@ -61,7 +61,7 @@ def _infer_shapes(self, x: Tensor) -> List[IntVar]: x_shape = x._attrs["shape"] return [x_shape[1], x_shape[0], x_shape[2]] - def __call__(self, x: Tensor) -> List[Tensor]: + def __call__(self, x: Tensor) -> Tensor: """ Parameters ---------- diff --git a/python/aitemplate/compiler/ops/tensor/permute210.py b/python/aitemplate/compiler/ops/tensor/permute210.py index 70abe3baf..d177073ef 100644 --- a/python/aitemplate/compiler/ops/tensor/permute210.py +++ b/python/aitemplate/compiler/ops/tensor/permute210.py @@ -21,7 +21,7 @@ from aitemplate import backend from aitemplate.backend import registry -from aitemplate.compiler.base import Operator, Tensor +from aitemplate.compiler.base import IntVar, Operator, Tensor # pylint: disable=C0103,W0221 @@ -55,7 +55,7 @@ def __init__(self): super().__init__() self._attrs["op"] = "permute210" - def _infer_shapes(self, x: Tensor): + def _infer_shapes(self, x: Tensor) -> List[IntVar]: """Infers shapes for permute210. Parameters @@ -64,7 +64,7 @@ def _infer_shapes(self, x: Tensor): Returns ------ - Tensor + List[IntVar] Inferred output 3d tensor with input shape. 
Because its a permute210 operation, Out.d0=In.d2, Out.d2=In.d0. @@ -72,7 +72,7 @@ def _infer_shapes(self, x: Tensor): x_shape = x._attrs["shape"] return [x_shape[2], x_shape[1], x_shape[0]] - def __call__(self, x: Tensor) -> List[Tensor]: + def __call__(self, x: Tensor) -> Tensor: """ Return the output tensor of permute210 diff --git a/python/aitemplate/compiler/transform/mark_param_tensor.py b/python/aitemplate/compiler/transform/mark_param_tensor.py index 677c1df93..1a6b4dc5c 100644 --- a/python/aitemplate/compiler/transform/mark_param_tensor.py +++ b/python/aitemplate/compiler/transform/mark_param_tensor.py @@ -23,6 +23,16 @@ def mark_special_views(sorted_graph: List[Tensor]): + """ + Associate each tensor with an external tensor if either of the following conditions holds: + 1. The tensor is a view-of-a-view of an external tensor. + 2. The tensor is a view of an input, constant or output tensor (i.e. an external tensor). + + Parameters + ---------- + sorted_graph : List[Tensor] + The graph to mutate. + """ for node in sorted_graph: view = node._attrs["is_view_of"] if view is None: