diff --git a/fbgemm_gpu/bench/batched_unary_embeddings_benchmark.py b/fbgemm_gpu/bench/batched_unary_embeddings_benchmark.py index 2ab90148b..2af0b5876 100644 --- a/fbgemm_gpu/bench/batched_unary_embeddings_benchmark.py +++ b/fbgemm_gpu/bench/batched_unary_embeddings_benchmark.py @@ -23,7 +23,10 @@ else: from fbgemm_gpu.bench.bench_utils import benchmark_torch_function - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") + if torch.version.hip: + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_hip") + else: + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") diff --git a/fbgemm_gpu/bench/histogram_binning_calibration_benchmark.py b/fbgemm_gpu/bench/histogram_binning_calibration_benchmark.py index 4ae3fd406..d7d33a9fd 100644 --- a/fbgemm_gpu/bench/histogram_binning_calibration_benchmark.py +++ b/fbgemm_gpu/bench/histogram_binning_calibration_benchmark.py @@ -18,7 +18,10 @@ # pyre-ignore[21] from fbgemm_gpu import open_source # noqa: F401 except Exception: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") + if torch.version.hip: + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_hip") + else: + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") diff --git a/fbgemm_gpu/bench/jagged_tensor_benchmark.py b/fbgemm_gpu/bench/jagged_tensor_benchmark.py index 9142ced87..d8f2f8e39 100644 --- a/fbgemm_gpu/bench/jagged_tensor_benchmark.py +++ b/fbgemm_gpu/bench/jagged_tensor_benchmark.py @@ -25,7 +25,10 @@ else: from fbgemm_gpu.bench.bench_utils import benchmark_torch_function - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") + if torch.version.hip: + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_hip") + else: + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") diff --git a/fbgemm_gpu/bench/merge_embeddings_benchmark.py b/fbgemm_gpu/bench/merge_embeddings_benchmark.py index 8059c8554..90023b347 100644 --- a/fbgemm_gpu/bench/merge_embeddings_benchmark.py +++ b/fbgemm_gpu/bench/merge_embeddings_benchmark.py @@ -37,7 +37,14 @@ else: from fbgemm_gpu.bench.bench_utils import benchmark_torch_function - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:merge_pooled_embeddings") + if torch.version.hip: + torch.ops.load_library( + "//deeplearning/fbgemm/fbgemm_gpu:merge_pooled_embeddings_hip" + ) + else: + torch.ops.load_library( + "//deeplearning/fbgemm/fbgemm_gpu:merge_pooled_embeddings" + ) torch.ops.load_library( "//deeplearning/fbgemm/fbgemm_gpu:merge_pooled_embeddings_cpu" ) diff --git a/fbgemm_gpu/bench/quantize_ops_benchmark.py b/fbgemm_gpu/bench/quantize_ops_benchmark.py index b02bbe62f..d5647e36a 100644 --- a/fbgemm_gpu/bench/quantize_ops_benchmark.py +++ b/fbgemm_gpu/bench/quantize_ops_benchmark.py @@ -25,7 +25,10 @@ else: from fbgemm_gpu.bench.bench_utils import benchmark_torch_function - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") + if torch.version.hip: + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_hip") + else: + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") diff --git a/fbgemm_gpu/bench/sparse_ops_benchmark.py b/fbgemm_gpu/bench/sparse_ops_benchmark.py index 0602d0ae8..3da7721fd 100644 --- a/fbgemm_gpu/bench/sparse_ops_benchmark.py +++ b/fbgemm_gpu/bench/sparse_ops_benchmark.py @@ -29,7 +29,10 @@ else: from fbgemm_gpu.bench.bench_utils import benchmark_torch_function - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") + if torch.version.hip: + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_hip") + else: + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu/codegen:index_select_ops") diff --git a/fbgemm_gpu/bench/split_embeddings_cache_benchmark.py b/fbgemm_gpu/bench/split_embeddings_cache_benchmark.py index e7b836aa4..e1b615cbf 100644 --- a/fbgemm_gpu/bench/split_embeddings_cache_benchmark.py +++ b/fbgemm_gpu/bench/split_embeddings_cache_benchmark.py @@ -29,10 +29,16 @@ # pyre-ignore[21] from fbgemm_gpu import open_source # noqa: F401 except Exception: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:cumem_utils") - torch.ops.load_library( - "//deeplearning/fbgemm/fbgemm_gpu:split_table_batched_embeddings" - ) + if torch.version.hip: + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:cumem_utils_hip") + torch.ops.load_library( + "//deeplearning/fbgemm/fbgemm_gpu:split_table_batched_embeddings_hip" + ) + else: + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:cumem_utils") + torch.ops.load_library( + "//deeplearning/fbgemm/fbgemm_gpu:split_table_batched_embeddings" + ) # pyre-ignore diff --git a/fbgemm_gpu/bench/ssd_table_batched_embeddings_benchmark.py b/fbgemm_gpu/bench/ssd_table_batched_embeddings_benchmark.py index 335e44342..df4377aaa 100644 --- a/fbgemm_gpu/bench/ssd_table_batched_embeddings_benchmark.py +++ b/fbgemm_gpu/bench/ssd_table_batched_embeddings_benchmark.py @@ -27,9 +27,14 @@ logging.basicConfig(level=logging.DEBUG) -torch.ops.load_library( - "//deeplearning/fbgemm/fbgemm_gpu:ssd_split_table_batched_embeddings" -) +if torch.version.hip: + torch.ops.load_library( + "//deeplearning/fbgemm/fbgemm_gpu:ssd_split_table_batched_embeddings_hip" + ) +else: + torch.ops.load_library( + "//deeplearning/fbgemm/fbgemm_gpu:ssd_split_table_batched_embeddings" + ) logging.basicConfig(level=logging.DEBUG) diff --git a/fbgemm_gpu/bench/stride_gemm_benchmark.py b/fbgemm_gpu/bench/stride_gemm_benchmark.py index bad5ba6ae..c80505938 100644 --- a/fbgemm_gpu/bench/stride_gemm_benchmark.py +++ b/fbgemm_gpu/bench/stride_gemm_benchmark.py @@ -17,7 +17,10 @@ # pyre-ignore[21] from fbgemm_gpu import open_source # noqa: F401 except Exception: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") + if torch.version.hip: + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_hip") + else: + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") diff --git a/fbgemm_gpu/codegen/split_embedding_codegen_lookup_invoker.template b/fbgemm_gpu/codegen/split_embedding_codegen_lookup_invoker.template index 3a00d7d1f..090d11235 100644 --- a/fbgemm_gpu/codegen/split_embedding_codegen_lookup_invoker.template +++ b/fbgemm_gpu/codegen/split_embedding_codegen_lookup_invoker.template @@ -16,17 +16,30 @@ from .lookup_args import * # Provide compatibility to downstream packages for eventual migration to the split training / inference packages try: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_cuda_training") + if torch.version.hip: + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_hip_training") + else: + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_cuda_training") torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_cpu_training") except Exception: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops") + if torch.version.hip: + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_hip") + else: + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops") torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_cpu") -torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:cumem_utils") -torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") +if torch.version.hip: + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:cumem_utils_hip") + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_hip") + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:split_table_batched_embeddings_hip") + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:embedding_inplace_update_hip") +else: + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:cumem_utils") + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:split_table_batched_embeddings") + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:embedding_inplace_update") + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") -torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:split_table_batched_embeddings") -torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:embedding_inplace_update") torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:embedding_inplace_update_cpu") {%- endif %} diff --git a/fbgemm_gpu/fbgemm_gpu/batched_unary_embeddings_ops.py b/fbgemm_gpu/fbgemm_gpu/batched_unary_embeddings_ops.py index db9b4c48d..b4ee5c1be 100644 --- a/fbgemm_gpu/fbgemm_gpu/batched_unary_embeddings_ops.py +++ b/fbgemm_gpu/fbgemm_gpu/batched_unary_embeddings_ops.py @@ -15,7 +15,11 @@ # pyre-ignore[21] from fbgemm_gpu import open_source # noqa: F401 except Exception: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") + if torch.version.hip: + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_hip") + else: + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") diff --git a/fbgemm_gpu/fbgemm_gpu/permute_pooled_embedding_modules.py b/fbgemm_gpu/fbgemm_gpu/permute_pooled_embedding_modules.py index cb66088b9..ffdc34449 100644 --- a/fbgemm_gpu/fbgemm_gpu/permute_pooled_embedding_modules.py +++ b/fbgemm_gpu/fbgemm_gpu/permute_pooled_embedding_modules.py @@ -18,9 +18,14 @@ torch.ops.load_library( "//deeplearning/fbgemm/fbgemm_gpu:permute_pooled_embedding_ops_cpu" ) - torch.ops.load_library( - "//deeplearning/fbgemm/fbgemm_gpu:permute_pooled_embedding_ops_gpu" - ) + if torch.version.hip: + torch.ops.load_library( + "//deeplearning/fbgemm/fbgemm_gpu:permute_pooled_embedding_ops_gpu_hip" + ) + else: + torch.ops.load_library( + "//deeplearning/fbgemm/fbgemm_gpu:permute_pooled_embedding_ops_gpu" + ) class PermutePooledEmbeddings: diff --git a/fbgemm_gpu/fbgemm_gpu/quantize_utils.py b/fbgemm_gpu/fbgemm_gpu/quantize_utils.py index ff5b2b123..1c8292616 100644 --- a/fbgemm_gpu/fbgemm_gpu/quantize_utils.py +++ b/fbgemm_gpu/fbgemm_gpu/quantize_utils.py @@ -15,7 +15,11 @@ # pyre-ignore[21] from fbgemm_gpu import open_source # noqa: F401 except Exception: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") + if torch.version.hip: + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_hip") + else: + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") TORCH_HALF_MIN: float = torch.finfo(torch.float16).min diff --git a/fbgemm_gpu/fbgemm_gpu/sparse_ops.py b/fbgemm_gpu/fbgemm_gpu/sparse_ops.py index 01495d2d7..a774c5ff4 100644 --- a/fbgemm_gpu/fbgemm_gpu/sparse_ops.py +++ b/fbgemm_gpu/fbgemm_gpu/sparse_ops.py @@ -15,15 +15,28 @@ # pyre-ignore from fbgemm_gpu import open_source # noqa: F401 except Exception: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") + if torch.version.hip: + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_hip") + torch.ops.load_library( + "//deeplearning/fbgemm/fbgemm_gpu:merge_pooled_embeddings_hip" + ) + torch.ops.load_library( + "//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_hip" + ) + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:input_combine_hip") + else: + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") + torch.ops.load_library( + "//deeplearning/fbgemm/fbgemm_gpu:merge_pooled_embeddings" + ) + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops") + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:input_combine") + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:merge_pooled_embeddings") torch.ops.load_library( "//deeplearning/fbgemm/fbgemm_gpu:merge_pooled_embeddings_cpu" ) - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops") torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_cpu") - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:input_combine") torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:input_combine_cpu") import torch.utils._pytree as pytree diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py index c8dccace9..83f988b06 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py @@ -49,7 +49,12 @@ ) try: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops") + if torch.version.hip: + torch.ops.load_library( + "//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_hip" + ) + else: + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops") torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_cpu") except Exception: pass diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py index d5b0d66d2..4e160858c 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py @@ -30,9 +30,14 @@ ) try: - torch.ops.load_library( - "//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_cuda_inference" - ) + if torch.version.hip: + torch.ops.load_library( + "//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_hip_inference" + ) + else: + torch.ops.load_library( + "//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_cuda_inference" + ) torch.ops.load_library( "//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_cpu_inference" ) diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py index 75f3ee6c4..2f49390eb 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py @@ -34,9 +34,14 @@ ) try: - torch.ops.load_library( - "//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_cuda_training" - ) + if torch.version.hip: + torch.ops.load_library( + "//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_hip_training" + ) + else: + torch.ops.load_library( + "//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_cuda_training" + ) torch.ops.load_library( "//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_cpu_training" ) diff --git a/fbgemm_gpu/fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py b/fbgemm_gpu/fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py index ea3760d06..f45bec33a 100644 --- a/fbgemm_gpu/fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +++ b/fbgemm_gpu/fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py @@ -37,9 +37,14 @@ from torch.autograd.profiler import record_function try: - torch.ops.load_library( - "//deeplearning/fbgemm/fbgemm_gpu:ssd_split_table_batched_embeddings" - ) + if torch.version.hip: + torch.ops.load_library( + "//deeplearning/fbgemm/fbgemm_gpu:ssd_split_table_batched_embeddings_hip" + ) + else: + torch.ops.load_library( + "//deeplearning/fbgemm/fbgemm_gpu:ssd_split_table_batched_embeddings" + ) except OSError: # Keep for BC: will be deprecated soon. torch.ops.load_library( diff --git a/fbgemm_gpu/test/batched_unary_embeddings_test.py b/fbgemm_gpu/test/batched_unary_embeddings_test.py index 1577a11f3..6c0bb02d8 100644 --- a/fbgemm_gpu/test/batched_unary_embeddings_test.py +++ b/fbgemm_gpu/test/batched_unary_embeddings_test.py @@ -24,7 +24,11 @@ from test_utils import gpu_unavailable except Exception: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") + if torch.version.hip: + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_hip") + else: + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") from fbgemm_gpu.test.test_utils import gpu_unavailable diff --git a/fbgemm_gpu/test/input_combine_test.py b/fbgemm_gpu/test/input_combine_test.py index 32e1dd246..a71bde0f4 100644 --- a/fbgemm_gpu/test/input_combine_test.py +++ b/fbgemm_gpu/test/input_combine_test.py @@ -20,7 +20,10 @@ # pyre-ignore[21] from test_utils import cpu_and_maybe_gpu, optests except Exception: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:input_combine") + if torch.version.hip: + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:input_combine_hip") + else: + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:input_combine") torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:input_combine_cpu") from fbgemm_gpu.test.test_utils import cpu_and_maybe_gpu, optests diff --git a/fbgemm_gpu/test/jagged_tensor_ops_test.py b/fbgemm_gpu/test/jagged_tensor_ops_test.py index 01e4333db..a98fa8d27 100644 --- a/fbgemm_gpu/test/jagged_tensor_ops_test.py +++ b/fbgemm_gpu/test/jagged_tensor_ops_test.py @@ -33,7 +33,11 @@ TEST_WITH_ROCM, ) except Exception: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") + if torch.version.hip: + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_hip") + else: + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") import fbgemm_gpu.sparse_ops # noqa: F401, E402 from fbgemm_gpu.test.test_utils import ( diff --git a/fbgemm_gpu/test/layout_transform_ops_test.py b/fbgemm_gpu/test/layout_transform_ops_test.py index 4bfc603d6..3c5a9e705 100644 --- a/fbgemm_gpu/test/layout_transform_ops_test.py +++ b/fbgemm_gpu/test/layout_transform_ops_test.py @@ -20,7 +20,11 @@ from test_utils import gpu_unavailable except Exception: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") + if torch.version.hip: + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_hip") + else: + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") from fbgemm_gpu.test.test_utils import gpu_unavailable diff --git a/fbgemm_gpu/test/merge_pooled_embeddings_test.py b/fbgemm_gpu/test/merge_pooled_embeddings_test.py index a686cdb3e..ba16a1d58 100644 --- a/fbgemm_gpu/test/merge_pooled_embeddings_test.py +++ b/fbgemm_gpu/test/merge_pooled_embeddings_test.py @@ -21,7 +21,15 @@ # pyre-ignore[21] from test_utils import gpu_unavailable except Exception: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:merge_pooled_embeddings") + if torch.version.hip: + torch.ops.load_library( + "//deeplearning/fbgemm/fbgemm_gpu:merge_pooled_embeddings_hip" + ) + else: + torch.ops.load_library( + "//deeplearning/fbgemm/fbgemm_gpu:merge_pooled_embeddings" + ) + torch.ops.load_library( "//deeplearning/fbgemm/fbgemm_gpu:merge_pooled_embeddings_cpu" ) diff --git a/fbgemm_gpu/test/metric_ops_test.py b/fbgemm_gpu/test/metric_ops_test.py index ae5bf434e..73db3f434 100644 --- a/fbgemm_gpu/test/metric_ops_test.py +++ b/fbgemm_gpu/test/metric_ops_test.py @@ -16,7 +16,10 @@ from fbgemm_gpu import open_source # noqa: F401 except Exception: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:metric_ops") + if torch.version.hip: + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:metric_ops_hip") + else: + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:metric_ops") class MetricOpsTest(unittest.TestCase): diff --git a/fbgemm_gpu/test/quantize_ops_test.py b/fbgemm_gpu/test/quantize_ops_test.py index 2b1d6dcde..5cc8527de 100644 --- a/fbgemm_gpu/test/quantize_ops_test.py +++ b/fbgemm_gpu/test/quantize_ops_test.py @@ -37,7 +37,11 @@ symint_vector_unsupported, ) except Exception: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") + if torch.version.hip: + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_hip") + else: + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") from fbgemm_gpu.test.test_utils import ( bytes_to_half_floats, diff --git a/fbgemm_gpu/test/sparse_ops_test.py b/fbgemm_gpu/test/sparse_ops_test.py index 8f6483701..e838174b4 100644 --- a/fbgemm_gpu/test/sparse_ops_test.py +++ b/fbgemm_gpu/test/sparse_ops_test.py @@ -36,7 +36,11 @@ # pyre-ignore[21] from test_utils import gpu_available, gpu_unavailable, skipIfRocm except Exception: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") + if torch.version.hip: + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_hip") + else: + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu/codegen:index_select_ops") import fbgemm_gpu.sparse_ops # noqa: F401, E402