diff --git a/fbgemm_gpu/experimental/gen_ai/gen_ai/quantize_ops.py b/fbgemm_gpu/experimental/gen_ai/gen_ai/quantize_ops.py
index 577a301a1..3e716affc 100644
--- a/fbgemm_gpu/experimental/gen_ai/gen_ai/quantize_ops.py
+++ b/fbgemm_gpu/experimental/gen_ai/gen_ai/quantize_ops.py
@@ -36,6 +36,7 @@ def f8f8bf16_blockwise_abstract(
     return torch.empty(
         [M, N],
         dtype=torch.bfloat16,
+        device=XQ.device,
     )
 
 
@@ -51,6 +52,7 @@ def f8f8bf16_tensorwise_abstract(
     return torch.empty(
         [M, N],
         dtype=torch.bfloat16,
+        device=XQ.device,
     )
 
 
@@ -69,6 +71,7 @@ def f8f8bf16_rowwise_abstract(
     return torch.empty(
         [M, N],
         dtype=torch.bfloat16,
+        device=XQ.device,
     )
 
 
@@ -83,8 +86,8 @@ def quantize_fp8_per_tensor_abstract(
         fp8_dtype = torch.float8_e4m3fnuz
     else:
         fp8_dtype = torch.float8_e4m3fn
-    output = torch.empty_like(input, dtype=fp8_dtype)
-    scale = torch.empty([], dtype=torch.bfloat16)
+    output = torch.empty_like(input, dtype=fp8_dtype, device=input.device)
+    scale = torch.empty([], dtype=torch.bfloat16, device=input.device)
     return output, scale
 
 
@@ -100,8 +103,8 @@ def quantize_fp8_per_row_abstract(
         fp8_dtype = torch.float8_e4m3fnuz
     else:
         fp8_dtype = torch.float8_e4m3fn
-    output = torch.empty_like(input, dtype=fp8_dtype)
-    scale = torch.empty([], dtype=torch.bfloat16)
+    output = torch.empty_like(input, dtype=fp8_dtype, device=input.device)
+    scale = torch.empty([], dtype=torch.bfloat16, device=input.device)
     return output, scale
 
 
@@ -115,8 +118,8 @@ def quantize_fp8_per_col_abstract(
         fp8_dtype = torch.float8_e4m3fnuz
     else:
         fp8_dtype = torch.float8_e4m3fn
-    output = torch.empty_like(input, dtype=fp8_dtype)
-    scale = torch.empty([], dtype=torch.bfloat16)
+    output = torch.empty_like(input, dtype=fp8_dtype, device=input.device)
+    scale = torch.empty([], dtype=torch.bfloat16, device=input.device)
     return output, scale
 
 
@@ -135,6 +138,7 @@ def i8i8bf16_abstract(
     return torch.empty(
         [M, N],
         dtype=torch.bfloat16,
+        device=XQ.device,
     )
 
 @torch.library.register_fake("fbgemm::f8f8bf16")
@@ -149,6 +153,7 @@ def f8f8bf16_abstract(
     return torch.empty(
         [M, N],
         dtype=torch.bfloat16,
+        device=XQ.device,
     )
 
 @torch.library.register_fake("fbgemm::f8f8bf16_cublas")
@@ -165,6 +170,7 @@ def f8f8bf16_cublas_abstract(
     return torch.empty(
         [M, N],
         dtype=torch.bfloat16,
+        device=A.device,
     )
 
 @torch.library.register_fake("fbgemm::f8f8bf16_rowwise_batched")
@@ -182,6 +188,7 @@ def f8f8bf16_rowwise_batched_abstract(
     return torch.empty(
         [M, N],
         dtype=torch.bfloat16,
+        device=XQ.device,
     )
 
 @torch.library.register_fake("fbgemm::f8i4bf16_rowwise")
@@ -197,6 +204,7 @@ def f8i4bf16_rowwise_abstract(
     return torch.empty(
         [M, N],
         dtype=torch.bfloat16,
+        device=XQ.device,
     )
 
 @torch.library.register_fake("fbgemm::bf16i4bf16_rowwise")
@@ -211,6 +219,7 @@ def bf16i4bf16_rowwise_abstract(
     return torch.empty(
         [M, N],
         dtype=torch.bfloat16,
+        device=X.device,
     )
 
 @torch.library.register_fake("fbgemm::bf16i4bf16_rowwise_batched")
@@ -225,4 +234,5 @@ def bf16i4bf16_rowwise_batched_abstract(
     return torch.empty(
         [M, N],
         dtype=torch.bfloat16,
+        device=X.device,
     )
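
For context, below is a minimal self-contained sketch (not part of this patch) of the pattern being fixed: a custom op whose register_fake implementation must allocate its output on the input's device. The op name "mylib::toy_gemm" and both function names are hypothetical. Without an explicit device= argument, torch.empty in a fake implementation defaults to CPU, so tracing under torch.compile would report a CPU output for an op whose inputs live on CUDA and can trigger device-mismatch errors downstream.

import torch

@torch.library.custom_op("mylib::toy_gemm", mutates_args=())
def toy_gemm(XQ: torch.Tensor, WQ: torch.Tensor) -> torch.Tensor:
    # Real implementation: runs on whatever device the inputs live on.
    # XQ is [M, K], WQ is [N, K]; the result is [M, N] in bfloat16.
    return (XQ.float() @ WQ.float().t()).to(torch.bfloat16)

@torch.library.register_fake("mylib::toy_gemm")
def toy_gemm_abstract(XQ: torch.Tensor, WQ: torch.Tensor) -> torch.Tensor:
    # Fake (meta) implementation: no computation happens here; only the
    # output's shape, dtype, and device are propagated to the tracer.
    M = XQ.shape[0]
    N = WQ.shape[0]
    return torch.empty(
        [M, N],
        dtype=torch.bfloat16,
        device=XQ.device,  # without this, the fake output lands on CPU
    )

The patch applies the same one-line fix uniformly: every abstract implementation in quantize_ops.py now pins its output (and, for the quantize ops, the scale tensor) to the device of its first tensor argument instead of the default CPU device.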