diff --git a/fbgemm_gpu/codegen/embedding_forward_quantized_split_nbit_kernel_template.cu b/fbgemm_gpu/codegen/embedding_forward_quantized_split_nbit_kernel_template.cu index e3af7034d..9e087ac37 100644 --- a/fbgemm_gpu/codegen/embedding_forward_quantized_split_nbit_kernel_template.cu +++ b/fbgemm_gpu/codegen/embedding_forward_quantized_split_nbit_kernel_template.cu @@ -309,7 +309,7 @@ __global__ void {{ emb_weight_type.enum_name }}_split_embedding{{ "_nobag" if no {% for params in emb_weight_type.template_params %} {% if output_type == 'at::BFloat16' %} -#if !( \ +#if defined(USE_ROCM) || !( \ ((defined(CUDA_VERSION) && CUDA_VERSION < 11000) || \ (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)))) {% endif %} diff --git a/fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh b/fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh index 78b931236..29a550abc 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh +++ b/fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh @@ -1651,9 +1651,9 @@ struct __align__(4) __nv_bfloat162 { }; #endif -#if !( \ - ((defined(CUDA_VERSION) && CUDA_VERSION < 11000) || \ - (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)))) +#if defined(USE_ROCM) || \ + !(((defined(CUDA_VERSION) && CUDA_VERSION < 11000) || \ + (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)))) struct __align__(8) bfloat16_4 { __host__ __device__ bfloat16_4() {} __nv_bfloat162 vals[2]; @@ -1771,9 +1771,9 @@ static DEVICE_INLINE void quantize_float_store( *output = input; } -#if !( \ - ((defined(CUDA_VERSION) && CUDA_VERSION < 11000) || \ - (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)))) +#if defined(USE_ROCM) || \ + !(((defined(CUDA_VERSION) && CUDA_VERSION < 11000) || \ + (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)))) DEVICE_INLINE __nv_bfloat16 to_bfloat16(float v) { return __float2bfloat16(v); } @@ -2347,9 +2347,9 @@ struct VecNT<1, PrimitiveType::FP> { *reinterpret_cast<__half*>(output_ptr) = val; } -#if !( \ - ((defined(CUDA_VERSION) && CUDA_VERSION < 11000) || \ - (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)))) +#if defined(USE_ROCM) || \ + !(((defined(CUDA_VERSION) && CUDA_VERSION < 11000) || \ + (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)))) DEVICE_INLINE void store( at::BFloat16* output_ptr, const int num_valid_outputs = 1) { @@ -2440,9 +2440,9 @@ struct VecNT<2, PrimitiveType::FP> { } } -#if !( \ - ((defined(CUDA_VERSION) && CUDA_VERSION < 11000) || \ - (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)))) +#if defined(USE_ROCM) || \ + !(((defined(CUDA_VERSION) && CUDA_VERSION < 11000) || \ + (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)))) DEVICE_INLINE void store( at::BFloat16* output_ptr, const int num_valid_outputs = 2) { @@ -2578,9 +2578,9 @@ struct VecNT<4, PrimitiveType::FP> { } } -#if !( \ - ((defined(CUDA_VERSION) && CUDA_VERSION < 11000) || \ - (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)))) +#if defined(USE_ROCM) || \ + !(((defined(CUDA_VERSION) && CUDA_VERSION < 11000) || \ + (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)))) DEVICE_INLINE void store( at::BFloat16* output_ptr, const int num_valid_outputs = 4) { @@ -2733,9 +2733,9 @@ struct VecNT<4, PrimitiveType::INT> { } } -#if !( \ - ((defined(CUDA_VERSION) && CUDA_VERSION < 11000) || \ - (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)))) +#if defined(USE_ROCM) || \ + !(((defined(CUDA_VERSION) && CUDA_VERSION < 11000) || \ + (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)))) DEVICE_INLINE void store( at::BFloat16* output_ptr, const int num_valid_outputs = 4) { @@ -2903,9 +2903,9 @@ struct VecNT<8, PrimitiveType::INT> { } } -#if !( \ - ((defined(CUDA_VERSION) && CUDA_VERSION < 11000) || \ - (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)))) +#if defined(USE_ROCM) || \ + !(((defined(CUDA_VERSION) && CUDA_VERSION < 11000) || \ + (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)))) DEVICE_INLINE void store( at::BFloat16* output_ptr, const int num_valid_outputs = 8) { @@ -3090,9 +3090,9 @@ struct VecNT<16, PrimitiveType::INT> { } } -#if !( \ - ((defined(CUDA_VERSION) && CUDA_VERSION < 11000) || \ - (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)))) +#if defined(USE_ROCM) || \ + !(((defined(CUDA_VERSION) && CUDA_VERSION < 11000) || \ + (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)))) DEVICE_INLINE void store( at::BFloat16* output_ptr, const int num_valid_outputs = 16) {