diff --git a/fbgemm_gpu/src/merge_pooled_embeddings_gpu.cpp b/fbgemm_gpu/src/merge_pooled_embeddings_gpu.cpp index 4b90c2bda..0a3ad3b63 100644 --- a/fbgemm_gpu/src/merge_pooled_embeddings_gpu.cpp +++ b/fbgemm_gpu/src/merge_pooled_embeddings_gpu.cpp @@ -149,7 +149,12 @@ void all_to_one( }); auto target_device_index = target_device.index(); - TORCH_CHECK(target_device_index < num_gpus && target_device_index >= 0); + TORCH_CHECK( + target_device_index != -1, + "target_device.index() is -1. Please pass target_device with device " + "index, e.g., torch.device(\"cuda:0\")") + + TORCH_CHECK(target_device_index < num_gpus); std::vector two_hop_transfers; two_hop_transfers.reserve(input_tensors.size());