Skip to content

Commit

Permalink
Merge pull request #18 from fjarri/reikna-fixes-p2
Browse files Browse the repository at this point in the history
Reikna fixes part 2
  • Loading branch information
fjarri authored Aug 1, 2024
2 parents 7b7addd + 4d25dca commit 47144ae
Show file tree
Hide file tree
Showing 15 changed files with 289 additions and 173 deletions.
4 changes: 4 additions & 0 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,10 @@ Buffers and arrays
.. autoclass:: Buffer()
:members:

.. autoclass:: ArrayMetadata
:members:
:special-members: __getitem__

.. autoclass:: ArrayMetadataLike()
:show-inheritance:
:members:
Expand Down
10 changes: 10 additions & 0 deletions docs/history.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,19 @@ Changed
^^^^^^^

* ``local_mem`` keyword parameter of kernel calls renamed to ``cu_dynamic_local_mem``. (PR_17_)
* Renamed ``no_async`` keyword parameter to ``sync``. (PR_18_)


Added
^^^^^

* Made ``ArrayMetadata`` public. (PR_18_)
* Added a ``metadata`` attribute to ``Array``. (PR_18_)
* Added ``ArrayMetadata.buffer_size``, ``span``, ``min_offset``, ``first_element_offset``, and ``get_sub_region()``, and ``Array.minimum_subregion()``. (PR_18_)


.. _PR_17: https://github.com/fjarri/grunnur/pull/17
.. _PR_18: https://github.com/fjarri/grunnur/pull/18



Expand Down
2 changes: 1 addition & 1 deletion grunnur/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
opencl_api_id,
)
from .array import Array, ArrayLike, MultiArray
from .array_metadata import ArrayMetadataLike
from .array_metadata import ArrayMetadata, ArrayMetadataLike
from .buffer import Buffer
from .context import Context
from .device import Device, DeviceFilter
Expand Down
2 changes: 1 addition & 1 deletion grunnur/adapter_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ def set(
queue_adapter: QueueAdapter,
source: numpy.ndarray[Any, numpy.dtype[Any]] | BufferAdapter,
*,
no_async: bool = False,
sync: bool = False,
) -> None:
pass

Expand Down
6 changes: 3 additions & 3 deletions grunnur/adapter_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,7 +491,7 @@ def set(
queue_adapter: QueueAdapter,
source: numpy.ndarray[Any, numpy.dtype[Any]] | BufferAdapter,
*,
no_async: bool = False,
sync: bool = False,
) -> None:
# Will be checked in the upper levels.
assert isinstance(queue_adapter, CuQueueAdapter) # noqa: S101
Expand All @@ -503,15 +503,15 @@ def set(
ptr = int(self._ptr) if isinstance(self._ptr, numpy.number) else self._ptr

if isinstance(source, numpy.ndarray):
if no_async:
if sync:
pycuda_driver.memcpy_htod(ptr, source)
else:
pycuda_driver.memcpy_htod_async(ptr, source, stream=queue_adapter._pycuda_stream) # noqa: SLF001
else:
# Will be checked in the upper levels.
assert isinstance(source, CuBufferAdapter) # noqa: S101
buf_ptr = int(source._ptr) if isinstance(source._ptr, numpy.number) else source._ptr # noqa: SLF001
if no_async:
if sync:
pycuda_driver.memcpy_dtod(ptr, buf_ptr, source.size)
else:
pycuda_driver.memcpy_dtod_async(
Expand Down
4 changes: 2 additions & 2 deletions grunnur/adapter_opencl.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,7 +491,7 @@ def set(
queue_adapter: QueueAdapter,
source: numpy.ndarray[Any, numpy.dtype[Any]] | BufferAdapter,
*,
no_async: bool = False,
sync: bool = False,
) -> None:
# Will be checked in the upper levels.
assert isinstance(queue_adapter, OclQueueAdapter) # noqa: S101
Expand All @@ -506,7 +506,7 @@ def set(
# This keyword is only supported for transfers involving hosts in PyOpenCL
kwds = {}
if not isinstance(source, OclBufferAdapter):
kwds["is_blocking"] = no_async
kwds["is_blocking"] = sync

pyopencl.enqueue_copy(
queue_adapter._pyopencl_queue, # noqa: SLF001
Expand Down
73 changes: 42 additions & 31 deletions grunnur/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@ class Array:
strides: tuple[int, ...]
"""Array strides."""

offset: int
"""Offset of the first element in the associated buffer."""

metadata: ArrayMetadata
"""Array metadata object."""

@classmethod
def from_host(
cls,
Expand Down Expand Up @@ -60,7 +66,7 @@ def from_host(
def empty(
cls,
device: BoundDevice,
shape: Sequence[int],
shape: Sequence[int] | int,
dtype: DTypeLike,
strides: Sequence[int] | None = None,
first_element_offset: int = 0,
Expand All @@ -80,62 +86,71 @@ def empty(
metadata = ArrayMetadata(
shape, dtype, strides=strides, first_element_offset=first_element_offset
)
size = metadata.buffer_size

if allocator is None:
allocator = Buffer.allocate
data = allocator(device, size)
data = allocator(device, metadata.buffer_size)

return cls(metadata, data)

@classmethod
def empty_like(cls, device: BoundDevice, array_like: ArrayMetadataLike) -> Array:
    """Creates an empty array with the same shape and dtype as ``array_like``."""
    # TODO: take other information like strides and offset
    template_shape = array_like.shape
    template_dtype = array_like.dtype
    return cls.empty(device, template_shape, template_dtype)

def __init__(self, array_metadata: ArrayMetadata, data: Buffer):
self._metadata = array_metadata
def __init__(self, metadata: ArrayMetadata, data: Buffer):
if data.size < metadata.buffer_size:
raise ValueError(
f"The buffer size required by the given metadata ({metadata.buffer_size}) "
f"is larger than the given buffer size ({data.size})"
)

self.metadata = metadata
self.device = data.device
self.shape = self._metadata.shape
self.dtype = self._metadata.dtype
self.strides = self._metadata.strides
self.first_element_offset = self._metadata.first_element_offset
self.buffer_size = self._metadata.buffer_size

if data.size < self.buffer_size:
raise ValueError(
"Provided data buffer is not big enough to hold the array "
"(minimum required {self.buffer_size})"
)
self.shape = self.metadata.shape
self.dtype = self.metadata.dtype
self.strides = self.metadata.strides

self.data = data

def _view(self, slices: slice | tuple[slice, ...]) -> Array:
new_metadata = self._metadata[slices]

origin, size, new_metadata = new_metadata.minimal_subregion()
def minimum_subregion(self) -> Array:
"""
Returns a new array with the same metadata and the buffer substituted with
the minimum-sized subregion of the original buffer,
such that all the elements described by the metadata still fit in it.
"""
# TODO: some platforms (e.g. POCL) require this to be aligned.
origin = self.metadata.min_offset
size = self.metadata.span
data = self.data.get_sub_region(origin, size)
return Array(new_metadata, data)
metadata = self.metadata.get_sub_region(origin, size)
return Array(metadata, data)

def __getitem__(self, slices: slice | tuple[slice, ...]) -> Array:
    """Returns a view of this array."""
    # Slicing only reinterprets the metadata; the underlying buffer is shared.
    view_metadata = self.metadata[slices]
    return Array(view_metadata, self.data)

def set(
self,
queue: Queue,
array: numpy.ndarray[Any, numpy.dtype[Any]] | Array,
*,
no_async: bool = False,
sync: bool = False,
) -> None:
"""
Copies the contents of the host array to the array.
:param queue: the queue to use for the transfer.
:param array: the source array.
:param no_async: if `True`, the transfer blocks until completion.
:param sync: if `True`, the transfer blocks until completion.
"""
array_data: numpy.ndarray[Any, numpy.dtype[Any]] | Buffer
if isinstance(array, numpy.ndarray):
array_data = array
elif isinstance(array, Array):
if not array._metadata.contiguous: # noqa: SLF001
if not array.metadata.is_contiguous:
raise ValueError("Setting from a non-contiguous device array is not supported")
array_data = array.data
else:
Expand All @@ -146,7 +161,7 @@ def set(
if self.dtype != array.dtype:
raise ValueError(f"Dtype mismatch: expected {self.dtype}, got {array.dtype}")

self.data.set(queue, array_data, no_async=no_async)
self.data.set(queue, array_data, sync=sync)

def get(
self,
Expand All @@ -167,10 +182,6 @@ def get(
self.data.get(queue, dest, async_=async_)
return dest

def __getitem__(self, slices: slice | tuple[slice, ...]) -> Array:
"""Returns a view of this array."""
return self._view(slices)


@runtime_checkable
class ArrayLike(ArrayMetadataLike, Protocol):
Expand Down Expand Up @@ -371,14 +382,14 @@ def set(
mqueue: MultiQueue,
array: numpy.ndarray[Any, numpy.dtype[Any]] | MultiArray,
*,
no_async: bool = False,
sync: bool = False,
) -> None:
"""
Copies the contents of the host array to the array.
:param mqueue: the queue to use for the transfer.
:param array: the source array.
:param no_async: if `True`, the transfer blocks until completion.
:param sync: if `True`, the transfer blocks until completion.
"""
subarrays: Mapping[BoundDevice, Array | numpy.ndarray[Any, numpy.dtype[Any]]]
if isinstance(array, numpy.ndarray):
Expand All @@ -392,4 +403,4 @@ def set(
raise ValueError("Mismatched device sets in the source and the destination")

for device in self.subarrays:
self.subarrays[device].set(mqueue.queues[device], subarrays[device], no_async=no_async)
self.subarrays[device].set(mqueue.queues[device], subarrays[device], sync=sync)
Loading

0 comments on commit 47144ae

Please sign in to comment.