Skip to content

Commit

Permalink
add N for hardware indices
Browse files Browse the repository at this point in the history
  • Loading branch information
vchuravy committed Oct 21, 2024
1 parent 73dc429 commit f5a82b8
Showing 1 changed file with 11 additions and 9 deletions.
20 changes: 11 additions & 9 deletions src/KernelAbstractions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,15 @@ using StaticArrays
using Adapt

"""
@kernel function f(args) end
@kernel [N] function f(args) end
Takes a function definition and generates a [`Kernel`](@ref) constructor from it.
The enclosed function is allowed to contain kernel language constructs.
In order to call it the kernel has first to be specialized on the backend
and then invoked on the arguments.
The optional `N` parameter can be used to fix the number of dimensions used for the ndrange.
# Kernel language
- [`@Const`](@ref)
Expand Down Expand Up @@ -55,7 +57,7 @@ macro kernel(expr)
end

"""
@kernel config function f(args) end
@kernel [N] config function f(args) end
This allows for two different configurations:
Expand Down Expand Up @@ -585,17 +587,17 @@ in a workgroup.
```
As well as the on-device functionality.
"""
struct Kernel{Backend, WorkgroupSize <: _Size, NDRange <: _Size, Fun}
struct Kernel{Backend, N, WorkgroupSize <: _Size, NDRange <: _Size, Fun}
backend::Backend
f::Fun
end

function Base.similar(kernel::Kernel{D, WS, ND}, f::F) where {D, WS, ND, F}
Kernel{D, WS, ND, F}(kernel.backend, f)
function Base.similar(kernel::Kernel{D, N, WS, ND}, f::F) where {D, N, WS, ND, F}
Kernel{D, N, WS, ND, F}(kernel.backend, f)
end

workgroupsize(::Kernel{D, WorkgroupSize}) where {D, WorkgroupSize} = WorkgroupSize
ndrange(::Kernel{D, WorkgroupSize, NDRange}) where {D, WorkgroupSize, NDRange} = NDRange
workgroupsize(::Kernel{D, N, WorkgroupSize}) where {D, WorkgroupSize} = WorkgroupSize
ndrange(::Kernel{D, N, WorkgroupSize, NDRange}) where {D, WorkgroupSize, NDRange} = NDRange
backend(kernel::Kernel) = kernel.backend

"""
Expand Down Expand Up @@ -658,8 +660,8 @@ Partition a kernel for the given ndrange and workgroupsize.
return iterspace, dynamic
end

function construct(backend::Backend, ::S, ::NDRange, xpu_name::XPUName) where {Backend <: Union{CPU, GPU}, S <: _Size, NDRange <: _Size, XPUName}
return Kernel{Backend, S, NDRange, XPUName}(backend, xpu_name)
function construct(backend::Backend, ::Val{N}, ::S, ::NDRange, xpu_name::XPUName) where {Backend <: Union{CPU, GPU}, N, S <: _Size, NDRange <: _Size, XPUName}
return Kernel{Backend, N, S, NDRange, XPUName}(backend, xpu_name)
end

###
Expand Down

0 comments on commit f5a82b8

Please sign in to comment.