Transition GPUArrays to KernelAbstractions
Co-authored-by: Valentin Churavy <[email protected]>
2 people authored and leios committed Mar 30, 2024
1 parent 7970d56 commit 0800eee
Showing 25 changed files with 324 additions and 614 deletions.
1 change: 1 addition & 0 deletions Project.toml
@@ -5,6 +5,7 @@ version = "10.0.2"
[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
LLVM = "929cbde3-209d-540e-8aea-75f648917ca0"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
2 changes: 0 additions & 2 deletions docs/src/interface.md
@@ -13,8 +13,6 @@ all, you need to provide a type that represents your execution back-end and a wa
kernels:

```@docs
GPUArrays.AbstractGPUBackend
GPUArrays.AbstractKernelContext
GPUArrays.gpu_call
GPUArrays.thread_block_heuristic
```
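The interface hunk above trims the documented execution API; with this commit, kernel launches go through KernelAbstractions rather than `gpu_call`. Below is a minimal illustrative sketch (not part of the commit), using KernelAbstractions' own `CPU()` backend so it runs without a GPU:

```julia
using KernelAbstractions

# Scale `src` into `dst`, one work-item per element.
@kernel function scale!(dst, src, factor)
    i = @index(Global)
    @inbounds dst[i] = factor * src[i]
end

src = rand(Float32, 16)
dst = similar(src)

# Instantiate the kernel for a backend, then launch it over `ndrange` work-items.
scale!(CPU())(dst, src, 2f0; ndrange = length(src))
KernelAbstractions.synchronize(CPU())

dst ≈ 2 .* src   # true
```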
6 changes: 3 additions & 3 deletions lib/GPUArraysCore/src/GPUArraysCore.jl
@@ -222,10 +222,10 @@ end
Gets the GPUArrays back-end responsible for managing arrays of type `T`.
"""
backend(::Type) = error("This object is not a GPU array") # COV_EXCL_LINE
backend(x) = backend(typeof(x))
get_backend(::Type) = error("This object is not a GPU array") # COV_EXCL_LINE
get_backend(x) = get_backend(typeof(x))

# WrappedArray from Adapt for Base wrappers.
backend(::Type{WA}) where WA<:WrappedArray = backend(unwrap_type(WA))
get_backend(::Type{WA}) where WA<:WrappedArray = get_backend(unwrap_type(WA))

end # module GPUArraysCore
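The renamed `get_backend` resolves an array's back-end by dispatching on its type, unwrapping Adapt's `WrappedArray` wrappers first. An illustrative sketch (the `Demo*` names are placeholders, not part of GPUArraysCore):

```julia
using GPUArraysCore

# Placeholder back-end and array type, defined only to exercise the dispatch above.
struct DemoBackend end
struct DemoArray{T,N} <: AbstractGPUArray{T,N}
    data::Array{T,N}
end
Base.size(a::DemoArray) = size(a.data)

GPUArraysCore.get_backend(::Type{<:DemoArray}) = DemoBackend()

a = DemoArray(rand(Float32, 4, 4))
GPUArraysCore.get_backend(a)             # DemoBackend(), via get_backend(typeof(a))
GPUArraysCore.get_backend(view(a, :, 1)) # also DemoBackend(): the WrappedArray method
                                         # unwraps the SubArray before dispatching
GPUArraysCore.get_backend([1, 2, 3])     # throws "This object is not a GPU array"
```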
3 changes: 2 additions & 1 deletion lib/JLArrays/Project.toml
@@ -6,10 +6,11 @@ version = "0.1.4"
[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

[compat]
Adapt = "2.0, 3.0, 4.0"
GPUArrays = "10"
julia = "1.8"
Random = "1"
julia = "1.8"
198 changes: 79 additions & 119 deletions lib/JLArrays/src/JLArrays.jl
@@ -1,54 +1,29 @@
# reference implementation on the CPU

# note that most of the code in this file serves to define a functional array type,
# the actual implementation of GPUArrays-interfaces is much more limited.
# This acts as a wrapper around KernelAbstractions's parallel CPU
# functionality. It is useful for testing GPUArrays (and other packages)
# when no GPU is present.
# This file follows conventions from AMDGPU.jl

module JLArrays

export JLArray, JLVector, JLMatrix, jl

using GPUArrays

using Adapt
import KernelAbstractions
import KernelAbstractions: Adapt, StaticArrays, Backend, Kernel, StaticSize, DynamicSize, partition, blocks, workitems, launch_config

export JLArray, JLVector, JLMatrix, jl, JLBackend

#
# Device functionality
#

const MAXTHREADS = 256


## execution

struct JLBackend <: AbstractGPUBackend end

mutable struct JLKernelContext <: AbstractKernelContext
blockdim::Int
griddim::Int
blockidx::Int
threadidx::Int

localmem_counter::Int
localmems::Vector{Vector{Array}}
end

function JLKernelContext(threads::Int, blockdim::Int)
blockcount = prod(blockdim)
lmems = [Vector{Array}() for i in 1:blockcount]
JLKernelContext(threads, blockdim, 1, 1, 0, lmems)
struct JLBackend <: KernelAbstractions.GPU
    static::Bool
    JLBackend(;static::Bool=false) = new(static)
end

function JLKernelContext(ctx::JLKernelContext, threadidx::Int)
JLKernelContext(
ctx.blockdim,
ctx.griddim,
ctx.blockidx,
threadidx,
0,
ctx.localmems
)
end

struct Adaptor end
jlconvert(arg) = adapt(Adaptor(), arg)
@@ -60,28 +35,35 @@ end
Base.getindex(r::JlRefValue) = r.x
Adapt.adapt_structure(to::Adaptor, r::Base.RefValue) = JlRefValue(adapt(to, r[]))

function GPUArrays.gpu_call(::JLBackend, f, args, threads::Int, blocks::Int;
name::Union{String,Nothing})
ctx = JLKernelContext(threads, blocks)
device_args = jlconvert.(args)
tasks = Array{Task}(undef, threads)
for blockidx in 1:blocks
ctx.blockidx = blockidx
for threadidx in 1:threads
thread_ctx = JLKernelContext(ctx, threadidx)
tasks[threadidx] = @async f(thread_ctx, device_args...)
# TODO: require 1.3 and use Base.Threads.@spawn for actual multithreading
# (this would require a different synchronization mechanism)
end
for t in tasks
fetch(t)
end
mutable struct JLArray{T, N} <: AbstractGPUArray{T, N}
    data::DataRef{Vector{UInt8}}

    offset::Int # offset of the data in the buffer, in number of elements

    dims::Dims{N}

    # allocating constructor
    function JLArray{T,N}(::UndefInitializer, dims::Dims{N}) where {T,N}
        check_eltype(T)
        maxsize = prod(dims) * sizeof(T)
        data = Vector{UInt8}(undef, maxsize)
        ref = DataRef(data)
        obj = new{T,N}(ref, 0, dims)
        finalizer(unsafe_free!, obj)
    end
return
end

    # low-level constructor for wrapping existing data
    function JLArray{T,N}(ref::DataRef{Vector{UInt8}}, dims::Dims{N};
                          offset::Int=0) where {T,N}
        check_eltype(T)
        obj = new{T,N}(ref, offset, dims)
        finalizer(unsafe_free!, obj)
    end
end

## executed on-device
Adapt.adapt_storage(::JLBackend, a::Array) = Adapt.adapt(JLArrays.JLArray, a)
Adapt.adapt_storage(::JLBackend, a::JLArrays.JLArray) = a
Adapt.adapt_storage(::KernelAbstractions.CPU, a::JLArrays.JLArray) = convert(Array, a)
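These adapt rules move data in both directions: adapting to `JLBackend` wraps plain `Array`s as `JLArray`s, while adapting to the KernelAbstractions `CPU` backend converts them back. A usage sketch (not part of the diff):

```julia
using Adapt, JLArrays, KernelAbstractions

host = rand(Float32, 8)
dev  = adapt(JLBackend(), host)              # JLArray{Float32, 1}
dev2 = adapt(JLBackend(), dev)               # already a JLArray, returned as-is
back = adapt(KernelAbstractions.CPU(), dev)  # back to a plain Array{Float32, 1}
```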

# array type

@@ -107,43 +89,6 @@ end
@inline Base.getindex(A::JLDeviceArray, index::Integer) = getindex(typed_data(A), index)
@inline Base.setindex!(A::JLDeviceArray, x, index::Integer) = setindex!(typed_data(A), x, index)


# indexing

for f in (:blockidx, :blockdim, :threadidx, :griddim)
@eval GPUArrays.$f(ctx::JLKernelContext) = ctx.$f
end

# memory

function GPUArrays.LocalMemory(ctx::JLKernelContext, ::Type{T}, ::Val{dims}, ::Val{id}) where {T, dims, id}
ctx.localmem_counter += 1
lmems = ctx.localmems[blockidx(ctx)]

# first invocation in block
data = if length(lmems) < ctx.localmem_counter
lmem = fill(zero(T), dims)
push!(lmems, lmem)
lmem
else
lmems[ctx.localmem_counter]
end

N = length(dims)
JLDeviceArray{T,N}(data, tuple(dims...))
end

# synchronization

@inline function GPUArrays.synchronize_threads(::JLKernelContext)
# All threads are getting started asynchronously, so a yield will yield to the next
# execution of the same function, which should call yield at the exact same point in the
# program, leading to a chain of yields effectively syncing the tasks (threads).
yield()
return
end


#
# Host abstractions
#
@@ -157,32 +102,6 @@ function check_eltype(T)
end
end

mutable struct JLArray{T, N} <: AbstractGPUArray{T, N}
data::DataRef{Vector{UInt8}}

offset::Int # offset of the data in the buffer, in number of elements

dims::Dims{N}

# allocating constructor
function JLArray{T,N}(::UndefInitializer, dims::Dims{N}) where {T,N}
check_eltype(T)
maxsize = prod(dims) * sizeof(T)
data = Vector{UInt8}(undef, maxsize)
ref = DataRef(data)
obj = new{T,N}(ref, 0, dims)
finalizer(unsafe_free!, obj)
end

# low-level constructor for wrapping existing data
function JLArray{T,N}(ref::DataRef{Vector{UInt8}}, dims::Dims{N};
offset::Int=0) where {T,N}
check_eltype(T)
obj = new{T,N}(ref, offset, dims)
finalizer(unsafe_free!, obj)
end
end

unsafe_free!(a::JLArray) = GPUArrays.unsafe_free!(a.data)

# conversion of untyped data to a typed Array
@@ -392,8 +311,6 @@ end

## GPUArrays interfaces

GPUArrays.backend(::Type{<:JLArray}) = JLBackend()

Adapt.adapt_storage(::Adaptor, x::JLArray{T,N}) where {T,N} =
JLDeviceArray{T,N}(x.data[], x.offset, x.dims)

@@ -406,4 +323,47 @@ function GPUArrays.mapreducedim!(f, op, R::AnyJLArray, A::Union{AbstractArray,Br
R
end

## KernelAbstractions interface

KernelAbstractions.get_backend(a::JLA) where JLA <: JLArray = JLBackend()

function KernelAbstractions.mkcontext(kernel::Kernel{JLBackend}, I, _ndrange, iterspace, ::Dynamic) where Dynamic
    return KernelAbstractions.CompilerMetadata{KernelAbstractions.ndrange(kernel), Dynamic}(I, _ndrange, iterspace)
end

KernelAbstractions.allocate(::JLBackend, ::Type{T}, dims::Tuple) where T = JLArray{T}(undef, dims)

@inline function launch_config(kernel::Kernel{JLBackend}, ndrange, workgroupsize)
    if ndrange isa Integer
        ndrange = (ndrange,)
    end
    if workgroupsize isa Integer
        workgroupsize = (workgroupsize, )
    end

    if KernelAbstractions.workgroupsize(kernel) <: DynamicSize && workgroupsize === nothing
        workgroupsize = (1024,) # Vectorization, 4x unrolling, minimal grain size
    end
    iterspace, dynamic = partition(kernel, ndrange, workgroupsize)
    # partition checked that the ndrange's agreed
    if KernelAbstractions.ndrange(kernel) <: StaticSize
        ndrange = nothing
    end

    return ndrange, workgroupsize, iterspace, dynamic
end

KernelAbstractions.isgpu(b::JLBackend) = false

function convert_to_cpu(obj::Kernel{JLBackend, W, N, F}) where {W, N, F}
    return Kernel{typeof(KernelAbstractions.CPU(; static = obj.backend.static)), W, N, F}(KernelAbstractions.CPU(; static = obj.backend.static), obj.f)
end

function (obj::Kernel{JLBackend})(args...; ndrange=nothing, workgroupsize=nothing)
    device_args = jlconvert.(args)
    new_obj = convert_to_cpu(obj)
    new_obj(device_args...; ndrange, workgroupsize)
end

end
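Taken together, the interface above lets a KernelAbstractions kernel instantiated on `JLBackend` adapt its arguments with `jlconvert` and re-launch on the CPU backend via `convert_to_cpu`. A hedged end-to-end sketch (not part of the commit):

```julia
using JLArrays, KernelAbstractions

@kernel function vadd!(c, a, b)
    i = @index(Global)
    @inbounds c[i] = a[i] + b[i]
end

a = jl(rand(Float32, 1024))
b = jl(rand(Float32, 1024))
c = similar(a)

# get_backend(c) returns JLBackend(); calling the kernel goes through the
# Kernel{JLBackend} method above, which forwards to an equivalent CPU kernel.
backend = KernelAbstractions.get_backend(c)
vadd!(backend)(c, a, b; ndrange = length(c))

Array(c) ≈ Array(a) .+ Array(b)   # true
```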
8 changes: 3 additions & 5 deletions src/GPUArrays.jl
@@ -1,5 +1,6 @@
module GPUArrays

using KernelAbstractions
using Serialization
using Random
using LinearAlgebra
@@ -14,14 +15,11 @@ using LLVM.Interop
using Reexport
@reexport using GPUArraysCore

# device functionality
include("device/execution.jl")
## executed on-device
include("device/execution.jl")
include("device/abstractarray.jl")
include("device/indexing.jl")
include("device/memory.jl")
include("device/synchronization.jl")

using KernelAbstractions
# host abstractions
include("host/abstractarray.jl")
include("host/construction.jl")
