Make sure fake tensor functions return on proper device (pytorch#3258)
Summary:
Pull Request resolved: pytorch#3258

X-link: facebookresearch/FBGEMM#359

I didn't realize that the device of fake tensors matters in abstract functions, but torch.compile will check it in some cases. This small diff adds proper device placement to all fbgemm abstract operators.
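As background, the pattern involved is sketched below. This is a minimal, hypothetical example, not taken from fbgemm: the op name example::scaled_mm_bf16 and its signature are made up for illustration. Under torch.compile, the fake (abstract) implementation never runs the kernel; it only describes the output's shape, dtype, and device, and that metadata must match what the real kernel would return.

import torch

# Hypothetical custom op for illustration only (not an actual fbgemm operator).
# The real kernel produces a bf16 output on the same device as its inputs.
@torch.library.custom_op("example::scaled_mm_bf16", mutates_args=())
def scaled_mm_bf16(xq: torch.Tensor, wq: torch.Tensor) -> torch.Tensor:
    return (xq.float() @ wq.float().t()).to(torch.bfloat16)

# Fake (abstract) implementation used by torch.compile. Allocating the empty
# output on xq.device, rather than on the default device, mirrors the change
# this commit applies across the fbgemm abstract operators.
@scaled_mm_bf16.register_fake
def _(xq: torch.Tensor, wq: torch.Tensor) -> torch.Tensor:
    M, N = xq.shape[0], wq.shape[0]
    return torch.empty([M, N], dtype=torch.bfloat16, device=xq.device)

Without the device argument, the fake output lands on the default (typically CPU) device, so compiling a call with CUDA inputs can fail torch.compile's device consistency checks even though the eager kernel runs fine.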

Reviewed By: jiawenliu64

Differential Revision: D64667681

fbshipit-source-id: 79b36af21cf8ad867d52beeb59f137368f0a48da
jwfromm authored and facebook-github-bot committed Oct 20, 2024
1 parent f728c94 commit 0c13ab9
Showing 1 changed file with 16 additions and 6 deletions: fbgemm_gpu/experimental/gen_ai/gen_ai/quantize_ops.py
@@ -36,6 +36,7 @@ def f8f8bf16_blockwise_abstract(
     return torch.empty(
         [M, N],
         dtype=torch.bfloat16,
+        device=XQ.device,
     )


@@ -51,6 +52,7 @@ def f8f8bf16_tensorwise_abstract(
     return torch.empty(
         [M, N],
         dtype=torch.bfloat16,
+        device=XQ.device,
     )


@@ -69,6 +71,7 @@ def f8f8bf16_rowwise_abstract(
     return torch.empty(
         [M, N],
         dtype=torch.bfloat16,
+        device=XQ.device,
     )


@@ -83,8 +86,8 @@ def quantize_fp8_per_tensor_abstract(
         fp8_dtype = torch.float8_e4m3fnuz
     else:
         fp8_dtype = torch.float8_e4m3fn
-    output = torch.empty_like(input, dtype=fp8_dtype)
-    scale = torch.empty([], dtype=torch.bfloat16)
+    output = torch.empty_like(input, dtype=fp8_dtype, device=input.device)
+    scale = torch.empty([], dtype=torch.bfloat16, device=input.device)
     return output, scale


@@ -100,8 +103,8 @@ def quantize_fp8_per_row_abstract(
         fp8_dtype = torch.float8_e4m3fnuz
     else:
         fp8_dtype = torch.float8_e4m3fn
-    output = torch.empty_like(input, dtype=fp8_dtype)
-    scale = torch.empty([], dtype=torch.bfloat16)
+    output = torch.empty_like(input, dtype=fp8_dtype, device=input.device)
+    scale = torch.empty([], dtype=torch.bfloat16, device=input.device)
     return output, scale


@@ -115,8 +118,8 @@ def quantize_fp8_per_col_abstract(
         fp8_dtype = torch.float8_e4m3fnuz
     else:
         fp8_dtype = torch.float8_e4m3fn
-    output = torch.empty_like(input, dtype=fp8_dtype)
-    scale = torch.empty([], dtype=torch.bfloat16)
+    output = torch.empty_like(input, dtype=fp8_dtype, device=input.device)
+    scale = torch.empty([], dtype=torch.bfloat16, device=input.device)
     return output, scale


@@ -135,6 +138,7 @@ def i8i8bf16_abstract(
     return torch.empty(
         [M, N],
         dtype=torch.bfloat16,
+        device=XQ.device,
     )

 @torch.library.register_fake("fbgemm::f8f8bf16")
@@ -149,6 +153,7 @@ def f8f8bf16_abstract(
     return torch.empty(
         [M, N],
         dtype=torch.bfloat16,
+        device=XQ.device,
     )

 @torch.library.register_fake("fbgemm::f8f8bf16_cublas")
@@ -165,6 +170,7 @@ def f8f8bf16_cublas_abstract(
     return torch.empty(
         [M, N],
         dtype=torch.bfloat16,
+        device=A.device,
     )

 @torch.library.register_fake("fbgemm::f8f8bf16_rowwise_batched")
@@ -182,6 +188,7 @@ def f8f8bf16_rowwise_batched_abstract(
     return torch.empty(
         [M, N],
         dtype=torch.bfloat16,
+        device=XQ.device,
     )

 @torch.library.register_fake("fbgemm::f8i4bf16_rowwise")
@@ -197,6 +204,7 @@ def f8i4bf16_rowwise_abstract(
     return torch.empty(
         [M, N],
         dtype=torch.bfloat16,
+        device=XQ.device,
     )

 @torch.library.register_fake("fbgemm::bf16i4bf16_rowwise")
@@ -211,6 +219,7 @@ def bf16i4bf16_rowwise_abstract(
     return torch.empty(
         [M, N],
         dtype=torch.bfloat16,
+        device=X.device,
     )

 @torch.library.register_fake("fbgemm::bf16i4bf16_rowwise_batched")
@@ -225,4 +234,5 @@ def bf16i4bf16_rowwise_batched_abstract(
     return torch.empty(
         [M, N],
         dtype=torch.bfloat16,
+        device=X.device,
     )
