Skip to content

Commit

Permalink
Query total memory in failure_callback_resource_adaptor tests (#1734)
Browse files Browse the repository at this point in the history
Fixes #1733 by querying total device memory and using twice as much in tests that are expected to fail allocation.

Authors:
  - Mark Harris (https://github.com/harrism)

Approvers:
  - Bradley Dice (https://github.com/bdice)

URL: #1734
  • Loading branch information
harrism authored Nov 20, 2024
1 parent 929a595 commit 3b5f6af
Showing 1 changed file with 19 additions and 13 deletions.
32 changes: 19 additions & 13 deletions python/rmm/rmm/tests/test_rmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -795,10 +795,28 @@ def callback(nbytes: int) -> bool:
rmm.mr.set_current_device_resource(mr)

with pytest.raises(MemoryError):
rmm.DeviceBuffer(size=int(1e11))
from rmm.mr import available_device_memory

total_memory = available_device_memory()[1]
rmm.DeviceBuffer(size=total_memory * 2)
assert retried[0]


def test_failure_callback_resource_adaptor_error():
def callback(nbytes: int) -> bool:
raise RuntimeError("MyError")

cuda_mr = rmm.mr.CudaMemoryResource()
mr = rmm.mr.FailureCallbackResourceAdaptor(cuda_mr, callback)
rmm.mr.set_current_device_resource(mr)

with pytest.raises(RuntimeError, match="MyError"):
from rmm.mr import available_device_memory

total_memory = available_device_memory()[1]
rmm.DeviceBuffer(size=total_memory * 2)


@pytest.mark.parametrize("managed", [True, False])
def test_prefetch_resource_adaptor(managed):
if managed:
Expand All @@ -823,18 +841,6 @@ def test_prefetch_resource_adaptor(managed):
assert_prefetched(db, device)


def test_failure_callback_resource_adaptor_error():
def callback(nbytes: int) -> bool:
raise RuntimeError("MyError")

cuda_mr = rmm.mr.CudaMemoryResource()
mr = rmm.mr.FailureCallbackResourceAdaptor(cuda_mr, callback)
rmm.mr.set_current_device_resource(mr)

with pytest.raises(RuntimeError, match="MyError"):
rmm.DeviceBuffer(size=int(1e11))


def test_dev_buf_circle_ref_dealloc():
# This test creates a reference cycle containing a `DeviceBuffer`
# and ensures that the garbage collector does not clear it, i.e.,
Expand Down

0 comments on commit 3b5f6af

Please sign in to comment.