From 367c78b1e9c6684c3f721ac94259517a0d19b4d3 Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Mon, 21 Oct 2024 17:17:44 +0200 Subject: [PATCH] add N for hardware indices --- src/KernelAbstractions.jl | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/src/KernelAbstractions.jl b/src/KernelAbstractions.jl index 741ac12f..dcba43b1 100644 --- a/src/KernelAbstractions.jl +++ b/src/KernelAbstractions.jl @@ -17,13 +17,15 @@ using StaticArrays using Adapt """ - @kernel function f(args) end + @kernel [N] function f(args) end Takes a function definition and generates a [`Kernel`](@ref) constructor from it. The enclosed function is allowed to contain kernel language constructs. In order to call it the kernel has first to be specialized on the backend and then invoked on the arguments. +The optional `N` parameter can be used to fix the number of dimensions used for the ndrange. + # Kernel language - [`@Const`](@ref) @@ -55,7 +57,7 @@ macro kernel(expr) end """ - @kernel config function f(args) end + @kernel [N] config function f(args) end This allows for two different configurations: @@ -585,17 +587,17 @@ in a workgroup. ``` As well as the on-device functionality. """ -struct Kernel{Backend, WorkgroupSize <: _Size, NDRange <: _Size, Fun} +struct Kernel{Backend, N, WorkgroupSize <: _Size, NDRange <: _Size, Fun} backend::Backend f::Fun end -function Base.similar(kernel::Kernel{D, WS, ND}, f::F) where {D, WS, ND, F} - Kernel{D, WS, ND, F}(kernel.backend, f) +function Base.similar(kernel::Kernel{D, N, WS, ND}, f::F) where {D, N, WS, ND, F} + Kernel{D, N, WS, ND, F}(kernel.backend, f) end -workgroupsize(::Kernel{D, WorkgroupSize}) where {D, WorkgroupSize} = WorkgroupSize -ndrange(::Kernel{D, WorkgroupSize, NDRange}) where {D, WorkgroupSize, NDRange} = NDRange +workgroupsize(::Kernel{D, N, WorkgroupSize}) where {D, WorkgroupSize} = WorkgroupSize +ndrange(::Kernel{D, N, WorkgroupSize, NDRange}) where {D, WorkgroupSize, NDRange} = NDRange backend(kernel::Kernel) = kernel.backend """ @@ -658,8 +660,8 @@ Partition a kernel for the given ndrange and workgroupsize. return iterspace, dynamic end -function construct(backend::Backend, ::S, ::NDRange, xpu_name::XPUName) where {Backend <: Union{CPU, GPU}, S <: _Size, NDRange <: _Size, XPUName} - return Kernel{Backend, S, NDRange, XPUName}(backend, xpu_name) +function construct(backend::Backend, ::Val{N}, ::S, ::NDRange, xpu_name::XPUName) where {Backend <: Union{CPU, GPU}, N, S <: _Size, NDRange <: _Size, XPUName} + return Kernel{Backend, N, S, NDRange, XPUName}(backend, xpu_name) end ###