Skip to content

Commit

Permalink
Add/update documentation (#883)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #883

Found some outdated docstrings and Python types while reading codebase.

Reviewed By: chenyang78, aakhundov

Differential Revision: D48135593

fbshipit-source-id: e39eb11aa38b25a8a172b2dadc4d0f7e492a75f5
  • Loading branch information
ColinPeppler authored and facebook-github-bot committed Aug 14, 2023
1 parent acfc815 commit 664b25d
Show file tree
Hide file tree
Showing 8 changed files with 34 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

class gemm_rcr_bias(gemm_rcr):
"""GEMM Specialization: GEMM_RCR(A, B) + Bias
A[RowMajor], B[ColMajor], Bias[RowMajor], C[RowMajor]
This operator is equivalent to the following pytorch code:
Expand Down
14 changes: 14 additions & 0 deletions python/aitemplate/compiler/ops/gemm_universal/gemm_rrr_bias.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,20 @@


class gemm_rrr_bias(gemm_rrr):
"""GEMM Specialization: GEMM_RRR(A, B) + Bias
A[RowMajor], B[RowMajor], Bias[RowMajor], C[RowMajor]
This operator is equivalent to the following pytorch code:
.. highlight:: python
.. code-block:: python
A = torch.randn(M, K).cuda().half()
B = torch.randn(K, N).cuda().half()
Bias = torch.randn(N).cuda().half()
y = torch.nn.functional.linear(A, B.t(), bias=Bias)
"""

def __init__(self):
super().__init__()
self._attrs["op"] = "gemm_rrr_bias"
Expand Down
2 changes: 1 addition & 1 deletion python/aitemplate/compiler/ops/tensor/concatenate.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def _infer_shapes(self, inputs: List[Tensor], dim) -> List[IntVar]:
output_shape.append(output_dim)
return output_shape

def __call__(self, inputs: List[Tensor], dim=0) -> List[Tensor]:
def __call__(self, inputs: List[Tensor], dim=0) -> Tensor:
self._attrs["inputs"] = list(inputs)
self._attrs["input_accessors"] = [
TensorAccessor(t) for t in self._attrs["inputs"]
Expand Down
4 changes: 2 additions & 2 deletions python/aitemplate/compiler/ops/tensor/dynamic_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,9 +133,9 @@ def __call__(
----------
x : Tensor
Input tensor.
start_indices : List[int]
start_indices : List[Union[IntVar, IntVarTensor, Optional[int]]]
Similar to PyTorch and numpy, indices can be negative
end_indices : List[int]
end_indices : List[Union[IntVar, IntVarTensor, Optional[int]]]
end_index is not included. Similar to PyTorch and
numpy, indices can be negative.
Expand Down
2 changes: 1 addition & 1 deletion python/aitemplate/compiler/ops/tensor/permute0213.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def _infer_shapes(self, x: Tensor) -> List[IntVar]:
x_shape = x._attrs["shape"]
return [x_shape[0], x_shape[2], x_shape[1], x_shape[3]]

def __call__(self, x: Tensor) -> List[Tensor]:
def __call__(self, x: Tensor) -> Tensor:
"""
Parameters
----------
Expand Down
2 changes: 1 addition & 1 deletion python/aitemplate/compiler/ops/tensor/permute102.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def _infer_shapes(self, x: Tensor) -> List[IntVar]:
x_shape = x._attrs["shape"]
return [x_shape[1], x_shape[0], x_shape[2]]

def __call__(self, x: Tensor) -> List[Tensor]:
def __call__(self, x: Tensor) -> Tensor:
"""
Parameters
----------
Expand Down
8 changes: 4 additions & 4 deletions python/aitemplate/compiler/ops/tensor/permute210.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from aitemplate import backend

from aitemplate.backend import registry
from aitemplate.compiler.base import Operator, Tensor
from aitemplate.compiler.base import IntVar, Operator, Tensor

# pylint: disable=C0103,W0221

Expand Down Expand Up @@ -55,7 +55,7 @@ def __init__(self):
super().__init__()
self._attrs["op"] = "permute210"

def _infer_shapes(self, x: Tensor):
def _infer_shapes(self, x: Tensor) -> List[IntVar]:
"""Infers shapes for permute210.
Parameters
Expand All @@ -64,15 +64,15 @@ def _infer_shapes(self, x: Tensor):
Returns
------
Tensor
List[IntVar]
Inferred output 3d tensor with input shape.
Because its a permute210 operation,
Out.d0=In.d2, Out.d2=In.d0.
"""
x_shape = x._attrs["shape"]
return [x_shape[2], x_shape[1], x_shape[0]]

def __call__(self, x: Tensor) -> List[Tensor]:
def __call__(self, x: Tensor) -> Tensor:
"""
Return the output tensor of permute210
Expand Down
10 changes: 10 additions & 0 deletions python/aitemplate/compiler/transform/mark_param_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,16 @@


def mark_special_views(sorted_graph: List[Tensor]):
"""
Associate each tensor with an external tensor if any of the conditions are true:
1. The tensor is a view-of-a-view of an external tensor.
2. The tensor is a view of an input, constant or output tensor (i.e. external tensor).
Parameters
----------
sorted_graph : List[Tensor]
The graph to mutate.
"""
for node in sorted_graph:
view = node._attrs["is_view_of"]
if view is None:
Expand Down

0 comments on commit 664b25d

Please sign in to comment.