diff --git a/tests/mr/device/mr_ref_multithreaded_tests.cpp b/tests/mr/device/mr_ref_multithreaded_tests.cpp index 7d749efd1..944ba1807 100644 --- a/tests/mr/device/mr_ref_multithreaded_tests.cpp +++ b/tests/mr/device/mr_ref_multithreaded_tests.cpp @@ -109,8 +109,13 @@ TEST_P(mr_ref_test_mt, SetCurrentDeviceResourceRef_mt) { // single thread changes default resource, then multiple threads use it auto old = rmm::mr::set_current_device_resource_ref(this->ref); + test_get_current_device_resource_ref(); - spawn([mr = this->ref]() { + int device; + RMM_CUDA_TRY(cudaGetDevice(&device)); + + spawn([device, mr = this->ref]() { + RMM_CUDA_TRY(cudaSetDevice(device)); EXPECT_EQ(mr, rmm::mr::get_current_device_resource_ref()); test_get_current_device_resource_ref(); // test allocating with the new default resource }); @@ -156,7 +161,17 @@ TEST_P(mr_ref_test_mt, SetCurrentDeviceResourceRefPerThread_mt) } } -TEST_P(mr_ref_test_mt, Allocate) { spawn(test_various_allocations, this->ref); } +TEST_P(mr_ref_test_mt, Allocate) +{ + int device; + RMM_CUDA_TRY(cudaGetDevice(&device)); + + auto mr = this->ref; + spawn([device, mr]() { + RMM_CUDA_TRY(cudaSetDevice(device)); + test_various_allocations(mr); + }); +} TEST_P(mr_ref_test_mt, AllocateDefaultStream) { diff --git a/tests/mr/device/test_utils.hpp b/tests/mr/device/test_utils.hpp index 2b9513793..5b7ef197b 100644 --- a/tests/mr/device/test_utils.hpp +++ b/tests/mr/device/test_utils.hpp @@ -31,17 +31,14 @@ inline bool is_device_accessible_memory(void* ptr) { cudaPointerAttributes attributes{}; if (cudaSuccess != cudaPointerGetAttributes(&attributes, ptr)) { return false; } - return (attributes.type == cudaMemoryTypeDevice) or (attributes.type == cudaMemoryTypeManaged) or - ((attributes.type == cudaMemoryTypeHost) and (attributes.devicePointer != nullptr)) or - ((attributes.type == cudaMemoryTypeUnregistered) and - (rmm::mr::detail::is_system_memory_supported(rmm::get_current_cuda_device()))); + return attributes.devicePointer != nullptr; } inline bool is_host_memory(void* ptr) { cudaPointerAttributes attributes{}; if (cudaSuccess != cudaPointerGetAttributes(&attributes, ptr)) { return false; } - return attributes.type == cudaMemoryTypeHost; + return attributes.hostPointer != nullptr || attributes.type == cudaMemoryTypeUnregistered; } inline bool is_properly_aligned(void* ptr)