diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_blockwise_gemm.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_blockwise_gemm.hip index 17d460483..94c9dbe13 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_blockwise_gemm.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_blockwise_gemm.hip @@ -12,11 +12,13 @@ #include #include +#if !defined(USE_ROCM) #include +#endif #include #if defined(USE_ROCM) - +#include #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_gemm.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_gemm.hip index afb306432..b57563df5 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_gemm.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_gemm.hip @@ -15,11 +15,13 @@ #include #include +#if !defined(USE_ROCM) #include +#endif #include #if defined(USE_ROCM) - +#include #include "kernels/fp8_rowwise_kernel_manifest.h" namespace fbgemm_gpu { diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_tensorwise_gemm.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_tensorwise_gemm.hip index 6170675a2..67966a43c 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_tensorwise_gemm.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_tensorwise_gemm.hip @@ -12,11 +12,13 @@ #include #include +#if !defined(USE_ROCM) #include +#endif #include #if defined(USE_ROCM) - +#include #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/kernels/fp8_rowwise_common.h b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/kernels/fp8_rowwise_common.h index 717720580..da4087989 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/kernels/fp8_rowwise_common.h +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/kernels/fp8_rowwise_common.h @@ -12,10 +12,13 @@ #include #include +#if !defined(USE_ROCM) #include +#endif #include #if defined(USE_ROCM) +#include #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"