Skip to content

Commit

Permalink
fix planar error
Browse files Browse the repository at this point in the history
  • Loading branch information
prbzrg committed May 14, 2024
1 parent 355fa2c commit fb7b747
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 11 deletions.
3 changes: 2 additions & 1 deletion src/ContinuousNormalizingFlows.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,10 @@ export construct,
PlanarLayer,
MulLayer

include(joinpath("layers", "base_layer.jl"))
include(joinpath("layers", "cond_layer.jl"))
include(joinpath("layers", "planar_layer.jl"))
include(joinpath("layers", "mul_layer.jl"))
include(joinpath("layers", "planar_layer.jl"))

include("types.jl")

Expand Down
11 changes: 11 additions & 0 deletions src/layers/base_layer.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
@inline function apply_act(::typeof(identity), x::Any)
x

Check warning on line 2 in src/layers/base_layer.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/base_layer.jl#L1-L2

Added lines #L1 - L2 were not covered by tests
end

@inline function apply_act(activation::Any, x::Number)
activation(x)
end

@inline function apply_act(activation::Any, x::AbstractArray)
activation.(x)
end
2 changes: 1 addition & 1 deletion src/layers/mul_layer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,5 @@ function LuxCore.outputsize(m::MulLayer)
end

@inline function (m::MulLayer)(x::AbstractVecOrMat, ps::Any, st::NamedTuple)
return Lux.apply_activation(m.activation, Octavian.matmul(ps.weight, x)), st
apply_act(m.activation, Octavian.matmul(ps.weight, x)), st

Check warning on line 34 in src/layers/mul_layer.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/mul_layer.jl#L34

Added line #L34 was not covered by tests
end
17 changes: 8 additions & 9 deletions src/layers/planar_layer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,34 +71,33 @@ function LuxCore.outputsize(m::PlanarLayer)
end

@inline function (m::PlanarLayer{true})(z::AbstractVector, ps::Any, st::NamedTuple)
ps.u * Lux.apply_bias_activation(m.activation, LinearAlgebra.dot(ps.w, z), only(ps.b)),
st
ps.u * apply_act(m.activation, LinearAlgebra.dot(ps.w, z) + only(ps.b)), st
end

@inline function (m::PlanarLayer{true})(z::AbstractMatrix, ps::Any, st::NamedTuple)
ps.u * Lux.apply_bias_activation(m.activation, transpose(ps.w) * z, only(ps.b)), st
ps.u * apply_act(m.activation, muladd(transpose(ps.w), z, only(ps.b))), st
end

@inline function (m::PlanarLayer{false})(z::AbstractVector, ps::Any, st::NamedTuple)
ps.u * Lux.apply_activation(m.activation, LinearAlgebra.dot(ps.w, z)), st
ps.u * apply_act(m.activation, LinearAlgebra.dot(ps.w, z)), st

Check warning on line 82 in src/layers/planar_layer.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/planar_layer.jl#L82

Added line #L82 was not covered by tests
end

@inline function (m::PlanarLayer{false})(z::AbstractMatrix, ps::Any, st::NamedTuple)
ps.u * Lux.apply_activation(m.activation, transpose(ps.w) * z), st
ps.u * apply_act(m.activation, transpose(ps.w) * z), st

Check warning on line 86 in src/layers/planar_layer.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/planar_layer.jl#L86

Added line #L86 was not covered by tests
end

@inline function pl_h(m::PlanarLayer{true}, z::AbstractVector, ps::Any, st::NamedTuple)
Lux.apply_bias_activation(m.activation, LinearAlgebra.dot(ps.w, z), only(ps.b)), st
apply_act(m.activation, LinearAlgebra.dot(ps.w, z) + only(ps.b)), st

Check warning on line 90 in src/layers/planar_layer.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/planar_layer.jl#L90

Added line #L90 was not covered by tests
end

@inline function pl_h(m::PlanarLayer{true}, z::AbstractMatrix, ps::Any, st::NamedTuple)
Lux.apply_bias_activation(m.activation, transpose(ps.w) * z, only(ps.b)), st
apply_act(m.activation, muladd(transpose(ps.w), z, only(ps.b))), st

Check warning on line 94 in src/layers/planar_layer.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/planar_layer.jl#L94

Added line #L94 was not covered by tests
end

@inline function pl_h(m::PlanarLayer{false}, z::AbstractVector, ps::Any, st::NamedTuple)
Lux.apply_activation(m.activation, LinearAlgebra.dot(ps.w, z)), st
apply_act(m.activation, LinearAlgebra.dot(ps.w, z)), st

Check warning on line 98 in src/layers/planar_layer.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/planar_layer.jl#L98

Added line #L98 was not covered by tests
end

@inline function pl_h(m::PlanarLayer{false}, z::AbstractMatrix, ps::Any, st::NamedTuple)
Lux.apply_activation(m.activation, transpose(ps.w) * z), st
apply_act(m.activation, transpose(ps.w) * z), st

Check warning on line 102 in src/layers/planar_layer.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/planar_layer.jl#L102

Added line #L102 was not covered by tests
end

0 comments on commit fb7b747

Please sign in to comment.