Skip to content

Commit

Permalink
Implement NNlib.∇conv_data! and NNlib.∇conv_filter! (#318)
Browse files Browse the repository at this point in the history
* implement NNlib.∇conv_data and NNlib.∇conv_filter

* cond filter flipkernel
  • Loading branch information
Pangoraw authored Dec 3, 2024
1 parent 445770f commit 48ac3d7
Show file tree
Hide file tree
Showing 3 changed files with 260 additions and 10 deletions.
222 changes: 215 additions & 7 deletions ext/ReactantNNlibExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,18 +113,14 @@ function NNlib.conv!(
)
result_type = Reactant.MLIR.IR.TensorType(size(y), Reactant.MLIR.IR.Type(T))

weight = W.mlir_data
weight = W
if !flipkernel
weight = Reactant.MLIR.IR.result(
Reactant.MLIR.Dialects.stablehlo.reverse(
weight; dimensions=collect(kernel_spatial_dims .- 1)
),
)
weight = Reactant.Ops.reverse(weight; dimensions=kernel_spatial_dims)
end

conv = Reactant.MLIR.Dialects.stablehlo.convolution(
x.mlir_data,
weight;
weight.mlir_data;
result_0=result_type,
window_strides=collect(stride),
padding,
Expand Down Expand Up @@ -377,4 +373,216 @@ function NNlib.gather!(dst::TracedRArray, src::AnyTracedRArray, idxs::AbstractAr
return dst
end

# Extent of an axis of size `s` once its elements are spread out by dilation
# factor `d`, i.e. `1 + d * (s - 1)`, clamped below at zero.
function dilate_shape(s, d)
    dilated = d * (s - 1) + 1
    return max(0, dilated)
end

# Gradient of a convolution with respect to its filter, emitted as a single
# `stablehlo.convolution` whose result is stored in `dw` (which is also returned).
#
# Modelled on lax._conv_general_dilated_transpose_rhs
# https://github.com/jax-ml/jax/blob/a1dfdc1d6164ad49afb337da9effd269d430d68b/jax/_src/lax/convolution.py#L495
function NNlib.∇conv_filter!(
    dw::Reactant.TracedRArray{T,N},
    x::AnyTracedRArray,
    dy::AnyTracedRArray,
    cdims::NNlib.DenseConvDims,
) where {T,N}
    # Layouts, written for two spatial dimensions:
    #   x  : (w, h, cin, b)
    #   dy : (w, h, cout, b)
    #   dw : (w, h, cin, cout)
    x = T.(materialize_traced_array(x))
    dy = T.(materialize_traced_array(dy))

    nspatial = N - 2
    spatial_dims = 1:nspatial

    # 1-based dimension roles handed to the convolution (made 0-based below).
    # `x` is the lhs with dimension N-1 as batch and N as feature; `dy` is the rhs
    # with N as kernel input and N-1 as kernel output; the result uses N-1 as
    # batch and N as feature.
    lhs_batch_dim, lhs_feature_dim = N - 1, N
    rhs_input_dim, rhs_output_dim = N, N - 1
    res_batch_dim, res_feature_dim = N - 1, N

    stride = NNlib.stride(cdims)
    dilation = NNlib.dilation(cdims)
    # One column per spatial dimension; row 1 is read below as the leading pad.
    fwd_padding = reshape(collect(NNlib.padding(cdims)), (2, nspatial))

    lhs_shape = first(size(x), nspatial)
    rhs_shape = dilate_shape.(first(size(dw), nspatial), dilation)
    out_shape = dilate_shape.(first(size(dy), nspatial), stride)

    # [before, after] padding for spatial dimension `i`.
    pad_column = function (i)
        before = fwd_padding[1, i]
        after = out_shape[i] - lhs_shape[i] + rhs_shape[i] - before - 1
        return [before, after]
    end
    pad_columns = reduce(hcat, (pad_column(i) for i in spatial_dims))
    # Transposed so that each row holds one spatial dimension's pair.
    padding_attr = Reactant.MLIR.IR.DenseElementsAttribute(pad_columns')

    # With more than one group, the group count is passed as `batch_group_count`
    # and `feature_group_count` is reset to 1.
    groups = NNlib.groupcount(cdims)
    feature_group_count, batch_group_count = groups > 1 ? (1, groups) : (groups, 1)

    zero_based_spatial = Int64[d - 1 for d in spatial_dims]
    dimension_numbers = MLIR.API.stablehloConvDimensionNumbersGet(
        MLIR.IR.context(),
        Int64(lhs_batch_dim - 1),
        Int64(lhs_feature_dim - 1),
        length(spatial_dims),
        zero_based_spatial,
        Int64(rhs_input_dim - 1),
        Int64(rhs_output_dim - 1),
        length(spatial_dims),
        zero_based_spatial,
        Int64(res_batch_dim - 1),
        Int64(res_feature_dim - 1),
        length(spatial_dims),
        zero_based_spatial,
    )

    # Note the swap: the forward dilation is used as the window stride and the
    # forward stride as the rhs dilation.
    conv = MLIR.Dialects.stablehlo.convolution(
        x.mlir_data,
        dy.mlir_data;
        result_0=Reactant.MLIR.IR.TensorType(size(dw), Reactant.MLIR.IR.Type(T)),
        window_strides=collect(dilation),
        padding=padding_attr,
        dimension_numbers,
        rhs_dilation=collect(stride),
        feature_group_count,
        batch_group_count,
    )
    dw.mlir_data = MLIR.IR.result(conv)

    # When the kernel is not flagged as flipped, reverse the result along its
    # spatial dimensions.
    if !NNlib.flipkernel(cdims)
        dw.mlir_data = Reactant.Ops.reverse(dw; dimensions=spatial_dims).mlir_data
    end

    return dw
end

# Gradient of a convolution with respect to its input, emitted as a single
# `stablehlo.convolution` whose result is stored in `dx` (which is also returned).
#
# see lax._conv_general_dilated_transpose_lhs
# https://github.com/jax-ml/jax/blob/a1dfdc1d6164ad49afb337da9effd269d430d68b/jax/_src/lax/convolution.py#L457
function NNlib.∇conv_data!(
    dx::Reactant.TracedRArray{T,N},
    dy::AnyTracedRArray,
    w::AnyTracedRArray,
    cdims::NNlib.DenseConvDims,
) where {T,N}
    # Layouts, written for two spatial dimensions:
    #   dy : (w, h, cout, b)
    #   w  : (w, h, cin, cout)
    #   dx : (w, h, cin, b)
    dy = T.(materialize_traced_array(dy))
    w = T.(materialize_traced_array(w))

    num_spatial_dims = N - 2
    input_batch_dim = N
    input_feature_dim = N - 1

    # The kernel's last dimension is used as the kernel *input* feature
    # dimension and dimension N-1 as the kernel *output* feature dimension.
    kernel_input_dim = N
    kernel_output_dim = N - 1

    output_batch_dim = N
    output_feature_dim = N - 1

    output_spatial_dims = kernel_spatial_dims = input_spatial_dims = 1:num_spatial_dims

    # One column per spatial dimension; row 1 is read below as the leading pad.
    fwd_padding = reshape(collect(NNlib.padding(cdims)), (2, num_spatial_dims))
    stride = NNlib.stride(cdims)
    dilation = NNlib.dilation(cdims)
    feature_group_count = NNlib.groupcount(cdims)

    # jax does
    # (cout, cin, h, w) -> (group, cout ÷ group, cin , h, w) -> (cout ÷ group, group, cin, h, w) -> (cout, cin * group, h, w)
    # we perform the same operation but in transposed form
    # (w, h, cin, cout) -> (w, h, cin, cout ÷ group, group) -> (w, h, cin, group, cout ÷ group) -> (w, h, cin * group, cout ÷ group)
    if feature_group_count > 1
        w = reshape(
            w,
            (size(w, i) for i in kernel_spatial_dims)...,
            size(w, N - 1),
            size(w, N) ÷ feature_group_count,
            feature_group_count,
        )
        w = permutedims(w, (kernel_spatial_dims..., N - 1, N + 1, N))
        w = reshape(
            w,
            (size(w, i) for i in kernel_spatial_dims)...,
            size(w, N - 1) * feature_group_count,
            size(w, N + 1),
        )
    end

    lhs_shape = first(size(dx), num_spatial_dims)
    rhs_shape = dilate_shape.(first(size(w), num_spatial_dims), dilation)
    out_shape = dilate_shape.(first(size(dy), num_spatial_dims), stride)

    # [before, after] padding for spatial dimension `i`.
    pad_column = function (i)
        pad_before = rhs_shape[i] - fwd_padding[1, i] - 1
        pad_after = lhs_shape[i] + rhs_shape[i] - 1 - out_shape[i] - pad_before
        return [pad_before, pad_after]
    end
    pad_columns = reduce(hcat, (pad_column(i) for i in input_spatial_dims))
    # Transposed so that each row holds one spatial dimension's pair.
    padding = Reactant.MLIR.IR.DenseElementsAttribute(pad_columns')

    dimension_numbers = MLIR.API.stablehloConvDimensionNumbersGet(
        MLIR.IR.context(),
        Int64(input_batch_dim - 1),
        Int64(input_feature_dim - 1),
        length(input_spatial_dims),
        Int64[i - 1 for i in input_spatial_dims],
        Int64(kernel_input_dim - 1),
        Int64(kernel_output_dim - 1),
        length(kernel_spatial_dims),
        Int64[i - 1 for i in kernel_spatial_dims],
        Int64(output_batch_dim - 1),
        Int64(output_feature_dim - 1),
        length(output_spatial_dims),
        Int64[i - 1 for i in output_spatial_dims],
    )

    result_type = Reactant.MLIR.IR.TensorType(size(dx), Reactant.MLIR.IR.Type(T))

    # When the kernel is flagged as flipped, reverse it along its spatial
    # dimensions before the convolution.
    if NNlib.flipkernel(cdims)
        w = Reactant.Ops.reverse(w; dimensions=kernel_spatial_dims)
    end

    # The forward stride is applied as a dilation of `dy` (lhs_dilation) and the
    # window itself moves with unit stride. `window_strides` is given one entry
    # per spatial dimension, like the other per-dimension attributes, instead
    # of a bare scalar.
    conv = MLIR.Dialects.stablehlo.convolution(
        dy.mlir_data,
        w.mlir_data;
        result_0=result_type,
        window_strides=fill(1, num_spatial_dims),
        padding,
        lhs_dilation=collect(stride),
        rhs_dilation=collect(dilation),
        dimension_numbers,
        feature_group_count,
        batch_group_count=1,
    )

    dx.mlir_data = MLIR.IR.result(conv)

    return dx
end

end # module ReactantNNlibExt
2 changes: 1 addition & 1 deletion src/Ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -815,7 +815,7 @@ function reverse(
stablehlo.reverse(
x.mlir_data;
result=mlir_type(TracedRArray{T,N}, size(x)),
dimensions=MLIR.IR.DenseArrayAttribute(dimensions .- 1),
dimensions=MLIR.IR.DenseArrayAttribute(collect(dimensions .- 1)),
location,
),
)
Expand Down
46 changes: 44 additions & 2 deletions test/nn/nnlib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,15 +67,26 @@ end

conv_dims = DenseConvDims(x, weight; stride, padding, dilation, groups)

output_size = (
NNlib.output_size(conv_dims)...,
size(weight, ndims(weight)),
size(x, ndims(x)),
)
dy = randn(Float32, output_size)
dy_reactant = Reactant.to_rarray(dy)

conv_compiled = Reactant.compile(
NNlib.conv, (x_reactant, weight_reactant, conv_dims)
)

@test conv_compiled(x_reactant, weight_reactant, conv_dims)
NNlib.conv(x, weight, conv_dims)
end

# TODO: test for gradients
@test Reactant.@jit(NNlib.∇conv_data(dy_reactant, weight_reactant, conv_dims))
NNlib.∇conv_data(dy, weight, conv_dims)
@test Reactant.@jit(NNlib.∇conv_filter(x_reactant, dy_reactant, conv_dims))
NNlib.∇conv_filter(x, dy, conv_dims)
end
end

@testset "conv 1d: flip" begin
Expand Down Expand Up @@ -351,3 +362,34 @@ end
@test size(y) == (size(src)[1:(Nsrc - M)]..., size(index)...)
end
end

# Checks the Reactant-compiled `∇conv_data` / `∇conv_filter` against the plain
# NNlib results for 1, 2 and 3 spatial dimensions over a small grid of padding
# and stride settings (dilation and groups are fixed at 1).
@testset "∇conv(D = $ndim)" for ndim in 1:3
    x_spatial_dim = 4
    batch_size = 2
    n_in_features = 3
    n_out_features = 4
    kernel_size = ntuple(_ -> 2, ndim)

    x = randn(Float32, ntuple(_ -> x_spatial_dim, ndim)..., n_in_features, batch_size)
    x_reactant = Reactant.to_rarray(x)

    w = randn(Float32, kernel_size..., n_in_features, n_out_features)
    w_reactant = Reactant.to_rarray(w)

    @testset "conv: padding=$padding stride=$stride dilation=$dilation groups=$groups" for (
        padding, stride, dilation, groups
    ) in Iterators.product(
        (0, 2), (1, 2), (1,), (1,)
    )
        conv_dims = NNlib.DenseConvDims(x, w; padding, stride, dilation, groups)

        output_size = (NNlib.output_size(conv_dims)..., n_out_features, batch_size)
        dy = randn(Float32, output_size)
        dy_reactant = Reactant.to_rarray(dy)

        # Compare against the NNlib reference with `≈`: the results are
        # floating-point arrays, and `@test` needs a Boolean expression.
        @test Reactant.@jit(NNlib.∇conv_data(dy_reactant, w_reactant, conv_dims)) ≈
            NNlib.∇conv_data(dy, w, conv_dims)
        @test Reactant.@jit(NNlib.∇conv_filter(x_reactant, dy_reactant, conv_dims)) ≈
            NNlib.∇conv_filter(x, dy, conv_dims)
    end
end

0 comments on commit 48ac3d7

Please sign in to comment.