diff --git a/src/KernelAbstractions.jl b/src/KernelAbstractions.jl index 741ac12f..dcba43b1 100644 --- a/src/KernelAbstractions.jl +++ b/src/KernelAbstractions.jl @@ -17,13 +17,15 @@ using StaticArrays using Adapt """ - @kernel function f(args) end + @kernel [N] function f(args) end Takes a function definition and generates a [`Kernel`](@ref) constructor from it. The enclosed function is allowed to contain kernel language constructs. In order to call it the kernel has first to be specialized on the backend and then invoked on the arguments. +The optional `N` parameter can be used to fix the number of dimensions used for the ndrange. + # Kernel language - [`@Const`](@ref) @@ -55,7 +57,7 @@ macro kernel(expr) end """ - @kernel config function f(args) end + @kernel [N] config function f(args) end This allows for two different configurations: @@ -585,17 +587,17 @@ in a workgroup. ``` As well as the on-device functionality. """ -struct Kernel{Backend, WorkgroupSize <: _Size, NDRange <: _Size, Fun} +struct Kernel{Backend, N, WorkgroupSize <: _Size, NDRange <: _Size, Fun} backend::Backend f::Fun end -function Base.similar(kernel::Kernel{D, WS, ND}, f::F) where {D, WS, ND, F} - Kernel{D, WS, ND, F}(kernel.backend, f) +function Base.similar(kernel::Kernel{D, N, WS, ND}, f::F) where {D, N, WS, ND, F} + Kernel{D, N, WS, ND, F}(kernel.backend, f) end -workgroupsize(::Kernel{D, WorkgroupSize}) where {D, WorkgroupSize} = WorkgroupSize -ndrange(::Kernel{D, WorkgroupSize, NDRange}) where {D, WorkgroupSize, NDRange} = NDRange +workgroupsize(::Kernel{D, N, WorkgroupSize}) where {D, WorkgroupSize} = WorkgroupSize +ndrange(::Kernel{D, N, WorkgroupSize, NDRange}) where {D, WorkgroupSize, NDRange} = NDRange backend(kernel::Kernel) = kernel.backend """ @@ -658,8 +660,8 @@ Partition a kernel for the given ndrange and workgroupsize. return iterspace, dynamic end -function construct(backend::Backend, ::S, ::NDRange, xpu_name::XPUName) where {Backend <: Union{CPU, GPU}, S <: _Size, NDRange <: _Size, XPUName} - return Kernel{Backend, S, NDRange, XPUName}(backend, xpu_name) +function construct(backend::Backend, ::Val{N}, ::S, ::NDRange, xpu_name::XPUName) where {Backend <: Union{CPU, GPU}, N, S <: _Size, NDRange <: _Size, XPUName} + return Kernel{Backend, N, S, NDRange, XPUName}(backend, xpu_name) end ###