open-source SLL jagged_dense_elementwise_mul_jagged_out (#3354)
Summary:
X-link: facebookresearch/FBGEMM#446


As titled; also adds CPU and Meta implementations.

Differential Revision: D65827782
TroyGarden authored and facebook-github-bot committed Nov 14, 2024
1 parent fd8230c commit 9125eae
Showing 4 changed files with 561 additions and 0 deletions.
44 changes: 44 additions & 0 deletions fbgemm_gpu/fbgemm_gpu/sll/__init__.py
@@ -13,15 +13,18 @@
    cpu_dense_jagged_cat_jagged_out,
    cpu_jagged2_to_padded_dense,
    cpu_jagged_dense_bmm,
    cpu_jagged_dense_elementwise_mul_jagged_out,
    cpu_jagged_jagged_bmm,
    cpu_jagged_self_substraction_jagged_out,
    meta_jagged_dense_elementwise_mul_jagged_out,
    meta_jagged_self_substraction_jagged_out,
)

from fbgemm_gpu.sll.triton_sll import ( # noqa F401
    dense_jagged_cat_jagged_out,
    jagged2_to_padded_dense,
    jagged_dense_bmm,
    jagged_dense_elementwise_mul_jagged_out,
    jagged_jagged_bmm,
    triton_jagged_self_substraction_jagged_out,
)
@@ -119,6 +122,17 @@ def op_registeration(
        """
    )

if "fbgemm::sll_jagged_dense_elementwise_mul_jagged_out" not in torch.library._defs:
    lib.define(
        """sll_jagged_dense_elementwise_mul_jagged_out(
            Tensor x,
            Tensor y,
            Tensor x_seq_lengths,
            Tensor x_offsets,
            int max_seq_len
        ) -> Tensor
        """
    )

op_registeration(lib, "sll_jagged_dense_bmm", jagged_dense_bmm, "CUDA")
op_registeration(lib, "sll_jagged_dense_bmm", jagged_dense_bmm, "AutogradCUDA")
@@ -160,3 +174,33 @@ def op_registeration(
op_registeration(
    lib, "sll_jagged2_to_padded_dense", cpu_jagged2_to_padded_dense, "AutogradCPU"
)
op_registeration(
    lib,
    "sll_jagged_dense_elementwise_mul_jagged_out",
    jagged_dense_elementwise_mul_jagged_out,
    "CUDA",
)
op_registeration(
    lib,
    "sll_jagged_dense_elementwise_mul_jagged_out",
    jagged_dense_elementwise_mul_jagged_out,
    "AutogradCUDA",
)
op_registeration(
    lib,
    "sll_jagged_dense_elementwise_mul_jagged_out",
    cpu_jagged_dense_elementwise_mul_jagged_out,
    "CPU",
)
op_registeration(
    lib,
    "sll_jagged_dense_elementwise_mul_jagged_out",
    cpu_jagged_dense_elementwise_mul_jagged_out,
    "AutogradCPU",
)
op_registeration(
    lib,
    "sll_jagged_dense_elementwise_mul_jagged_out",
    meta_jagged_dense_elementwise_mul_jagged_out,
    "Meta",
)
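Once these registrations have run (for example via importing fbgemm_gpu.sll), the new op is reachable through torch.ops.fbgemm. The snippet below is an illustrative sketch and not part of the diff; the lengths, offsets, and shapes are chosen to match the jagged layout assumed by the CPU kernel, where batch element i contributes a flattened L_i x L_i block:

import torch
import fbgemm_gpu.sll  # noqa: F401  -- triggers the op registrations shown above

# Two batch rows with block sizes 2 and 3: the jagged values hold 2*2 + 3*3 = 13 elements.
x_seq_lengths = torch.tensor([2, 3])
x_offsets = torch.tensor([0, 4, 13])
max_seq_len = 3

x = torch.randn(13)                        # jagged values (flattened L_i x L_i blocks)
y = torch.randn(max_seq_len, max_seq_len)  # dense multiplier

z = torch.ops.fbgemm.sll_jagged_dense_elementwise_mul_jagged_out(
    x, y, x_seq_lengths, x_offsets, max_seq_len
)
print(z.shape)  # torch.Size([13]), same layout as x

On CPU tensors this dispatches to the cpu_jagged_dense_elementwise_mul_jagged_out implementation added below; CUDA tensors hit the Triton kernel.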
151 changes: 151 additions & 0 deletions fbgemm_gpu/fbgemm_gpu/sll/cpu_sll.py
@@ -255,3 +255,154 @@ def cpu_jagged2_to_padded_dense(
        dense_output[b, 0:Ni, 0:Ni] = values[begin:end].view(Ni, Ni)

    return dense_output


class CPUJaggedDenseElementwiseMul(torch.autograd.Function):
    """
    Compute elementwise multiplication between a jagged tensor and a dense tensor.
    z = x * y
    x: jagged values laid out as flattened L_i x L_i blocks, [sum_B(L_i * L_i)]
    y: dense tensor, [max_seq_len, max_seq_len]
    z: jagged output with the same layout as x
    """

    @staticmethod
    def jagged_dense_elementwise_mul_jagged_out(
        jagged: torch.Tensor,
        dense: torch.Tensor,
        seq_lengths: torch.Tensor,
        offsets: torch.Tensor,
        max_seq_len: int,
    ) -> torch.Tensor:
        out = torch.empty_like(jagged)
        for i in range(seq_lengths.size(0)):
            if seq_lengths[i] == 0:
                continue
            a = jagged[offsets[i] : offsets[i + 1]]
            a = a.view(int(seq_lengths[i]), int(seq_lengths[i]))
            out[offsets[i] : offsets[i + 1]] = (
                a * dense[0 : seq_lengths[i], 0 : seq_lengths[i]]
            ).flatten()
        return out

    @staticmethod
    # pyre-fixme
    def forward(
        ctx,  # pyre-ignore [2]
        x: torch.Tensor,
        y: torch.Tensor,
        x_seq_lengths: torch.Tensor,
        x_offsets: torch.Tensor,
        max_seq_len: int,
    ):
        ctx.max_seq_len = max_seq_len

        ctx.save_for_backward(
            x,
            y,
            x_seq_lengths,
            x_offsets,
        )

        return CPUJaggedDenseElementwiseMul.jagged_dense_elementwise_mul_jagged_out(
            x,
            y,
            x_seq_lengths,
            x_offsets,
            max_seq_len,
        )

    @staticmethod
    # pyre-fixme
    def backward(ctx, grad_output: torch.Tensor):
        (
            x,
            y,
            x_seq_lengths,
            x_offsets,
        ) = ctx.saved_tensors

        grad_x = CPUJaggedDenseElementwiseMul.jagged_dense_elementwise_mul_jagged_out(
            grad_output,
            y,
            x_seq_lengths,
            x_offsets,
            ctx.max_seq_len,
        )

        return grad_x, None, None, None, None


def cpu_jagged_dense_elementwise_mul_jagged_out(
    x: torch.Tensor,
    y: torch.Tensor,
    x_seq_lengths: torch.Tensor,
    x_offsets: torch.Tensor,
    max_seq_len: int,
) -> torch.Tensor:
    return CPUJaggedDenseElementwiseMul.apply(
        x,
        y,
        x_seq_lengths,
        x_offsets,
        max_seq_len,
    )


class MetaJaggedDenseElementwiseMul(torch.autograd.Function):
    @staticmethod
    # pyre-fixme
    def forward(
        ctx,  # pyre-ignore [2]
        x: torch.Tensor,
        y: torch.Tensor,
        x_seq_lengths: torch.Tensor,
        x_offsets: torch.Tensor,
        max_seq_len: int,
    ) -> torch.Tensor:
        ctx.max_seq_len = max_seq_len

        ctx.save_for_backward(
            x,
            y,
            x_seq_lengths,
            x_offsets,
        )

        total_L = x.size(0)
        jagged_C = torch.zeros((total_L), device=x.device, dtype=x.dtype)

        return jagged_C

    @staticmethod
    # pyre-fixme
    def backward(ctx, grad_output: torch.Tensor):
        (
            x,
            y,
            x_seq_lengths,
            x_offsets,
        ) = ctx.saved_tensors

        total_L = grad_output.size(0)
        jagged_C = torch.zeros(
            (total_L), device=grad_output.device, dtype=grad_output.dtype
        )

        return jagged_C, None, None, None, None


def meta_jagged_dense_elementwise_mul_jagged_out(
    x: torch.Tensor,
    y: torch.Tensor,
    x_seq_lengths: torch.Tensor,
    x_offsets: torch.Tensor,
    max_seq_len: int,
) -> torch.Tensor:
    return MetaJaggedDenseElementwiseMul.apply(
        x,
        y,
        x_seq_lengths,
        x_offsets,
        max_seq_len,
    )
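As a quick sanity check (illustrative only, not part of the commit), the CPU path can be exercised directly and compared segment-by-segment against a plain dense multiply. The shapes below assume the same flattened L_i x L_i block layout used by the kernel:

import torch
from fbgemm_gpu.sll.cpu_sll import cpu_jagged_dense_elementwise_mul_jagged_out

seq_lengths = torch.tensor([2, 3])
offsets = torch.tensor([0, 4, 13])  # each segment holds L_i * L_i values
max_seq_len = 3

x = torch.randn(13, requires_grad=True)
y = torch.randn(max_seq_len, max_seq_len)

z = cpu_jagged_dense_elementwise_mul_jagged_out(x, y, seq_lengths, offsets, max_seq_len)

# Segment 0 is the first 2x2 block of x scaled elementwise by the top-left 2x2 block of y.
expected = (x[0:4].view(2, 2) * y[0:2, 0:2]).flatten()
torch.testing.assert_close(z[0:4].detach(), expected.detach())

# The backward pass reuses the same kernel, so dL/dx is grad_output scaled per block by y.
z.sum().backward()

The Meta implementation returns an appropriately shaped zero tensor for both forward and backward, which is all that shape propagation on meta tensors needs.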