diff --git a/include/rmm/mr/device/cuda_async_memory_resource.hpp b/include/rmm/mr/device/cuda_async_memory_resource.hpp index b1fc0b112..dbb324a1b 100644 --- a/include/rmm/mr/device/cuda_async_memory_resource.hpp +++ b/include/rmm/mr/device/cuda_async_memory_resource.hpp @@ -53,14 +53,31 @@ class cuda_async_memory_resource final : public device_memory_resource { * memory pools (CUDA 11.2) did not support these flags, so we need a placeholder that can be * used consistently in the constructor of `cuda_async_memory_resource` with all versions of * CUDA >= 11.2. See the `cudaMemAllocationHandleType` docs at - * https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html + * https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html and ensure the enum + * values are kept in sync with the CUDA documentation. */ enum class allocation_handle_type { none = 0x0, ///< Does not allow any export mechanism. posix_file_descriptor = 0x1, ///< Allows a file descriptor to be used for exporting. Permitted ///< only on POSIX systems. win32 = 0x2, ///< Allows a Win32 NT handle to be used for exporting. (HANDLE) - win32_kmt = 0x4 ///< Allows a Win32 KMT handle to be used for exporting. (D3DKMT_HANDLE) + win32_kmt = 0x4, ///< Allows a Win32 KMT handle to be used for exporting. (D3DKMT_HANDLE) + fabric = 0x8 ///< Allows a fabric handle to be used for exporting. (cudaMemFabricHandle_t) + }; + + /** + * @brief Flags for specifying the memory pool accessibility from other devices. + * + * @note These values are exact copies from `cudaMemAccessFlags`. See the `cudaMemAccessFlags` + * docs at https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html and ensure the + * enum values are kept in sync with the CUDA documentation. The default, `none`, marks the pool's + * memory as private to the device in which it was created. `read_write` should only be used if + * memory sharing among devices is required. Note that there is a `cudaMemAccessFlagsProtRead` + * documented, but memory pools don't support read-only access, so it has been omitted. + */ + enum class access_flags { + none = 0, ///< Default, make pool not accessible. + read_write = 3 ///< Make pool read-write accessible. }; /** @@ -77,13 +94,16 @@ class cuda_async_memory_resource final : public device_memory_resource { * @param release_threshold Optional release threshold size in bytes of the pool. If no value is * provided, the release threshold is set to the total amount of memory on the current device. * @param export_handle_type Optional `cudaMemAllocationHandleType` that allocations from this - * resource should support interprocess communication (IPC). Default is - * `cudaMemHandleTypeNone` for no IPC support. + * resource should support interprocess communication (IPC). Default is `cudaMemHandleTypeNone` + * for no IPC support. + * @param access_flag Optional `cudaMemAccessFlags` that controls pool memory accessibility + * from other devices. Default is `cudaMemAccessFlagsProtNone` for no accessibility. */ // NOLINTNEXTLINE(bugprone-easily-swappable-parameters) cuda_async_memory_resource(std::optional initial_pool_size = {}, std::optional release_threshold = {}, - std::optional export_handle_type = {}) + std::optional export_handle_type = {}, + std::optional access_flag = {}) { // Check if cudaMallocAsync Memory pool supported RMM_EXPECTS(rmm::detail::runtime_async_alloc::is_supported(), @@ -115,6 +135,12 @@ class cuda_async_memory_resource final : public device_memory_resource { cudaMemPoolSetAttribute(pool_handle(), cudaMemPoolReuseAllowOpportunistic, &disabled)); } + if (access_flag) { + cudaMemAccessDesc desc = {.location = pool_props.location, + .flags = static_cast(*access_flag)}; + RMM_CUDA_TRY(cudaMemPoolSetAccess(pool_handle(), &desc, 1)); + } + auto const [free, total] = rmm::available_device_memory(); // Need an l-value to take address to pass to cudaMemPoolSetAttribute diff --git a/tests/mr/device/cuda_async_mr_tests.cpp b/tests/mr/device/cuda_async_mr_tests.cpp index a39188548..167adf1e1 100644 --- a/tests/mr/device/cuda_async_mr_tests.cpp +++ b/tests/mr/device/cuda_async_mr_tests.cpp @@ -66,5 +66,43 @@ TEST_F(AsyncMRTest, DifferentPoolsUnequal) EXPECT_FALSE(mr1.is_equal(mr2)); } +class AsyncMRFabricTest : public AsyncMRTest { + void SetUp() override + { + AsyncMRTest::SetUp(); + + auto handle_type = static_cast( + rmm::mr::cuda_async_memory_resource::allocation_handle_type::fabric); + if (!rmm::detail::runtime_async_alloc::is_export_handle_type_supported(handle_type)) { + GTEST_SKIP() << "Fabric handles are not supported in this environment. Skipping test."; + } + } +}; + +TEST_F(AsyncMRFabricTest, FabricHandlesSupport) +{ + const auto pool_init_size{100}; + const auto pool_release_threshold{1000}; + cuda_async_mr mr{pool_init_size, + pool_release_threshold, + rmm::mr::cuda_async_memory_resource::allocation_handle_type::fabric}; + void* ptr = mr.allocate(pool_init_size); + mr.deallocate(ptr, pool_init_size); + RMM_CUDA_TRY(cudaDeviceSynchronize()); +} + +TEST_F(AsyncMRFabricTest, FabricHandlesSupportReadWriteShareable) +{ + const auto pool_init_size{100}; + const auto pool_release_threshold{1000}; + cuda_async_mr mr{pool_init_size, + pool_release_threshold, + rmm::mr::cuda_async_memory_resource::allocation_handle_type::fabric, + rmm::mr::cuda_async_memory_resource::access_flags::read_write}; + void* ptr = mr.allocate(pool_init_size); + mr.deallocate(ptr, pool_init_size); + RMM_CUDA_TRY(cudaDeviceSynchronize()); +} + } // namespace } // namespace rmm::test