migrate to lux 1
prbzrg committed Sep 8, 2024
1 parent c026259 commit 2e55ff4
Showing 8 changed files with 46 additions and 41 deletions.
4 changes: 2 additions & 2 deletions Project.toml
@@ -53,8 +53,8 @@ Distributions = "0.25"
DistributionsAD = "0.6"
FillArrays = "1"
LinearAlgebra = "1"
Lux = "0.5, 1"
LuxCore = "0.1, 1"
Lux = "1"
LuxCore = "1"
MLJBase = "1"
MLJModelInterface = "1"
MLUtils = "0.4"
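With the compat entries narrowed to `Lux = "1"` and `LuxCore = "1"`, environments that previously resolved Lux 0.5 have to re-resolve before they pick up this release. A quick check, assuming the package is already tracked in the active environment:

```julia
using Pkg

Pkg.update("Lux")      # let the resolver move to a 1.x release now that compat allows only 1.x
Pkg.status("Lux")      # confirm the resolved versions
Pkg.status("LuxCore")
```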
2 changes: 1 addition & 1 deletion benchmark/Project.toml
@@ -13,7 +13,7 @@ ADTypes = "1"
BenchmarkTools = "1"
ComponentArrays = "0.15"
DifferentiationInterface = "0.5"
Lux = "0.5, 1"
Lux = "1"
PkgBenchmark = "0.2"
StableRNGs = "1"
Zygote = "0.6"
4 changes: 2 additions & 2 deletions src/base_icnf.jl
@@ -1,6 +1,6 @@
function construct(
aicnf::Type{<:AbstractICNF},
-nn::LuxCore.AbstractExplicitLayer,
+nn::LuxCore.AbstractLuxLayer,
nvars::Int,
naugmented::Int = 0;
data_type::Type{<:AbstractFloat} = Float32,
@@ -515,7 +515,7 @@ end
@inline function make_ode_func(
icnf::AbstractICNF{T, CM, INPLACE},
mode::Mode,
-nn::LuxCore.AbstractExplicitLayer,
+nn::LuxCore.AbstractLuxLayer,
st::NamedTuple,
ϵ::AbstractVecOrMat{T},
) where {T <: AbstractFloat, CM, INPLACE}
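For context, LuxCore 1 renames AbstractExplicitLayer to AbstractLuxLayer, so only the type annotation changes here; any ordinary Lux layer still satisfies the bound that construct and make_ode_func place on nn. A minimal sketch using plain Lux/LuxCore, independent of this package:

```julia
using Lux, LuxCore, Random

# Any Lux model subtypes the new abstract type and can be passed as `nn`.
nn = Lux.Chain(Lux.Dense(2 => 4, tanh), Lux.Dense(4 => 2))
@assert nn isa LuxCore.AbstractLuxLayer

ps, st = LuxCore.setup(Random.default_rng(), nn)  # parameters and states, as before
```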
38 changes: 19 additions & 19 deletions src/icnf.jl
@@ -76,7 +76,7 @@ struct ICNF{
NORM_Z,
NORM_J,
NORM_Z_AUG,
-NN <: LuxCore.AbstractExplicitLayer,
+NN <: LuxCore.AbstractLuxLayer,
NVARS <: Int,
RESOURCE <: ComputationalResources.AbstractResource,
BASEDIST <: Distributions.Distribution,
@@ -113,7 +113,7 @@ function augmented_f(
::Any,
icnf::ICNF{T, <:ADVectorMode, false},
mode::TestMode,
-nn::LuxCore.AbstractExplicitLayer,
+nn::LuxCore.AbstractLuxLayer,
st::NamedTuple,
ϵ::AbstractVector{T},
) where {T <: AbstractFloat}
@@ -132,7 +132,7 @@ function augmented_f(
::Any,
icnf::ICNF{T, <:ADVectorMode, true},
mode::TestMode,
-nn::LuxCore.AbstractExplicitLayer,
+nn::LuxCore.AbstractLuxLayer,
st::NamedTuple,
ϵ::AbstractVector{T},
) where {T <: AbstractFloat}
@@ -151,7 +151,7 @@ function augmented_f(
::Any,
icnf::ICNF{T, <:DIVectorMode, false},
mode::TestMode,
-nn::LuxCore.AbstractExplicitLayer,
+nn::LuxCore.AbstractLuxLayer,
st::NamedTuple,
ϵ::AbstractVector{T},
) where {T <: AbstractFloat}
@@ -170,7 +170,7 @@ function augmented_f(
::Any,
icnf::ICNF{T, <:DIVectorMode, true},
mode::TestMode,
-nn::LuxCore.AbstractExplicitLayer,
+nn::LuxCore.AbstractLuxLayer,
st::NamedTuple,
ϵ::AbstractVector{T},
) where {T <: AbstractFloat}
@@ -189,7 +189,7 @@ function augmented_f(
::Any,
icnf::ICNF{T, <:MatrixMode, false},
mode::TestMode,
-nn::LuxCore.AbstractExplicitLayer,
+nn::LuxCore.AbstractLuxLayer,
st::NamedTuple,
ϵ::AbstractMatrix{T},
) where {T <: AbstractFloat}
@@ -208,7 +208,7 @@ function augmented_f(
::Any,
icnf::ICNF{T, <:MatrixMode, true},
mode::TestMode,
-nn::LuxCore.AbstractExplicitLayer,
+nn::LuxCore.AbstractLuxLayer,
st::NamedTuple,
ϵ::AbstractMatrix{T},
) where {T <: AbstractFloat}
@@ -227,7 +227,7 @@ function augmented_f(
::Any,
icnf::ICNF{T, <:ADVecJacVectorMode, false, COND, AUGMENTED, STEER, NORM_Z, NORM_J},
mode::TrainMode,
-nn::LuxCore.AbstractExplicitLayer,
+nn::LuxCore.AbstractLuxLayer,
st::NamedTuple,
ϵ::AbstractVector{T},
) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J}
@@ -261,7 +261,7 @@ function augmented_f(
::Any,
icnf::ICNF{T, <:ADVecJacVectorMode, true, COND, AUGMENTED, STEER, NORM_Z, NORM_J},
mode::TrainMode,
-nn::LuxCore.AbstractExplicitLayer,
+nn::LuxCore.AbstractLuxLayer,
st::NamedTuple,
ϵ::AbstractVector{T},
) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J}
@@ -295,7 +295,7 @@ function augmented_f(
::Any,
icnf::ICNF{T, <:ADJacVecVectorMode, false, COND, AUGMENTED, STEER, NORM_Z, NORM_J},
mode::TrainMode,
-nn::LuxCore.AbstractExplicitLayer,
+nn::LuxCore.AbstractLuxLayer,
st::NamedTuple,
ϵ::AbstractVector{T},
) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J}
@@ -330,7 +330,7 @@ function augmented_f(
::Any,
icnf::ICNF{T, <:ADJacVecVectorMode, true, COND, AUGMENTED, STEER, NORM_Z, NORM_J},
mode::TrainMode,
-nn::LuxCore.AbstractExplicitLayer,
+nn::LuxCore.AbstractLuxLayer,
st::NamedTuple,
ϵ::AbstractVector{T},
) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J}
@@ -365,7 +365,7 @@ function augmented_f(
::Any,
icnf::ICNF{T, <:DIVecJacVectorMode, false, COND, AUGMENTED, STEER, NORM_Z, NORM_J},
mode::TrainMode,
-nn::LuxCore.AbstractExplicitLayer,
+nn::LuxCore.AbstractLuxLayer,
st::NamedTuple,
ϵ::AbstractVector{T},
) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J}
@@ -395,7 +395,7 @@ function augmented_f(
::Any,
icnf::ICNF{T, <:DIVecJacVectorMode, true, COND, AUGMENTED, STEER, NORM_Z, NORM_J},
mode::TrainMode,
-nn::LuxCore.AbstractExplicitLayer,
+nn::LuxCore.AbstractLuxLayer,
st::NamedTuple,
ϵ::AbstractVector{T},
) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J}
@@ -425,7 +425,7 @@ function augmented_f(
::Any,
icnf::ICNF{T, <:DIJacVecVectorMode, false, COND, AUGMENTED, STEER, NORM_Z, NORM_J},
mode::TrainMode,
-nn::LuxCore.AbstractExplicitLayer,
+nn::LuxCore.AbstractLuxLayer,
st::NamedTuple,
ϵ::AbstractVector{T},
) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J}
@@ -455,7 +455,7 @@ function augmented_f(
::Any,
icnf::ICNF{T, <:DIJacVecVectorMode, true, COND, AUGMENTED, STEER, NORM_Z, NORM_J},
mode::TrainMode,
-nn::LuxCore.AbstractExplicitLayer,
+nn::LuxCore.AbstractLuxLayer,
st::NamedTuple,
ϵ::AbstractVector{T},
) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J}
@@ -485,7 +485,7 @@ function augmented_f(
::Any,
icnf::ICNF{T, <:DIVecJacMatrixMode, false, COND, AUGMENTED, STEER, NORM_Z, NORM_J},
mode::TrainMode,
-nn::LuxCore.AbstractExplicitLayer,
+nn::LuxCore.AbstractLuxLayer,
st::NamedTuple,
ϵ::AbstractMatrix{T},
) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J}
@@ -519,7 +519,7 @@ function augmented_f(
::Any,
icnf::ICNF{T, <:DIVecJacMatrixMode, true, COND, AUGMENTED, STEER, NORM_Z, NORM_J},
mode::TrainMode,
-nn::LuxCore.AbstractExplicitLayer,
+nn::LuxCore.AbstractLuxLayer,
st::NamedTuple,
ϵ::AbstractMatrix{T},
) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J}
@@ -549,7 +549,7 @@ function augmented_f(
::Any,
icnf::ICNF{T, <:DIJacVecMatrixMode, false, COND, AUGMENTED, STEER, NORM_Z, NORM_J},
mode::TrainMode,
-nn::LuxCore.AbstractExplicitLayer,
+nn::LuxCore.AbstractLuxLayer,
st::NamedTuple,
ϵ::AbstractMatrix{T},
) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J}
@@ -583,7 +583,7 @@ function augmented_f(
::Any,
icnf::ICNF{T, <:DIJacVecMatrixMode, true, COND, AUGMENTED, STEER, NORM_Z, NORM_J},
mode::TrainMode,
-nn::LuxCore.AbstractExplicitLayer,
+nn::LuxCore.AbstractLuxLayer,
st::NamedTuple,
ϵ::AbstractMatrix{T},
) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J}
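The hunks above are one mechanical rename repeated across every augmented_f method; what actually distinguishes the methods is dispatch on the compute-mode type parameter and the in-place flag of ICNF{T, CM, INPLACE}. An illustrative-only sketch of that dispatch pattern, using hypothetical names so it does not collide with the package's real types:

```julia
# Hypothetical stand-ins for the compute-mode hierarchy; not the package's types.
abstract type DemoComputeMode end
struct DemoVectorMode <: DemoComputeMode end
struct DemoMatrixMode <: DemoComputeMode end

struct DemoFlow{CM <: DemoComputeMode, INPLACE} end

# Methods specialize on the mode parameter and on the Bool in-place flag,
# mirroring how augmented_f is specialized above.
describe(::DemoFlow{<:DemoVectorMode, false}) = "out-of-place, one sample per vector"
describe(::DemoFlow{<:DemoVectorMode, true}) = "in-place, one sample per vector"
describe(::DemoFlow{<:DemoMatrixMode, false}) = "out-of-place, batched matrix"

describe(DemoFlow{DemoVectorMode, true}())  # "in-place, one sample per vector"
```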
4 changes: 2 additions & 2 deletions src/layers/cond_layer.jl
@@ -1,5 +1,5 @@
-struct CondLayer{NN <: LuxCore.AbstractExplicitLayer, AT <: AbstractArray} <:
-    LuxCore.AbstractExplicitContainerLayer{(:nn,)}
+struct CondLayer{NN <: LuxCore.AbstractLuxLayer, AT <: AbstractArray} <:
+    LuxCore.AbstractLuxContainerLayer{(:nn,)}
nn::NN
ys::AT
end
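AbstractExplicitContainerLayer is likewise renamed to AbstractLuxContainerLayer in LuxCore 1; the tuple of field names still tells LuxCore where the child layers live, and parameters stay namespaced under those names. A self-contained sketch of the same pattern (illustrative only, not this package's implementation):

```julia
using Lux, LuxCore, Random

# A container layer holding a single inner layer under the field :nn.
struct WrappedLayer{NN <: LuxCore.AbstractLuxLayer} <:
       LuxCore.AbstractLuxContainerLayer{(:nn,)}
    nn::NN
end

function (l::WrappedLayer)(x, ps, st)
    y, st_nn = LuxCore.apply(l.nn, x, ps.nn, st.nn)
    return y, (; nn = st_nn)
end

ps, st = LuxCore.setup(Random.default_rng(), WrappedLayer(Lux.Dense(2 => 2)))
# ps.nn and st.nn hold the inner Dense layer's parameters and state.
```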
31 changes: 18 additions & 13 deletions src/layers/planar_layer.jl
@@ -3,8 +3,7 @@ Implementation of Planar Layer from
[Chen, Ricky TQ, Yulia Rubanova, Jesse Bettencourt, and David Duvenaud. "Neural Ordinary Differential Equations." arXiv preprint arXiv:1806.07366 (2018).](https://arxiv.org/abs/1806.07366)
"""
-struct PlanarLayer{use_bias, cond, F1, F2, F3, NVARS <: Int} <:
-    LuxCore.AbstractExplicitLayer
+struct PlanarLayer{use_bias, cond, F1, F2, F3, NVARS <: Int} <: LuxCore.AbstractLuxLayer
activation::F1
nvars::NVARS
init_weight::F2
@@ -18,10 +17,8 @@ function PlanarLayer(
init_weight::Any = Lux.glorot_uniform,
init_bias::Any = Lux.zeros32,
use_bias::Bool = true,
-allow_fast_activation::Bool = true,
n_cond::Int = 0,
)
-activation = ifelse(allow_fast_activation, NNlib.fast_act(activation), activation)
PlanarLayer{
use_bias,
!iszero(n_cond),
@@ -66,38 +63,46 @@ function LuxCore.parameterlength(m::PlanarLayer{use_bias, cond}) where {use_bias
m.nvars + ifelse(cond, (m.nvars + m.n_cond), m.nvars) + ifelse(use_bias, 1, 0)
end

-function LuxCore.outputsize(m::PlanarLayer)
+function LuxCore.outputsize(m::PlanarLayer, ::Any, ::AbstractRNG)
(m.nvars,)
end

@inline function (m::PlanarLayer{true})(z::AbstractVector, ps::Any, st::NamedTuple)
-ps.u * m.activation.(LinearAlgebra.dot(ps.w, z) + only(ps.b)), st
+activation = NNlib.fast_act(m.activation, z)
+ps.u * activation.(LinearAlgebra.dot(ps.w, z) + only(ps.b)), st
end

@inline function (m::PlanarLayer{true})(z::AbstractMatrix, ps::Any, st::NamedTuple)
-ps.u * m.activation.(muladd(transpose(ps.w), z, only(ps.b))), st
+activation = NNlib.fast_act(m.activation, z)
+ps.u * activation.(muladd(transpose(ps.w), z, only(ps.b))), st
end

@inline function (m::PlanarLayer{false})(z::AbstractVector, ps::Any, st::NamedTuple)
-ps.u * m.activation.(LinearAlgebra.dot(ps.w, z)), st
+activation = NNlib.fast_act(m.activation, z)
+ps.u * activation.(LinearAlgebra.dot(ps.w, z)), st
end

@inline function (m::PlanarLayer{false})(z::AbstractMatrix, ps::Any, st::NamedTuple)
-ps.u * m.activation.(transpose(ps.w) * z), st
+activation = NNlib.fast_act(m.activation, z)
+ps.u * activation.(transpose(ps.w) * z), st
end

@inline function pl_h(m::PlanarLayer{true}, z::AbstractVector, ps::Any, st::NamedTuple)
-m.activation.(LinearAlgebra.dot(ps.w, z) + only(ps.b)), st
+activation = NNlib.fast_act(m.activation, z)
+activation.(LinearAlgebra.dot(ps.w, z) + only(ps.b)), st
end

@inline function pl_h(m::PlanarLayer{true}, z::AbstractMatrix, ps::Any, st::NamedTuple)
-m.activation.(muladd(transpose(ps.w), z, only(ps.b))), st
+activation = NNlib.fast_act(m.activation, z)
+activation.(muladd(transpose(ps.w), z, only(ps.b))), st
end

@inline function pl_h(m::PlanarLayer{false}, z::AbstractVector, ps::Any, st::NamedTuple)
-m.activation.(LinearAlgebra.dot(ps.w, z)), st
+activation = NNlib.fast_act(m.activation, z)
+activation.(LinearAlgebra.dot(ps.w, z)), st
end

@inline function pl_h(m::PlanarLayer{false}, z::AbstractMatrix, ps::Any, st::NamedTuple)
-m.activation.(transpose(ps.w) * z), st
+activation = NNlib.fast_act(m.activation, z)
+activation.(transpose(ps.w) * z), st
end
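Two Lux 1 behaviour changes show up in this file: the allow_fast_activation keyword is gone, with the fast-activation substitution now done at call time through NNlib.fast_act(f, x), and LuxCore.outputsize now also receives an input and an RNG, which is why the method above gained the extra ::Any and ::AbstractRNG arguments. A small standalone sketch of the call-time substitution:

```julia
using NNlib

x = randn(Float32, 4)
act = NNlib.fast_act(tanh, x)  # may substitute tanh_fast for plain dense arrays
y = act.(x)
```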
2 changes: 1 addition & 1 deletion src/types.jl
@@ -46,6 +46,6 @@ abstract type AbstractICNF{
AUGMENTED,
STEER,
NORM_Z_AUG,
-} <: LuxCore.AbstractExplicitContainerLayer{(:nn,)} end
+} <: LuxCore.AbstractLuxContainerLayer{(:nn,)} end

abstract type MLJICNF{AICNF <: AbstractICNF} <: MLJModelInterface.Unsupervised end
2 changes: 1 addition & 1 deletion test/Project.toml
@@ -32,7 +32,7 @@ DifferentiationInterface = "0.5"
Distributions = "0.25"
GPUArraysCore = "0.1"
JET = "0.9"
Lux = "0.5, 1"
Lux = "1"
LuxCUDA = "0.3"
MLJBase = "1"
SciMLBase = "2"