Skip to content

Commit

Permalink
Replace to_host by Array. (#226)
Browse files Browse the repository at this point in the history
  • Loading branch information
maleadt authored Sep 11, 2024
1 parent ec6a054 commit 01be2fb
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 25 deletions.
2 changes: 1 addition & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ Breaking changes:
- Argument conversion has been removed; the user should make sure Julia arguments passed to
kernels match the OpenCL argument types (i.e., no empty types, and 4-element tuples for
3-element `float3` arguments).

- The `to_host` function has been replaced by simply calling `Array` on the `CLArray`.

New features:

Expand Down
31 changes: 14 additions & 17 deletions src/array.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import LinearAlgebra

export CLArray, CLMatrix, CLVector, to_host
export CLArray, CLMatrix, CLVector

mutable struct CLArray{T, N} <: CLObject
ctx::cl.Context
Expand Down Expand Up @@ -66,26 +66,24 @@ Base.size(A::CLArray) = A.size
Base.size(A::CLArray, dim::Integer) = A.size[dim]
Base.ndims(A::CLArray) = length(size(A))
Base.length(A::CLArray) = prod(size(A))
Base.:(==)(A:: CLArray, B:: CLArray) =
buffer(A) == buffer(B) && size(A) == size(B)
# Two CLArrays are considered equal when `buffer(A) == buffer(B)` and their sizes match.
function Base.:(==)(A::CLArray, B::CLArray)
    return buffer(A) == buffer(B) && size(A) == size(B)
end

"""
    reshape(A::CLArray, dims...)

Return the result of `copy(A, size=dims)`, i.e. `A` re-dimensioned to `dims`.

The total number of elements must be preserved: `prod(dims)` has to equal
`prod(size(A))`.

# Throws
- `DimensionMismatch`: if `dims` does not describe the same number of elements as `A`.
"""
function Base.reshape(A::CLArray, dims...)
    # Validate with an explicit check instead of `@assert`: assertions are intended for
    # internal invariants and may be disabled, which would let a mismatched size through.
    if prod(dims) != prod(size(A))
        throw(DimensionMismatch(
            "new dimensions $(dims) must be consistent with array size $(prod(size(A)))"))
    end
    return copy(A, size=dims)
end

## show

Base.show(io::IO, A::CLArray{T,N}) where {T, N} =
print(io, "CLArray{$T,$N}($(buffer(A)),$(size(A)))")

## to_host

function to_host(A::CLArray{T,N}) where {T, N}
"""
    Array(A::CLArray{T,N})

Allocate an uninitialized host `Array{T}` with the same size as `A`, fill it with
`copy!` from `buffer(A)`, and return it.
"""
function Base.Array(A::CLArray{T,N}) where {T, N}
    dims = size(A)
    host = Array{T}(undef, dims...)
    # Populate the freshly allocated host array from the array's underlying buffer.
    copy!(host, buffer(A))
    return host
end

## show

# Print a compact one-line representation: element type, rank, buffer and size.
function Base.show(io::IO, A::CLArray{T,N}) where {T, N}
    return print(io, "CLArray{$T,$N}($(buffer(A)),$(size(A)))")
end

## other array operations

const TRANSPOSE_FLOAT_PROGRAM_PATH = joinpath(@__DIR__, "kernels", "transpose_float.cl")
Expand All @@ -107,8 +105,8 @@ function LinearAlgebra.transpose!(B::CLMatrix{Float32}, A::CLMatrix{Float32})
block_size=block_size)
h, w = size(A)
lmem = cl.LocalMem(Float32, block_size * (block_size + 1))
cl.set_args!(kernel, buffer(B), buffer(A), UInt32(h), UInt32(w), lmem)
return cl.enqueue_kernel(kernel, (h, w), (block_size, block_size))
return cl.call(kernel, buffer(B), buffer(A), UInt32(h), UInt32(w), lmem;
global_size=(h, w), local_size=(block_size, block_size))
end

"""Transpose CLMatrix A"""
Expand All @@ -128,10 +126,9 @@ function LinearAlgebra.transpose!(B::CLMatrix{Float64}, A::CLMatrix{Float64})
kernel = get_kernel(TRANSPOSE_DOUBLE_PROGRAM_PATH, "transpose",
block_size=block_size)
h, w = size(A)
# lmem = cl.LocalMem(Float64, block_size * (block_size + 1))
lmem = cl.LocalMem(Float64, block_size * block_size)
cl.set_args!(kernel, buffer(B), buffer(A), UInt32(h), UInt32(w), lmem)
return cl.enqueue_kernel(kernel, (h, w), (block_size, block_size))
# This method handles `CLMatrix{Float64}`, so the local scratch buffer must be sized in
# Float64 elements. Typing it as Float32 (as in the single-precision method) would
# presumably reserve only half the bytes the double-precision kernel needs.
lmem = cl.LocalMem(Float64, block_size * (block_size + 1))
return cl.call(kernel, buffer(B), buffer(A), UInt32(h), UInt32(w), lmem;
               global_size=(h, w), local_size=(block_size, block_size))
end

"""Transpose CLMatrix A"""
Expand Down
12 changes: 5 additions & 7 deletions test/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@ using LinearAlgebra
end

@testset "fill" begin
@test to_host(OpenCL.fill(Float32, Float32(0.5),
32, 64)) == fill(Float32(0.5), 32, 64)
@test to_host(OpenCL.zeros(Float32, 64)) == zeros(Float32, 64)
@test to_host(OpenCL.ones(Float32, 64)) == ones(Float32, 64)
@test Array(OpenCL.fill(Float32, Float32(0.5),
32, 64)) == fill(Float32(0.5), 32, 64)
@test Array(OpenCL.zeros(Float32, 64)) == zeros(Float32, 64)
@test Array(OpenCL.ones(Float32, 64)) == ones(Float32, 64)
end

@testset "core functions" begin
Expand All @@ -35,10 +35,8 @@ using LinearAlgebra
@test reshape(B, 128, 64) == A

# transpose
X = CLArray(rand(Float32, 32, 32))
B = OpenCL.zeros(Float32, 64, 128)
ev = transpose!(B, A)
cl.wait(ev)
#@test to_host(copy(A')) == to_host(B)
@test Array(A)' == Array(B)
end
end

0 comments on commit 01be2fb

Please sign in to comment.