Skip to content

Commit

Permalink
change types in models (#267)
Browse files Browse the repository at this point in the history
  • Loading branch information
prbzrg authored Aug 11, 2023
1 parent e68469c commit 01f16ad
Show file tree
Hide file tree
Showing 12 changed files with 161 additions and 77 deletions.
16 changes: 13 additions & 3 deletions src/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ export construct
function construct(
aicnf::Type{<:AbstractFlows},
nn,
nvars::Integer,
naugmented::Integer = 0;
nvars::Int,
naugmented::Int = 0;
data_type::Type{<:AbstractFloat} = Float32,
compute_mode::Type{<:ComputeMode} = ADVectorMode,
resource::AbstractResource = CPU1(),
Expand All @@ -23,7 +23,17 @@ function construct(
)
steerdist = Uniform{data_type}(-steer_rate, steer_rate)

aicnf{data_type, compute_mode, !iszero(naugmented), !iszero(steer_rate)}(
aicnf{
data_type,
compute_mode,
!iszero(naugmented),
!iszero(steer_rate),
typeof(resource),
typeof(basedist),
typeof(tspan),
typeof(steerdist),
typeof(differentiation_backend),
}(
nn,
nvars,
naugmented,
Expand Down
4 changes: 2 additions & 2 deletions src/base_cond_icnf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ function generate_prob(
ys::AbstractMatrix{<:Real},
ps::Any,
st::Any,
n::Integer;
n::Int;
resource::AbstractResource = icnf.resource,
tspan::NTuple{2} = icnf.tspan,
steerdist::Distribution = icnf.steerdist,
Expand Down Expand Up @@ -310,7 +310,7 @@ function generate(
ys::AbstractMatrix{<:Real},
ps::Any,
st::Any,
n::Integer;
n::Int;
resource::AbstractResource = icnf.resource,
tspan::NTuple{2} = icnf.tspan,
steerdist::Distribution = icnf.steerdist,
Expand Down
4 changes: 2 additions & 2 deletions src/base_icnf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ function generate_prob(
mode::Mode,
ps::Any,
st::Any,
n::Integer;
n::Int;
resource::AbstractResource = icnf.resource,
tspan::NTuple{2} = icnf.tspan,
steerdist::Distribution = icnf.steerdist,
Expand Down Expand Up @@ -294,7 +294,7 @@ function generate(
mode::Mode,
ps::Any,
st::Any,
n::Integer;
n::Int;
resource::AbstractResource = icnf.resource,
tspan::NTuple{2} = icnf.tspan,
steerdist::Distribution = icnf.steerdist,
Expand Down
27 changes: 18 additions & 9 deletions src/cond_ffjord.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,26 @@ export CondFFJORD
"""
Implementation of FFJORD (Conditional Version)
"""
struct CondFFJORD{T <: AbstractFloat, CM <: ComputeMode, AUGMENTED, STEER} <:
AbstractCondICNF{T, CM, AUGMENTED, STEER}
struct CondFFJORD{
T <: AbstractFloat,
CM <: ComputeMode,
AUGMENTED,
STEER,
RESOURCE <: AbstractResource,
BASEDIST <: Distribution,
TSPAN <: NTuple{2, T},
STEERDIST <: Distribution,
DIFFERENTIATION_BACKEND <: AbstractDifferentiation.AbstractBackend,
} <: AbstractCondICNF{T, CM, AUGMENTED, STEER}
nn::LuxCore.AbstractExplicitLayer
nvars::Integer
naugmented::Integer
nvars::Int
naugmented::Int

resource::AbstractResource
basedist::Distribution
tspan::NTuple{2, T}
steerdist::Distribution
differentiation_backend::AbstractDifferentiation.AbstractBackend
resource::RESOURCE
basedist::BASEDIST
tspan::TSPAN
steerdist::STEERDIST
differentiation_backend::DIFFERENTIATION_BACKEND
sol_args::Tuple
sol_kwargs::Dict
end
Expand Down
27 changes: 18 additions & 9 deletions src/cond_planar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,26 @@ export CondPlanar
"""
Implementation of Planar (Conditional Version)
"""
struct CondPlanar{T <: AbstractFloat, CM <: ComputeMode, AUGMENTED, STEER} <:
AbstractCondICNF{T, CM, AUGMENTED, STEER}
struct CondPlanar{
T <: AbstractFloat,
CM <: ComputeMode,
AUGMENTED,
STEER,
RESOURCE <: AbstractResource,
BASEDIST <: Distribution,
TSPAN <: NTuple{2, T},
STEERDIST <: Distribution,
DIFFERENTIATION_BACKEND <: AbstractDifferentiation.AbstractBackend,
} <: AbstractCondICNF{T, CM, AUGMENTED, STEER}
nn::PlanarLayer
nvars::Integer
naugmented::Integer
nvars::Int
naugmented::Int

resource::AbstractResource
basedist::Distribution
tspan::NTuple{2, T}
steerdist::Distribution
differentiation_backend::AbstractDifferentiation.AbstractBackend
resource::RESOURCE
basedist::BASEDIST
tspan::TSPAN
steerdist::STEERDIST
differentiation_backend::DIFFERENTIATION_BACKEND
sol_args::Tuple
sol_kwargs::Dict
end
Expand Down
43 changes: 31 additions & 12 deletions src/cond_rnode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,26 @@ export CondRNODE
"""
Implementation of RNODE (Conditional Version)
"""
struct CondRNODE{T <: AbstractFloat, CM <: ComputeMode, AUGMENTED, STEER} <:
AbstractCondICNF{T, CM, AUGMENTED, STEER}
struct CondRNODE{
T <: AbstractFloat,
CM <: ComputeMode,
AUGMENTED,
STEER,
RESOURCE <: AbstractResource,
BASEDIST <: Distribution,
TSPAN <: NTuple{2, T},
STEERDIST <: Distribution,
DIFFERENTIATION_BACKEND <: AbstractDifferentiation.AbstractBackend,
} <: AbstractCondICNF{T, CM, AUGMENTED, STEER}
nn::LuxCore.AbstractExplicitLayer
nvars::Integer
naugmented::Integer
nvars::Int
naugmented::Int

resource::AbstractResource
basedist::Distribution
tspan::NTuple{2, T}
steerdist::Distribution
differentiation_backend::AbstractDifferentiation.AbstractBackend
resource::RESOURCE
basedist::BASEDIST
tspan::TSPAN
steerdist::STEERDIST
differentiation_backend::DIFFERENTIATION_BACKEND
sol_args::Tuple
sol_kwargs::Dict
λ₁::T
Expand All @@ -23,8 +32,8 @@ end
function construct(
aicnf::Type{<:CondRNODE},
nn,
nvars::Integer,
naugmented::Integer = 0;
nvars::Int,
naugmented::Int = 0;
data_type::Type{<:AbstractFloat} = Float32,
compute_mode::Type{<:ComputeMode} = ADVectorMode,
resource::AbstractResource = CPU1(),
Expand All @@ -45,7 +54,17 @@ function construct(
)
steerdist = Uniform{data_type}(-steer_rate, steer_rate)

aicnf{data_type, compute_mode, !iszero(naugmented), !iszero(steer_rate)}(
aicnf{
data_type,
compute_mode,
!iszero(naugmented),
!iszero(steer_rate),
typeof(resource),
typeof(basedist),
typeof(tspan),
typeof(steerdist),
typeof(differentiation_backend),
}(
nn,
nvars,
naugmented,
Expand Down
8 changes: 4 additions & 4 deletions src/core_cond_icnf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@ mutable struct CondICNFModel <: MLJICNF
loss::Function

optimizers::AbstractVector
n_epochs::Integer
n_epochs::Int
adtype::ADTypes.AbstractADType

use_batch::Bool
batch_size::Integer
batch_size::Int
have_callback::Bool

data_type::Type{<:AbstractFloat}
Expand All @@ -22,10 +22,10 @@ function CondICNFModel(
m::AbstractCondICNF{T, CM},
loss::Function = loss;
optimizers::AbstractVector = Any[Optimisers.Lion(),],
n_epochs::Integer = 300,
n_epochs::Int = 300,
adtype::ADTypes.AbstractADType = AutoZygote(),
use_batch::Bool = true,
batch_size::Integer = 32,
batch_size::Int = 32,
have_callback::Bool = true,
) where {T <: AbstractFloat, CM <: ComputeMode}
CondICNFModel(
Expand Down
8 changes: 4 additions & 4 deletions src/core_icnf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@ mutable struct ICNFModel <: MLJICNF
loss::Function

optimizers::AbstractVector
n_epochs::Integer
n_epochs::Int
adtype::ADTypes.AbstractADType

use_batch::Bool
batch_size::Integer
batch_size::Int
have_callback::Bool

data_type::Type{<:AbstractFloat}
Expand All @@ -22,10 +22,10 @@ function ICNFModel(
m::AbstractICNF{T, CM},
loss::Function = loss;
optimizers::AbstractVector = Any[Optimisers.Lion(),],
n_epochs::Integer = 300,
n_epochs::Int = 300,
adtype::ADTypes.AbstractADType = AutoZygote(),
use_batch::Bool = true,
batch_size::Integer = 32,
batch_size::Int = 32,
have_callback::Bool = true,
) where {T <: AbstractFloat, CM <: ComputeMode}
ICNFModel(
Expand Down
27 changes: 18 additions & 9 deletions src/ffjord.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,26 @@ Implementation of FFJORD from
[Grathwohl, Will, Ricky TQ Chen, Jesse Bettencourt, Ilya Sutskever, and David Duvenaud. "Ffjord: Free-form continuous dynamics for scalable reversible generative models." arXiv preprint arXiv:1810.01367 (2018).](https://arxiv.org/abs/1810.01367)
"""
struct FFJORD{T <: AbstractFloat, CM <: ComputeMode, AUGMENTED, STEER} <:
AbstractICNF{T, CM, AUGMENTED, STEER}
struct FFJORD{
T <: AbstractFloat,
CM <: ComputeMode,
AUGMENTED,
STEER,
RESOURCE <: AbstractResource,
BASEDIST <: Distribution,
TSPAN <: NTuple{2, T},
STEERDIST <: Distribution,
DIFFERENTIATION_BACKEND <: AbstractDifferentiation.AbstractBackend,
} <: AbstractICNF{T, CM, AUGMENTED, STEER}
nn::LuxCore.AbstractExplicitLayer
nvars::Integer
naugmented::Integer
nvars::Int
naugmented::Int

resource::AbstractResource
basedist::Distribution
tspan::NTuple{2, T}
steerdist::Distribution
differentiation_backend::AbstractDifferentiation.AbstractBackend
resource::RESOURCE
basedist::BASEDIST
tspan::TSPAN
steerdist::STEERDIST
differentiation_backend::DIFFERENTIATION_BACKEND
sol_args::Tuple
sol_kwargs::Dict
end
Expand Down
27 changes: 18 additions & 9 deletions src/planar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,26 @@ Implementation of Planar Flows from
[Chen, Ricky TQ, Yulia Rubanova, Jesse Bettencourt, and David Duvenaud. "Neural Ordinary Differential Equations." arXiv preprint arXiv:1806.07366 (2018).](https://arxiv.org/abs/1806.07366)
"""
struct Planar{T <: AbstractFloat, CM <: ComputeMode, AUGMENTED, STEER} <:
AbstractICNF{T, CM, AUGMENTED, STEER}
struct Planar{
T <: AbstractFloat,
CM <: ComputeMode,
AUGMENTED,
STEER,
RESOURCE <: AbstractResource,
BASEDIST <: Distribution,
TSPAN <: NTuple{2, T},
STEERDIST <: Distribution,
DIFFERENTIATION_BACKEND <: AbstractDifferentiation.AbstractBackend,
} <: AbstractICNF{T, CM, AUGMENTED, STEER}
nn::PlanarLayer
nvars::Integer
naugmented::Integer
nvars::Int
naugmented::Int

resource::AbstractResource
basedist::Distribution
tspan::NTuple{2, T}
steerdist::Distribution
differentiation_backend::AbstractDifferentiation.AbstractBackend
resource::RESOURCE
basedist::BASEDIST
tspan::TSPAN
steerdist::STEERDIST
differentiation_backend::DIFFERENTIATION_BACKEND
sol_args::Tuple
sol_kwargs::Dict
end
Expand Down
4 changes: 2 additions & 2 deletions src/planar_layer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ struct PlanarLayer{use_bias, cond, F1, F2, F3} <: LuxCore.AbstractExplicitLayer
nvars::Int
init_weight::F2
init_bias::F3
n_cond::Integer
n_cond::Int
end

function PlanarLayer(
Expand All @@ -21,7 +21,7 @@ function PlanarLayer(
use_bias::Bool = true,
allow_fast_activation::Bool = true,
cond::Bool = false,
n_cond::Integer = 0,
n_cond::Int = 0,
)
activation = allow_fast_activation ? NNlib.fast_act(activation) : activation
PlanarLayer{use_bias, cond, typeof(activation), typeof(init_weight), typeof(init_bias)}(
Expand Down
Loading

0 comments on commit 01f16ad

Please sign in to comment.