From a1066bab03dc95b3b0916ae7bec6bc0c344fc31e Mon Sep 17 00:00:00 2001 From: MartinuzziFrancesco Date: Sat, 13 Jan 2024 19:18:40 +0100 Subject: [PATCH 01/26] initial work for initializers --- Project.toml | 1 + src/esn/esn_reservoirs.jl | 423 +++++++------------------------------- test/runtests.jl | 36 +++- 3 files changed, 97 insertions(+), 363 deletions(-) diff --git a/Project.toml b/Project.toml index b3207f6e..b003aecd 100644 --- a/Project.toml +++ b/Project.toml @@ -13,6 +13,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MLJLinearModels = "6ee0df7b-362f-4a72-a706-9e79364fb692" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Optim = "429524aa-4258-5aef-a3af-852621145aeb" +PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" diff --git a/src/esn/esn_reservoirs.jl b/src/esn/esn_reservoirs.jl index e014e4e7..047fb0bf 100644 --- a/src/esn/esn_reservoirs.jl +++ b/src/esn/esn_reservoirs.jl @@ -1,393 +1,108 @@ -abstract type AbstractReservoir end - -function get_ressize(reservoir::AbstractReservoir) - return reservoir.res_size -end - function get_ressize(reservoir) return size(reservoir, 1) end -struct RandSparseReservoir{T, C} <: AbstractReservoir - res_size::Int - radius::T - sparsity::C -end - -""" - RandSparseReservoir(res_size, radius, sparsity) - RandSparseReservoir(res_size; radius=1.0, sparsity=0.1) - - -Returns a random sparse reservoir initializer, which generates a matrix of size `res_size x res_size` with the specified `sparsity` and scaled spectral radius according to `radius`. This type of reservoir initializer is commonly used in Echo State Networks (ESNs) for capturing complex temporal dependencies. - -# Arguments -- `res_size`: The size of the reservoir matrix. -- `radius`: The desired spectral radius of the reservoir. By default, it is set to 1.0. -- `sparsity`: The sparsity level of the reservoir matrix, controlling the fraction of zero elements. By default, it is set to 0.1. - -# Returns -A RandSparseReservoir object that can be used as a reservoir initializer in ESN construction. - -# References -This type of reservoir initialization is a common choice in ESN construction for its ability to capture temporal dependencies in data. However, there is no specific reference associated with this function. -""" -function RandSparseReservoir(res_size; radius = 1.0, sparsity = 0.1) - return RandSparseReservoir(res_size, radius, sparsity) -end - """ - create_reservoir(reservoir::AbstractReservoir, res_size) - create_reservoir(reservoir, args...) + rand_sparse(rng::AbstractRNG, ::Type{T}, dims::Integer...; radius=1.0, sparsity=0.1) -Given an `AbstractReservoir` constructor and the size of the reservoir (`res_size`), this function returns the corresponding reservoir matrix. Alternatively, it accepts a pre-generated matrix. +Create and return a random sparse reservoir matrix for use in Echo State Networks (ESNs). The matrix will be of size specified by `dims`, with specified `sparsity` and scaled spectral radius according to `radius`. # Arguments -- `reservoir`: An `AbstractReservoir` object or constructor. -- `res_size`: The size of the reservoir matrix. -- `matrix_type`: The type of the resulting matrix. By default, it is set to `Matrix{Float64}`. +- `rng`: An instance of `AbstractRNG` for random number generation. +- `T`: The data type for the elements of the matrix. +- `dims`: Dimensions of the reservoir matrix. +- `radius`: The desired spectral radius of the reservoir. Defaults to 1.0. +- `sparsity`: The sparsity level of the reservoir matrix, controlling the fraction of zero elements. Defaults to 0.1. # Returns -A matrix representing the reservoir, generated based on the properties of the specified `reservoir` object or constructor. +A matrix representing the random sparse reservoir. # References -The choice of reservoir initialization is crucial in Echo State Networks (ESNs) for achieving effective temporal modeling. Specific references for reservoir initialization methods may vary based on the type of reservoir used, but the practice of initializing reservoirs for ESNs is widely documented in the ESN literature. -""" -function create_reservoir(reservoir::RandSparseReservoir, - res_size; - matrix_type = Matrix{Float64}) - reservoir_matrix = Matrix(sprand(res_size, res_size, reservoir.sparsity)) +This type of reservoir initialization is commonly used in ESNs for capturing temporal dependencies in data. +""" +function rand_sparse(rng::AbstractRNG, + ::Type{T}, + dims::Integer...; + radius = 1.0, + sparsity = 0.1) where {T <: Number} + reservoir_matrix = Matrix{T}(sprand(rng, dims..., sparsity)) reservoir_matrix = 2.0 .* (reservoir_matrix .- 0.5) replace!(reservoir_matrix, -1.0 => 0.0) rho_w = maximum(abs.(eigvals(reservoir_matrix))) - reservoir_matrix .*= reservoir.radius / rho_w - #TODO: change to explicit if - Inf in unique(reservoir_matrix) || -Inf in unique(reservoir_matrix) ? - error("Sparsity too low for size of the matrix. - Increase res_size or increase sparsity") : nothing - return Adapt.adapt(matrix_type, reservoir_matrix) -end - -function create_reservoir(reservoir, args...; kwargs...) - return reservoir -end - -#= -function create_reservoir(res_size, reservoir::RandReservoir) - sparsity = degree/res_size - W = Matrix(sprand(Float64, res_size, res_size, sparsity)) - W = 2.0 .*(W.-0.5) - replace!(W, -1.0=>0.0) - rho_w = maximum(abs.(eigvals(W))) - W .*= radius/rho_w - W -end -=# - -struct PseudoSVDReservoir{T, C} <: AbstractReservoir - res_size::Int - max_value::T - sparsity::C - sorted::Bool - reverse_sort::Bool -end - -function PseudoSVDReservoir(res_size; - max_value = 1.0, - sparsity = 0.1, - sorted = true, - reverse_sort = false) - return PseudoSVDReservoir(res_size, max_value, sparsity, sorted, reverse_sort) + reservoir_matrix .*= radius / rho_w + if Inf in unique(reservoir_matrix) || -Inf in unique(reservoir_matrix) + error("Sparsity too low for size of the matrix. Increase res_size or increase sparsity") + end + return reservoir_matrix end """ - PseudoSVDReservoir(max_value, sparsity, sorted, reverse_sort) - PseudoSVDReservoir(max_value, sparsity; sorted=true, reverse_sort=false) + delay_line(rng::AbstractRNG, ::Type{T}, dims::Integer...; weight=0.1) where {T <: Number} -Returns an initializer to build a sparse reservoir matrix with the given `sparsity` by using a pseudo-SVD approach as described in [^yang]. +Create and return a delay line reservoir matrix for use in Echo State Networks (ESNs). A delay line reservoir is a deterministic structure where each unit is connected only to its immediate predecessor with a specified weight. This method is particularly useful for tasks that require specific temporal processing. # Arguments -- `res_size`: The size of the reservoir matrix. -- `max_value`: The maximum absolute value of elements in the matrix. -- `sparsity`: The desired sparsity level of the reservoir matrix. -- `sorted`: A boolean indicating whether to sort the singular values before creating the diagonal matrix. By default, it is set to `true`. -- `reverse_sort`: A boolean indicating whether to reverse the sorted singular values. By default, it is set to `false`. +- `rng`: An instance of `AbstractRNG` for random number generation. This argument is not used in the current implementation but is included for consistency with other initialization functions. +- `T`: The data type for the elements of the matrix. +- `dims`: Dimensions of the reservoir matrix. Typically, this should be a tuple of two equal integers representing a square matrix. +- `weight`: The weight determines the absolute value of all connections in the reservoir. Defaults to 0.1. # Returns -A PseudoSVDReservoir object that can be used as a reservoir initializer in ESN construction. - -# References -This reservoir initialization method, based on a pseudo-SVD approach, is inspired by the work in [^yang], which focuses on designing polynomial echo state networks for time series prediction. - -[^yang]: Yang, Cuili, et al. "_Design of polynomial echo state networks for time series prediction._" Neurocomputing 290 (2018): 148-160. -""" -function PseudoSVDReservoir(res_size, max_value, sparsity; sorted = true, - reverse_sort = false) - return PseudoSVDReservoir(res_size, max_value, sparsity, sorted, reverse_sort) -end +A delay line reservoir matrix with dimensions specified by `dims`. The matrix is initialized such that each element in the `i+1`th row and `i`th column is set to `weight`, and all other elements are zeros. -function create_reservoir(reservoir::PseudoSVDReservoir, - res_size; - matrix_type = Matrix{Float64}) - sorted = reservoir.sorted - reverse_sort = reservoir.reverse_sort - reservoir_matrix = create_diag(res_size, reservoir.max_value, sorted = sorted, - reverse_sort = reverse_sort) - tmp_sparsity = get_sparsity(reservoir_matrix, res_size) +# Example +```julia +reservoir = delay_line(Float64, 100, 100; weight=0.2) +``` - while tmp_sparsity <= reservoir.sparsity - reservoir_matrix *= create_qmatrix(res_size, rand(1:res_size), rand(1:res_size), - rand() * 2 - 1) - tmp_sparsity = get_sparsity(reservoir_matrix, res_size) +# References +This type of reservoir initialization is described in: +Rodan, Ali, and Peter Tino. "Minimum complexity echo state network." IEEE Transactions on Neural Networks 22.1 (2010): 131-144. +""" +function delay_line(rng::AbstractRNG, + ::Type{T}, + dims::Integer...; + weight = 0.1) where {T <: Number} + reservoir_matrix = zeros(T, dims...) + @assert length(dims) == 2 && dims[1] == dims[2], + "The dimensions must define a square matrix (e.g., (100, 100))" + + for i in 1:(dims[1] - 1) + reservoir_matrix[i + 1, i] = weight end - return Adapt.adapt(matrix_type, reservoir_matrix) + return reservoir_matrix end -function create_diag(dim, max_value; sorted = true, reverse_sort = false) - diagonal_matrix = zeros(dim, dim) - if sorted == true - if reverse_sort == true - diagonal_values = sort(rand(dim) .* max_value, rev = true) - diagonal_values[1] = max_value - else - diagonal_values = sort(rand(dim) .* max_value) - diagonal_values[end] = max_value - end - else - diagonal_values = rand(dim) .* max_value +for initializer in (:rand_sparse, :delay_line) + NType = ifelse(initializer === :rand_sparse, Real, Number) + @eval function ($initializer)(dims::Integer...; kwargs...) + return $initializer(_default_rng(), Float32, dims...; kwargs...) end - - for i in 1:dim - diagonal_matrix[i, i] = diagonal_values[i] + @eval function ($initializer)(rng::AbstractRNG, dims::Integer...; kwargs...) + return $initializer(rng, Float32, dims...; kwargs...) end - - return diagonal_matrix -end - -function create_qmatrix(dim, coord_i, coord_j, theta) - qmatrix = zeros(dim, dim) - - for i in 1:dim - qmatrix[i, i] = 1.0 + @eval function ($initializer)(::Type{T}, + dims::Integer...; kwargs...) where {T <: $NType} + return $initializer(_default_rng(), T, dims...; kwargs...) end - - qmatrix[coord_i, coord_i] = cos(theta) - qmatrix[coord_j, coord_j] = cos(theta) - qmatrix[coord_i, coord_j] = -sin(theta) - qmatrix[coord_j, coord_i] = sin(theta) - return qmatrix -end - -function get_sparsity(M, dim) - return size(M[M .!= 0], 1) / (dim * dim - size(M[M .!= 0], 1)) #nonzero/zero elements -end - -#from "minimum complexity echo state network" Rodan -# Delay Line Reservoir - -struct DelayLineReservoir{T} <: AbstractReservoir - res_size::Int - weight::T -end - -""" - DelayLineReservoir(res_size, weight) - DelayLineReservoir(res_size; weight=0.1) - -Returns a Delay Line Reservoir matrix constructor to obtain a deterministic reservoir as -described in [^Rodan2010]. - -# Arguments -- `res_size::Int`: The size of the reservoir. -- `weight::T`: The weight determines the absolute value of all the connections in the reservoir. - -# Returns -A `DelayLineReservoir` object. - -# References -[^Rodan2010]: Rodan, Ali, and Peter Tino. "Minimum complexity echo state network." -IEEE transactions on neural networks 22.1 (2010): 131-144. -""" -function DelayLineReservoir(res_size; weight = 0.1) - return DelayLineReservoir(res_size, weight) -end - -function create_reservoir(reservoir::DelayLineReservoir, - res_size; - matrix_type = Matrix{Float64}) - reservoir_matrix = zeros(res_size, res_size) - - for i in 1:(res_size - 1) - reservoir_matrix[i + 1, i] = reservoir.weight + @eval function ($initializer)(rng::AbstractRNG; kwargs...) + return __partial_apply($initializer, (rng, (; kwargs...))) end - - return Adapt.adapt(matrix_type, reservoir_matrix) -end - -#from "minimum complexity echo state network" Rodan -# Delay Line Reservoir with backward connections -struct DelayLineBackwardReservoir{T} <: AbstractReservoir - res_size::Int - weight::T - fb_weight::T -end - -""" - DelayLineBackwardReservoir(res_size, weight, fb_weight) - DelayLineBackwardReservoir(res_size; weight=0.1, fb_weight=0.2) - -Returns a Delay Line Reservoir constructor to create a matrix with backward connections -as described in [^Rodan2010]. The `weight` and `fb_weight` can be passed as either arguments or -keyword arguments, and they determine the absolute values of the connections in the reservoir. - -# Arguments -- `res_size::Int`: The size of the reservoir. -- `weight::T`: The weight determines the absolute value of forward connections in the reservoir. -- `fb_weight::T`: The `fb_weight` determines the absolute value of backward connections in the reservoir. - -# Returns -A `DelayLineBackwardReservoir` object. - -# References -[^Rodan2010]: Rodan, Ali, and Peter Tino. "Minimum complexity echo state network." -IEEE transactions on neural networks 22.1 (2010): 131-144. -""" -function DelayLineBackwardReservoir(res_size; weight = 0.1, fb_weight = 0.2) - return DelayLineBackwardReservoir(res_size, weight, fb_weight) -end - -function create_reservoir(reservoir::DelayLineBackwardReservoir, - res_size; - matrix_type = Matrix{Float64}) - reservoir_matrix = zeros(res_size, res_size) - - for i in 1:(res_size - 1) - reservoir_matrix[i + 1, i] = reservoir.weight - reservoir_matrix[i, i + 1] = reservoir.fb_weight - end - - return Adapt.adapt(matrix_type, reservoir_matrix) -end - -#from "minimum complexity echo state network" Rodan -# Simple cycle reservoir -struct SimpleCycleReservoir{T} <: AbstractReservoir - res_size::Int - weight::T -end - -""" - SimpleCycleReservoir(res_size, weight) - SimpleCycleReservoir(res_size; weight=0.1) - -Returns a Simple Cycle Reservoir constructor to build a reservoir matrix as -described in [^Rodan2010]. The `weight` can be passed as an argument or a keyword argument, and it determines the -absolute value of all the connections in the reservoir. - -# Arguments -- `res_size::Int`: The size of the reservoir. -- `weight::T`: The weight determines the absolute value of connections in the reservoir. - -# Returns -A `SimpleCycleReservoir` object. - -# References -[^Rodan2010]: Rodan, Ali, and Peter Tino. "Minimum complexity echo state network." -IEEE transactions on neural networks 22.1 (2010): 131-144. -""" -function SimpleCycleReservoir(res_size; weight = 0.1) - return SimpleCycleReservoir(res_size, weight) -end - -function create_reservoir(reservoir::SimpleCycleReservoir, - res_size; - matrix_type = Matrix{Float64}) - reservoir_matrix = zeros(Float64, res_size, res_size) - - for i in 1:(res_size - 1) - reservoir_matrix[i + 1, i] = reservoir.weight + @eval function ($initializer)(rng::AbstractRNG, + ::Type{T}; kwargs...) where {T <: $NType} + return __partial_apply($initializer, ((rng, T), (; kwargs...))) end - - reservoir_matrix[1, res_size] = reservoir.weight - return Adapt.adapt(matrix_type, reservoir_matrix) -end - -#from "simple deterministically constructed cycle reservoirs with regular jumps" by Rodan and Tino -# Cycle Reservoir with Jumps -struct CycleJumpsReservoir{T} <: AbstractReservoir - res_size::Int - cycle_weight::T - jump_weight::T - jump_size::Int -end - -""" - CycleJumpsReservoir(res_size; cycle_weight=0.1, jump_weight=0.1, jump_size=3) - CycleJumpsReservoir(res_size, cycle_weight, jump_weight, jump_size) - -Return a Cycle Reservoir with Jumps constructor to create a reservoir matrix as described -in [^Rodan2012]. The `cycle_weight`, `jump_weight`, and `jump_size` can be passed as arguments or keyword arguments, and they -determine the absolute values of connections in the reservoir. The `jump_size` determines the jumps between `jump_weight`s. - -# Arguments -- `res_size::Int`: The size of the reservoir. -- `cycle_weight::T`: The weight of cycle connections. -- `jump_weight::T`: The weight of jump connections. -- `jump_size::Int`: The number of steps between jump connections. - -# Returns -A `CycleJumpsReservoir` object. - -# References -[^Rodan2012]: Rodan, Ali, and Peter Tiňo. "Simple deterministically constructed cycle reservoirs -with regular jumps." Neural computation 24.7 (2012): 1822-1852. -""" -function CycleJumpsReservoir(res_size; cycle_weight = 0.1, jump_weight = 0.1, jump_size = 3) - return CycleJumpsReservoir(res_size, cycle_weight, jump_weight, jump_size) + @eval ($initializer)(; kwargs...) = __partial_apply($initializer, (; kwargs...)) end -function create_reservoir(reservoir::CycleJumpsReservoir, - res_size; - matrix_type = Matrix{Float64}) - reservoir_matrix = zeros(res_size, res_size) - - for i in 1:(res_size - 1) - reservoir_matrix[i + 1, i] = reservoir.cycle_weight - end - - reservoir_matrix[1, res_size] = reservoir.cycle_weight - - for i in 1:(reservoir.jump_size):(res_size - reservoir.jump_size) - tmp = (i + reservoir.jump_size) % res_size - if tmp == 0 - tmp = res_size - end - reservoir_matrix[i, tmp] = reservoir.jump_weight - reservoir_matrix[tmp, i] = reservoir.jump_weight +# from WeightInitializers.jl, TODO @MartinuzziFrancesco consider importing package +function _default_rng() + @static if VERSION >= v"1.7" + return Xoshiro(1234) + else + return MersenneTwister(1234) end - - return Adapt.adapt(matrix_type, reservoir_matrix) end -""" - NullReservoir() - -Return a constructor for a matrix of zeros with dimensions `res_size x res_size`. - -# Arguments -- None - -# Returns -A `NullReservoir` object. - -# References -- None -""" -struct NullReservoir <: AbstractReservoir end - -function create_reservoir(reservoir::NullReservoir, - res_size; - matrix_type = Matrix{Float64}) - return Adapt.adapt(matrix_type, zeros(res_size, res_size)) -end +__partial_apply(fn, inp) = fn$inp diff --git a/test/runtests.jl b/test/runtests.jl index b1b28ad1..2d114e99 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,19 +2,37 @@ using SafeTestsets using Test @testset "Common Utilities" begin - @safetestset "Quality Assurance" begin include("qa.jl") end - @safetestset "States" begin include("test_states.jl") end + @safetestset "Quality Assurance" begin + include("qa.jl") + end + @safetestset "States" begin + include("test_states.jl") + end end @testset "Echo State Networks" begin - @safetestset "ESN Input Layers" begin include("esn/test_input_layers.jl") end - @safetestset "ESN Reservoirs" begin include("esn/test_reservoirs.jl") end - @safetestset "ESN States" begin include("esn/test_states.jl") end - @safetestset "ESN Train and Predict" begin include("esn/test_train.jl") end - @safetestset "ESN Drivers" begin include("esn/test_drivers.jl") end - @safetestset "Hybrid ESN" begin include("esn/test_hybrid.jl") end + @safetestset "ESN Input Layers" begin + include("esn/test_input_layers.jl") + end + @safetestset "ESN Reservoirs" begin + include("esn/test_reservoirs.jl") + end + @safetestset "ESN States" begin + include("esn/test_states.jl") + end + @safetestset "ESN Train and Predict" begin + include("esn/test_train.jl") + end + @safetestset "ESN Drivers" begin + include("esn/test_drivers.jl") + end + @safetestset "Hybrid ESN" begin + include("esn/test_hybrid.jl") + end end @testset "CA based Reservoirs" begin - @safetestset "RECA" begin include("reca/test_predictive.jl") end + @safetestset "RECA" begin + include("reca/test_predictive.jl") + end end From 7530157e5417be4fb6c2d6695976a9c8f2166fab Mon Sep 17 00:00:00 2001 From: MartinuzziFrancesco Date: Sat, 13 Jan 2024 19:19:53 +0100 Subject: [PATCH 02/26] exports --- src/ReservoirComputing.jl | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/ReservoirComputing.jl b/src/ReservoirComputing.jl index 4c24427a..84b0d145 100644 --- a/src/ReservoirComputing.jl +++ b/src/ReservoirComputing.jl @@ -18,9 +18,7 @@ export StandardRidge, LinearModel export AbstractLayer, create_layer export WeightedLayer, DenseLayer, SparseLayer, MinimumLayer, InformedLayer, NullLayer export BernoulliSample, IrrationalSample -export AbstractReservoir, create_reservoir -export RandSparseReservoir, PseudoSVDReservoir, DelayLineReservoir -export DelayLineBackwardReservoir, SimpleCycleReservoir, CycleJumpsReservoir, NullReservoir +export rand_sparse, delay_line export RNN, MRNN, GRU, GRUParams, FullyGated, Minimal export ESN, Default, Hybrid, train export RECA, train From ba9ec649002d195c25c18d4d52281fa78fd2ef7a Mon Sep 17 00:00:00 2001 From: MartinuzziFrancesco Date: Thu, 18 Jan 2024 18:17:21 +0100 Subject: [PATCH 03/26] fixing types and starting tests --- Project.toml | 2 + src/ReservoirComputing.jl | 2 + src/esn/esn_reservoirs.jl | 10 ++-- test/esn/test_reservoirs.jl | 99 +++++++++++-------------------------- test/utils.jl | 5 ++ 5 files changed, 42 insertions(+), 76 deletions(-) create mode 100644 test/utils.jl diff --git a/Project.toml b/Project.toml index b003aecd..fb57c37a 100644 --- a/Project.toml +++ b/Project.toml @@ -14,6 +14,7 @@ MLJLinearModels = "6ee0df7b-362f-4a72-a706-9e79364fb692" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Optim = "429524aa-4258-5aef-a3af-852621145aeb" PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" @@ -29,6 +30,7 @@ LinearAlgebra = "1.10" MLJLinearModels = "0.9.2" NNlib = "0.8.4, 0.9" Optim = "1" +PartialFunctions = "1.2" Random = "1" SafeTestsets = "0.1" SparseArrays = "1.10" diff --git a/src/ReservoirComputing.jl b/src/ReservoirComputing.jl index 84b0d145..c826b3f3 100644 --- a/src/ReservoirComputing.jl +++ b/src/ReservoirComputing.jl @@ -9,6 +9,8 @@ using LinearAlgebra using MLJLinearModels using NNlib using Optim +using PartialFunctions +using Random using SparseArrays using Statistics diff --git a/src/esn/esn_reservoirs.jl b/src/esn/esn_reservoirs.jl index 047fb0bf..fa57109b 100644 --- a/src/esn/esn_reservoirs.jl +++ b/src/esn/esn_reservoirs.jl @@ -23,11 +23,11 @@ This type of reservoir initialization is commonly used in ESNs for capturing tem function rand_sparse(rng::AbstractRNG, ::Type{T}, dims::Integer...; - radius = 1.0, - sparsity = 0.1) where {T <: Number} + radius = T(1.0), + sparsity = T(0.1)) where {T <: Number} reservoir_matrix = Matrix{T}(sprand(rng, dims..., sparsity)) - reservoir_matrix = 2.0 .* (reservoir_matrix .- 0.5) - replace!(reservoir_matrix, -1.0 => 0.0) + reservoir_matrix = T(2.0) .* (reservoir_matrix .- T(0.5)) + replace!(reservoir_matrix, T(-1.0) => T(0.0)) rho_w = maximum(abs.(eigvals(reservoir_matrix))) reservoir_matrix .*= radius / rho_w if Inf in unique(reservoir_matrix) || -Inf in unique(reservoir_matrix) @@ -62,7 +62,7 @@ Rodan, Ali, and Peter Tino. "Minimum complexity echo state network." IEEE Transa function delay_line(rng::AbstractRNG, ::Type{T}, dims::Integer...; - weight = 0.1) where {T <: Number} + weight = T(0.1)) where {T <: Number} reservoir_matrix = zeros(T, dims...) @assert length(dims) == 2 && dims[1] == dims[2], "The dimensions must define a square matrix (e.g., (100, 100))" diff --git a/test/esn/test_reservoirs.jl b/test/esn/test_reservoirs.jl index ac751712..df58bd3c 100644 --- a/test/esn/test_reservoirs.jl +++ b/test/esn/test_reservoirs.jl @@ -1,79 +1,36 @@ using ReservoirComputing +using LinearAlgebra +using Random +include("../utils.jl") const res_size = 20 const radius = 1.0 const sparsity = 0.1 const weight = 0.2 const jump_size = 3 +const rng = Random.default_rng() + +dtypes = [Float16, Float32, Float64] +reservoir_inits = [rand_sparse] + +@testset "Sizes and types" begin + for init in reservoir_inits + for dt in dtypes + #sizes + @test size(init(res_size, res_size)) == (res_size, res_size) + @test size(init(rng, res_size, res_size)) == (res_size, res_size) + #types + @test eltype(init(dt, res_size, res_size)) == dt + @test eltype(init(rng, dt, res_size, res_size)) == dt + #closure + cl = init(rng) + @test cl(dt, res_size, res_size) isa AbstractArray{dt} + end + end +end + +@testset "rand_sparse" begin + sp = rand_sparse(res_size, res_size) + @test check_radius(sp, radius) +end -#testing RandSparseReservoir implicit and esplicit constructors -reservoir_constructor = RandSparseReservoir(res_size, radius, sparsity) -reservoir_matrix = create_reservoir(reservoir_constructor, res_size) -@test size(reservoir_matrix) == (res_size, res_size) - -reservoir_constructor = RandSparseReservoir(res_size, radius = radius, sparsity = sparsity) -reservoir_matrix = create_reservoir(reservoir_constructor, res_size) -@test size(reservoir_matrix) == (res_size, res_size) - -#testing PseudoSVDReservoir implicit and esplicit constructors -reservoir_constructor = PseudoSVDReservoir(res_size, radius, sparsity) -reservoir_matrix = create_reservoir(reservoir_constructor, res_size) -@test size(reservoir_matrix) == (res_size, res_size) -@test maximum(reservoir_matrix) <= radius - -reservoir_constructor = PseudoSVDReservoir(res_size, max_value = radius, - sparsity = sparsity) -reservoir_matrix = create_reservoir(reservoir_constructor, res_size) -@test size(reservoir_matrix) == (res_size, res_size) -@test maximum(reservoir_matrix) <= radius - -#testing DelayLineReservoir implicit and esplicit constructors -reservoir_constructor = DelayLineReservoir(res_size, weight) -reservoir_matrix = create_reservoir(reservoir_constructor, res_size) -@test size(reservoir_matrix) == (res_size, res_size) -@test maximum(reservoir_matrix) == weight - -reservoir_constructor = DelayLineReservoir(res_size, weight = weight) -reservoir_matrix = create_reservoir(reservoir_constructor, res_size) -@test size(reservoir_matrix) == (res_size, res_size) -@test maximum(reservoir_matrix) == weight - -#testing DelayLineReservoir implicit and esplicit constructors -reservoir_constructor = DelayLineBackwardReservoir(res_size, weight, weight) -reservoir_matrix = create_reservoir(reservoir_constructor, res_size) -@test size(reservoir_matrix) == (res_size, res_size) -@test maximum(reservoir_matrix) == weight - -reservoir_constructor = DelayLineBackwardReservoir(res_size, weight = weight, - fb_weight = weight) -reservoir_matrix = create_reservoir(reservoir_constructor, res_size) -@test size(reservoir_matrix) == (res_size, res_size) -@test maximum(reservoir_matrix) == weight - -#testing SimpleCycleReservoir implicit and esplicit constructors -reservoir_constructor = SimpleCycleReservoir(res_size, weight) -reservoir_matrix = create_reservoir(reservoir_constructor, res_size) -@test size(reservoir_matrix) == (res_size, res_size) -@test maximum(reservoir_matrix) == weight - -reservoir_constructor = SimpleCycleReservoir(res_size, weight = weight) -reservoir_matrix = create_reservoir(reservoir_constructor, res_size) -@test size(reservoir_matrix) == (res_size, res_size) -@test maximum(reservoir_matrix) == weight - -#testing CycleJumpsReservoir implicit and esplicit constructors -reservoir_constructor = CycleJumpsReservoir(res_size, weight, weight, jump_size) -reservoir_matrix = create_reservoir(reservoir_constructor, res_size) -@test size(reservoir_matrix) == (res_size, res_size) -@test maximum(reservoir_matrix) == weight - -reservoir_constructor = CycleJumpsReservoir(res_size, cycle_weight = weight, - jump_weight = weight, jump_size = jump_size) -reservoir_matrix = create_reservoir(reservoir_constructor, res_size) -@test size(reservoir_matrix) == (res_size, res_size) -@test maximum(reservoir_matrix) == weight - -#testing NullReservoir constructors -reservoir_constructor = NullReservoir() -reservoir_matrix = create_reservoir(reservoir_constructor, res_size) -@test size(reservoir_matrix) == (res_size, res_size) diff --git a/test/utils.jl b/test/utils.jl new file mode 100644 index 00000000..9ef6f360 --- /dev/null +++ b/test/utils.jl @@ -0,0 +1,5 @@ +function check_radius(matrix, target_radius; tolerance=1e-5) + eigenvalues = eigvals(matrix) + spectral_radius = maximum(abs.(eigenvalues)) + return isapprox(spectral_radius, target_radius, atol=tolerance) +end \ No newline at end of file From 3a5a62202623a14f2cac040f83bf357cfdbc55e9 Mon Sep 17 00:00:00 2001 From: MartinuzziFrancesco Date: Sat, 20 Jan 2024 18:10:11 +0100 Subject: [PATCH 04/26] start of input_layers, start of streamline to new api --- Project.toml | 1 + src/ReservoirComputing.jl | 27 ++- src/esn/echostatenetwork.jl | 37 ++-- src/esn/esn_input_layers.jl | 381 ++---------------------------------- src/esn/esn_reservoirs.jl | 29 +-- test/esn/test_reservoirs.jl | 7 +- 6 files changed, 71 insertions(+), 411 deletions(-) diff --git a/Project.toml b/Project.toml index fb57c37a..e0d6bb2b 100644 --- a/Project.toml +++ b/Project.toml @@ -17,6 +17,7 @@ PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +WeightInitializers = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d" [compat] Adapt = "3.3.3, 4" diff --git a/src/ReservoirComputing.jl b/src/ReservoirComputing.jl index c826b3f3..d1e47609 100644 --- a/src/ReservoirComputing.jl +++ b/src/ReservoirComputing.jl @@ -13,13 +13,13 @@ using PartialFunctions using Random using SparseArrays using Statistics +using WeightInitializers export NLADefault, NLAT1, NLAT2, NLAT3 export StandardStates, ExtendedStates, PaddedStates, PaddedExtendedStates export StandardRidge, LinearModel export AbstractLayer, create_layer -export WeightedLayer, DenseLayer, SparseLayer, MinimumLayer, InformedLayer, NullLayer -export BernoulliSample, IrrationalSample +export scaled_rand export rand_sparse, delay_line export RNN, MRNN, GRU, GRUParams, FullyGated, Minimal export ESN, Default, Hybrid, train @@ -72,6 +72,29 @@ function Predictive(prediction_data) Predictive(prediction_data, prediction_len) end +#fallbacks for initializers +for initializer in (:rand_sparse, :delay_line, :scaled_rand) + NType = ifelse(initializer === :rand_sparse, Real, Number) + @eval function ($initializer)(dims::Integer...; kwargs...) + return $initializer(_default_rng(), Float32, dims...; kwargs...) + end + @eval function ($initializer)(rng::AbstractRNG, dims::Integer...; kwargs...) + return $initializer(rng, Float32, dims...; kwargs...) + end + @eval function ($initializer)(::Type{T}, + dims::Integer...; kwargs...) where {T <: $NType} + return $initializer(_default_rng(), T, dims...; kwargs...) + end + @eval function ($initializer)(rng::AbstractRNG; kwargs...) + return __partial_apply($initializer, (rng, (; kwargs...))) + end + @eval function ($initializer)(rng::AbstractRNG, + ::Type{T}; kwargs...) where {T <: $NType} + return __partial_apply($initializer, ((rng, T), (; kwargs...))) + end + @eval ($initializer)(; kwargs...) = __partial_apply($initializer, (; kwargs...)) +end + #general include("states.jl") include("predict.jl") diff --git a/src/esn/echostatenetwork.jl b/src/esn/echostatenetwork.jl index 42fab481..adbf85a4 100644 --- a/src/esn/echostatenetwork.jl +++ b/src/esn/echostatenetwork.jl @@ -90,33 +90,30 @@ train_data = rand(10, 100) # 10 features, 100 time steps esn = ESN(train_data, reservoir=RandSparseReservoir(200), washout=10) ``` """ -function ESN(train_data; - variation = Default(), - input_layer = DenseLayer(), - reservoir = RandSparseReservoir(100), - bias = NullLayer(), - reservoir_driver = RNN(), - nla_type = NLADefault(), - states_type = StandardStates(), - washout = 0, - matrix_type = typeof(train_data)) - if variation isa Hybrid - train_data = vcat(train_data, variation.model_data[:, 1:(end - 1)]) - end +function ESN( + train_data, + in_size, + res_size; + input_layer = scaled_rand, + reservoir = rand_sparse, + bias = zeros64, + reservoir_driver = RNN(), + nla_type = NLADefault(), + states_type = StandardStates(), + washout = 0, + rng = _default_rng(), + matrix_type = typeof(train_data) +) where {T <: Number} if states_type isa AbstractPaddedStates in_size = size(train_data, 1) + 1 train_data = vcat(Adapt.adapt(matrix_type, ones(1, size(train_data, 2))), train_data) - else - in_size = size(train_data, 1) end - input_matrix, reservoir_matrix, bias_vector, res_size = obtain_layers(in_size, - input_layer, - reservoir, bias; - matrix_type = matrix_type) - + reservoir_matrix = reservoir(rng, T, res_size, res_size) + input_matrix = input_layer(rng, T, res_size, in_size) + bias_vector = bias(rng, T, res_size) inner_res_driver = reservoir_driver_params(reservoir_driver, res_size, in_size) states = create_states(inner_res_driver, train_data, washout, reservoir_matrix, input_matrix, bias_vector) diff --git a/src/esn/esn_input_layers.jl b/src/esn/esn_input_layers.jl index e7bb950c..42b82282 100644 --- a/src/esn/esn_input_layers.jl +++ b/src/esn/esn_input_layers.jl @@ -1,371 +1,32 @@ -abstract type AbstractLayer end - -struct WeightedLayer{T} <: AbstractLayer - scaling::T -end - -""" - WeightedInput(scaling) - WeightedInput(;scaling=0.1) - -Creates a `WeightedInput` layer initializer for Echo State Networks. -This initializer generates a weighted input matrix with random non-zero -elements distributed uniformly within the range [-`scaling`, `scaling`], -following the approach in [^Lu]. - -# Parameters -- `scaling`: The scaling factor for the weight distribution (default: 0.1). - -# Returns -- A `WeightedInput` instance to be used for initializing the input layer of an ESN. - -Reference: -[^Lu]: Lu, Zhixin, et al. - "Reservoir observers: Model-free inference of unmeasured variables in chaotic systems." - Chaos: An Interdisciplinary Journal of Nonlinear Science 27.4 (2017): 041102. -""" -function WeightedLayer(; scaling = 0.1) - return WeightedLayer(scaling) -end - -function create_layer(input_layer::WeightedLayer, - approx_res_size, - in_size; - matrix_type = Matrix{Float64}) - scaling = input_layer.scaling - res_size = Int(floor(approx_res_size / in_size) * in_size) - layer_matrix = zeros(res_size, in_size) - q = floor(Int, res_size / in_size) - - for i in 1:in_size - layer_matrix[((i - 1) * q + 1):((i) * q), i] = rand(Uniform(-scaling, scaling), 1, - q) - end - - return Adapt.adapt(matrix_type, layer_matrix) -end - -function create_layer(layer, args...; kwargs...) - return layer -end - -""" - DenseLayer(scaling) - DenseLayer(;scaling=0.1) - -Creates a `DenseLayer` initializer for Echo State Networks, generating a fully connected input layer. -The layer is initialized with random weights uniformly distributed within [-`scaling`, `scaling`]. -This scaling factor can be provided either as an argument or a keyword argument. -The `DenseLayer` is the default input layer in `ESN` construction. - -# Parameters -- `scaling`: The scaling factor for weight distribution (default: 0.1). - -# Returns -- A `DenseLayer` instance for initializing the ESN's input layer. -""" -struct DenseLayer{T} <: AbstractLayer - scaling::T -end - -function DenseLayer(; scaling = 0.1) - return DenseLayer(scaling) -end - -""" - create_layer(input_layer::AbstractLayer, res_size, in_size) - -Generates a matrix layer of size `res_size` x `in_size`, constructed according to the specifications of the `input_layer`. - -# Parameters -- `input_layer`: An instance of `AbstractLayer` determining the layer construction. -- `res_size`: The number of rows (reservoir size) for the layer. -- `in_size`: The number of columns (input size) for the layer. - -# Returns -- A matrix representing the constructed layer. -""" -function create_layer(input_layer::DenseLayer, - res_size, - in_size; - matrix_type = Matrix{Float64}) - scaling = input_layer.scaling - layer_matrix = rand(Uniform(-scaling, scaling), res_size, in_size) - return Adapt.adapt(matrix_type, layer_matrix) -end - -""" - SparseLayer(scaling, sparsity) - SparseLayer(scaling; sparsity=0.1) - SparseLayer(;scaling=0.1, sparsity=0.1) - -Creates a `SparseLayer` initializer for Echo State Networks, generating a sparse input layer. -The layer is initialized with weights distributed within [-`scaling`, `scaling`] -and a specified `sparsity` level. Both `scaling` and `sparsity` can be set as arguments or keyword arguments. - -# Parameters -- `scaling`: Scaling factor for weight distribution (default: 0.1). -- `sparsity`: Sparsity level of the layer (default: 0.1). - -# Returns -- A `SparseLayer` instance for initializing ESN's input layer with sparse connections. -""" -struct SparseLayer{T} <: AbstractLayer - scaling::T - sparsity::T -end - -function SparseLayer(; scaling = 0.1, sparsity = 0.1) - return SparseLayer(scaling, sparsity) -end - -function SparseLayer(scaling_arg; scaling = scaling_arg, sparsity = 0.1) - return SparseLayer(scaling, sparsity) -end - -function create_layer(input_layer::SparseLayer, - res_size, - in_size; - matrix_type = Matrix{Float64}) - layer_matrix = Matrix(sprand(res_size, in_size, input_layer.sparsity)) - layer_matrix = 2.0 .* (layer_matrix .- 0.5) - replace!(layer_matrix, -1.0 => 0.0) - layer_matrix = input_layer.scaling .* layer_matrix - return Adapt.adapt(matrix_type, layer_matrix) -end - -#from "minimum complexity echo state network" Rodan -#and "simple deterministically constructed cycle reservoirs with regular jumps" -#by Rodan and Tino -struct BernoulliSample{T} - p::T -end - -""" - BernoulliSample(p) - BernoulliSample(;p=0.5) - -Creates a `BernoulliSample` constructor for the `MinimumLayer`. -It uses a Bernoulli distribution to determine the sign of weights in the input layer. -The parameter `p` sets the probability of a weight being positive, as per the `Distributions` package. -This method of sign weight determination for input layers is based on the approach in [^Rodan]. - -# Parameters -- `p`: Probability of a positive weight (default: 0.5). - -# Returns -- A `BernoulliSample` instance for generating sign weights in `MinimumLayer`. - -Reference: -[^Rodan]: Rodan, Ali, and Peter Tino. - "Minimum complexity echo state network." - IEEE Transactions on Neural Networks 22.1 (2010): 131-144. -""" -function BernoulliSample(; p = 0.5) - return BernoulliSample(p) -end - -struct IrrationalSample{K} - irrational::Irrational - start::K -end - -""" - IrrationalSample(irrational, start) - IrrationalSample(;irrational=pi, start=1) - -Creates an `IrrationalSample` constructor for the `MinimumLayer`. -It determines the sign of weights in the input layer based on the decimal expansion of an `irrational` number. -The `start` parameter sets the starting point in the decimal sequence. -The signs are assigned based on the thresholding of each decimal digit against 4.5, as described in [^Rodan]. - -# Parameters -- `irrational`: An irrational number for weight sign determination (default: π). -- `start`: Starting index in the decimal expansion (default: 1). - -# Returns -- An `IrrationalSample` instance for generating sign weights in `MinimumLayer`. - -Reference: -[^Rodan]: Rodan, Ali, and Peter Tiňo. - "Simple deterministically constructed cycle reservoirs with regular jumps." - Neural Computation 24.7 (2012): 1822-1852. -""" -function IrrationalSample(; irrational = pi, start = 1) - return IrrationalSample(irrational, start) -end - -struct MinimumLayer{T, K} <: AbstractLayer - weight::T - sampling::K -end - -""" - MinimumLayer(weight, sampling) - MinimumLayer(weight; sampling=BernoulliSample(0.5)) - MinimumLayer(;weight=0.1, sampling=BernoulliSample(0.5)) - -Creates a `MinimumLayer` initializer for Echo State Networks, generating a fully connected input layer. -This layer has a uniform absolute weight value (`weight`) with the sign of each -weight determined by the `sampling` method. This approach, as detailed in [^Rodan1] and [^Rodan2], -allows for controlled weight distribution in the layer. - -# Parameters -- `weight`: Absolute value of weights in the layer. -- `sampling`: Method for determining the sign of weights (default: `BernoulliSample(0.5)`). - -# Returns -- A `MinimumLayer` instance for initializing the ESN's input layer. - -References: -[^Rodan1]: Rodan, Ali, and Peter Tino. - "Minimum complexity echo state network." - IEEE Transactions on Neural Networks 22.1 (2010): 131-144. -[^Rodan2]: Rodan, Ali, and Peter Tiňo. - "Simple deterministically constructed cycle reservoirs with regular jumps." - Neural Computation 24.7 (2012): 1822-1852. -""" -function MinimumLayer(weight; sampling = BernoulliSample(0.5)) - return MinimumLayer(weight, sampling) -end - -function MinimumLayer(; weight = 0.1, sampling = BernoulliSample(0.5)) - return MinimumLayer(weight, sampling) -end - -function create_layer(input_layer::MinimumLayer, - res_size, - in_size; - matrix_type = Matrix{Float64}) - sampling = input_layer.sampling - weight = input_layer.weight - layer_matrix = create_minimum_input(sampling, res_size, in_size, weight) - return Adapt.adapt(matrix_type, layer_matrix) -end - -function create_minimum_input(sampling::BernoulliSample, res_size, in_size, weight) - p = sampling.p - input_matrix = zeros(res_size, in_size) - for i in 1:res_size - for j in 1:in_size - rand(Bernoulli(p)) ? input_matrix[i, j] = weight : input_matrix[i, j] = -weight - end - end - - return input_matrix -end - -function create_minimum_input(sampling::IrrationalSample, res_size, in_size, weight) - setprecision(BigFloat, Int(ceil(log2(10) * (res_size * in_size + sampling.start + 1)))) - ir_string = string(BigFloat(sampling.irrational)) |> collect - deleteat!(ir_string, findall(x -> x == '.', ir_string)) - ir_array = zeros(length(ir_string)) - input_matrix = zeros(res_size, in_size) - - for i in 1:length(ir_string) - ir_array[i] = parse(Int, ir_string[i]) - end - - co = sampling.start - counter = 1 - - for i in 1:res_size - for j in 1:in_size - ir_array[counter] < 5 ? input_matrix[i, j] = -weight : - input_matrix[i, j] = weight - counter += 1 - end - end - - return input_matrix -end - -struct InformedLayer{T, K, M} <: AbstractLayer - scaling::T - gamma::K - model_in_size::M -end - """ - InformedLayer(model_in_size; scaling=0.1, gamma=0.5) + scaled_rand(rng::AbstractRNG, ::Type{T}, dims::Integer...; scaling=T(0.1)) where {T <: Number} -Creates an `InformedLayer` initializer for Echo State Networks (ESNs) that generates -a weighted input layer matrix. The matrix contains random non-zero elements drawn from -the range [-```scaling```, ```scaling```]. This initializer ensures that a fraction (`gamma`) -of reservoir nodes are exclusively connected to the raw inputs, while the rest are -connected to the outputs of a prior knowledge model, as described in [^Pathak]. +Create and return a matrix with random values, uniformly distributed within a range defined by `scaling`. This function is useful for initializing matrices, such as the layers of a neural network, with scaled random values. # Arguments -- `model_in_size`: The size of the prior knowledge model's output, - which determines the number of columns in the input layer matrix. - -# Keyword Arguments -- `scaling`: The absolute value of the weights (default: 0.1). -- `gamma`: The fraction of reservoir nodes connected exclusively to raw inputs (default: 0.5). +- `rng`: An instance of `AbstractRNG` for random number generation. +- `T`: The data type for the elements of the matrix. +- `dims`: Dimensions of the matrix. It must be a 2-element tuple specifying the number of rows and columns (e.g., `(res_size, in_size)`). +- `scaling`: A scaling factor to define the range of the uniform distribution. The matrix elements will be randomly chosen from the range `[-scaling, scaling]`. Defaults to `T(0.1)`. # Returns -- An `InformedLayer` instance for initializing the ESN's input layer matrix. - -Reference: -[^Pathak]: Jaideep Pathak et al. - "Hybrid Forecasting of Chaotic Processes: Using Machine Learning in Conjunction with a Knowledge-Based Model" (2018). -""" -function InformedLayer(model_in_size; scaling = 0.1, gamma = 0.5) - return InformedLayer(scaling, gamma, model_in_size) -end - -function create_layer(input_layer::InformedLayer, - res_size, - in_size; - matrix_type = Matrix{Float64}) - scaling = input_layer.scaling - state_size = in_size - input_layer.model_in_size - - if state_size <= 0 - throw(DimensionMismatch("in_size must be greater than model_in_size")) - end - - input_matrix = zeros(res_size, in_size) - #Vector used to find res nodes not yet connected - zero_connections = zeros(in_size) - #Num of res nodes allotted for raw states - num_for_state = floor(Int, res_size * input_layer.gamma) - #Num of res nodes allotted for prior model input - num_for_model = floor(Int, (res_size * (1 - input_layer.gamma))) - - for i in 1:num_for_state - #find res nodes with no connections - idxs = findall(Bool[zero_connections == input_matrix[i, :] - for i in 1:size(input_matrix, 1)]) - random_row_idx = idxs[rand(1:end)] - random_clm_idx = range(1, state_size, step = 1)[rand(1:end)] - input_matrix[random_row_idx, random_clm_idx] = rand(Uniform(-scaling, scaling)) - end - - for i in 1:num_for_model - idxs = findall(Bool[zero_connections == input_matrix[i, :] - for i in 1:size(input_matrix, 1)]) - random_row_idx = idxs[rand(1:end)] - random_clm_idx = range(state_size + 1, in_size, step = 1)[rand(1:end)] - input_matrix[random_row_idx, random_clm_idx] = rand(Uniform(-scaling, scaling)) - end - - return Adapt.adapt(matrix_type, input_matrix) -end +A matrix of type with dimensions specified by `dims`. Each element of the matrix is a random number uniformly distributed between `-scaling` and `scaling`. +# Example +```julia +rng = Random.default_rng() +matrix = scaled_rand(rng, Float64, (100, 50); scaling=0.2) """ - NullLayer() +function scaled_rand( + rng::AbstractRNG, + ::Type{T}, + dims::Integer...; + scaling=T(0.1) +) where {T <: Number} -Creates a `NullLayer` initializer for Echo State Networks (ESNs) that generates a vector of zeros. - -# Returns -- A `NullLayer` instance for initializing the ESN's input layer matrix. -""" -struct NullLayer <: AbstractLayer end + @assert length(dims) == 2, "The dimensions must define a matrix (e.g., (res_size, in_size))" -function create_layer(input_layer::NullLayer, - res_size, - in_size; - matrix_type = Matrix{Float64}) - return Adapt.adapt(matrix_type, zeros(res_size, in_size)) + res_size, in_size = dims + layer_matrix = rand(rng, Uniform(-scaling, scaling), res_size, in_size) + return layer_matrix end diff --git a/src/esn/esn_reservoirs.jl b/src/esn/esn_reservoirs.jl index fa57109b..ab2eaf42 100644 --- a/src/esn/esn_reservoirs.jl +++ b/src/esn/esn_reservoirs.jl @@ -1,7 +1,3 @@ -function get_ressize(reservoir) - return size(reservoir, 1) -end - """ rand_sparse(rng::AbstractRNG, ::Type{T}, dims::Integer...; radius=1.0, sparsity=0.1) @@ -64,8 +60,7 @@ function delay_line(rng::AbstractRNG, dims::Integer...; weight = T(0.1)) where {T <: Number} reservoir_matrix = zeros(T, dims...) - @assert length(dims) == 2 && dims[1] == dims[2], - "The dimensions must define a square matrix (e.g., (100, 100))" + @assert length(dims) == 2 && dims[1] == dims[2] "The dimensions must define a square matrix (e.g., (100, 100))" for i in 1:(dims[1] - 1) reservoir_matrix[i + 1, i] = weight @@ -74,28 +69,6 @@ function delay_line(rng::AbstractRNG, return reservoir_matrix end -for initializer in (:rand_sparse, :delay_line) - NType = ifelse(initializer === :rand_sparse, Real, Number) - @eval function ($initializer)(dims::Integer...; kwargs...) - return $initializer(_default_rng(), Float32, dims...; kwargs...) - end - @eval function ($initializer)(rng::AbstractRNG, dims::Integer...; kwargs...) - return $initializer(rng, Float32, dims...; kwargs...) - end - @eval function ($initializer)(::Type{T}, - dims::Integer...; kwargs...) where {T <: $NType} - return $initializer(_default_rng(), T, dims...; kwargs...) - end - @eval function ($initializer)(rng::AbstractRNG; kwargs...) - return __partial_apply($initializer, (rng, (; kwargs...))) - end - @eval function ($initializer)(rng::AbstractRNG, - ::Type{T}; kwargs...) where {T <: $NType} - return __partial_apply($initializer, ((rng, T), (; kwargs...))) - end - @eval ($initializer)(; kwargs...) = __partial_apply($initializer, (; kwargs...)) -end - # from WeightInitializers.jl, TODO @MartinuzziFrancesco consider importing package function _default_rng() @static if VERSION >= v"1.7" diff --git a/test/esn/test_reservoirs.jl b/test/esn/test_reservoirs.jl index df58bd3c..debd9be0 100644 --- a/test/esn/test_reservoirs.jl +++ b/test/esn/test_reservoirs.jl @@ -11,7 +11,7 @@ const jump_size = 3 const rng = Random.default_rng() dtypes = [Float16, Float32, Float64] -reservoir_inits = [rand_sparse] +reservoir_inits = [rand_sparse, delay_line] @testset "Sizes and types" begin for init in reservoir_inits @@ -34,3 +34,8 @@ end @test check_radius(sp, radius) end +@testset "delay_line" begin + dl = delay_line(res_size, res_size) + @test unique(dl) == Float32.([0.0, 0.1]) +end + From ab3337cad9bdd669e35d52e57bb3933bc088cabb Mon Sep 17 00:00:00 2001 From: MartinuzziFrancesco Date: Sun, 21 Jan 2024 16:41:52 +0100 Subject: [PATCH 05/26] made ESN work with new initilaizers, started separation of different models --- README.md | 9 +- src/ReservoirComputing.jl | 11 +- src/esn/deepesn.jl | 84 ++++++++++++ src/esn/echostatenetwork.jl | 266 ------------------------------------ src/esn/esn.jl | 143 +++++++++++++++++++ src/esn/esn_input_layers.jl | 40 +++++- src/esn/esn_predict.jl | 35 +---- src/esn/hybridesn.jl | 82 +++++++++++ 8 files changed, 366 insertions(+), 304 deletions(-) create mode 100644 src/esn/deepesn.jl delete mode 100644 src/esn/echostatenetwork.jl create mode 100644 src/esn/esn.jl create mode 100644 src/esn/hybridesn.jl diff --git a/README.md b/README.md index 9172725e..60f12963 100644 --- a/README.md +++ b/README.md @@ -51,14 +51,15 @@ test = data[:, (shift + train_len):(shift + train_len + predict_len - 1)] Now that we have the data we can initialize the ESN with the chosen parameters. Given that this is a quick example we are going to change the least amount of possible parameters. For more detailed examples and explanations of the functions please refer to the documentation. ```julia +input_size = 3 res_size = 300 -esn = ESN(input_data; - reservoir = RandSparseReservoir(res_size, radius = 1.2, sparsity = 6 / res_size), - input_layer = WeightedLayer(), +esn = ESN(input_data, input_size, res_size; + reservoir = rand_sparse(;radius = 1.2, sparsity = 6 / res_size), + input_layer = weighted_init, nla_type = NLAT2()) ``` -The echo state network can now be trained and tested. If not specified, the training will always be Ordinary Least Squares regression. The full range of training methods is detailed in the documentation. +The echo state network can now be trained and tested. If not specified, the training will always be ordinary least squares regression. The full range of training methods is detailed in the documentation. ```julia output_layer = train(esn, target_data) diff --git a/src/ReservoirComputing.jl b/src/ReservoirComputing.jl index d1e47609..f8668b54 100644 --- a/src/ReservoirComputing.jl +++ b/src/ReservoirComputing.jl @@ -19,10 +19,11 @@ export NLADefault, NLAT1, NLAT2, NLAT3 export StandardStates, ExtendedStates, PaddedStates, PaddedExtendedStates export StandardRidge, LinearModel export AbstractLayer, create_layer -export scaled_rand +export scaled_rand, weighted_init export rand_sparse, delay_line export RNN, MRNN, GRU, GRUParams, FullyGated, Minimal -export ESN, Default, Hybrid, train +export ESN, train +export DeepESN, HybridESN export RECA, train export RandomMapping, RandomMaps export Generative, Predictive, OutputLayer @@ -73,7 +74,7 @@ function Predictive(prediction_data) end #fallbacks for initializers -for initializer in (:rand_sparse, :delay_line, :scaled_rand) +for initializer in (:rand_sparse, :delay_line, :scaled_rand, :weighted_init) NType = ifelse(initializer === :rand_sparse, Real, Number) @eval function ($initializer)(dims::Integer...; kwargs...) return $initializer(_default_rng(), Float32, dims...; kwargs...) @@ -107,7 +108,9 @@ include("train/supportvector_regression.jl") include("esn/esn_input_layers.jl") include("esn/esn_reservoirs.jl") include("esn/esn_reservoir_drivers.jl") -include("esn/echostatenetwork.jl") +include("esn/esn.jl") +include("esn/deepesn.jl") +include("esn/hybridesn.jl") include("esn/esn_predict.jl") #reca diff --git a/src/esn/deepesn.jl b/src/esn/deepesn.jl new file mode 100644 index 00000000..4ab05f39 --- /dev/null +++ b/src/esn/deepesn.jl @@ -0,0 +1,84 @@ +struct DeepESN{I, S, V, N, T, O, M, B, ST, W, IS} <: AbstractEchoStateNetwork + res_size::I + train_data::S + variation::V + nla_type::N + input_matrix::T + reservoir_driver::O + reservoir_matrix::M + bias_vector::B + states_type::ST + washout::W + states::IS +end + +function DeepESN( + train_data, + in_size::Int, + res_size::AbstractArray; + input_layer = scaled_rand, + reservoir = rand_sparse, + bias = zeros64, + reservoir_driver = RNN(), + nla_type = NLADefault(), + states_type = StandardStates(), + washout = 0, + rng = _default_rng(), + T=Float64, + matrix_type = typeof(train_data) +) + + if states_type isa AbstractPaddedStates + in_size = size(train_data, 1) + 1 + train_data = vcat(Adapt.adapt(matrix_type, ones(1, size(train_data, 2))), + train_data) + end + + reservoir_matrix = reservoir(rng, T, res_size, res_size) + input_matrix = input_layer(rng, T, res_size, in_size) + bias_vector = bias(rng, T, res_size) + inner_res_driver = reservoir_driver_params(reservoir_driver, res_size, in_size) + states = create_states(inner_res_driver, train_data, washout, reservoir_matrix, + input_matrix, bias_vector) + train_data = train_data[:, (washout + 1):end] + + ESN(sum(res_size), train_data, variation, nla_type, input_matrix, + inner_res_driver, reservoir_matrix, bias_vector, states_type, washout, + states) +end + +function obtain_layers(in_size, + input_layer, + reservoir::Vector, + bias; + matrix_type = Matrix{Float64}) +esn_depth = length(reservoir) +input_res_sizes = [get_ressize(reservoir[i]) for i in 1:esn_depth] +in_sizes = zeros(Int, esn_depth) +in_sizes[2:end] = input_res_sizes[1:(end - 1)] +in_sizes[1] = in_size + +if input_layer isa Array + input_matrix = [create_layer(input_layer[j], input_res_sizes[j], in_sizes[j], + matrix_type = matrix_type) for j in 1:esn_depth] +else + _input_layer = fill(input_layer, esn_depth) + input_matrix = [create_layer(_input_layer[k], input_res_sizes[k], in_sizes[k], + matrix_type = matrix_type) for k in 1:esn_depth] +end + +res_sizes = [get_ressize(input_matrix[j]) for j in 1:esn_depth] +reservoir_matrix = [create_reservoir(reservoir[k], res_sizes[k], + matrix_type = matrix_type) for k in 1:esn_depth] + +if bias isa Array + bias_vector = [create_layer(bias[j], res_sizes[j], 1, matrix_type = matrix_type) + for j in 1:esn_depth] +else + _bias = fill(bias, esn_depth) + bias_vector = [create_layer(_bias[k], res_sizes[k], 1, matrix_type = matrix_type) + for k in 1:esn_depth] +end + +return input_matrix, reservoir_matrix, bias_vector, res_sizes +end \ No newline at end of file diff --git a/src/esn/echostatenetwork.jl b/src/esn/echostatenetwork.jl deleted file mode 100644 index adbf85a4..00000000 --- a/src/esn/echostatenetwork.jl +++ /dev/null @@ -1,266 +0,0 @@ -abstract type AbstractEchoStateNetwork <: AbstractReservoirComputer end -struct ESN{I, S, V, N, T, O, M, B, ST, W, IS} <: AbstractEchoStateNetwork - res_size::I - train_data::S - variation::V - nla_type::N - input_matrix::T - reservoir_driver::O - reservoir_matrix::M - bias_vector::B - states_type::ST - washout::W - states::IS -end - -""" - Default() - -The `Default` struct specifies the use of the standard model in Echo State Networks (ESNs). -It requires no parameters and is used when no specific variations or customizations of the ESN model are needed. -This struct is ideal for straightforward applications where the default ESN settings are sufficient. -""" -struct Default <: AbstractVariation end -struct Hybrid{T, K, O, I, S, D} <: AbstractVariation - prior_model::T - u0::K - tspan::O - dt::I - datasize::S - model_data::D -end - -""" - Hybrid(prior_model, u0, tspan, datasize) - -Constructs a `Hybrid` variation of Echo State Networks (ESNs) integrating a knowledge-based model -(`prior_model`) with ESNs for advanced training and prediction in chaotic systems. - -# Parameters -- `prior_model`: A knowledge-based model function for integration with ESNs. -- `u0`: Initial conditions for the model. -- `tspan`: Time span as a tuple, indicating the duration for model operation. -- `datasize`: The size of the data to be processed. - -# Returns -- A `Hybrid` struct instance representing the combined ESN and knowledge-based model. - -This method is effective for chaotic processes as highlighted in [^Pathak]. - -Reference: -[^Pathak]: Jaideep Pathak et al. - "Hybrid Forecasting of Chaotic Processes: - Using Machine Learning in Conjunction with a Knowledge-Based Model" (2018). -""" -function Hybrid(prior_model, u0, tspan, datasize) - trange = collect(range(tspan[1], tspan[2], length = datasize)) - dt = trange[2] - trange[1] - tsteps = push!(trange, dt + trange[end]) - tspan_new = (tspan[1], dt + tspan[2]) - model_data = prior_model(u0, tspan_new, tsteps) - return Hybrid(prior_model, u0, tspan, dt, datasize, model_data) -end - -""" - ESN(train_data; kwargs...) -> ESN - -Creates an Echo State Network (ESN) using specified parameters and training data, suitable for various machine learning tasks. - -# Parameters -- `train_data`: Matrix of training data (columns as time steps, rows as features). -- `variation`: Variation of ESN (default: `Default()`). -- `input_layer`: Input layer of ESN (default: `DenseLayer()`). -- `reservoir`: Reservoir of the ESN (default: `RandSparseReservoir(100)`). -- `bias`: Bias vector for each time step (default: `NullLayer()`). -- `reservoir_driver`: Mechanism for evolving reservoir states (default: `RNN()`). -- `nla_type`: Non-linear activation type (default: `NLADefault()`). -- `states_type`: Format for storing states (default: `StandardStates()`). -- `washout`: Initial time steps to discard (default: `0`). -- `matrix_type`: Type of matrices used internally (default: type of `train_data`). - -# Returns -- An initialized ESN instance with specified parameters. - -# Examples -```julia -using ReservoirComputing - -train_data = rand(10, 100) # 10 features, 100 time steps - -esn = ESN(train_data, reservoir=RandSparseReservoir(200), washout=10) -``` -""" -function ESN( - train_data, - in_size, - res_size; - input_layer = scaled_rand, - reservoir = rand_sparse, - bias = zeros64, - reservoir_driver = RNN(), - nla_type = NLADefault(), - states_type = StandardStates(), - washout = 0, - rng = _default_rng(), - matrix_type = typeof(train_data) -) where {T <: Number} - - if states_type isa AbstractPaddedStates - in_size = size(train_data, 1) + 1 - train_data = vcat(Adapt.adapt(matrix_type, ones(1, size(train_data, 2))), - train_data) - end - - reservoir_matrix = reservoir(rng, T, res_size, res_size) - input_matrix = input_layer(rng, T, res_size, in_size) - bias_vector = bias(rng, T, res_size) - inner_res_driver = reservoir_driver_params(reservoir_driver, res_size, in_size) - states = create_states(inner_res_driver, train_data, washout, reservoir_matrix, - input_matrix, bias_vector) - train_data = train_data[:, (washout + 1):end] - - ESN(sum(res_size), train_data, variation, nla_type, input_matrix, - inner_res_driver, reservoir_matrix, bias_vector, states_type, washout, - states) -end - -#shallow esn construction -function obtain_layers(in_size, - input_layer, - reservoir, - bias; - matrix_type = Matrix{Float64}) - input_res_size = get_ressize(reservoir) - input_matrix = create_layer(input_layer, input_res_size, in_size, - matrix_type = matrix_type) - res_size = size(input_matrix, 1) #WeightedInput actually changes the res size - reservoir_matrix = create_reservoir(reservoir, res_size, matrix_type = matrix_type) - @assert size(reservoir_matrix, 1) == res_size - bias_vector = create_layer(bias, res_size, 1, matrix_type = matrix_type) - return input_matrix, reservoir_matrix, bias_vector, res_size -end - -#deep esn construction -#there is a bug going on with WeightedLayer in this construction. -#it works for eny other though -function obtain_layers(in_size, - input_layer, - reservoir::Vector, - bias; - matrix_type = Matrix{Float64}) - esn_depth = length(reservoir) - input_res_sizes = [get_ressize(reservoir[i]) for i in 1:esn_depth] - in_sizes = zeros(Int, esn_depth) - in_sizes[2:end] = input_res_sizes[1:(end - 1)] - in_sizes[1] = in_size - - if input_layer isa Array - input_matrix = [create_layer(input_layer[j], input_res_sizes[j], in_sizes[j], - matrix_type = matrix_type) for j in 1:esn_depth] - else - _input_layer = fill(input_layer, esn_depth) - input_matrix = [create_layer(_input_layer[k], input_res_sizes[k], in_sizes[k], - matrix_type = matrix_type) for k in 1:esn_depth] - end - - res_sizes = [get_ressize(input_matrix[j]) for j in 1:esn_depth] - reservoir_matrix = [create_reservoir(reservoir[k], res_sizes[k], - matrix_type = matrix_type) for k in 1:esn_depth] - - if bias isa Array - bias_vector = [create_layer(bias[j], res_sizes[j], 1, matrix_type = matrix_type) - for j in 1:esn_depth] - else - _bias = fill(bias, esn_depth) - bias_vector = [create_layer(_bias[k], res_sizes[k], 1, matrix_type = matrix_type) - for k in 1:esn_depth] - end - - return input_matrix, reservoir_matrix, bias_vector, res_sizes -end - -function (esn::ESN)(prediction::AbstractPrediction, - output_layer::AbstractOutputLayer; - last_state = esn.states[:, [end]], - kwargs...) - variation = esn.variation - pred_len = prediction.prediction_len - - if variation isa Hybrid - model = variation.prior_model - predict_tsteps = [variation.tspan[2] + variation.dt] - [append!(predict_tsteps, predict_tsteps[end] + variation.dt) for i in 1:pred_len] - tspan_new = (variation.tspan[2] + variation.dt, predict_tsteps[end]) - u0 = variation.model_data[:, end] - model_pred_data = model(u0, tspan_new, predict_tsteps)[:, 2:end] - return obtain_esn_prediction(esn, prediction, last_state, output_layer, - model_pred_data; - kwargs...) - else - return obtain_esn_prediction(esn, prediction, last_state, output_layer; - kwargs...) - end -end - -#training dispatch on esn -""" - train(esn::AbstractEchoStateNetwork, target_data, training_method = StandardRidge(0.0)) - -Trains an Echo State Network (ESN) using the provided target data and a specified training method. - -# Parameters -- `esn::AbstractEchoStateNetwork`: The ESN instance to be trained. -- `target_data`: Supervised training data for the ESN. -- `training_method`: The method for training the ESN (default: `StandardRidge(0.0)`). - -# Returns -- The trained ESN model. Its type and structure depend on `training_method` and the ESN's implementation. - - -# Returns -The trained ESN model. The exact type and structure of the return value depends on the -`training_method` and the specific ESN implementation. - -```julia -using ReservoirComputing - -# Initialize an ESN instance and target data -esn = ESN(train_data, reservoir=RandSparseReservoir(200), washout=10) -target_data = rand(size(train_data, 2)) - -# Train the ESN using the default training method -trained_esn = train(esn, target_data) - -# Train the ESN using a custom training method -trained_esn = train(esn, target_data, training_method=StandardRidge(1.0)) -``` - -# Notes -- When using a `Hybrid` variation, the function extends the state matrix with data from the - physical model included in the `variation`. -- The training is handled by a lower-level `_train` function which takes the new state matrix - and performs the actual training using the specified `training_method`. -""" -function train(esn::AbstractEchoStateNetwork, - target_data, - training_method = StandardRidge(0.0)) - variation = esn.variation - - if esn.variation isa Hybrid - states = vcat(esn.states, esn.variation.model_data[:, 2:end]) - else - states = esn.states - end - states_new = esn.states_type(esn.nla_type, states, esn.train_data[:, 1:end]) - - return _train(states_new, target_data, training_method) -end - -function pad_esnstate(variation::Hybrid, states_type, x_pad, x, model_prediction_data) - x_tmp = vcat(x, model_prediction_data) - x_pad = pad_state!(states_type, x_pad, x_tmp) -end - -function pad_esnstate!(variation, states_type, x_pad, x, args...) - x_pad = pad_state!(states_type, x_pad, x) -end diff --git a/src/esn/esn.jl b/src/esn/esn.jl new file mode 100644 index 00000000..3592ed8d --- /dev/null +++ b/src/esn/esn.jl @@ -0,0 +1,143 @@ +abstract type AbstractEchoStateNetwork <: AbstractReservoirComputer end +struct ESN{I, S, N, T, O, M, B, ST, W, IS} <: AbstractEchoStateNetwork + res_size::I + train_data::S + nla_type::N + input_matrix::T + reservoir_driver::O + reservoir_matrix::M + bias_vector::B + states_type::ST + washout::W + states::IS +end + +""" + ESN(train_data; kwargs...) -> ESN + +Creates an Echo State Network (ESN) using specified parameters and training data, suitable for various machine learning tasks. + +# Parameters +- `train_data`: Matrix of training data (columns as time steps, rows as features). +- `variation`: Variation of ESN (default: `Default()`). +- `input_layer`: Input layer of ESN (default: `DenseLayer()`). +- `reservoir`: Reservoir of the ESN (default: `RandSparseReservoir(100)`). +- `bias`: Bias vector for each time step (default: `NullLayer()`). +- `reservoir_driver`: Mechanism for evolving reservoir states (default: `RNN()`). +- `nla_type`: Non-linear activation type (default: `NLADefault()`). +- `states_type`: Format for storing states (default: `StandardStates()`). +- `washout`: Initial time steps to discard (default: `0`). +- `matrix_type`: Type of matrices used internally (default: type of `train_data`). + +# Returns +- An initialized ESN instance with specified parameters. + +# Examples +```julia +using ReservoirComputing + +train_data = rand(10, 100) # 10 features, 100 time steps + +esn = ESN(train_data, reservoir=RandSparseReservoir(200), washout=10) +``` +""" +function ESN( + train_data, + in_size::Int, + res_size::Int; + input_layer = scaled_rand, + reservoir = rand_sparse, + bias = zeros64, + reservoir_driver = RNN(), + nla_type = NLADefault(), + states_type = StandardStates(), + washout = 0, + rng = _default_rng(), + T = Float32, + matrix_type = typeof(train_data) +) + + if states_type isa AbstractPaddedStates + in_size = size(train_data, 1) + 1 + train_data = vcat(Adapt.adapt(matrix_type, ones(1, size(train_data, 2))), + train_data) + end + + reservoir_matrix = reservoir(rng, T, res_size, res_size) + input_matrix = input_layer(rng, T, in_size, res_size) + bias_vector = bias(rng, res_size) + inner_res_driver = reservoir_driver_params(reservoir_driver, res_size, in_size) + states = create_states(inner_res_driver, train_data, washout, reservoir_matrix, + input_matrix, bias_vector) + train_data = train_data[:, (washout + 1):end] + + ESN(res_size, train_data, nla_type, input_matrix, + inner_res_driver, reservoir_matrix, bias_vector, states_type, washout, + states) +end + +function (esn::ESN)(prediction::AbstractPrediction, + output_layer::AbstractOutputLayer; + last_state = esn.states[:, [end]], + kwargs...) + pred_len = prediction.prediction_len + + return obtain_esn_prediction(esn, prediction, last_state, output_layer; + kwargs...) +end + +#training dispatch on esn +""" + train(esn::AbstractEchoStateNetwork, target_data, training_method = StandardRidge(0.0)) + +Trains an Echo State Network (ESN) using the provided target data and a specified training method. + +# Parameters +- `esn::AbstractEchoStateNetwork`: The ESN instance to be trained. +- `target_data`: Supervised training data for the ESN. +- `training_method`: The method for training the ESN (default: `StandardRidge(0.0)`). + +# Returns +- The trained ESN model. Its type and structure depend on `training_method` and the ESN's implementation. + + +# Returns +The trained ESN model. The exact type and structure of the return value depends on the +`training_method` and the specific ESN implementation. + +```julia +using ReservoirComputing + +# Initialize an ESN instance and target data +esn = ESN(train_data, reservoir=RandSparseReservoir(200), washout=10) +target_data = rand(size(train_data, 2)) + +# Train the ESN using the default training method +trained_esn = train(esn, target_data) + +# Train the ESN using a custom training method +trained_esn = train(esn, target_data, training_method=StandardRidge(1.0)) +``` + +# Notes +- When using a `Hybrid` variation, the function extends the state matrix with data from the + physical model included in the `variation`. +- The training is handled by a lower-level `_train` function which takes the new state matrix + and performs the actual training using the specified `training_method`. +""" +function train(esn::ESN, + target_data, + training_method = StandardRidge(0.0)) + states_new = esn.states_type(esn.nla_type, esn.states, esn.train_data[:, 1:end]) + + return _train(states_new, target_data, training_method) +end + +#function pad_esnstate(variation::Hybrid, states_type, x_pad, x, model_prediction_data) +# x_tmp = vcat(x, model_prediction_data) +# x_pad = pad_state!(states_type, x_pad, x_tmp) +#end + +function pad_esnstate!(variation, states_type, x_pad, x, args...) + x_pad = pad_state!(states_type, x_pad, x) +end diff --git a/src/esn/esn_input_layers.jl b/src/esn/esn_input_layers.jl index 42b82282..e79df1d5 100644 --- a/src/esn/esn_input_layers.jl +++ b/src/esn/esn_input_layers.jl @@ -24,9 +24,45 @@ function scaled_rand( scaling=T(0.1) ) where {T <: Number} - @assert length(dims) == 2, "The dimensions must define a matrix (e.g., (res_size, in_size))" - res_size, in_size = dims layer_matrix = rand(rng, Uniform(-scaling, scaling), res_size, in_size) return layer_matrix end + +""" + weighted_init(rng::AbstractRNG, ::Type{T}, dims::Integer...; scaling=T(0.1)) where {T <: Number} + +Create and return a matrix representing a weighted input layer for Echo State Networks (ESNs). This initializer generates a weighted input matrix with random non-zero elements distributed uniformly within the range [-`scaling`, `scaling`], inspired by the approach in [^Lu]. + +# Arguments +- `rng`: An instance of `AbstractRNG` for random number generation. +- `T`: The data type for the elements of the matrix. +- `dims`: A 2-element tuple specifying the approximate reservoir size and input size (e.g., `(approx_res_size, in_size)`). +- `scaling`: The scaling factor for the weight distribution. Defaults to `T(0.1)`. + +# Returns +A matrix representing the weighted input layer as defined in [^Lu2017]. The matrix dimensions will be adjusted to ensure each input unit connects to an equal number of reservoir units. + +# Example +```julia +rng = Random.default_rng() +input_layer = weighted_init(rng, Float64, (3, 300); scaling=0.2) +``` +# References +[^Lu2017]: Lu, Zhixin, et al. + "Reservoir observers: Model-free inference of unmeasured variables in chaotic systems." + Chaos: An Interdisciplinary Journal of Nonlinear Science 27.4 (2017): 041102. +""" +function weighted_init(rng::AbstractRNG, ::Type{T}, dims::Integer...; scaling=T(0.1)) where {T <: Number} + + in_size, approx_res_size = dims + res_size = Int(floor(approx_res_size / in_size) * in_size) + layer_matrix = zeros(T, res_size, in_size) + q = floor(Int, res_size / in_size) + + for i in 1:in_size + layer_matrix[((i - 1) * q + 1):((i) * q), i] = rand(rng, Uniform(-scaling, scaling), q) + end + + return layer_matrix +end diff --git a/src/esn/esn_predict.jl b/src/esn/esn_predict.jl index daa6fc34..5955c762 100644 --- a/src/esn/esn_predict.jl +++ b/src/esn/esn_predict.jl @@ -13,7 +13,7 @@ function obtain_esn_prediction(esn, out = initial_conditions states = similar(esn.states, size(esn.states, 1), prediction_len) - out_pad = allocate_outpad(esn.variation, esn.states_type, out) + out_pad = allocate_outpad(esn, esn.states_type, out) tmp_array = allocate_tmp(esn.reservoir_driver, typeof(esn.states), esn.res_size) x_new = esn.states_type(esn.nla_type, x, out_pad) @@ -43,7 +43,7 @@ function obtain_esn_prediction(esn, out = initial_conditions states = similar(esn.states, size(esn.states, 1), prediction_len) - out_pad = allocate_outpad(esn.variation, esn.states_type, out) + out_pad = allocate_outpad(esn, esn.states_type, out) tmp_array = allocate_tmp(esn.reservoir_driver, typeof(esn.states), esn.res_size) x_new = esn.states_type(esn.nla_type, x, out_pad) @@ -60,20 +60,6 @@ end #prediction dispatch on esn function next_state_prediction!(esn::ESN, x, x_new, out, out_pad, i, tmp_array, args...) - return _variation_prediction!(esn.variation, esn, x, x_new, out, out_pad, i, tmp_array, - args...) -end - -#dispatch the prediction on the esn variation -function _variation_prediction!(variation, - esn, - x, - x_new, - out, - out_pad, - i, - tmp_array, - args...) out_pad = pad_state!(esn.states_type, out_pad, out) xv = @view x[1:(esn.res_size)] x = next_state!(x, esn.reservoir_driver, xv, out_pad, @@ -82,15 +68,8 @@ function _variation_prediction!(variation, return x, x_new end -function _variation_prediction!(variation::Hybrid, - esn, - x, - x_new, - out, - out_pad, - i, - tmp_array, - model_prediction_data) +#TODO fixme @MatrinuzziFra +function next_state_prediction!(hesn::HybridESN, x, x_new, out, out_pad, i, tmp_array, args...) out_tmp = vcat(out, model_prediction_data[:, i]) out_pad = pad_state!(esn.states_type, out_pad, out_tmp) x = next_state!(x, esn.reservoir_driver, x[1:(esn.res_size)], out_pad, @@ -100,12 +79,12 @@ function _variation_prediction!(variation::Hybrid, return x, x_new end -function allocate_outpad(variation, states_type, out) +function allocate_outpad(ens::ESN, states_type, out) return allocate_singlepadding(states_type, out) end -function allocate_outpad(variation::Hybrid, states_type, out) - pad_length = length(out) + size(variation.model_data[:, 1], 1) +function allocate_outpad(hesn::HybridESN, states_type, out) + pad_length = length(out) + size(hesn.model.model_data[:, 1], 1) out_tmp = Adapt.adapt(typeof(out), zeros(pad_length)) return allocate_singlepadding(states_type, out_tmp) end diff --git a/src/esn/hybridesn.jl b/src/esn/hybridesn.jl new file mode 100644 index 00000000..d1cfdac9 --- /dev/null +++ b/src/esn/hybridesn.jl @@ -0,0 +1,82 @@ +struct HybridESN{I, S, V, N, T, O, M, B, ST, W, IS} <: AbstractEchoStateNetwork + res_size::I + train_data::S + model::V + nla_type::N + input_matrix::T + reservoir_driver::O + reservoir_matrix::M + bias_vector::B + states_type::ST + washout::W + states::IS +end + +struct KnowledgeModel{T, K, O, I, S, D} + prior_model::T + u0::K + tspan::O + dt::I + datasize::S + model_data::D +end + +""" + Hybrid(prior_model, u0, tspan, datasize) + +Constructs a `Hybrid` variation of Echo State Networks (ESNs) integrating a knowledge-based model +(`prior_model`) with ESNs for advanced training and prediction in chaotic systems. + +# Parameters +- `prior_model`: A knowledge-based model function for integration with ESNs. +- `u0`: Initial conditions for the model. +- `tspan`: Time span as a tuple, indicating the duration for model operation. +- `datasize`: The size of the data to be processed. + +# Returns +- A `Hybrid` struct instance representing the combined ESN and knowledge-based model. + +This method is effective for chaotic processes as highlighted in [^Pathak]. + +Reference: +[^Pathak]: Jaideep Pathak et al. + "Hybrid Forecasting of Chaotic Processes: + Using Machine Learning in Conjunction with a Knowledge-Based Model" (2018). +""" +function KnowledgeModel(prior_model, u0, tspan, datasize) + trange = collect(range(tspan[1], tspan[2], length = datasize)) + dt = trange[2] - trange[1] + tsteps = push!(trange, dt + trange[end]) + tspan_new = (tspan[1], dt + tspan[2]) + model_data = prior_model(u0, tspan_new, tsteps) + return Hybrid(prior_model, u0, tspan, dt, datasize, model_data) +end + +function (hesn::HybridESN)(prediction::AbstractPrediction, + output_layer::AbstractOutputLayer; + last_state = esn.states[:, [end]], + kwargs...) + + pred_len = prediction.prediction_len + + model = variation.prior_model + predict_tsteps = [variation.tspan[2] + variation.dt] + [append!(predict_tsteps, predict_tsteps[end] + variation.dt) for i in 1:pred_len] + tspan_new = (variation.tspan[2] + variation.dt, predict_tsteps[end]) + u0 = variation.model_data[:, end] + model_pred_data = model(u0, tspan_new, predict_tsteps)[:, 2:end] + + return obtain_esn_prediction(esn, prediction, last_state, output_layer, + model_pred_data; + kwargs...) +end + +function train(hesn::HybridESN, + target_data, + training_method = StandardRidge(0.0)) + + states = vcat(esn.states, esn.variation.model_data[:, 2:end]) + states_new = esn.states_type(esn.nla_type, states, esn.train_data[:, 1:end]) + + return _train(states_new, target_data, training_method) +end \ No newline at end of file From 8cdc64640194657a04c4f5326bc7b9fb167b375d Mon Sep 17 00:00:00 2001 From: Jay-sanjay <134289328+Jay-sanjay@users.noreply.github.com> Date: Sun, 21 Jan 2024 22:06:02 +0530 Subject: [PATCH 06/26] sparse layer --- src/esn/esn_input_layers.jl | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/src/esn/esn_input_layers.jl b/src/esn/esn_input_layers.jl index e79df1d5..32682b46 100644 --- a/src/esn/esn_input_layers.jl +++ b/src/esn/esn_input_layers.jl @@ -66,3 +66,39 @@ function weighted_init(rng::AbstractRNG, ::Type{T}, dims::Integer...; scaling=T( return layer_matrix end + + +""" + sparse_layer(rng::AbstractRNG, ::Type{T}, dims::Integer...; scaling=T(0.1), sparsity=T(0.1)) where {T <: Number} + +Create and return a sparse layer matrix for use in neural network models. +The matrix will be of size specified by `dims`, with the specified `sparsity` and `scaling`. + +# Arguments +- `rng`: An instance of `AbstractRNG` for random number generation. +- `T`: The data type for the elements of the matrix. +- `dims`: Dimensions of the resulting sparse layer matrix. +- `scaling`: The scaling factor for the sparse layer matrix. Defaults to 0.1. +- `sparsity`: The sparsity level of the sparse layer matrix, controlling the fraction of zero elements. Defaults to 0.1. + +# Returns +A sparse layer matrix. + + +# Example +```julia +rng = Random.default_rng() +input_layer = weighted_init(rng, Float64, (3, 300); scaling=0.2, sparsity=0.1) +``` +""" +function sparse_layer(rng::AbstractRNG,::Type{T}, dims::Integer...; + scaling=T(0.1), sparsity=T(0.1)) where {T <: Number} + + in_size, res_size = dims + layer_matrix = Matrix(sprand(rng, res_size, in_size, sparsity)) + layer_matrix = 2.0 .* (layer_matrix .- 0.5) + replace!(layer_matrix, -1.0 => 0.0) + layer_matrix = scaling .* layer_matrix + + return layer_matrix +end From 2d964054f474f0051aa99d86dfc1ed08a097c6ed Mon Sep 17 00:00:00 2001 From: MartinuzziFrancesco Date: Sun, 21 Jan 2024 18:08:04 +0100 Subject: [PATCH 07/26] HybridESN working, modified docs and readme to follow changes --- README.md | 3 ++ docs/src/esn_tutorials/hybrid.md | 19 +++++--- docs/src/esn_tutorials/lorenz_basic.md | 14 +++--- src/ReservoirComputing.jl | 3 +- src/esn/esn.jl | 6 +-- src/esn/esn_input_layers.jl | 36 --------------- src/esn/esn_predict.jl | 10 ++--- src/esn/hybridesn.jl | 62 +++++++++++++++++++++----- 8 files changed, 84 insertions(+), 69 deletions(-) diff --git a/README.md b/README.md index 60f12963..2b123d66 100644 --- a/README.md +++ b/README.md @@ -104,3 +104,6 @@ If you use this library in your work, please cite: url = {http://jmlr.org/papers/v23/22-0611.html} } ``` +## Acknowledgements + +This project was possible thanks to initial funding through the [Google summer of code](https://summerofcode.withgoogle.com/) 2020 program. Francesco M. further acknowledges [ScaDS.AI](https://scads.ai/) and [RSC4Earth](https://rsc4earth.de/) for supporting the current progress on the library. diff --git a/docs/src/esn_tutorials/hybrid.md b/docs/src/esn_tutorials/hybrid.md index bf274f01..5682e9db 100644 --- a/docs/src/esn_tutorials/hybrid.md +++ b/docs/src/esn_tutorials/hybrid.md @@ -1,6 +1,6 @@ # Hybrid Echo State Networks -Following the idea of giving physical information to machine learning models, the hybrid echo state networks [^1] try to achieve this results by feeding model data into the ESN. In this example, it is explained how to create and leverage such models in ReservoirComputing.jl. The full script for this example is available [here](https://github.com/MartinuzziFrancesco/reservoir-computing-examples/blob/main/hybrid/hybrid.jl). This example was run on Julia v1.7.2. +Following the idea of giving physical information to machine learning models, the hybrid echo state networks [^1] try to achieve this results by feeding model data into the ESN. In this example, it is explained how to create and leverage such models in ReservoirComputing.jl. ## Generating the data @@ -47,17 +47,22 @@ function prior_model_data_generator(u0, tspan, tsteps, model = lorenz) end ``` -Given the initial condition, time span, and time steps, this function returns the data for the chosen model. Now, using the `Hybrid` method, it is possible to input all this information to the model. +Given the initial condition, time span, and time steps, this function returns the data for the chosen model. Now, using the `KnowledgeModel` method, it is possible to input all this information to `HybridESN`. ```@example hybrid using ReservoirComputing, Random Random.seed!(42) -hybrid = Hybrid(prior_model_data_generator, u0, tspan_train, train_len) +km = KnowledgeModel(prior_model_data_generator, u0, tspan_train, train_len) -esn = ESN(input_data, - reservoir = RandSparseReservoir(300), - variation = hybrid) +in_size = 3 +res_size = 300 +hesn = HybridESN( + km, + input_data, + in_size, + res_size; + reservoir = rand_sparse) ``` ## Training and Prediction @@ -65,7 +70,7 @@ esn = ESN(input_data, The training and prediction of the Hybrid ESN can proceed as usual: ```@example hybrid -output_layer = train(esn, target_data, StandardRidge(0.3)) +output_layer = train(hesn, target_data, StandardRidge(0.3)) output = esn(Generative(predict_len), output_layer) ``` diff --git a/docs/src/esn_tutorials/lorenz_basic.md b/docs/src/esn_tutorials/lorenz_basic.md index f820a36d..364c2c4b 100644 --- a/docs/src/esn_tutorials/lorenz_basic.md +++ b/docs/src/esn_tutorials/lorenz_basic.md @@ -1,6 +1,6 @@ # Lorenz System Forecasting -This example expands on the readme Lorenz system forecasting to better showcase how to use methods and functions provided in the library for Echo State Networks. Here the prediction method used is `Generative`, for a more detailed explanation of the differences between `Generative` and `Predictive` please refer to the other examples given in the documentation. The full script for this example is available [here](https://github.com/MartinuzziFrancesco/reservoir-computing-examples/blob/main/lorenz_basic/lorenz_basic.jl). This example was run on Julia v1.7.2. +This example expands on the readme Lorenz system forecasting to better showcase how to use methods and functions provided in the library for Echo State Networks. Here the prediction method used is `Generative`, for a more detailed explanation of the differences between `Generative` and `Predictive` please refer to the other examples given in the documentation. ## Generating the data @@ -46,15 +46,15 @@ using ReservoirComputing #define ESN parameters res_size = 300 +in_size = 3 res_radius = 1.2 res_sparsity = 6 / 300 input_scaling = 0.1 #build ESN struct -esn = ESN(input_data; - variation = Default(), - reservoir = RandSparseReservoir(res_size, radius = res_radius, sparsity = res_sparsity), - input_layer = WeightedLayer(scaling = input_scaling), +esn = ESN(input_data, in_size, res_size; + reservoir = rand_sparse(;radius = res_radius, sparsity = res_sparsity), + input_layer = weighted_init(;scaling = input_scaling), reservoir_driver = RNN(), nla_type = NLADefault(), states_type = StandardStates()) @@ -62,9 +62,9 @@ esn = ESN(input_data; Most of the parameters chosen here mirror the default ones, so a direct call is not necessary. The readme example is identical to this one, except for the explicit call. Going line by line to see what is happening, starting from `res_size`: this value determines the dimensions of the reservoir matrix. In this case, a size of 300 has been chosen, so the reservoir matrix will be 300 x 300. This is not always the case, since some input layer constructions can modify the dimensions of the reservoir, but in that case, everything is taken care of internally. -The `res_radius` determines the scaling of the spectral radius of the reservoir matrix; a proper scaling is necessary to assure the Echo State Property. The default value in the `RandSparseReservoir()` method is 1.0 in accordance with the most commonly followed guidelines found in the literature (see [^2] and references therein). The `sparsity` of the reservoir matrix in this case is obtained by choosing a degree of connections and dividing that by the reservoir size. Of course, it is also possible to simply choose any value between 0.0 and 1.0 to test behaviors for different sparsity values. In this example, the call to the parameters inside `RandSparseReservoir()` was done explicitly to showcase the meaning of each of them, but it is also possible to simply pass the values directly, like so `RandSparseReservoir(1.2, 6/300)`. +The `res_radius` determines the scaling of the spectral radius of the reservoir matrix; a proper scaling is necessary to assure the Echo State Property. The default value in the `rand_sparse` method is 1.0 in accordance with the most commonly followed guidelines found in the literature (see [^2] and references therein). The `sparsity` of the reservoir matrix in this case is obtained by choosing a degree of connections and dividing that by the reservoir size. Of course, it is also possible to simply choose any value between 0.0 and 1.0 to test behaviors for different sparsity values. -The value of `input_scaling` determines the upper and lower bounds of the uniform distribution of the weights in the `WeightedLayer()`. Like before, this value can be passed either as an argument or as a keyword argument `WeightedLayer(0.1)`. The value of 0.1 represents the default. The default input layer is the `DenseLayer`, a fully connected layer. The details of the weighted version can be found in [^3], for this example, this version returns the best results. +The value of `input_scaling` determines the upper and lower bounds of the uniform distribution of the weights in the `weighted_init`. The value of 0.1 represents the default. The default input layer is the `scaled_rand`, a dense matrix. The details of the weighted version can be found in [^3], for this example, this version returns the best results. The reservoir driver represents the dynamics of the reservoir. In the standard ESN definition, these dynamics are obtained through a Recurrent Neural Network (RNN), and this is reflected by calling the `RNN` driver for the `ESN` struct. This option is set as the default, and unless there is the need to change parameters, it is not needed. The full equation is the following: diff --git a/src/ReservoirComputing.jl b/src/ReservoirComputing.jl index f8668b54..aa38de0c 100644 --- a/src/ReservoirComputing.jl +++ b/src/ReservoirComputing.jl @@ -23,7 +23,8 @@ export scaled_rand, weighted_init export rand_sparse, delay_line export RNN, MRNN, GRU, GRUParams, FullyGated, Minimal export ESN, train -export DeepESN, HybridESN +export HybridESN, KnowledgeModel +export DeepESN export RECA, train export RandomMapping, RandomMaps export Generative, Predictive, OutputLayer diff --git a/src/esn/esn.jl b/src/esn/esn.jl index 3592ed8d..2beb552f 100644 --- a/src/esn/esn.jl +++ b/src/esn/esn.jl @@ -138,6 +138,6 @@ end # x_pad = pad_state!(states_type, x_pad, x_tmp) #end -function pad_esnstate!(variation, states_type, x_pad, x, args...) - x_pad = pad_state!(states_type, x_pad, x) -end +#function pad_esnstate!(variation, states_type, x_pad, x, args...) +# x_pad = pad_state!(states_type, x_pad, x) +#end diff --git a/src/esn/esn_input_layers.jl b/src/esn/esn_input_layers.jl index 32682b46..e79df1d5 100644 --- a/src/esn/esn_input_layers.jl +++ b/src/esn/esn_input_layers.jl @@ -66,39 +66,3 @@ function weighted_init(rng::AbstractRNG, ::Type{T}, dims::Integer...; scaling=T( return layer_matrix end - - -""" - sparse_layer(rng::AbstractRNG, ::Type{T}, dims::Integer...; scaling=T(0.1), sparsity=T(0.1)) where {T <: Number} - -Create and return a sparse layer matrix for use in neural network models. -The matrix will be of size specified by `dims`, with the specified `sparsity` and `scaling`. - -# Arguments -- `rng`: An instance of `AbstractRNG` for random number generation. -- `T`: The data type for the elements of the matrix. -- `dims`: Dimensions of the resulting sparse layer matrix. -- `scaling`: The scaling factor for the sparse layer matrix. Defaults to 0.1. -- `sparsity`: The sparsity level of the sparse layer matrix, controlling the fraction of zero elements. Defaults to 0.1. - -# Returns -A sparse layer matrix. - - -# Example -```julia -rng = Random.default_rng() -input_layer = weighted_init(rng, Float64, (3, 300); scaling=0.2, sparsity=0.1) -``` -""" -function sparse_layer(rng::AbstractRNG,::Type{T}, dims::Integer...; - scaling=T(0.1), sparsity=T(0.1)) where {T <: Number} - - in_size, res_size = dims - layer_matrix = Matrix(sprand(rng, res_size, in_size, sparsity)) - layer_matrix = 2.0 .* (layer_matrix .- 0.5) - replace!(layer_matrix, -1.0 => 0.0) - layer_matrix = scaling .* layer_matrix - - return layer_matrix -end diff --git a/src/esn/esn_predict.jl b/src/esn/esn_predict.jl index 5955c762..1b4cd462 100644 --- a/src/esn/esn_predict.jl +++ b/src/esn/esn_predict.jl @@ -69,13 +69,13 @@ function next_state_prediction!(esn::ESN, x, x_new, out, out_pad, i, tmp_array, end #TODO fixme @MatrinuzziFra -function next_state_prediction!(hesn::HybridESN, x, x_new, out, out_pad, i, tmp_array, args...) +function next_state_prediction!(hesn::HybridESN, x, x_new, out, out_pad, i, tmp_array, model_prediction_data) out_tmp = vcat(out, model_prediction_data[:, i]) - out_pad = pad_state!(esn.states_type, out_pad, out_tmp) - x = next_state!(x, esn.reservoir_driver, x[1:(esn.res_size)], out_pad, - esn.reservoir_matrix, esn.input_matrix, esn.bias_vector, tmp_array) + out_pad = pad_state!(hesn.states_type, out_pad, out_tmp) + x = next_state!(x, hesn.reservoir_driver, x[1:(hesn.res_size)], out_pad, + hesn.reservoir_matrix, hesn.input_matrix, hesn.bias_vector, tmp_array) x_tmp = vcat(x, model_prediction_data[:, i]) - x_new = esn.states_type(esn.nla_type, x_tmp, out_pad) + x_new = hesn.states_type(hesn.nla_type, x_tmp, out_pad) return x, x_new end diff --git a/src/esn/hybridesn.jl b/src/esn/hybridesn.jl index d1cfdac9..f29028fc 100644 --- a/src/esn/hybridesn.jl +++ b/src/esn/hybridesn.jl @@ -49,24 +49,66 @@ function KnowledgeModel(prior_model, u0, tspan, datasize) tsteps = push!(trange, dt + trange[end]) tspan_new = (tspan[1], dt + tspan[2]) model_data = prior_model(u0, tspan_new, tsteps) - return Hybrid(prior_model, u0, tspan, dt, datasize, model_data) + return KnowledgeModel(prior_model, u0, tspan, dt, datasize, model_data) +end + +function HybridESN( + model, + train_data, + in_size::Int, + res_size::Int; + input_layer = scaled_rand, + reservoir = rand_sparse, + bias = zeros64, + reservoir_driver = RNN(), + nla_type = NLADefault(), + states_type = StandardStates(), + washout = 0, + rng = _default_rng(), + T = Float32, + matrix_type = typeof(train_data) +) + + train_data = vcat(train_data, model.model_data[:, 1:(end - 1)]) + + if states_type isa AbstractPaddedStates + in_size = size(train_data, 1) + 1 + train_data = vcat(Adapt.adapt(matrix_type, ones(1, size(train_data, 2))), + train_data) + else + in_size = size(train_data, 1) + end + + reservoir_matrix = reservoir(rng, T, res_size, res_size) + #different from ESN, why? + input_matrix = input_layer(rng, T, res_size, in_size) + bias_vector = bias(rng, res_size) + inner_res_driver = reservoir_driver_params(reservoir_driver, res_size, in_size) + states = create_states(inner_res_driver, train_data, washout, reservoir_matrix, + input_matrix, bias_vector) + train_data = train_data[:, (washout + 1):end] + + HybridESN(res_size, train_data, model, nla_type, input_matrix, + inner_res_driver, reservoir_matrix, bias_vector, states_type, washout, + states) end function (hesn::HybridESN)(prediction::AbstractPrediction, output_layer::AbstractOutputLayer; - last_state = esn.states[:, [end]], + last_state = hesn.states[:, [end]], kwargs...) + km = hesn.model pred_len = prediction.prediction_len - model = variation.prior_model - predict_tsteps = [variation.tspan[2] + variation.dt] - [append!(predict_tsteps, predict_tsteps[end] + variation.dt) for i in 1:pred_len] - tspan_new = (variation.tspan[2] + variation.dt, predict_tsteps[end]) - u0 = variation.model_data[:, end] + model = km.prior_model + predict_tsteps = [km.tspan[2] + km.dt] + [append!(predict_tsteps, predict_tsteps[end] + km.dt) for i in 1:pred_len] + tspan_new = (km.tspan[2] + km.dt, predict_tsteps[end]) + u0 = km.model_data[:, end] model_pred_data = model(u0, tspan_new, predict_tsteps)[:, 2:end] - return obtain_esn_prediction(esn, prediction, last_state, output_layer, + return obtain_esn_prediction(hesn, prediction, last_state, output_layer, model_pred_data; kwargs...) end @@ -75,8 +117,8 @@ function train(hesn::HybridESN, target_data, training_method = StandardRidge(0.0)) - states = vcat(esn.states, esn.variation.model_data[:, 2:end]) - states_new = esn.states_type(esn.nla_type, states, esn.train_data[:, 1:end]) + states = vcat(hesn.states, hesn.model.model_data[:, 2:end]) + states_new = hesn.states_type(hesn.nla_type, states, hesn.train_data[:, 1:end]) return _train(states_new, target_data, training_method) end \ No newline at end of file From 8b6ceb128acb484f1da01ca6b4b24c5b38d1a5ba Mon Sep 17 00:00:00 2001 From: Jay-sanjay <134289328+Jay-sanjay@users.noreply.github.com> Date: Sun, 21 Jan 2024 22:51:08 +0530 Subject: [PATCH 08/26] sparse layer --- src/ReservoirComputing.jl | 4 ++-- src/esn/esn_input_layers.jl | 38 +++++++++++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 2 deletions(-) diff --git a/src/ReservoirComputing.jl b/src/ReservoirComputing.jl index aa38de0c..099e840c 100644 --- a/src/ReservoirComputing.jl +++ b/src/ReservoirComputing.jl @@ -19,7 +19,7 @@ export NLADefault, NLAT1, NLAT2, NLAT3 export StandardStates, ExtendedStates, PaddedStates, PaddedExtendedStates export StandardRidge, LinearModel export AbstractLayer, create_layer -export scaled_rand, weighted_init +export scaled_rand, weighted_init, sparse_layer export rand_sparse, delay_line export RNN, MRNN, GRU, GRUParams, FullyGated, Minimal export ESN, train @@ -75,7 +75,7 @@ function Predictive(prediction_data) end #fallbacks for initializers -for initializer in (:rand_sparse, :delay_line, :scaled_rand, :weighted_init) +for initializer in (:rand_sparse, :delay_line, :scaled_rand, :weighted_init, :sparse_layer) NType = ifelse(initializer === :rand_sparse, Real, Number) @eval function ($initializer)(dims::Integer...; kwargs...) return $initializer(_default_rng(), Float32, dims...; kwargs...) diff --git a/src/esn/esn_input_layers.jl b/src/esn/esn_input_layers.jl index e79df1d5..fa73453e 100644 --- a/src/esn/esn_input_layers.jl +++ b/src/esn/esn_input_layers.jl @@ -66,3 +66,41 @@ function weighted_init(rng::AbstractRNG, ::Type{T}, dims::Integer...; scaling=T( return layer_matrix end + + + +""" + sparse_layer(rng::AbstractRNG, ::Type{T}, dims::Integer...; scaling=T(0.1), sparsity=T(0.1)) where {T <: Number} + +Create and return a sparse layer matrix for use in neural network models. +The matrix will be of size specified by `dims`, with the specified `sparsity` and `scaling`. + +# Arguments +- `rng`: An instance of `AbstractRNG` for random number generation. +- `T`: The data type for the elements of the matrix. +- `dims`: Dimensions of the resulting sparse layer matrix. +- `scaling`: The scaling factor for the sparse layer matrix. Defaults to 0.1. +- `sparsity`: The sparsity level of the sparse layer matrix, controlling the fraction of zero elements. Defaults to 0.1. + +# Returns +A sparse layer matrix. + + +# Example +```julia +rng = Random.default_rng() +input_layer = weighted_init(rng, Float64, (3, 300); scaling=0.2, sparsity=0.1) +``` +""" +function sparse_layer(rng::AbstractRNG,::Type{T}, dims::Integer...; + scaling=T(0.1), sparsity=T(0.1)) where {T <: Number} + + in_size, res_size = dims + layer_matrix = Matrix(sprand(rng, res_size, in_size, sparsity)) + layer_matrix = 2.0 .* (layer_matrix .- 0.5) + replace!(layer_matrix, -1.0 => 0.0) + layer_matrix = scaling .* layer_matrix + + return layer_matrix +end + From d8f5822ab1d9d6c938a0644a5745ea05fceea3ca Mon Sep 17 00:00:00 2001 From: Jay-sanjay <134289328+Jay-sanjay@users.noreply.github.com> Date: Sun, 21 Jan 2024 23:33:35 +0530 Subject: [PATCH 09/26] informed layer --- src/ReservoirComputing.jl | 4 +-- src/esn/esn_input_layers.jl | 60 ++++++++++++++++++++++++++++++++++++- 2 files changed, 61 insertions(+), 3 deletions(-) diff --git a/src/ReservoirComputing.jl b/src/ReservoirComputing.jl index 099e840c..8fcbeff6 100644 --- a/src/ReservoirComputing.jl +++ b/src/ReservoirComputing.jl @@ -19,7 +19,7 @@ export NLADefault, NLAT1, NLAT2, NLAT3 export StandardStates, ExtendedStates, PaddedStates, PaddedExtendedStates export StandardRidge, LinearModel export AbstractLayer, create_layer -export scaled_rand, weighted_init, sparse_layer +export scaled_rand, weighted_init, sparse_layer, informed_layer export rand_sparse, delay_line export RNN, MRNN, GRU, GRUParams, FullyGated, Minimal export ESN, train @@ -75,7 +75,7 @@ function Predictive(prediction_data) end #fallbacks for initializers -for initializer in (:rand_sparse, :delay_line, :scaled_rand, :weighted_init, :sparse_layer) +for initializer in (:rand_sparse, :delay_line, :scaled_rand, :weighted_init, :sparse_layer, :informed_layer) NType = ifelse(initializer === :rand_sparse, Real, Number) @eval function ($initializer)(dims::Integer...; kwargs...) return $initializer(_default_rng(), Float32, dims...; kwargs...) diff --git a/src/esn/esn_input_layers.jl b/src/esn/esn_input_layers.jl index fa73453e..8871ee47 100644 --- a/src/esn/esn_input_layers.jl +++ b/src/esn/esn_input_layers.jl @@ -89,7 +89,7 @@ A sparse layer matrix. # Example ```julia rng = Random.default_rng() -input_layer = weighted_init(rng, Float64, (3, 300); scaling=0.2, sparsity=0.1) +input_layer = sparse_layer(rng, Float64, (3, 300); scaling=0.2, sparsity=0.1) ``` """ function sparse_layer(rng::AbstractRNG,::Type{T}, dims::Integer...; @@ -104,3 +104,61 @@ function sparse_layer(rng::AbstractRNG,::Type{T}, dims::Integer...; return layer_matrix end + +""" + informed_layer(rng::AbstractRNG, ::Type{T}, dims::Integer...; scaling=T(0.1), model_in_size, gamma=T(0.5)) where {T <: Number} + +Create a layer of a neural network. + +# Arguments +- `rng::AbstractRNG`: The random number generator. +- `T::Type`: The data type. +- `dims::Integer...`: The dimensions of the layer. +- `scaling::T = T(0.1)`: The scaling factor for the input matrix. +- `model_in_size`: The size of the input model. +- `gamma::T = T(0.5)`: The gamma value. + +# Returns +- `input_matrix`: The created input matrix for the layer. + +# Example +```julia +rng = Random.GLOBAL_RNG +dims = (100, 200) +model_in_size = 50 +input_matrix = informed_layer(rng, Float64, dims; model_in_size=model_in_size) +``` +""" +function informed_layer(rng::AbstractRNG, ::Type{T}, dims::Integer...; + scaling=T(0.1), model_in_size, gamma=T(0.5)) where {T <: Number} + + res_size, in_size = dims + state_size = in_size - model_in_size + + if state_size <= 0 + throw(DimensionMismatch("in_size must be greater than model_in_size")) + end + + input_matrix = zeros(res_size, in_size) + zero_connections = zeros(in_size) + num_for_state = floor(Int, res_size * gamma) + num_for_model = floor(Int, res_size * (1 - gamma)) + + for i in 1:num_for_state + idxs = findall(Bool[zero_connections .== input_matrix[i, :] + for i in 1:size(input_matrix, 1)]) + random_row_idx = idxs[rand(rng, 1:end)] + random_clm_idx = range(1, state_size, step=1)[rand(rng, 1:end)] + input_matrix[random_row_idx, random_clm_idx] = rand(rng, Uniform(-scaling, scaling)) + end + + for i in 1:num_for_model + idxs = findall(Bool[zero_connections .== input_matrix[i, :] + for i in 1:size(input_matrix, 1)]) + random_row_idx = idxs[rand(rng, 1:end)] + random_clm_idx = range(state_size + 1, in_size, step=1)[rand(rng, 1:end)] + input_matrix[random_row_idx, random_clm_idx] = rand(rng, Uniform(-scaling, scaling)) + end + + return input_matrix +end \ No newline at end of file From 9273407ebf36e022f01f6d864dd8357122e7f4f0 Mon Sep 17 00:00:00 2001 From: Jay-sanjay <134289328+Jay-sanjay@users.noreply.github.com> Date: Mon, 22 Jan 2024 01:39:16 +0530 Subject: [PATCH 10/26] bernoulli sample layer --- src/ReservoirComputing.jl | 4 +-- src/esn/esn_input_layers.jl | 59 ++++++++++++++++++++++++++++++++++++- 2 files changed, 60 insertions(+), 3 deletions(-) diff --git a/src/ReservoirComputing.jl b/src/ReservoirComputing.jl index 8fcbeff6..498234c6 100644 --- a/src/ReservoirComputing.jl +++ b/src/ReservoirComputing.jl @@ -19,7 +19,7 @@ export NLADefault, NLAT1, NLAT2, NLAT3 export StandardStates, ExtendedStates, PaddedStates, PaddedExtendedStates export StandardRidge, LinearModel export AbstractLayer, create_layer -export scaled_rand, weighted_init, sparse_layer, informed_layer +export scaled_rand, weighted_init, sparse_layer, informed_layer, bernoulli_sample_layer export rand_sparse, delay_line export RNN, MRNN, GRU, GRUParams, FullyGated, Minimal export ESN, train @@ -75,7 +75,7 @@ function Predictive(prediction_data) end #fallbacks for initializers -for initializer in (:rand_sparse, :delay_line, :scaled_rand, :weighted_init, :sparse_layer, :informed_layer) +for initializer in (:rand_sparse, :delay_line, :scaled_rand, :weighted_init, :sparse_layer, :informed_layer, :bernoulli_sample_layer) NType = ifelse(initializer === :rand_sparse, Real, Number) @eval function ($initializer)(dims::Integer...; kwargs...) return $initializer(_default_rng(), Float32, dims...; kwargs...) diff --git a/src/esn/esn_input_layers.jl b/src/esn/esn_input_layers.jl index 8871ee47..76d34283 100644 --- a/src/esn/esn_input_layers.jl +++ b/src/esn/esn_input_layers.jl @@ -123,7 +123,7 @@ Create a layer of a neural network. # Example ```julia -rng = Random.GLOBAL_RNG +rng = Random.default_rng() dims = (100, 200) model_in_size = 50 input_matrix = informed_layer(rng, Float64, dims; model_in_size=model_in_size) @@ -161,4 +161,61 @@ function informed_layer(rng::AbstractRNG, ::Type{T}, dims::Integer...; end return input_matrix +end + + + +function BernoulliSample(; p = 0.5) + return BernoulliSample(p) +end + +function create_minimum_input(BernoulliSample(; p = 0.5), res_size, in_size, weight, rng::AbstractRNG) + p = BernoulliSample(; p = 0.5) + input_matrix = zeros(res_size, in_size) + for i in 1:res_size + for j in 1:in_size + rand(rng, Bernoulli(p)) ? (input_matrix[i, j] = weight) : (input_matrix[i, j] = -weight) + end + end + return input_matrix +end + + +""" + bernoulli_sample_layer(rng::AbstractRNG, ::Type{T}, dims::Integer...; + weight = 0.1, + sampling = BernoulliSample(0.5) + ) where {T <: Number} + +Create a layer matrix using the Bernoulli sampling method. + +# Arguments +- `rng::AbstractRNG`: The random number generator. +- `dims::Integer...`: The dimensions of the layer matrix. +- `weight::Number = 0.1`: The weight value. +- `sampling::BernoulliSample = BernoulliSample(0.5)`: The Bernoulli sampling object. + +# Returns +The generated layer matrix. + +# Example +```julia +rng = Random.default_rng() +dims = (100, 200) +weight = 0.1 +sampling = BernoulliSample(0.5) +layer_matrix = bernoulli_sample_layer(rng, Float64, dims; weight=weight, sampling=sampling) +``` + +""" +function bernoulli_sample_layer(rng::AbstractRNG, ::Type{T}, dims::Integer...; + weight = 0.1, + sampling = BernoulliSample(0.5) + )where {T <: Number} + + res_size, in_size = dims + sampling = sampling + weight = weight + layer_matrix = create_minimum_input(sampling, res_size, in_size, weight, rng) + return layer_matrix end \ No newline at end of file From dc3905fcaeafcc43014cb20ed84feca3c37473d0 Mon Sep 17 00:00:00 2001 From: Jay-sanjay <134289328+Jay-sanjay@users.noreply.github.com> Date: Mon, 22 Jan 2024 02:39:29 +0530 Subject: [PATCH 11/26] irrational sample layer --- src/ReservoirComputing.jl | 4 +- src/esn/esn_input_layers.jl | 73 +++++++++++++++++++++++++++++++++++-- 2 files changed, 72 insertions(+), 5 deletions(-) diff --git a/src/ReservoirComputing.jl b/src/ReservoirComputing.jl index 498234c6..9d4074ba 100644 --- a/src/ReservoirComputing.jl +++ b/src/ReservoirComputing.jl @@ -19,7 +19,7 @@ export NLADefault, NLAT1, NLAT2, NLAT3 export StandardStates, ExtendedStates, PaddedStates, PaddedExtendedStates export StandardRidge, LinearModel export AbstractLayer, create_layer -export scaled_rand, weighted_init, sparse_layer, informed_layer, bernoulli_sample_layer +export scaled_rand, weighted_init, sparse_layer, informed_layer, bernoulli_sample_layer, irrational_sample_layer export rand_sparse, delay_line export RNN, MRNN, GRU, GRUParams, FullyGated, Minimal export ESN, train @@ -75,7 +75,7 @@ function Predictive(prediction_data) end #fallbacks for initializers -for initializer in (:rand_sparse, :delay_line, :scaled_rand, :weighted_init, :sparse_layer, :informed_layer, :bernoulli_sample_layer) +for initializer in (:rand_sparse, :delay_line, :scaled_rand, :weighted_init, :sparse_layer, :informed_layer, :bernoulli_sample_layer, :irrational_sample_layer) NType = ifelse(initializer === :rand_sparse, Real, Number) @eval function ($initializer)(dims::Integer...; kwargs...) return $initializer(_default_rng(), Float32, dims...; kwargs...) diff --git a/src/esn/esn_input_layers.jl b/src/esn/esn_input_layers.jl index 76d34283..61cfc742 100644 --- a/src/esn/esn_input_layers.jl +++ b/src/esn/esn_input_layers.jl @@ -166,11 +166,11 @@ end function BernoulliSample(; p = 0.5) - return BernoulliSample(p) + return p end -function create_minimum_input(BernoulliSample(; p = 0.5), res_size, in_size, weight, rng::AbstractRNG) - p = BernoulliSample(; p = 0.5) +function create_minimum_input(p = sampling, res_size, in_size, weight, rng::AbstractRNG) + input_matrix = zeros(res_size, in_size) for i in 1:res_size for j in 1:in_size @@ -217,5 +217,72 @@ function bernoulli_sample_layer(rng::AbstractRNG, ::Type{T}, dims::Integer...; sampling = sampling weight = weight layer_matrix = create_minimum_input(sampling, res_size, in_size, weight, rng) + return layer_matrix +end + + + +function IrrationalSample(; irrational = pi, start = 1) + return irrational, start +end + +function create_minimum_input(irrational, start = sampling, res_size, in_size, weight, rng::AbstractRNG) + setprecision(BigFloat, Int(ceil(log2(10) * (res_size * in_size + start + 1)))) + ir_string = string(BigFloat(irrational)) |> collect + deleteat!(ir_string, findall(x -> x == '.', ir_string)) + ir_array = zeros(length(ir_string)) + input_matrix = zeros(res_size, in_size) + + for i in 1:length(ir_string) + ir_array[i] = parse(Int, ir_string[i]) + end + + for i in 1:res_size + for j in 1:in_size + random_number = rand(rng) + input_matrix[i, j] = random_number < 0.5 ? -weight : weight + end + end + + return input_matrix +end + +""" + irrational_sample_layer(rng::AbstractRNG, ::Type{T}, dims::Integer...; + weight = 0.1, + sampling = IrrationalSample(; irrational = pi, start = 1) + ) where {T <: Number} + +Create a layer matrix using the provided random number generator and sampling parameters. + +# Arguments +- `rng::AbstractRNG`: The random number generator used to generate random numbers. +- `dims::Integer...`: The dimensions of the layer matrix. +- `weight`: The weight used to fill the layer matrix. Default is 0.1. +- `sampling`: The sampling parameters used to generate the input matrix. Default is IrrationalSample(irrational = pi, start = 1). + +# Returns +The layer matrix generated using the provided random number generator and sampling parameters. + +# Example +```julia +using Random +rng = Random.default_rng() +dims = (3, 2) +weight = 0.5 +layer_matrix = irrational_sample_layer(rng, Float64, dims; weight = weight, sampling = IrrationalSample(irrational = sqrt(2), start = 1)) +``` +""" + +function irrational_sample_layer(rng::AbstractRNG, ::Type{T}, dims::Integer...; + weight = 0.1, + sampling = IrrationalSample(; irrational = pi, start = 1) + )where {T <: Number} + + res_size, in_size = dims + sampling = sampling + weight = weight + layer_matrix = create_minimum_input(sampling, res_size, in_size, weight, rng) + return layer_matrix end \ No newline at end of file From 2d3cc4d0afd51fb3530cb500f855d075f8b5d8bc Mon Sep 17 00:00:00 2001 From: Jay-sanjay <134289328+Jay-sanjay@users.noreply.github.com> Date: Wed, 31 Jan 2024 22:24:05 +0530 Subject: [PATCH 12/26] added delay line backward reservoir --- src/ReservoirComputing.jl | 4 ++-- src/esn/esn_reservoirs.jl | 38 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 2 deletions(-) diff --git a/src/ReservoirComputing.jl b/src/ReservoirComputing.jl index 9d4074ba..d86b0b19 100644 --- a/src/ReservoirComputing.jl +++ b/src/ReservoirComputing.jl @@ -20,7 +20,7 @@ export StandardStates, ExtendedStates, PaddedStates, PaddedExtendedStates export StandardRidge, LinearModel export AbstractLayer, create_layer export scaled_rand, weighted_init, sparse_layer, informed_layer, bernoulli_sample_layer, irrational_sample_layer -export rand_sparse, delay_line +export rand_sparse, delay_line, delay_line_backward_reservoir export RNN, MRNN, GRU, GRUParams, FullyGated, Minimal export ESN, train export HybridESN, KnowledgeModel @@ -75,7 +75,7 @@ function Predictive(prediction_data) end #fallbacks for initializers -for initializer in (:rand_sparse, :delay_line, :scaled_rand, :weighted_init, :sparse_layer, :informed_layer, :bernoulli_sample_layer, :irrational_sample_layer) +for initializer in (:rand_sparse, :delay_line, :delay_line_backward_reservoir, :scaled_rand, :weighted_init, :sparse_layer, :informed_layer, :bernoulli_sample_layer, :irrational_sample_layer) NType = ifelse(initializer === :rand_sparse, Real, Number) @eval function ($initializer)(dims::Integer...; kwargs...) return $initializer(_default_rng(), Float32, dims...; kwargs...) diff --git a/src/esn/esn_reservoirs.jl b/src/esn/esn_reservoirs.jl index ab2eaf42..53ac961f 100644 --- a/src/esn/esn_reservoirs.jl +++ b/src/esn/esn_reservoirs.jl @@ -69,6 +69,44 @@ function delay_line(rng::AbstractRNG, return reservoir_matrix end +""" + delay_line_backward_reservoir(rng::AbstractRNG, ::Type{T}, dims::Integer...; + weight = T(0.1), fb_weight = T(0.2)) where {T <: Number} + +Create a delay line backward reservoir with the specified by `dims` and weights. Creates a matrix with backward connections +as described in [^Rodan2010]. The `weight` and `fb_weight` can be passed as either arguments or +keyword arguments, and they determine the absolute values of the connections in the reservoir. + +# Arguments +- `rng::AbstractRNG`: Random number generator. +- `T::Type`: Type of the elements in the reservoir matrix. +- `dims::Integer...`: Dimensions of the reservoir matrix. +- `weight::T`: The weight determines the absolute value of forward connections in the reservoir, and is set to 0.1 by default. +- `fb_weight::T`: The `fb_weight` determines the absolute value of backward connections in the reservoir, and is set to 0.2 by default. + + +# Returns +Reservoir matrix with the dimensions specified by `dims` and weights. + +# References +[^Rodan2010]: Rodan, Ali, and Peter Tino. "Minimum complexity echo state network." +IEEE transactions on neural networks 22.1 (2010): 131-144. +""" +function delay_line_backward_reservoir(rng::AbstractRNG, + ::Type{T}, + dims::Integer...; + weight = T(0.1), + fb_weight = T(0.2)) where {T <: Number} + reservoir_matrix = zeros(res_size, res_size) + + for i in 1:(res_size - 1) + reservoir_matrix[i + 1, i] = weight + reservoir_matrix[i, i + 1] = fb_weight + end + + return reservoir_matrix +end + # from WeightInitializers.jl, TODO @MartinuzziFrancesco consider importing package function _default_rng() @static if VERSION >= v"1.7" From 95c78702783ab8d41f074bc857ca93d824d8f6a3 Mon Sep 17 00:00:00 2001 From: Jay-sanjay <134289328+Jay-sanjay@users.noreply.github.com> Date: Wed, 31 Jan 2024 23:06:01 +0530 Subject: [PATCH 13/26] added cycle jumps reservoir --- src/ReservoirComputing.jl | 4 ++-- src/esn/esn_reservoirs.jl | 50 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 52 insertions(+), 2 deletions(-) diff --git a/src/ReservoirComputing.jl b/src/ReservoirComputing.jl index d86b0b19..86ac5da9 100644 --- a/src/ReservoirComputing.jl +++ b/src/ReservoirComputing.jl @@ -20,7 +20,7 @@ export StandardStates, ExtendedStates, PaddedStates, PaddedExtendedStates export StandardRidge, LinearModel export AbstractLayer, create_layer export scaled_rand, weighted_init, sparse_layer, informed_layer, bernoulli_sample_layer, irrational_sample_layer -export rand_sparse, delay_line, delay_line_backward_reservoir +export rand_sparse, delay_line, delay_line_backward_reservoir, cycle_jumps_reservoir export RNN, MRNN, GRU, GRUParams, FullyGated, Minimal export ESN, train export HybridESN, KnowledgeModel @@ -75,7 +75,7 @@ function Predictive(prediction_data) end #fallbacks for initializers -for initializer in (:rand_sparse, :delay_line, :delay_line_backward_reservoir, :scaled_rand, :weighted_init, :sparse_layer, :informed_layer, :bernoulli_sample_layer, :irrational_sample_layer) +for initializer in (:rand_sparse, :delay_line, :delay_line_backward_reservoir, :cycle_jumps_reservoir, :scaled_rand, :weighted_init, :sparse_layer, :informed_layer, :bernoulli_sample_layer, :irrational_sample_layer) NType = ifelse(initializer === :rand_sparse, Real, Number) @eval function ($initializer)(dims::Integer...; kwargs...) return $initializer(_default_rng(), Float32, dims...; kwargs...) diff --git a/src/esn/esn_reservoirs.jl b/src/esn/esn_reservoirs.jl index 53ac961f..a3f93f24 100644 --- a/src/esn/esn_reservoirs.jl +++ b/src/esn/esn_reservoirs.jl @@ -107,6 +107,56 @@ function delay_line_backward_reservoir(rng::AbstractRNG, return reservoir_matrix end + +""" + cycle_jumps_reservoir(rng::AbstractRNG, ::Type{T}, dims::Integer...; + cycle_weight = T(0.1), jump_weight = T(0.1), jump_size = 3) where {T <: Number} + +Create a cycle jumps reservoir with the specified dimensions, cycle weight, jump weight, and jump size. + +# Arguments +- `rng::AbstractRNG`: Random number generator. +- `T::Type`: Type of the elements in the reservoir matrix. +- `dims::Integer...`: Dimensions of the reservoir matrix. +- `cycle_weight::T = T(0.1)`: The weight of cycle connections. +- `jump_weight::T = T(0.1)`: The weight of jump connections. +- `jump_size::Int = 3`: The number of steps between jump connections. + +# Returns +Reservoir matrix with the specified dimensions, cycle weight, jump weight, and jump size. + +# References +[^Rodan2012]: Rodan, Ali, and Peter Tiňo. "Simple deterministically constructed cycle reservoirs +with regular jumps." Neural computation 24.7 (2012): 1822-1852. +""" +function cycle_jumps_reservoir(rng::AbstractRNG, + ::Type{T}, + dims::Integer...; + cycle_weight = T(0.1), + jump_weight = T(0.1), + jump_size = T(3)) where {T <: Number} + + reservoir_matrix = zeros(T, dims...) + + for i in 1:(dims[1] - 1) + reservoir_matrix[i + 1, i] = cycle_weight + end + + reservoir_matrix[1, dims[1]] = cycle_weight + + for i in 1:jump_size:(dims[1] - jump_size) + tmp = (i + jump_size) % dims[1] + if tmp == 0 + tmp = dims[1] + end + reservoir_matrix[i, tmp] = jump_weight + reservoir_matrix[tmp, i] = jump_weight + end + + return reservoir_matrix +end + + # from WeightInitializers.jl, TODO @MartinuzziFrancesco consider importing package function _default_rng() @static if VERSION >= v"1.7" From 2e89a3859ce5d6ede20d38bb05e75c722d3c68c3 Mon Sep 17 00:00:00 2001 From: Jay-sanjay <134289328+Jay-sanjay@users.noreply.github.com> Date: Wed, 31 Jan 2024 23:07:21 +0530 Subject: [PATCH 14/26] added simple cycle reservoir --- src/ReservoirComputing.jl | 4 ++-- src/esn/esn_reservoirs.jl | 34 ++++++++++++++++++++++++++++++++++ 2 files changed, 36 insertions(+), 2 deletions(-) diff --git a/src/ReservoirComputing.jl b/src/ReservoirComputing.jl index 86ac5da9..1b34d098 100644 --- a/src/ReservoirComputing.jl +++ b/src/ReservoirComputing.jl @@ -20,7 +20,7 @@ export StandardStates, ExtendedStates, PaddedStates, PaddedExtendedStates export StandardRidge, LinearModel export AbstractLayer, create_layer export scaled_rand, weighted_init, sparse_layer, informed_layer, bernoulli_sample_layer, irrational_sample_layer -export rand_sparse, delay_line, delay_line_backward_reservoir, cycle_jumps_reservoir +export rand_sparse, delay_line, delay_line_backward_reservoir, cycle_jumps_reservoir, simple_cycle_reservoir export RNN, MRNN, GRU, GRUParams, FullyGated, Minimal export ESN, train export HybridESN, KnowledgeModel @@ -75,7 +75,7 @@ function Predictive(prediction_data) end #fallbacks for initializers -for initializer in (:rand_sparse, :delay_line, :delay_line_backward_reservoir, :cycle_jumps_reservoir, :scaled_rand, :weighted_init, :sparse_layer, :informed_layer, :bernoulli_sample_layer, :irrational_sample_layer) +for initializer in (:rand_sparse, :delay_line, :delay_line_backward_reservoir, :cycle_jumps_reservoir, :simple_cycle_reservoir, :scaled_rand, :weighted_init, :sparse_layer, :informed_layer, :bernoulli_sample_layer, :irrational_sample_layer) NType = ifelse(initializer === :rand_sparse, Real, Number) @eval function ($initializer)(dims::Integer...; kwargs...) return $initializer(_default_rng(), Float32, dims...; kwargs...) diff --git a/src/esn/esn_reservoirs.jl b/src/esn/esn_reservoirs.jl index a3f93f24..276554d2 100644 --- a/src/esn/esn_reservoirs.jl +++ b/src/esn/esn_reservoirs.jl @@ -157,6 +157,40 @@ function cycle_jumps_reservoir(rng::AbstractRNG, end +""" + simple_cycle_reservoir(rng::AbstractRNG, ::Type{T}, dims::Integer...; + weight = T(0.1)) where {T <: Number} + +Create a simple cycle reservoir with the specified dimensions and weight. + +# Arguments +- `rng::AbstractRNG`: Random number generator. +- `T::Type`: Type of the elements in the reservoir matrix. +- `dims::Integer...`: Dimensions of the reservoir matrix. +- `weight::T = T(0.1)`: Weight of the connections in the reservoir matrix. + +# Returns +Reservoir matrix with the dimensions specified by `dims` and weights. + +# References +[^Rodan2010]: Rodan, Ali, and Peter Tino. "Minimum complexity echo state network." +IEEE transactions on neural networks 22.1 (2010): 131-144. +""" +function simple_cycle_reservoir(rng::AbstractRNG, + ::Type{T}, + dims::Integer...; + weight = T(0.1)) where {T <: Number} + reservoir_matrix = zeros(T, dims...) + + for i in 1:(dims[1] - 1) + reservoir_matrix[i + 1, i] = weight + end + + reservoir_matrix[1, dims[1]] = weight + return reservoir_matrix +end + + # from WeightInitializers.jl, TODO @MartinuzziFrancesco consider importing package function _default_rng() @static if VERSION >= v"1.7" From 6ecceb267e942f0da59ae3edf4f5b244768d7b0f Mon Sep 17 00:00:00 2001 From: Jay-sanjay <134289328+Jay-sanjay@users.noreply.github.com> Date: Wed, 31 Jan 2024 23:08:33 +0530 Subject: [PATCH 15/26] added pseudo svd reservoir --- src/ReservoirComputing.jl | 4 +- src/esn/esn_reservoirs.jl | 84 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 86 insertions(+), 2 deletions(-) diff --git a/src/ReservoirComputing.jl b/src/ReservoirComputing.jl index 1b34d098..f7a2046a 100644 --- a/src/ReservoirComputing.jl +++ b/src/ReservoirComputing.jl @@ -20,7 +20,7 @@ export StandardStates, ExtendedStates, PaddedStates, PaddedExtendedStates export StandardRidge, LinearModel export AbstractLayer, create_layer export scaled_rand, weighted_init, sparse_layer, informed_layer, bernoulli_sample_layer, irrational_sample_layer -export rand_sparse, delay_line, delay_line_backward_reservoir, cycle_jumps_reservoir, simple_cycle_reservoir +export rand_sparse, delay_line, delay_line_backward_reservoir, cycle_jumps_reservoir, simple_cycle_reservoir, pseudo_svd_reservoir export RNN, MRNN, GRU, GRUParams, FullyGated, Minimal export ESN, train export HybridESN, KnowledgeModel @@ -75,7 +75,7 @@ function Predictive(prediction_data) end #fallbacks for initializers -for initializer in (:rand_sparse, :delay_line, :delay_line_backward_reservoir, :cycle_jumps_reservoir, :simple_cycle_reservoir, :scaled_rand, :weighted_init, :sparse_layer, :informed_layer, :bernoulli_sample_layer, :irrational_sample_layer) +for initializer in (:rand_sparse, :delay_line, :delay_line_backward_reservoir, :cycle_jumps_reservoir, :simple_cycle_reservoir, :pseudo_svd_reservoir, :scaled_rand, :weighted_init, :sparse_layer, :informed_layer, :bernoulli_sample_layer, :irrational_sample_layer) NType = ifelse(initializer === :rand_sparse, Real, Number) @eval function ($initializer)(dims::Integer...; kwargs...) return $initializer(_default_rng(), Float32, dims...; kwargs...) diff --git a/src/esn/esn_reservoirs.jl b/src/esn/esn_reservoirs.jl index 276554d2..fccb3f24 100644 --- a/src/esn/esn_reservoirs.jl +++ b/src/esn/esn_reservoirs.jl @@ -191,6 +191,90 @@ function simple_cycle_reservoir(rng::AbstractRNG, end + + +""" + pseudo_svd_reservoir(rng::AbstractRNG, ::Type{T}, dims::Integer...; + max_value, sparsity, sorted = true, reverse_sort = false) where {T <: Number} + + Returns an initializer to build a sparse reservoir matrix with the given `sparsity` by using a pseudo-SVD approach as described in [^yang]. + +# Arguments +- `rng::AbstractRNG`: Random number generator. +- `T::Type`: Type of the elements in the reservoir matrix. +- `dims::Integer...`: Dimensions of the reservoir matrix. +- `max_value`: The maximum absolute value of elements in the matrix. +- `sparsity`: The desired sparsity level of the reservoir matrix. +- `sorted`: A boolean indicating whether to sort the singular values before creating the diagonal matrix. By default, it is set to `true`. +- `reverse_sort`: A boolean indicating whether to reverse the sorted singular values. By default, it is set to `false`. + +# Returns +Reservoir matrix with the specified dimensions, max value, and sparsity. + +# References +This reservoir initialization method, based on a pseudo-SVD approach, is inspired by the work in [^yang], which focuses on designing polynomial echo state networks for time series prediction. + +[^yang]: Yang, Cuili, et al. "_Design of polynomial echo state networks for time series prediction._" Neurocomputing 290 (2018): 148-160. +""" +function pseudo_svd_reservoir(rng::AbstractRNG, + ::Type{T}, + dims::Integer...; + max_value, sparsity; + sorted = true, + reverse_sort = false) where {T <: Number} + + reservoir_matrix = create_diag( dims[1], max_value, sorted = sorted, reverse_sort = reverse_sort) + tmp_sparsity = get_sparsity(reservoir_matrix, dims[1]) + + while tmp_sparsity <= sparsity + reservoir_matrix *= create_qmatrix(dims[1], rand(1:dims[1]), rand(1:dims[1]), rand() * 2 - 1) + tmp_sparsity = get_sparsity(reservoir_matrix, dims[1]) + end + + return reservoir_matrix +end + +function create_diag(dim, max_value; sorted = true, reverse_sort = false) + diagonal_matrix = zeros(dim, dim) + if sorted == true + if reverse_sort == true + diagonal_values = sort(rand(dim) .* max_value, rev = true) + diagonal_values[1] = max_value + else + diagonal_values = sort(rand(dim) .* max_value) + diagonal_values[end] = max_value + end + else + diagonal_values = rand(dim) .* max_value + end + + for i in 1:dim + diagonal_matrix[i, i] = diagonal_values[i] + end + + return diagonal_matrix +end + +function create_qmatrix(dim, coord_i, coord_j, theta) + qmatrix = zeros(dim, dim) + + for i in 1:dim + qmatrix[i, i] = 1.0 + end + + qmatrix[coord_i, coord_i] = cos(theta) + qmatrix[coord_j, coord_j] = cos(theta) + qmatrix[coord_i, coord_j] = -sin(theta) + qmatrix[coord_j, coord_i] = sin(theta) + return qmatrix +end + +function get_sparsity(M, dim) + return size(M[M .!= 0], 1) / (dim * dim - size(M[M .!= 0], 1)) #nonzero/zero elements +end + + + # from WeightInitializers.jl, TODO @MartinuzziFrancesco consider importing package function _default_rng() @static if VERSION >= v"1.7" From 4c3925b20cace670eb879ab2e322392f5cf10f74 Mon Sep 17 00:00:00 2001 From: MartinuzziFrancesco Date: Sun, 4 Feb 2024 18:34:43 +0100 Subject: [PATCH 16/26] small changes to inits, start of streamline --- README.md | 3 +- docs/src/esn_tutorials/hybrid.md | 3 +- docs/src/esn_tutorials/lorenz_basic.md | 4 +- src/ReservoirComputing.jl | 8 +- src/esn/deepesn.jl | 92 ++++++----- src/esn/esn.jl | 31 ++-- src/esn/esn_input_layers.jl | 203 ++++++++++--------------- src/esn/esn_predict.jl | 11 +- src/esn/esn_reservoir_drivers.jl | 24 +-- src/esn/esn_reservoirs.jl | 91 ++++++----- src/esn/hybridesn.jl | 45 +++--- test/esn/test_drivers.jl | 83 ++++------ test/esn/test_inits.jl | 90 +++++++++++ test/esn/test_input_layers.jl | 55 ------- test/esn/test_reservoirs.jl | 41 ----- test/runtests.jl | 7 +- test/test_states.jl | 76 ++++----- test/utils.jl | 5 - 18 files changed, 394 insertions(+), 478 deletions(-) create mode 100644 test/esn/test_inits.jl delete mode 100644 test/esn/test_input_layers.jl delete mode 100644 test/esn/test_reservoirs.jl delete mode 100644 test/utils.jl diff --git a/README.md b/README.md index 2b123d66..b4667be0 100644 --- a/README.md +++ b/README.md @@ -54,7 +54,7 @@ Now that we have the data we can initialize the ESN with the chosen parameters. input_size = 3 res_size = 300 esn = ESN(input_data, input_size, res_size; - reservoir = rand_sparse(;radius = 1.2, sparsity = 6 / res_size), + reservoir = rand_sparse(; radius = 1.2, sparsity = 6 / res_size), input_layer = weighted_init, nla_type = NLAT2()) ``` @@ -104,6 +104,7 @@ If you use this library in your work, please cite: url = {http://jmlr.org/papers/v23/22-0611.html} } ``` + ## Acknowledgements This project was possible thanks to initial funding through the [Google summer of code](https://summerofcode.withgoogle.com/) 2020 program. Francesco M. further acknowledges [ScaDS.AI](https://scads.ai/) and [RSC4Earth](https://rsc4earth.de/) for supporting the current progress on the library. diff --git a/docs/src/esn_tutorials/hybrid.md b/docs/src/esn_tutorials/hybrid.md index 5682e9db..25797089 100644 --- a/docs/src/esn_tutorials/hybrid.md +++ b/docs/src/esn_tutorials/hybrid.md @@ -57,8 +57,7 @@ km = KnowledgeModel(prior_model_data_generator, u0, tspan_train, train_len) in_size = 3 res_size = 300 -hesn = HybridESN( - km, +hesn = HybridESN(km, input_data, in_size, res_size; diff --git a/docs/src/esn_tutorials/lorenz_basic.md b/docs/src/esn_tutorials/lorenz_basic.md index 364c2c4b..1a66f834 100644 --- a/docs/src/esn_tutorials/lorenz_basic.md +++ b/docs/src/esn_tutorials/lorenz_basic.md @@ -53,8 +53,8 @@ input_scaling = 0.1 #build ESN struct esn = ESN(input_data, in_size, res_size; - reservoir = rand_sparse(;radius = res_radius, sparsity = res_sparsity), - input_layer = weighted_init(;scaling = input_scaling), + reservoir = rand_sparse(; radius = res_radius, sparsity = res_sparsity), + input_layer = weighted_init(; scaling = input_scaling), reservoir_driver = RNN(), nla_type = NLADefault(), states_type = StandardStates()) diff --git a/src/ReservoirComputing.jl b/src/ReservoirComputing.jl index f7a2046a..4afb7be1 100644 --- a/src/ReservoirComputing.jl +++ b/src/ReservoirComputing.jl @@ -19,8 +19,8 @@ export NLADefault, NLAT1, NLAT2, NLAT3 export StandardStates, ExtendedStates, PaddedStates, PaddedExtendedStates export StandardRidge, LinearModel export AbstractLayer, create_layer -export scaled_rand, weighted_init, sparse_layer, informed_layer, bernoulli_sample_layer, irrational_sample_layer -export rand_sparse, delay_line, delay_line_backward_reservoir, cycle_jumps_reservoir, simple_cycle_reservoir, pseudo_svd_reservoir +export scaled_rand, weighted_init, sparse_init, informed_init, minimal_init +export rand_sparse, delay_line, delay_line_backward, cycle_jumps, simple_cycle, pseudo_svd export RNN, MRNN, GRU, GRUParams, FullyGated, Minimal export ESN, train export HybridESN, KnowledgeModel @@ -75,7 +75,9 @@ function Predictive(prediction_data) end #fallbacks for initializers -for initializer in (:rand_sparse, :delay_line, :delay_line_backward_reservoir, :cycle_jumps_reservoir, :simple_cycle_reservoir, :pseudo_svd_reservoir, :scaled_rand, :weighted_init, :sparse_layer, :informed_layer, :bernoulli_sample_layer, :irrational_sample_layer) +for initializer in (:rand_sparse, :delay_line, :delay_line_backward, :cycle_jumps, + :simple_cycle, :pseudo_svd, + :scaled_rand, :weighted_init, :sparse_init, :informed_init, :minimal_init) NType = ifelse(initializer === :rand_sparse, Real, Number) @eval function ($initializer)(dims::Integer...; kwargs...) return $initializer(_default_rng(), Float32, dims...; kwargs...) diff --git a/src/esn/deepesn.jl b/src/esn/deepesn.jl index 4ab05f39..3402d2b5 100644 --- a/src/esn/deepesn.jl +++ b/src/esn/deepesn.jl @@ -1,7 +1,6 @@ -struct DeepESN{I, S, V, N, T, O, M, B, ST, W, IS} <: AbstractEchoStateNetwork +struct DeepESN{I, S, N, T, O, M, B, ST, W, IS} <: AbstractEchoStateNetwork res_size::I train_data::S - variation::V nla_type::N input_matrix::T reservoir_driver::O @@ -12,22 +11,19 @@ struct DeepESN{I, S, V, N, T, O, M, B, ST, W, IS} <: AbstractEchoStateNetwork states::IS end -function DeepESN( - train_data, - in_size::Int, - res_size::AbstractArray; - input_layer = scaled_rand, - reservoir = rand_sparse, - bias = zeros64, - reservoir_driver = RNN(), - nla_type = NLADefault(), - states_type = StandardStates(), - washout = 0, - rng = _default_rng(), - T=Float64, - matrix_type = typeof(train_data) -) - +function DeepESN(train_data, + in_size::Int, + res_size::AbstractArray; + input_layer = scaled_rand, + reservoir = rand_sparse, + bias = zeros64, + reservoir_driver = RNN(), + nla_type = NLADefault(), + states_type = StandardStates(), + washout = 0, + rng = _default_rng(), + T = Float64, + matrix_type = typeof(train_data)) if states_type isa AbstractPaddedStates in_size = size(train_data, 1) + 1 train_data = vcat(Adapt.adapt(matrix_type, ones(1, size(train_data, 2))), @@ -42,43 +38,43 @@ function DeepESN( input_matrix, bias_vector) train_data = train_data[:, (washout + 1):end] - ESN(sum(res_size), train_data, variation, nla_type, input_matrix, + DeepESN(sum(res_size), train_data, variation, nla_type, input_matrix, inner_res_driver, reservoir_matrix, bias_vector, states_type, washout, states) end function obtain_layers(in_size, - input_layer, - reservoir::Vector, - bias; - matrix_type = Matrix{Float64}) -esn_depth = length(reservoir) -input_res_sizes = [get_ressize(reservoir[i]) for i in 1:esn_depth] -in_sizes = zeros(Int, esn_depth) -in_sizes[2:end] = input_res_sizes[1:(end - 1)] -in_sizes[1] = in_size + input_layer, + reservoir::Vector, + bias; + matrix_type = Matrix{Float64}) + esn_depth = length(reservoir) + input_res_sizes = [get_ressize(reservoir[i]) for i in 1:esn_depth] + in_sizes = zeros(Int, esn_depth) + in_sizes[2:end] = input_res_sizes[1:(end - 1)] + in_sizes[1] = in_size + + if input_layer isa Array + input_matrix = [create_layer(input_layer[j], input_res_sizes[j], in_sizes[j], + matrix_type = matrix_type) for j in 1:esn_depth] + else + _input_layer = fill(input_layer, esn_depth) + input_matrix = [create_layer(_input_layer[k], input_res_sizes[k], in_sizes[k], + matrix_type = matrix_type) for k in 1:esn_depth] + end -if input_layer isa Array - input_matrix = [create_layer(input_layer[j], input_res_sizes[j], in_sizes[j], - matrix_type = matrix_type) for j in 1:esn_depth] -else - _input_layer = fill(input_layer, esn_depth) - input_matrix = [create_layer(_input_layer[k], input_res_sizes[k], in_sizes[k], + res_sizes = [get_ressize(input_matrix[j]) for j in 1:esn_depth] + reservoir_matrix = [create_reservoir(reservoir[k], res_sizes[k], matrix_type = matrix_type) for k in 1:esn_depth] -end -res_sizes = [get_ressize(input_matrix[j]) for j in 1:esn_depth] -reservoir_matrix = [create_reservoir(reservoir[k], res_sizes[k], - matrix_type = matrix_type) for k in 1:esn_depth] + if bias isa Array + bias_vector = [create_layer(bias[j], res_sizes[j], 1, matrix_type = matrix_type) + for j in 1:esn_depth] + else + _bias = fill(bias, esn_depth) + bias_vector = [create_layer(_bias[k], res_sizes[k], 1, matrix_type = matrix_type) + for k in 1:esn_depth] + end -if bias isa Array - bias_vector = [create_layer(bias[j], res_sizes[j], 1, matrix_type = matrix_type) - for j in 1:esn_depth] -else - _bias = fill(bias, esn_depth) - bias_vector = [create_layer(_bias[k], res_sizes[k], 1, matrix_type = matrix_type) - for k in 1:esn_depth] + return input_matrix, reservoir_matrix, bias_vector, res_sizes end - -return input_matrix, reservoir_matrix, bias_vector, res_sizes -end \ No newline at end of file diff --git a/src/esn/esn.jl b/src/esn/esn.jl index 2beb552f..260e7aff 100644 --- a/src/esn/esn.jl +++ b/src/esn/esn.jl @@ -41,22 +41,19 @@ train_data = rand(10, 100) # 10 features, 100 time steps esn = ESN(train_data, reservoir=RandSparseReservoir(200), washout=10) ``` """ -function ESN( - train_data, - in_size::Int, - res_size::Int; - input_layer = scaled_rand, - reservoir = rand_sparse, - bias = zeros64, - reservoir_driver = RNN(), - nla_type = NLADefault(), - states_type = StandardStates(), - washout = 0, - rng = _default_rng(), - T = Float32, - matrix_type = typeof(train_data) -) - +function ESN(train_data, + in_size::Int, + res_size::Int; + input_layer = scaled_rand, + reservoir = rand_sparse, + bias = zeros64, + reservoir_driver = RNN(), + nla_type = NLADefault(), + states_type = StandardStates(), + washout = 0, + rng = _default_rng(), + T = Float32, + matrix_type = typeof(train_data)) if states_type isa AbstractPaddedStates in_size = size(train_data, 1) + 1 train_data = vcat(Adapt.adapt(matrix_type, ones(1, size(train_data, 2))), @@ -64,7 +61,7 @@ function ESN( end reservoir_matrix = reservoir(rng, T, res_size, res_size) - input_matrix = input_layer(rng, T, in_size, res_size) + input_matrix = input_layer(rng, T, res_size, in_size) bias_vector = bias(rng, res_size) inner_res_driver = reservoir_driver_params(reservoir_driver, res_size, in_size) states = create_states(inner_res_driver, train_data, washout, reservoir_matrix, diff --git a/src/esn/esn_input_layers.jl b/src/esn/esn_input_layers.jl index 61cfc742..10ea2330 100644 --- a/src/esn/esn_input_layers.jl +++ b/src/esn/esn_input_layers.jl @@ -17,15 +17,12 @@ A matrix of type with dimensions specified by `dims`. Each element of the matrix rng = Random.default_rng() matrix = scaled_rand(rng, Float64, (100, 50); scaling=0.2) """ -function scaled_rand( - rng::AbstractRNG, - ::Type{T}, - dims::Integer...; - scaling=T(0.1) -) where {T <: Number} - +function scaled_rand(rng::AbstractRNG, + ::Type{T}, + dims::Integer...; + scaling = T(0.1)) where {T <: Number} res_size, in_size = dims - layer_matrix = rand(rng, Uniform(-scaling, scaling), res_size, in_size) + layer_matrix = T.(rand(rng, Uniform(-scaling, scaling), res_size, in_size)) return layer_matrix end @@ -53,24 +50,27 @@ input_layer = weighted_init(rng, Float64, (3, 300); scaling=0.2) "Reservoir observers: Model-free inference of unmeasured variables in chaotic systems." Chaos: An Interdisciplinary Journal of Nonlinear Science 27.4 (2017): 041102. """ -function weighted_init(rng::AbstractRNG, ::Type{T}, dims::Integer...; scaling=T(0.1)) where {T <: Number} - - in_size, approx_res_size = dims +function weighted_init(rng::AbstractRNG, + ::Type{T}, + dims::Integer...; + scaling = T(0.1)) where {T <: Number} + approx_res_size, in_size = dims res_size = Int(floor(approx_res_size / in_size) * in_size) layer_matrix = zeros(T, res_size, in_size) q = floor(Int, res_size / in_size) for i in 1:in_size - layer_matrix[((i - 1) * q + 1):((i) * q), i] = rand(rng, Uniform(-scaling, scaling), q) + layer_matrix[((i - 1) * q + 1):((i) * q), i] = rand(rng, + Uniform(-scaling, scaling), + q) end return layer_matrix end - - +# TODO: @MartinuzziFrancesco remove when pr gets into WeightInitializers """ - sparse_layer(rng::AbstractRNG, ::Type{T}, dims::Integer...; scaling=T(0.1), sparsity=T(0.1)) where {T <: Number} + sparse_init(rng::AbstractRNG, ::Type{T}, dims::Integer...; scaling=T(0.1), sparsity=T(0.1)) where {T <: Number} Create and return a sparse layer matrix for use in neural network models. The matrix will be of size specified by `dims`, with the specified `sparsity` and `scaling`. @@ -89,24 +89,22 @@ A sparse layer matrix. # Example ```julia rng = Random.default_rng() -input_layer = sparse_layer(rng, Float64, (3, 300); scaling=0.2, sparsity=0.1) +input_layer = sparse_init(rng, Float64, (3, 300); scaling=0.2, sparsity=0.1) ``` """ -function sparse_layer(rng::AbstractRNG,::Type{T}, dims::Integer...; - scaling=T(0.1), sparsity=T(0.1)) where {T <: Number} - - in_size, res_size = dims - layer_matrix = Matrix(sprand(rng, res_size, in_size, sparsity)) - layer_matrix = 2.0 .* (layer_matrix .- 0.5) - replace!(layer_matrix, -1.0 => 0.0) +function sparse_init(rng::AbstractRNG, ::Type{T}, dims::Integer...; + scaling = T(0.1), sparsity = T(0.1)) where {T <: Number} + res_size, in_size = dims + layer_matrix = Matrix(sprand(rng, T, res_size, in_size, sparsity)) + layer_matrix = T.(2.0) .* (layer_matrix .- T.(0.5)) + replace!(layer_matrix, T(-1.0) => T(0.0)) layer_matrix = scaling .* layer_matrix return layer_matrix end - """ - informed_layer(rng::AbstractRNG, ::Type{T}, dims::Integer...; scaling=T(0.1), model_in_size, gamma=T(0.5)) where {T <: Number} + informed_init(rng::AbstractRNG, ::Type{T}, dims::Integer...; scaling=T(0.1), model_in_size, gamma=T(0.5)) where {T <: Number} Create a layer of a neural network. @@ -126,12 +124,11 @@ Create a layer of a neural network. rng = Random.default_rng() dims = (100, 200) model_in_size = 50 -input_matrix = informed_layer(rng, Float64, dims; model_in_size=model_in_size) +input_matrix = informed_init(rng, Float64, dims; model_in_size=model_in_size) ``` """ -function informed_layer(rng::AbstractRNG, ::Type{T}, dims::Integer...; - scaling=T(0.1), model_in_size, gamma=T(0.5)) where {T <: Number} - +function informed_init(rng::AbstractRNG, ::Type{T}, dims::Integer...; + scaling = T(0.1), model_in_size, gamma = T(0.5)) where {T <: Number} res_size, in_size = dims state_size = in_size - model_in_size @@ -148,7 +145,7 @@ function informed_layer(rng::AbstractRNG, ::Type{T}, dims::Integer...; idxs = findall(Bool[zero_connections .== input_matrix[i, :] for i in 1:size(input_matrix, 1)]) random_row_idx = idxs[rand(rng, 1:end)] - random_clm_idx = range(1, state_size, step=1)[rand(rng, 1:end)] + random_clm_idx = range(1, state_size, step = 1)[rand(rng, 1:end)] input_matrix[random_row_idx, random_clm_idx] = rand(rng, Uniform(-scaling, scaling)) end @@ -156,82 +153,90 @@ function informed_layer(rng::AbstractRNG, ::Type{T}, dims::Integer...; idxs = findall(Bool[zero_connections .== input_matrix[i, :] for i in 1:size(input_matrix, 1)]) random_row_idx = idxs[rand(rng, 1:end)] - random_clm_idx = range(state_size + 1, in_size, step=1)[rand(rng, 1:end)] + random_clm_idx = range(state_size + 1, in_size, step = 1)[rand(rng, 1:end)] input_matrix[random_row_idx, random_clm_idx] = rand(rng, Uniform(-scaling, scaling)) end return input_matrix end - - -function BernoulliSample(; p = 0.5) - return p -end - -function create_minimum_input(p = sampling, res_size, in_size, weight, rng::AbstractRNG) - - input_matrix = zeros(res_size, in_size) - for i in 1:res_size - for j in 1:in_size - rand(rng, Bernoulli(p)) ? (input_matrix[i, j] = weight) : (input_matrix[i, j] = -weight) - end - end - return input_matrix -end - - """ - bernoulli_sample_layer(rng::AbstractRNG, ::Type{T}, dims::Integer...; - weight = 0.1, - sampling = BernoulliSample(0.5) + irrational_sample_init(rng::AbstractRNG, ::Type{T}, dims::Integer...; + weight = 0.1, + sampling = IrrationalSample(; irrational = pi, start = 1) ) where {T <: Number} -Create a layer matrix using the Bernoulli sampling method. +Create a layer matrix using the provided random number generator and sampling parameters. # Arguments -- `rng::AbstractRNG`: The random number generator. +- `rng::AbstractRNG`: The random number generator used to generate random numbers. - `dims::Integer...`: The dimensions of the layer matrix. -- `weight::Number = 0.1`: The weight value. -- `sampling::BernoulliSample = BernoulliSample(0.5)`: The Bernoulli sampling object. +- `weight`: The weight used to fill the layer matrix. Default is 0.1. +- `sampling`: The sampling parameters used to generate the input matrix. Default is IrrationalSample(irrational = pi, start = 1). # Returns -The generated layer matrix. +The layer matrix generated using the provided random number generator and sampling parameters. # Example ```julia +using Random rng = Random.default_rng() -dims = (100, 200) -weight = 0.1 -sampling = BernoulliSample(0.5) -layer_matrix = bernoulli_sample_layer(rng, Float64, dims; weight=weight, sampling=sampling) +dims = (3, 2) +weight = 0.5 +layer_matrix = irrational_sample_init(rng, Float64, dims; weight = weight, sampling = IrrationalSample(irrational = sqrt(2), start = 1)) ``` - """ -function bernoulli_sample_layer(rng::AbstractRNG, ::Type{T}, dims::Integer...; - weight = 0.1, - sampling = BernoulliSample(0.5) - )where {T <: Number} - +function minimal_init(rng::AbstractRNG, ::Type{T}, dims::Integer...; + sampling_type::Symbol = :bernoulli, + weight::Number = T(0.1), + irrational::Real = pi, + start::Int = 1, + p::Number = T(0.5)) where {T <: Number} res_size, in_size = dims - sampling = sampling - weight = weight - layer_matrix = create_minimum_input(sampling, res_size, in_size, weight, rng) + if sampling_type == :bernoulli + layer_matrix = _create_bernoulli(p, res_size, in_size, weight, rng, T) + elseif sampling_type == :irrational + layer_matrix = _create_irrational(irrational, + start, + res_size, + in_size, + weight, + rng, + T) + else + error("Sampling type not allowed. Please use one of :bernoulli or :irrational") + end return layer_matrix end - - -function IrrationalSample(; irrational = pi, start = 1) - return irrational, start +function _create_bernoulli(p::T, + res_size::Int, + in_size::Int, + weight::T, + rng::AbstractRNG, + ::Type{T}) where {T <: Number} + input_matrix = zeros(T, res_size, in_size) + for i in 1:res_size + for j in 1:in_size + rand(rng, Bernoulli(p)) ? (input_matrix[i, j] = weight) : + (input_matrix[i, j] = -weight) + end + end + return input_matrix end -function create_minimum_input(irrational, start = sampling, res_size, in_size, weight, rng::AbstractRNG) +function _create_irrational(irrational::Irrational, + start::Int, + res_size::Int, + in_size::Int, + weight::T, + rng::AbstractRNG, + ::Type{T}) where {T <: Number} setprecision(BigFloat, Int(ceil(log2(10) * (res_size * in_size + start + 1)))) ir_string = string(BigFloat(irrational)) |> collect deleteat!(ir_string, findall(x -> x == '.', ir_string)) ir_array = zeros(length(ir_string)) - input_matrix = zeros(res_size, in_size) + input_matrix = zeros(T, res_size, in_size) for i in 1:length(ir_string) ir_array[i] = parse(Int, ir_string[i]) @@ -239,50 +244,10 @@ function create_minimum_input(irrational, start = sampling, res_size, in_size, w for i in 1:res_size for j in 1:in_size - random_number = rand(rng) - input_matrix[i, j] = random_number < 0.5 ? -weight : weight + random_number = rand(rng, T) + input_matrix[i, j] = random_number < 0.5 ? -weight : weight end end - return input_matrix + return T.(input_matrix) end - -""" - irrational_sample_layer(rng::AbstractRNG, ::Type{T}, dims::Integer...; - weight = 0.1, - sampling = IrrationalSample(; irrational = pi, start = 1) - ) where {T <: Number} - -Create a layer matrix using the provided random number generator and sampling parameters. - -# Arguments -- `rng::AbstractRNG`: The random number generator used to generate random numbers. -- `dims::Integer...`: The dimensions of the layer matrix. -- `weight`: The weight used to fill the layer matrix. Default is 0.1. -- `sampling`: The sampling parameters used to generate the input matrix. Default is IrrationalSample(irrational = pi, start = 1). - -# Returns -The layer matrix generated using the provided random number generator and sampling parameters. - -# Example -```julia -using Random -rng = Random.default_rng() -dims = (3, 2) -weight = 0.5 -layer_matrix = irrational_sample_layer(rng, Float64, dims; weight = weight, sampling = IrrationalSample(irrational = sqrt(2), start = 1)) -``` -""" - -function irrational_sample_layer(rng::AbstractRNG, ::Type{T}, dims::Integer...; - weight = 0.1, - sampling = IrrationalSample(; irrational = pi, start = 1) - )where {T <: Number} - - res_size, in_size = dims - sampling = sampling - weight = weight - layer_matrix = create_minimum_input(sampling, res_size, in_size, weight, rng) - - return layer_matrix -end \ No newline at end of file diff --git a/src/esn/esn_predict.jl b/src/esn/esn_predict.jl index 1b4cd462..05deb98c 100644 --- a/src/esn/esn_predict.jl +++ b/src/esn/esn_predict.jl @@ -69,11 +69,18 @@ function next_state_prediction!(esn::ESN, x, x_new, out, out_pad, i, tmp_array, end #TODO fixme @MatrinuzziFra -function next_state_prediction!(hesn::HybridESN, x, x_new, out, out_pad, i, tmp_array, model_prediction_data) +function next_state_prediction!(hesn::HybridESN, + x, + x_new, + out, + out_pad, + i, + tmp_array, + model_prediction_data) out_tmp = vcat(out, model_prediction_data[:, i]) out_pad = pad_state!(hesn.states_type, out_pad, out_tmp) x = next_state!(x, hesn.reservoir_driver, x[1:(hesn.res_size)], out_pad, - hesn.reservoir_matrix, hesn.input_matrix, hesn.bias_vector, tmp_array) + hesn.reservoir_matrix, hesn.input_matrix, hesn.bias_vector, tmp_array) x_tmp = vcat(x, model_prediction_data[:, i]) x_new = hesn.states_type(hesn.nla_type, x_tmp, out_pad) return x, x_new diff --git a/src/esn/esn_reservoir_drivers.jl b/src/esn/esn_reservoir_drivers.jl index 6ab198d7..98e7ab78 100644 --- a/src/esn/esn_reservoir_drivers.jl +++ b/src/esn/esn_reservoir_drivers.jl @@ -285,9 +285,9 @@ A GRUParams object containing the parameters needed for the GRU-based reservoir function GRU( ; activation_function = [NNlib.sigmoid, NNlib.sigmoid, tanh], - inner_layer = fill(DenseLayer(), 2), - reservoir = fill(RandSparseReservoir(0), 2), - bias = fill(DenseLayer(), 2), + inner_layer = fill(scaled_rand, 2), + reservoir = fill(rand_sparse, 2), + bias = fill(scaled_rand, 2), variant = FullyGated()) return GRU(activation_function, inner_layer, reservoir, bias, variant) end @@ -312,22 +312,22 @@ end #dispatch on the different gru variations function create_gru_layers(gru, variant::FullyGated, res_size, in_size) - Wz_in = create_layer(gru.inner_layer[1], res_size, in_size) - Wz = create_reservoir(gru.reservoir[1], res_size) - bz = create_layer(gru.bias[1], res_size, 1) + Wz_in = gru.inner_layer[1](res_size, in_size) + Wz = gru.reservoir[1](res_size, res_size) + bz = gru.bias[1](res_size, 1) - Wr_in = create_layer(gru.inner_layer[2], res_size, in_size) - Wr = create_reservoir(gru.reservoir[2], res_size) - br = create_layer(gru.bias[2], res_size, 1) + Wr_in = gru.inner_layer[2](res_size, in_size) + Wr = gru.reservoir[2](res_size, res_size) + br = gru.bias[2](res_size, 1) return GRUParams(gru.activation_function, variant, Wz_in, Wz, bz, Wr_in, Wr, br) end #check this one, not sure function create_gru_layers(gru, variant::Minimal, res_size, in_size) - Wz_in = create_layer(gru.inner_layer, res_size, in_size) - Wz = create_reservoir(gru.reservoir, res_size) - bz = create_layer(gru.bias, res_size, 1) + Wz_in = gru.inner_layer(res_size, in_size) + Wz = gru.reservoir(res_size, res_size) + bz = gru.bias(res_size, 1) Wr_in = nothing Wr = nothing diff --git a/src/esn/esn_reservoirs.jl b/src/esn/esn_reservoirs.jl index fccb3f24..5d9dd615 100644 --- a/src/esn/esn_reservoirs.jl +++ b/src/esn/esn_reservoirs.jl @@ -60,7 +60,7 @@ function delay_line(rng::AbstractRNG, dims::Integer...; weight = T(0.1)) where {T <: Number} reservoir_matrix = zeros(T, dims...) - @assert length(dims) == 2 && dims[1] == dims[2] "The dimensions must define a square matrix (e.g., (100, 100))" + @assert length(dims) == 2&&dims[1] == dims[2] "The dimensions must define a square matrix (e.g., (100, 100))" for i in 1:(dims[1] - 1) reservoir_matrix[i + 1, i] = weight @@ -70,7 +70,7 @@ function delay_line(rng::AbstractRNG, end """ - delay_line_backward_reservoir(rng::AbstractRNG, ::Type{T}, dims::Integer...; + delay_line_backward(rng::AbstractRNG, ::Type{T}, dims::Integer...; weight = T(0.1), fb_weight = T(0.2)) where {T <: Number} Create a delay line backward reservoir with the specified by `dims` and weights. Creates a matrix with backward connections @@ -92,12 +92,13 @@ Reservoir matrix with the dimensions specified by `dims` and weights. [^Rodan2010]: Rodan, Ali, and Peter Tino. "Minimum complexity echo state network." IEEE transactions on neural networks 22.1 (2010): 131-144. """ -function delay_line_backward_reservoir(rng::AbstractRNG, +function delay_line_backward(rng::AbstractRNG, ::Type{T}, dims::Integer...; - weight = T(0.1), + weight = T(0.1), fb_weight = T(0.2)) where {T <: Number} - reservoir_matrix = zeros(res_size, res_size) + res_size = first(dims) + reservoir_matrix = zeros(T, dims...) for i in 1:(res_size - 1) reservoir_matrix[i + 1, i] = weight @@ -107,9 +108,8 @@ function delay_line_backward_reservoir(rng::AbstractRNG, return reservoir_matrix end - """ - cycle_jumps_reservoir(rng::AbstractRNG, ::Type{T}, dims::Integer...; + cycle_jumps(rng::AbstractRNG, ::Type{T}, dims::Integer...; cycle_weight = T(0.1), jump_weight = T(0.1), jump_size = 3) where {T <: Number} Create a cycle jumps reservoir with the specified dimensions, cycle weight, jump weight, and jump size. @@ -129,25 +129,25 @@ Reservoir matrix with the specified dimensions, cycle weight, jump weight, and j [^Rodan2012]: Rodan, Ali, and Peter Tiňo. "Simple deterministically constructed cycle reservoirs with regular jumps." Neural computation 24.7 (2012): 1822-1852. """ -function cycle_jumps_reservoir(rng::AbstractRNG, +function cycle_jumps(rng::AbstractRNG, ::Type{T}, dims::Integer...; - cycle_weight = T(0.1), - jump_weight = T(0.1), - jump_size = T(3)) where {T <: Number} - + cycle_weight::Number = T(0.1), + jump_weight::Number = T(0.1), + jump_size::Int = 3) where {T <: Number} + res_size = first(dims) reservoir_matrix = zeros(T, dims...) - for i in 1:(dims[1] - 1) + for i in 1:(res_size - 1) reservoir_matrix[i + 1, i] = cycle_weight end - reservoir_matrix[1, dims[1]] = cycle_weight + reservoir_matrix[1, res_size] = cycle_weight - for i in 1:jump_size:(dims[1] - jump_size) - tmp = (i + jump_size) % dims[1] + for i in 1:jump_size:(res_size - jump_size) + tmp = (i + jump_size) % res_size if tmp == 0 - tmp = dims[1] + tmp = res_size end reservoir_matrix[i, tmp] = jump_weight reservoir_matrix[tmp, i] = jump_weight @@ -156,9 +156,8 @@ function cycle_jumps_reservoir(rng::AbstractRNG, return reservoir_matrix end - """ - simple_cycle_reservoir(rng::AbstractRNG, ::Type{T}, dims::Integer...; + simple_cycle(rng::AbstractRNG, ::Type{T}, dims::Integer...; weight = T(0.1)) where {T <: Number} Create a simple cycle reservoir with the specified dimensions and weight. @@ -176,7 +175,7 @@ Reservoir matrix with the dimensions specified by `dims` and weights. [^Rodan2010]: Rodan, Ali, and Peter Tino. "Minimum complexity echo state network." IEEE transactions on neural networks 22.1 (2010): 131-144. """ -function simple_cycle_reservoir(rng::AbstractRNG, +function simple_cycle(rng::AbstractRNG, ::Type{T}, dims::Integer...; weight = T(0.1)) where {T <: Number} @@ -190,11 +189,8 @@ function simple_cycle_reservoir(rng::AbstractRNG, return reservoir_matrix end - - - """ - pseudo_svd_reservoir(rng::AbstractRNG, ::Type{T}, dims::Integer...; + pseudo_svd(rng::AbstractRNG, ::Type{T}, dims::Integer...; max_value, sparsity, sorted = true, reverse_sort = false) where {T <: Number} Returns an initializer to build a sparse reservoir matrix with the given `sparsity` by using a pseudo-SVD approach as described in [^yang]. @@ -216,36 +212,45 @@ This reservoir initialization method, based on a pseudo-SVD approach, is inspire [^yang]: Yang, Cuili, et al. "_Design of polynomial echo state networks for time series prediction._" Neurocomputing 290 (2018): 148-160. """ -function pseudo_svd_reservoir(rng::AbstractRNG, +function pseudo_svd(rng::AbstractRNG, ::Type{T}, - dims::Integer...; - max_value, sparsity; - sorted = true, - reverse_sort = false) where {T <: Number} - - reservoir_matrix = create_diag( dims[1], max_value, sorted = sorted, reverse_sort = reverse_sort) - tmp_sparsity = get_sparsity(reservoir_matrix, dims[1]) + dims::Integer...; + max_value::Number = T(1.0), + sparsity::Number = 0.1, + sorted::Bool = true, + reverse_sort::Bool = false) where {T <: Number} + reservoir_matrix = create_diag(dims[1], + max_value, + T; + sorted = sorted, + reverse_sort = reverse_sort) + tmp_sparsity = get_sparsity(reservoir_matrix, dims[1]) while tmp_sparsity <= sparsity - reservoir_matrix *= create_qmatrix(dims[1], rand(1:dims[1]), rand(1:dims[1]), rand() * 2 - 1) + reservoir_matrix *= create_qmatrix(dims[1], + rand(1:dims[1]), + rand(1:dims[1]), + rand(T) * T(2) - T(1), + T) tmp_sparsity = get_sparsity(reservoir_matrix, dims[1]) end return reservoir_matrix end -function create_diag(dim, max_value; sorted = true, reverse_sort = false) - diagonal_matrix = zeros(dim, dim) +function create_diag(dim::Number, max_value::Number, ::Type{T}; + sorted::Bool = true, reverse_sort::Bool = false) where {T <: Number} + diagonal_matrix = zeros(T, dim, dim) if sorted == true if reverse_sort == true - diagonal_values = sort(rand(dim) .* max_value, rev = true) + diagonal_values = sort(rand(T, dim) .* max_value, rev = true) diagonal_values[1] = max_value else - diagonal_values = sort(rand(dim) .* max_value) + diagonal_values = sort(rand(T, dim) .* max_value) diagonal_values[end] = max_value end else - diagonal_values = rand(dim) .* max_value + diagonal_values = rand(T, dim) .* max_value end for i in 1:dim @@ -255,8 +260,12 @@ function create_diag(dim, max_value; sorted = true, reverse_sort = false) return diagonal_matrix end -function create_qmatrix(dim, coord_i, coord_j, theta) - qmatrix = zeros(dim, dim) +function create_qmatrix(dim::Number, + coord_i::Number, + coord_j::Number, + theta::Number, + ::Type{T}) where {T <: Number} + qmatrix = zeros(T, dim, dim) for i in 1:dim qmatrix[i, i] = 1.0 @@ -273,8 +282,6 @@ function get_sparsity(M, dim) return size(M[M .!= 0], 1) / (dim * dim - size(M[M .!= 0], 1)) #nonzero/zero elements end - - # from WeightInitializers.jl, TODO @MartinuzziFrancesco consider importing package function _default_rng() @static if VERSION >= v"1.7" diff --git a/src/esn/hybridesn.jl b/src/esn/hybridesn.jl index f29028fc..37b40d0c 100644 --- a/src/esn/hybridesn.jl +++ b/src/esn/hybridesn.jl @@ -52,23 +52,20 @@ function KnowledgeModel(prior_model, u0, tspan, datasize) return KnowledgeModel(prior_model, u0, tspan, dt, datasize, model_data) end -function HybridESN( - model, - train_data, - in_size::Int, - res_size::Int; - input_layer = scaled_rand, - reservoir = rand_sparse, - bias = zeros64, - reservoir_driver = RNN(), - nla_type = NLADefault(), - states_type = StandardStates(), - washout = 0, - rng = _default_rng(), - T = Float32, - matrix_type = typeof(train_data) -) - +function HybridESN(model, + train_data, + in_size::Int, + res_size::Int; + input_layer = scaled_rand, + reservoir = rand_sparse, + bias = zeros64, + reservoir_driver = RNN(), + nla_type = NLADefault(), + states_type = StandardStates(), + washout = 0, + rng = _default_rng(), + T = Float32, + matrix_type = typeof(train_data)) train_data = vcat(train_data, model.model_data[:, 1:(end - 1)]) if states_type isa AbstractPaddedStates @@ -94,10 +91,9 @@ function HybridESN( end function (hesn::HybridESN)(prediction::AbstractPrediction, - output_layer::AbstractOutputLayer; - last_state = hesn.states[:, [end]], - kwargs...) - + output_layer::AbstractOutputLayer; + last_state = hesn.states[:, [end]], + kwargs...) km = hesn.model pred_len = prediction.prediction_len @@ -114,11 +110,10 @@ function (hesn::HybridESN)(prediction::AbstractPrediction, end function train(hesn::HybridESN, - target_data, - training_method = StandardRidge(0.0)) - + target_data, + training_method = StandardRidge(0.0)) states = vcat(hesn.states, hesn.model.model_data[:, 2:end]) states_new = hesn.states_type(hesn.nla_type, states, hesn.train_data[:, 1:end]) return _train(states_new, target_data, training_method) -end \ No newline at end of file +end diff --git a/test/esn/test_drivers.jl b/test/esn/test_drivers.jl index db1c7b01..6dd7f7a2 100644 --- a/test/esn/test_drivers.jl +++ b/test/esn/test_drivers.jl @@ -1,65 +1,40 @@ using ReservoirComputing, Random, Statistics, NNlib -const res_size = 20 +const res_size = 50 const ts = 0.0:0.1:50.0 const data = sin.(ts) const train_len = 400 const input_data = reduce(hcat, data[1:(train_len - 1)]) const target_data = reduce(hcat, data[2:train_len]) const predict_len = 100 -const test = reduce(hcat, data[(train_len + 1):(train_len + predict_len)]) +const test_data = reduce(hcat, data[(train_len + 1):(train_len + predict_len)]) const training_method = StandardRidge(10e-6) - Random.seed!(77) -esn = ESN(input_data; - reservoir = RandSparseReservoir(res_size, 1.2, 0.1), - reservoir_driver = GRU(variant = FullyGated(), - reservoir = [ - RandSparseReservoir(res_size, 1.0, 0.5), - RandSparseReservoir(res_size, 1.2, 0.1), - ])) - -output_layer = train(esn, target_data, training_method) -output = esn(Predictive(target_data), output_layer, initial_conditions = target_data[1]) -@test mean(abs.(target_data .- output)) ./ mean(abs.(target_data)) < 0.11 - -esn = ESN(input_data; - reservoir = RandSparseReservoir(res_size, 1.2, 0.1), - reservoir_driver = GRU(variant = Minimal(), - reservoir = RandSparseReservoir(res_size, 1.0, 0.5), - inner_layer = DenseLayer(), - bias = DenseLayer())) - -output_layer = train(esn, target_data, training_method) -output = esn(Predictive(target_data), output_layer, initial_conditions = target_data[1]) -@test mean(abs.(target_data .- output)) ./ mean(abs.(target_data)) < 0.11 - -#multiple rnn -esn = ESN(input_data; - reservoir = RandSparseReservoir(res_size, 1.2, 0.1), - reservoir_driver = MRNN(activation_function = (tanh, sigmoid), - scaling_factor = (0.8, 0.1))) -output_layer = train(esn, target_data, training_method) -output = esn(Predictive(target_data), output_layer, initial_conditions = target_data[1]) -@test mean(abs.(target_data .- output)) ./ mean(abs.(target_data)) < 0.11 - -#deep esn -esn = ESN(input_data; - reservoir = [ - RandSparseReservoir(res_size, 1.2, 0.1), - RandSparseReservoir(res_size, 1.2, 0.1), - ]) -output_layer = train(esn, target_data, training_method) -output = esn(Predictive(target_data), output_layer, initial_conditions = target_data[1]) -@test mean(abs.(target_data .- output)) ./ mean(abs.(target_data)) < 0.11 -esn = ESN(input_data; - reservoir = [ - RandSparseReservoir(res_size, 1.2, 0.1), - RandSparseReservoir(res_size, 1.2, 0.1), - ], - input_layer = [DenseLayer(), DenseLayer()], - bias = [NullLayer(), NullLayer()]) -output_layer = train(esn, target_data, training_method) -output = esn(Predictive(target_data), output_layer, initial_conditions = target_data[1]) -@test mean(abs.(target_data .- output)) ./ mean(abs.(target_data)) < 0.11 +function test_esn(input_data, target_data, training_method, esn_config) + esn = ESN(input_data, 1, res_size; esn_config...) + output_layer = train(esn, target_data, training_method) + + output = esn(Predictive(target_data), output_layer, initial_conditions = target_data[1]) + @test mean(abs.(target_data .- output)) ./ mean(abs.(target_data)) < 0.15 +end + +esn_configs = [ + Dict(:reservoir => rand_sparse(; radius = 1.2), + :reservoir_driver => GRU(variant = FullyGated(), + reservoir = [ + rand_sparse(; radius = 1.0, sparsity = 0.5), + rand_sparse(; radius = 1.2, sparsity = 0.1), + ])), + Dict(:reservoir => rand_sparse(; radius = 1.2), + :reservoir_driver => GRU(variant = Minimal(), + reservoir = rand_sparse(; radius = 1.0, sparsity = 0.5), + inner_layer = scaled_rand)), + Dict(:reservoir => rand_sparse(; radius = 1.2), + :reservoir_driver => MRNN(activation_function = (tanh, sigmoid), + scaling_factor = (0.8, 0.1))), +] + +for config in esn_configs + test_esn(input_data, target_data, training_method, config) +end diff --git a/test/esn/test_inits.jl b/test/esn/test_inits.jl new file mode 100644 index 00000000..789f71e1 --- /dev/null +++ b/test/esn/test_inits.jl @@ -0,0 +1,90 @@ +using ReservoirComputing +using LinearAlgebra +using Random + +const res_size = 30 +const in_size = 3 +const radius = 1.0 +const sparsity = 0.1 +const weight = 0.2 +const jump_size = 3 +const rng = Random.default_rng() + +function check_radius(matrix, target_radius; tolerance = 1e-5) + eigenvalues = eigvals(matrix) + spectral_radius = maximum(abs.(eigenvalues)) + return isapprox(spectral_radius, target_radius, atol = tolerance) +end + +ft = [Float16, Float32, Float64] +reservoir_inits = [ + rand_sparse, + delay_line, + delay_line_backward, + cycle_jumps, + simple_cycle, + pseudo_svd, +] +input_inits = [ + scaled_rand, + weighted_init, + sparse_init, + minimal_init, + minimal_init(; sampling_type = :irrational), +] + +@testset "Reservoir Initializers" begin + @testset "Sizes and types: $init $T" for init in reservoir_inits, T in ft + #sizes + @test size(init(res_size, res_size)) == (res_size, res_size) + @test size(init(rng, res_size, res_size)) == (res_size, res_size) + #types + @test eltype(init(T, res_size, res_size)) == T + @test eltype(init(rng, T, res_size, res_size)) == T + #closure + cl = init(rng) + @test eltype(cl(T, res_size, res_size)) == T + end + + @testset "Check spectral radius" begin + sp = rand_sparse(res_size, res_size) + @test check_radius(sp, radius) + end + + @testset "Minimum complexity: $init" for init in [ + delay_line, + delay_line_backward, + cycle_jumps, + simple_cycle, + ] + dl = init(res_size, res_size) + if init === delay_line_backward + @test unique(dl) == Float32.([0.0, 0.1, 0.2]) + else + @test unique(dl) == Float32.([0.0, 0.1]) + end + end +end + +# TODO: @MartinuzziFrancesco Missing tests for informed_init +@testset "Input Initializers" begin + @testset "Sizes and types: $init $T" for init in input_inits, T in ft + #sizes + @test size(init(res_size, in_size)) == (res_size, in_size) + @test size(init(rng, res_size, in_size)) == (res_size, in_size) + #types + @test eltype(init(T, res_size, in_size)) == T + @test eltype(init(rng, T, res_size, in_size)) == T + #closure + cl = init(rng) + @test eltype(cl(T, res_size, in_size)) == T + end + + @testset "Minimum complexity: $init" for init in [ + minimal_init, + minimal_init(; sampling_type = :irrational), + ] + dl = init(res_size, in_size) + @test sort(unique(dl)) == Float32.([-0.1, 0.1]) + end +end diff --git a/test/esn/test_input_layers.jl b/test/esn/test_input_layers.jl deleted file mode 100644 index adc655da..00000000 --- a/test/esn/test_input_layers.jl +++ /dev/null @@ -1,55 +0,0 @@ -using ReservoirComputing - -const res_size = 10 -const in_size = 3 -const scaling = 0.1 -const weight = 0.2 - -#testing WeightedLayer implicit and esplicit constructors -input_constructor = WeightedLayer(scaling) -input_matrix = create_layer(input_constructor, res_size, in_size) -@test size(input_matrix) == (Int(floor(res_size / in_size) * in_size), in_size) -@test maximum(input_matrix) <= scaling - -input_constructor = WeightedLayer(scaling = scaling) -input_matrix = create_layer(input_constructor, res_size, in_size) -@test size(input_matrix) == (Int(floor(res_size / in_size) * in_size), in_size) -@test maximum(input_matrix) <= scaling - -#testing DenseLayer implicit and esplicit constructors -input_constructor = DenseLayer(scaling) -input_matrix = create_layer(input_constructor, res_size, in_size) -@test size(input_matrix) == (res_size, in_size) -@test maximum(input_matrix) <= scaling - -input_constructor = DenseLayer(scaling = scaling) -input_matrix = create_layer(input_constructor, res_size, in_size) -@test size(input_matrix) == (res_size, in_size) -@test maximum(input_matrix) <= scaling - -#testing SparseLayer implicit and esplicit constructors -input_constructor = SparseLayer(scaling) -input_matrix = create_layer(input_constructor, res_size, in_size) -@test size(input_matrix) == (res_size, in_size) -@test maximum(input_matrix) <= scaling - -input_constructor = SparseLayer(scaling = scaling) -input_matrix = create_layer(input_constructor, res_size, in_size) -@test size(input_matrix) == (res_size, in_size) -@test maximum(input_matrix) <= scaling - -#testing MinimumLayer implicit and esplicit constructors -input_constructor = MinimumLayer(weight) -input_matrix = create_layer(input_constructor, res_size, in_size) -@test size(input_matrix) == (res_size, in_size) -@test maximum(input_matrix) == weight - -input_constructor = MinimumLayer(weight = weight) -input_matrix = create_layer(input_constructor, res_size, in_size) -@test size(input_matrix) == (res_size, in_size) -@test maximum(input_matrix) == weight - -#testing NullLayer constructor -input_constructor = NullLayer() -input_matrix = create_layer(input_constructor, res_size, in_size) -@test size(input_matrix) == (res_size, in_size) diff --git a/test/esn/test_reservoirs.jl b/test/esn/test_reservoirs.jl deleted file mode 100644 index debd9be0..00000000 --- a/test/esn/test_reservoirs.jl +++ /dev/null @@ -1,41 +0,0 @@ -using ReservoirComputing -using LinearAlgebra -using Random -include("../utils.jl") - -const res_size = 20 -const radius = 1.0 -const sparsity = 0.1 -const weight = 0.2 -const jump_size = 3 -const rng = Random.default_rng() - -dtypes = [Float16, Float32, Float64] -reservoir_inits = [rand_sparse, delay_line] - -@testset "Sizes and types" begin - for init in reservoir_inits - for dt in dtypes - #sizes - @test size(init(res_size, res_size)) == (res_size, res_size) - @test size(init(rng, res_size, res_size)) == (res_size, res_size) - #types - @test eltype(init(dt, res_size, res_size)) == dt - @test eltype(init(rng, dt, res_size, res_size)) == dt - #closure - cl = init(rng) - @test cl(dt, res_size, res_size) isa AbstractArray{dt} - end - end -end - -@testset "rand_sparse" begin - sp = rand_sparse(res_size, res_size) - @test check_radius(sp, radius) -end - -@testset "delay_line" begin - dl = delay_line(res_size, res_size) - @test unique(dl) == Float32.([0.0, 0.1]) -end - diff --git a/test/runtests.jl b/test/runtests.jl index 2d114e99..4dfad5c4 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -11,11 +11,8 @@ using Test end @testset "Echo State Networks" begin - @safetestset "ESN Input Layers" begin - include("esn/test_input_layers.jl") - end - @safetestset "ESN Reservoirs" begin - include("esn/test_reservoirs.jl") + @safetestset "Test initializers" begin + include("esn/test_inits.jl") end @safetestset "ESN States" begin include("esn/test_states.jl") diff --git a/test/test_states.jl b/test/test_states.jl index c8808bbf..1215d47b 100644 --- a/test/test_states.jl +++ b/test/test_states.jl @@ -1,51 +1,37 @@ using ReservoirComputing -#padding test_array = [1, 2, 3, 4, 5, 6, 7, 8, 9] -standard_array = zeros(length(test_array), 1) extension = [0, 0, 0] -padded_array = zeros(length(test_array) + 1, 1) -extended_array = zeros(length(test_array) + length(extension), 1) -padded_extended_array = zeros(length(test_array) + length(extension) + 1, 1) padding = 10.0 -#testing non linear algos -nla_array = ReservoirComputing.nla(NLADefault(), test_array) -@test nla_array == test_array - -nla_array = ReservoirComputing.nla(NLAT1(), test_array) -@test nla_array == [1, 2, 9, 4, 25, 6, 49, 8, 81] - -nla_array = ReservoirComputing.nla(NLAT2(), test_array) -@test nla_array == [1, 2, 2, 4, 12, 6, 30, 8, 9] - -nla_array = ReservoirComputing.nla(NLAT3(), test_array) -@test nla_array == [1, 2, 8, 4, 24, 6, 48, 8, 9] - -#testing padding and extension -states_type = StandardStates() -standard_array = states_type(NLADefault(), test_array, extension) -@test standard_array == test_array - -states_type = PaddedStates(padding = padding) -padded_array = states_type(NLADefault(), test_array, extension) -@test padded_array == reshape(vcat(padding, test_array), length(test_array) + 1, 1) - -states_type = PaddedStates(padding) -padded_array = states_type(NLADefault(), test_array, extension) -@test padded_array == reshape(vcat(padding, test_array), length(test_array) + 1, 1) - -states_type = PaddedExtendedStates(padding = padding) -padded_extended_array = states_type(NLADefault(), test_array, extension) -@test padded_extended_array == reshape(vcat(padding, extension, test_array), - length(test_array) + length(extension) + 1, 1) - -states_type = PaddedExtendedStates(padding) -padded_extended_array = states_type(NLADefault(), test_array, extension) -@test padded_extended_array == reshape(vcat(padding, extension, test_array), - length(test_array) + length(extension) + 1, 1) - -states_type = ExtendedStates() -extended_array = states_type(NLADefault(), test_array, extension) -@test extended_array == vcat(extension, test_array) -#reshape(vcat(extension, test_array), length(test_array)+length(extension), 1) +nlas = [(NLADefault(), test_array), + (NLAT1(), [1, 2, 9, 4, 25, 6, 49, 8, 81]), + (NLAT2(), [1, 2, 2, 4, 12, 6, 30, 8, 9]), + (NLAT3(), [1, 2, 8, 4, 24, 6, 48, 8, 9])] + +pes = [(StandardStates(), test_array), + (PaddedStates(padding = padding), + reshape(vcat(padding, test_array), length(test_array) + 1, 1)), + (PaddedExtendedStates(padding = padding), + reshape(vcat(padding, extension, test_array), + length(test_array) + length(extension) + 1, + 1)), + (ExtendedStates(), vcat(extension, test_array))] + +function test_nla(algo, expected_output) + nla_array = ReservoirComputing.nla(algo, test_array) + @test nla_array == expected_output +end + +function test_states_type(state_type, expected_output) + states_output = state_type(NLADefault(), test_array, extension) + @test states_output == expected_output +end + +@testset "Nonlinear Algorithms Testing" for (algo, expected_output) in nlas + test_nla(algo, expected_output) +end + +@testset "States Testing" for (state_type, expected_output) in pes + test_states_type(state_type, expected_output) +end diff --git a/test/utils.jl b/test/utils.jl deleted file mode 100644 index 9ef6f360..00000000 --- a/test/utils.jl +++ /dev/null @@ -1,5 +0,0 @@ -function check_radius(matrix, target_radius; tolerance=1e-5) - eigenvalues = eigvals(matrix) - spectral_radius = maximum(abs.(eigenvalues)) - return isapprox(spectral_radius, target_radius, atol=tolerance) -end \ No newline at end of file From ebde2324c4022ed9c9dce852933da9aff73137e3 Mon Sep 17 00:00:00 2001 From: MartinuzziFrancesco Date: Mon, 12 Feb 2024 17:30:03 +0100 Subject: [PATCH 17/26] various test fixes --- src/esn/deepesn.jl | 1 + test/esn/test_drivers.jl | 5 +++-- test/esn/test_nla.jl | 23 -------------------- test/esn/test_states.jl | 46 ---------------------------------------- test/esn/test_train.jl | 17 ++++++--------- test/test_states.jl | 30 ++++++++++++-------------- 6 files changed, 24 insertions(+), 98 deletions(-) delete mode 100644 test/esn/test_nla.jl delete mode 100644 test/esn/test_states.jl diff --git a/src/esn/deepesn.jl b/src/esn/deepesn.jl index 3402d2b5..5879aa05 100644 --- a/src/esn/deepesn.jl +++ b/src/esn/deepesn.jl @@ -24,6 +24,7 @@ function DeepESN(train_data, rng = _default_rng(), T = Float64, matrix_type = typeof(train_data)) + if states_type isa AbstractPaddedStates in_size = size(train_data, 1) + 1 train_data = vcat(Adapt.adapt(matrix_type, ones(1, size(train_data, 2))), diff --git a/test/esn/test_drivers.jl b/test/esn/test_drivers.jl index 6dd7f7a2..b7034484 100644 --- a/test/esn/test_drivers.jl +++ b/test/esn/test_drivers.jl @@ -29,12 +29,13 @@ esn_configs = [ Dict(:reservoir => rand_sparse(; radius = 1.2), :reservoir_driver => GRU(variant = Minimal(), reservoir = rand_sparse(; radius = 1.0, sparsity = 0.5), - inner_layer = scaled_rand)), + inner_layer = scaled_rand, + bias = scaled_rand)), Dict(:reservoir => rand_sparse(; radius = 1.2), :reservoir_driver => MRNN(activation_function = (tanh, sigmoid), scaling_factor = (0.8, 0.1))), ] -for config in esn_configs +@testset "Test Drivers: $config" for config in esn_configs test_esn(input_data, target_data, training_method, config) end diff --git a/test/esn/test_nla.jl b/test/esn/test_nla.jl deleted file mode 100644 index f5ac42f8..00000000 --- a/test/esn/test_nla.jl +++ /dev/null @@ -1,23 +0,0 @@ -using ReservoirComputing - -states = [1, 2, 3, 4, 5, 6, 7, 8, 9] -nla1_states = [1, 2, 9, 4, 25, 6, 49, 8, 81] -nla2_states = [1, 2, 2, 4, 12, 6, 30, 8, 9] -nla3_states = [1, 2, 8, 4, 24, 6, 48, 8, 9] - -test_types = [Float64, Float32, Float16] - -for tt in test_types - # test default - nla_states = ReservoirComputing.nla(NLADefault(), tt.(states)) - @test nla_states == tt.(states) - # test NLAT1 - nla_states = ReservoirComputing.nla(NLAT1(), tt.(states)) - @test nla_states = tt.(nla1_states) - # test nlat2 - nla_states = ReservoirComputing.nla(NLAT2(), tt.(states)) - @test nla_states = tt.(nla2_states) - # test nlat3 - nla_states = ReservoirComputing.nla(NLAT3(), tt.(states)) - @test nla_states = tt.(nla3_states) -end diff --git a/test/esn/test_states.jl b/test/esn/test_states.jl deleted file mode 100644 index 479d29c9..00000000 --- a/test/esn/test_states.jl +++ /dev/null @@ -1,46 +0,0 @@ -using ReservoirComputing - -test_types = [Float64, Float32, Float16] -states = [1, 2, 3, 4, 5, 6, 7, 8, 9] -in_data = fill(1, 3) - -states_types = [StandardStates, ExtendedStates, PaddedStates, PaddedExtendedStates] - -# testing extension and padding -for tt in test_types - st_states = StandardStates()(NLADefault(), tt.(states), tt.(in_data)) - @test length(st_states) == length(states) - @test typeof(st_states) == typeof(tt.(states)) - - st_states = ExtendedStates()(NLADefault(), tt.(states), tt.(in_data)) - @test length(st_states) == length(states) + length(in_data) - @test typeof(st_states) == typeof(tt.(states)) - - st_states = PaddedStates()(NLADefault(), tt.(states), tt.(in_data)) - @test length(st_states) == length(states) + 1 - @test typeof(st_states[1]) == typeof(tt.(states)[1]) - - st_states = PaddedExtendedStates()(NLADefault(), tt.(states), tt.(in_data)) - @test length(st_states) == length(states) + length(in_data) + 1 - @test typeof(st_states[1]) == typeof(tt.(states)[1]) -end - -## testing non linear algos -nla1_states = [1, 2, 9, 4, 25, 6, 49, 8, 81] -nla2_states = [1, 2, 2, 4, 12, 6, 30, 8, 9] -nla3_states = [1, 2, 8, 4, 24, 6, 48, 8, 9] - -for tt in test_types - # test default - nla_states = ReservoirComputing.nla(NLADefault(), tt.(states)) - @test nla_states == tt.(states) - # test NLAT1 - nla_states = ReservoirComputing.nla(NLAT1(), tt.(states)) - @test nla_states == tt.(nla1_states) - # test nlat2 - nla_states = ReservoirComputing.nla(NLAT2(), tt.(states)) - @test nla_states == tt.(nla2_states) - # test nlat3 - nla_states = ReservoirComputing.nla(NLAT3(), tt.(states)) - @test nla_states == tt.(nla3_states) -end diff --git a/test/esn/test_train.jl b/test/esn/test_train.jl index e5140b97..a0f6a4c1 100644 --- a/test/esn/test_train.jl +++ b/test/esn/test_train.jl @@ -9,10 +9,12 @@ const input_data = reduce(hcat, data[1:(train_len - 1)]) const target_data = reduce(hcat, data[2:train_len]) const test = reduce(hcat, data[(train_len + 1):(train_len + predict_len)]) const reg = 10e-6 +#test_types = [Float64, Float32, Float16] Random.seed!(77) -esn = ESN(input_data; - reservoir = RandSparseReservoir(res_size, 1.2, 0.1)) +res = rand_sparse(; radius=1.2, sparsity=0.1) +esn = ESN(input_data, 1, res_size; + reservoir = rand_sparse) training_methods = [ StandardRidge(regularization_coeff = reg), @@ -21,14 +23,9 @@ training_methods = [ EpsilonSVR(), ] -for t in training_methods - output_layer = train(esn, target_data, t) +# TODO check types +@testset "Training Algo Tests: $ta" for ta in training_methods + output_layer = train(esn, target_data, ta) output = esn(Predictive(input_data), output_layer) @test mean(abs.(target_data .- output)) ./ mean(abs.(target_data)) < 0.22 end - -for t in training_methods - output_layer = train(esn, target_data, t) - output, states = esn(Predictive(input_data), output_layer, save_states = true) - @test size(states) == (res_size, size(input_data, 2)) -end diff --git a/test/test_states.jl b/test/test_states.jl index 1215d47b..34570668 100644 --- a/test/test_states.jl +++ b/test/test_states.jl @@ -3,6 +3,7 @@ using ReservoirComputing test_array = [1, 2, 3, 4, 5, 6, 7, 8, 9] extension = [0, 0, 0] padding = 10.0 +test_types = [Float64, Float32, Float16] nlas = [(NLADefault(), test_array), (NLAT1(), [1, 2, 9, 4, 25, 6, 49, 8, 81]), @@ -18,20 +19,15 @@ pes = [(StandardStates(), test_array), 1)), (ExtendedStates(), vcat(extension, test_array))] -function test_nla(algo, expected_output) - nla_array = ReservoirComputing.nla(algo, test_array) - @test nla_array == expected_output -end - -function test_states_type(state_type, expected_output) - states_output = state_type(NLADefault(), test_array, extension) - @test states_output == expected_output -end - -@testset "Nonlinear Algorithms Testing" for (algo, expected_output) in nlas - test_nla(algo, expected_output) -end - -@testset "States Testing" for (state_type, expected_output) in pes - test_states_type(state_type, expected_output) -end +@testset "States Testing" for T in test_types + @testset "Nonlinear Algorithms Testing: $algo $T" for (algo, expected_output) in nlas + nla_array = ReservoirComputing.nla(algo, T.(test_array)) + @test nla_array == expected_output + @test eltype(nla_array) == T + end + @testset "States Testing: $state_type $T" for (state_type, expected_output) in pes + states_output = state_type(NLADefault(), T.(test_array), T.(extension)) + @test states_output == expected_output + @test eltype(states_output) == T + end +end \ No newline at end of file From a3c764efd5b44caa10f6925a73d7483593b77deb Mon Sep 17 00:00:00 2001 From: MartinuzziFrancesco Date: Wed, 14 Feb 2024 18:02:24 +0100 Subject: [PATCH 18/26] some work on the DeepESN --- src/esn/deepesn.jl | 55 ++++++-------------------------- src/esn/esn.jl | 4 +-- src/esn/esn_reservoir_drivers.jl | 2 +- test/esn/deepesn.jl | 20 ++++++++++++ test/runtests.jl | 6 ++-- 5 files changed, 36 insertions(+), 51 deletions(-) diff --git a/src/esn/deepesn.jl b/src/esn/deepesn.jl index 5879aa05..17933f3a 100644 --- a/src/esn/deepesn.jl +++ b/src/esn/deepesn.jl @@ -13,14 +13,15 @@ end function DeepESN(train_data, in_size::Int, - res_size::AbstractArray; - input_layer = scaled_rand, - reservoir = rand_sparse, - bias = zeros64, + res_size::Int; + depth::Int=2, + input_layer = fill(scaled_rand, depth), + bias = fill(zeros64, depth), + reservoir = fill(rand_sparse, depth), reservoir_driver = RNN(), nla_type = NLADefault(), states_type = StandardStates(), - washout = 0, + washout::Int = 0, rng = _default_rng(), T = Float64, matrix_type = typeof(train_data)) @@ -31,51 +32,15 @@ function DeepESN(train_data, train_data) end - reservoir_matrix = reservoir(rng, T, res_size, res_size) - input_matrix = input_layer(rng, T, res_size, in_size) - bias_vector = bias(rng, T, res_size) + reservoir_matrix = [reservoir[i](rng, T, res_size, res_size) for i in 1:depth] + input_matrix = [input_layer[i](rng, T, res_size, in_size) for i in 1:depth] + bias_vector = [bias[i](rng, res_size) for i in 1:depth] inner_res_driver = reservoir_driver_params(reservoir_driver, res_size, in_size) states = create_states(inner_res_driver, train_data, washout, reservoir_matrix, input_matrix, bias_vector) train_data = train_data[:, (washout + 1):end] - DeepESN(sum(res_size), train_data, variation, nla_type, input_matrix, + DeepESN(res_size, train_data, variation, nla_type, input_matrix, inner_res_driver, reservoir_matrix, bias_vector, states_type, washout, states) end - -function obtain_layers(in_size, - input_layer, - reservoir::Vector, - bias; - matrix_type = Matrix{Float64}) - esn_depth = length(reservoir) - input_res_sizes = [get_ressize(reservoir[i]) for i in 1:esn_depth] - in_sizes = zeros(Int, esn_depth) - in_sizes[2:end] = input_res_sizes[1:(end - 1)] - in_sizes[1] = in_size - - if input_layer isa Array - input_matrix = [create_layer(input_layer[j], input_res_sizes[j], in_sizes[j], - matrix_type = matrix_type) for j in 1:esn_depth] - else - _input_layer = fill(input_layer, esn_depth) - input_matrix = [create_layer(_input_layer[k], input_res_sizes[k], in_sizes[k], - matrix_type = matrix_type) for k in 1:esn_depth] - end - - res_sizes = [get_ressize(input_matrix[j]) for j in 1:esn_depth] - reservoir_matrix = [create_reservoir(reservoir[k], res_sizes[k], - matrix_type = matrix_type) for k in 1:esn_depth] - - if bias isa Array - bias_vector = [create_layer(bias[j], res_sizes[j], 1, matrix_type = matrix_type) - for j in 1:esn_depth] - else - _bias = fill(bias, esn_depth) - bias_vector = [create_layer(_bias[k], res_sizes[k], 1, matrix_type = matrix_type) - for k in 1:esn_depth] - end - - return input_matrix, reservoir_matrix, bias_vector, res_sizes -end diff --git a/src/esn/esn.jl b/src/esn/esn.jl index 260e7aff..4e5fb41e 100644 --- a/src/esn/esn.jl +++ b/src/esn/esn.jl @@ -73,7 +73,7 @@ function ESN(train_data, states) end -function (esn::ESN)(prediction::AbstractPrediction, +function (esn::AbstractEchoStateNetwork)(prediction::AbstractPrediction, output_layer::AbstractOutputLayer; last_state = esn.states[:, [end]], kwargs...) @@ -122,7 +122,7 @@ trained_esn = train(esn, target_data, training_method=StandardRidge(1.0)) - The training is handled by a lower-level `_train` function which takes the new state matrix and performs the actual training using the specified `training_method`. """ -function train(esn::ESN, +function train(esn::AbstractEchoStateNetwork, target_data, training_method = StandardRidge(0.0)) states_new = esn.states_type(esn.nla_type, esn.states, esn.train_data[:, 1:end]) diff --git a/src/esn/esn_reservoir_drivers.jl b/src/esn/esn_reservoir_drivers.jl index 98e7ab78..a8d7b3b1 100644 --- a/src/esn/esn_reservoir_drivers.jl +++ b/src/esn/esn_reservoir_drivers.jl @@ -129,7 +129,7 @@ end function next_state!(out, rnn::RNN, x, y, W::Vector, W_in, b, tmp_array) esn_depth = length(W) - res_sizes = vcat(0, [get_ressize(W[i]) for i in 1:esn_depth]) + res_sizes = vcat(0, [size(W[i],1) for i in 1:esn_depth]) inner_states = [x[(1 + sum(res_sizes[1:i])):sum(res_sizes[1:(i + 1)])] for i in 1:esn_depth] inner_inputs = vcat([y], inner_states[1:(end - 1)]) diff --git a/test/esn/deepesn.jl b/test/esn/deepesn.jl index 8b137891..ce07ec0f 100644 --- a/test/esn/deepesn.jl +++ b/test/esn/deepesn.jl @@ -1 +1,21 @@ +using ReservoirComputing, Random, Statistics + +const res_size = 20 +const ts = 0.0:0.1:50.0 +const data = sin.(ts) +const train_len = 400 +const predict_len = 100 +const input_data = reduce(hcat, data[1:(train_len - 1)]) +const target_data = reduce(hcat, data[2:train_len]) +const test = reduce(hcat, data[(train_len + 1):(train_len + predict_len)]) +const reg = 10e-6 +#test_types = [Float64, Float32, Float16] + +Random.seed!(77) +res = rand_sparse(; radius=1.2, sparsity=0.1) +esn = DeepESN(input_data, 1, res_size) + +output_layer = train(esn, target_data, ta) +output = esn(Predictive(input_data), output_layer) +@test mean(abs.(target_data .- output)) ./ mean(abs.(target_data)) < 0.22 diff --git a/test/runtests.jl b/test/runtests.jl index 4dfad5c4..d424f467 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -14,9 +14,6 @@ end @safetestset "Test initializers" begin include("esn/test_inits.jl") end - @safetestset "ESN States" begin - include("esn/test_states.jl") - end @safetestset "ESN Train and Predict" begin include("esn/test_train.jl") end @@ -26,6 +23,9 @@ end @safetestset "Hybrid ESN" begin include("esn/test_hybrid.jl") end + @safetestset "Deep ESN" begin + include("esn/test_hybrid.jl") + end end @testset "CA based Reservoirs" begin From 49865f0a841ae0aea4f48d614658ec9e6e176efa Mon Sep 17 00:00:00 2001 From: MartinuzziFrancesco Date: Wed, 21 Feb 2024 11:06:29 +0100 Subject: [PATCH 19/26] fixed DeepESN and added tests --- Project.toml | 1 + src/ReservoirComputing.jl | 1 - src/esn/deepesn.jl | 4 ++-- src/esn/esn_predict.jl | 6 +++--- src/esn/esn_reservoir_drivers.jl | 1 + test/esn/deepesn.jl | 6 +++--- test/esn/test_hybrid.jl | 14 ++++++++------ test/runtests.jl | 2 +- 8 files changed, 19 insertions(+), 16 deletions(-) diff --git a/Project.toml b/Project.toml index e0d6bb2b..2da4d5c1 100644 --- a/Project.toml +++ b/Project.toml @@ -37,6 +37,7 @@ SafeTestsets = "0.1" SparseArrays = "1.10" Statistics = "1.10" Test = "1" +WeightInitializers = "0.1" julia = "1.6" [extras] diff --git a/src/ReservoirComputing.jl b/src/ReservoirComputing.jl index 4afb7be1..d2047940 100644 --- a/src/ReservoirComputing.jl +++ b/src/ReservoirComputing.jl @@ -18,7 +18,6 @@ using WeightInitializers export NLADefault, NLAT1, NLAT2, NLAT3 export StandardStates, ExtendedStates, PaddedStates, PaddedExtendedStates export StandardRidge, LinearModel -export AbstractLayer, create_layer export scaled_rand, weighted_init, sparse_init, informed_init, minimal_init export rand_sparse, delay_line, delay_line_backward, cycle_jumps, simple_cycle, pseudo_svd export RNN, MRNN, GRU, GRUParams, FullyGated, Minimal diff --git a/src/esn/deepesn.jl b/src/esn/deepesn.jl index 17933f3a..e1e1a641 100644 --- a/src/esn/deepesn.jl +++ b/src/esn/deepesn.jl @@ -33,14 +33,14 @@ function DeepESN(train_data, end reservoir_matrix = [reservoir[i](rng, T, res_size, res_size) for i in 1:depth] - input_matrix = [input_layer[i](rng, T, res_size, in_size) for i in 1:depth] + input_matrix = [i == 1 ? input_layer[i](rng, T, res_size, in_size) : input_layer[i](rng, T, res_size, res_size) for i in 1:depth] bias_vector = [bias[i](rng, res_size) for i in 1:depth] inner_res_driver = reservoir_driver_params(reservoir_driver, res_size, in_size) states = create_states(inner_res_driver, train_data, washout, reservoir_matrix, input_matrix, bias_vector) train_data = train_data[:, (washout + 1):end] - DeepESN(res_size, train_data, variation, nla_type, input_matrix, + DeepESN(res_size, train_data, nla_type, input_matrix, inner_res_driver, reservoir_matrix, bias_vector, states_type, washout, states) end diff --git a/src/esn/esn_predict.jl b/src/esn/esn_predict.jl index 05deb98c..cd843063 100644 --- a/src/esn/esn_predict.jl +++ b/src/esn/esn_predict.jl @@ -59,10 +59,10 @@ function obtain_esn_prediction(esn, end #prediction dispatch on esn -function next_state_prediction!(esn::ESN, x, x_new, out, out_pad, i, tmp_array, args...) +function next_state_prediction!(esn::AbstractEchoStateNetwork, x, x_new, out, out_pad, i, tmp_array, args...) out_pad = pad_state!(esn.states_type, out_pad, out) xv = @view x[1:(esn.res_size)] - x = next_state!(x, esn.reservoir_driver, xv, out_pad, + x = next_state!(x, esn.reservoir_driver, x, out_pad, esn.reservoir_matrix, esn.input_matrix, esn.bias_vector, tmp_array) x_new = esn.states_type(esn.nla_type, x, out_pad) return x, x_new @@ -86,7 +86,7 @@ function next_state_prediction!(hesn::HybridESN, return x, x_new end -function allocate_outpad(ens::ESN, states_type, out) +function allocate_outpad(ens::AbstractEchoStateNetwork, states_type, out) return allocate_singlepadding(states_type, out) end diff --git a/src/esn/esn_reservoir_drivers.jl b/src/esn/esn_reservoir_drivers.jl index a8d7b3b1..41d3439c 100644 --- a/src/esn/esn_reservoir_drivers.jl +++ b/src/esn/esn_reservoir_drivers.jl @@ -135,6 +135,7 @@ function next_state!(out, rnn::RNN, x, y, W::Vector, W_in, b, tmp_array) inner_inputs = vcat([y], inner_states[1:(end - 1)]) for i in 1:esn_depth + inner_states[i] = (1 - rnn.leaky_coefficient) .* inner_states[i] + rnn.leaky_coefficient * rnn.activation_function.((W[i] * inner_states[i]) .+ diff --git a/test/esn/deepesn.jl b/test/esn/deepesn.jl index ce07ec0f..6ebb9c95 100644 --- a/test/esn/deepesn.jl +++ b/test/esn/deepesn.jl @@ -15,7 +15,7 @@ Random.seed!(77) res = rand_sparse(; radius=1.2, sparsity=0.1) esn = DeepESN(input_data, 1, res_size) -output_layer = train(esn, target_data, ta) -output = esn(Predictive(input_data), output_layer) -@test mean(abs.(target_data .- output)) ./ mean(abs.(target_data)) < 0.22 +output_layer = train(esn, target_data) +output = esn(Generative(length(test)), output_layer) +@test mean(abs.(test .- output)) ./ mean(abs.(test)) < 0.22 diff --git a/test/esn/test_hybrid.jl b/test/esn/test_hybrid.jl index 4f858208..415b1343 100644 --- a/test/esn/test_hybrid.jl +++ b/test/esn/test_hybrid.jl @@ -30,15 +30,17 @@ test_data = ode_data[:, (train_len + 1):end][:, 1:1000] predict_len = size(test_data, 2) tspan_train = (tspan[1], ode_sol.t[train_len]) -hybrid = Hybrid(prior_model_data_generator, u0, tspan_train, train_len) +km = KnowledgeModel(prior_model_data_generator, u0, tspan_train, train_len) Random.seed!(77) -esn = ESN(input_data, - reservoir = RandSparseReservoir(300), - variation = hybrid) +hesn = HybridESN(km, + input_data, + 3, + 300; + reservoir = rand_sparse) -output_layer = train(esn, target_data, StandardRidge(0.3)) +output_layer = train(hesn, target_data, StandardRidge(0.3)) -output = esn(Generative(predict_len), output_layer) +output = hesn(Generative(predict_len), output_layer) @test mean(abs.(test_data[1:100] .- output[1:100])) ./ mean(abs.(test_data[1:100])) < 0.11 diff --git a/test/runtests.jl b/test/runtests.jl index d424f467..7aa8defc 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -24,7 +24,7 @@ end include("esn/test_hybrid.jl") end @safetestset "Deep ESN" begin - include("esn/test_hybrid.jl") + include("esn/deepesn.jl") end end From 331b6511e7f268a8dd4456c9961db8cb65bd7291 Mon Sep 17 00:00:00 2001 From: MartinuzziFrancesco Date: Mon, 26 Feb 2024 17:34:55 +0100 Subject: [PATCH 20/26] fixed formatting --- docs/pages.jl | 8 +++++--- docs/src/esn_tutorials/change_layers.md | 2 +- src/esn/deepesn.jl | 6 +++--- src/esn/esn_predict.jl | 3 ++- src/esn/esn_reservoir_drivers.jl | 11 +++-------- src/reca/reca_input_encodings.jl | 3 ++- src/train/linear_regression.jl | 3 +-- test/esn/deepesn.jl | 3 +-- test/esn/test_drivers.jl | 4 ++-- test/esn/test_inits.jl | 8 ++++---- test/esn/test_train.jl | 4 ++-- test/test_states.jl | 2 +- 12 files changed, 27 insertions(+), 30 deletions(-) diff --git a/docs/pages.jl b/docs/pages.jl index 309f76fa..0c1465c7 100644 --- a/docs/pages.jl +++ b/docs/pages.jl @@ -1,9 +1,11 @@ pages = [ "ReservoirComputing.jl" => "index.md", - "General Settings" => Any["Changing Training Algorithms" => "general/different_training.md", + "General Settings" => Any[ + "Changing Training Algorithms" => "general/different_training.md", "Altering States" => "general/states_variation.md", "Generative vs Predictive" => "general/predictive_generative.md"], - "Echo State Network Tutorials" => Any["Lorenz System Forecasting" => "esn_tutorials/lorenz_basic.md", + "Echo State Network Tutorials" => Any[ + "Lorenz System Forecasting" => "esn_tutorials/lorenz_basic.md", #"Mackey-Glass Forecasting on GPU" => "esn_tutorials/mackeyglass_basic.md", "Using Different Layers" => "esn_tutorials/change_layers.md", "Using Different Reservoir Drivers" => "esn_tutorials/different_drivers.md", @@ -17,5 +19,5 @@ pages = [ "Echo State Networks" => "api/esn.md", "ESN Layers" => "api/esn_layers.md", "ESN Drivers" => "api/esn_drivers.md", - "ReCA" => "api/reca.md"], + "ReCA" => "api/reca.md"] ] diff --git a/docs/src/esn_tutorials/change_layers.md b/docs/src/esn_tutorials/change_layers.md index 5cfb65cf..86c70477 100644 --- a/docs/src/esn_tutorials/change_layers.md +++ b/docs/src/esn_tutorials/change_layers.md @@ -76,7 +76,7 @@ using ReservoirComputing, StatsBase res_size = 300 input_layer = [ MinimumLayer(0.85, IrrationalSample()), - MinimumLayer(0.95, IrrationalSample()), + MinimumLayer(0.95, IrrationalSample()) ] reservoirs = [SimpleCycleReservoir(res_size, 0.7), CycleJumpsReservoir(res_size, cycle_weight = 0.7, jump_weight = 0.2, jump_size = 5)] diff --git a/src/esn/deepesn.jl b/src/esn/deepesn.jl index e1e1a641..8fe8f798 100644 --- a/src/esn/deepesn.jl +++ b/src/esn/deepesn.jl @@ -14,7 +14,7 @@ end function DeepESN(train_data, in_size::Int, res_size::Int; - depth::Int=2, + depth::Int = 2, input_layer = fill(scaled_rand, depth), bias = fill(zeros64, depth), reservoir = fill(rand_sparse, depth), @@ -25,7 +25,6 @@ function DeepESN(train_data, rng = _default_rng(), T = Float64, matrix_type = typeof(train_data)) - if states_type isa AbstractPaddedStates in_size = size(train_data, 1) + 1 train_data = vcat(Adapt.adapt(matrix_type, ones(1, size(train_data, 2))), @@ -33,7 +32,8 @@ function DeepESN(train_data, end reservoir_matrix = [reservoir[i](rng, T, res_size, res_size) for i in 1:depth] - input_matrix = [i == 1 ? input_layer[i](rng, T, res_size, in_size) : input_layer[i](rng, T, res_size, res_size) for i in 1:depth] + input_matrix = [i == 1 ? input_layer[i](rng, T, res_size, in_size) : + input_layer[i](rng, T, res_size, res_size) for i in 1:depth] bias_vector = [bias[i](rng, res_size) for i in 1:depth] inner_res_driver = reservoir_driver_params(reservoir_driver, res_size, in_size) states = create_states(inner_res_driver, train_data, washout, reservoir_matrix, diff --git a/src/esn/esn_predict.jl b/src/esn/esn_predict.jl index cd843063..cc7cdc5d 100644 --- a/src/esn/esn_predict.jl +++ b/src/esn/esn_predict.jl @@ -59,7 +59,8 @@ function obtain_esn_prediction(esn, end #prediction dispatch on esn -function next_state_prediction!(esn::AbstractEchoStateNetwork, x, x_new, out, out_pad, i, tmp_array, args...) +function next_state_prediction!( + esn::AbstractEchoStateNetwork, x, x_new, out, out_pad, i, tmp_array, args...) out_pad = pad_state!(esn.states_type, out_pad, out) xv = @view x[1:(esn.res_size)] x = next_state!(x, esn.reservoir_driver, x, out_pad, diff --git a/src/esn/esn_reservoir_drivers.jl b/src/esn/esn_reservoir_drivers.jl index 41d3439c..b7b9e03c 100644 --- a/src/esn/esn_reservoir_drivers.jl +++ b/src/esn/esn_reservoir_drivers.jl @@ -129,13 +129,12 @@ end function next_state!(out, rnn::RNN, x, y, W::Vector, W_in, b, tmp_array) esn_depth = length(W) - res_sizes = vcat(0, [size(W[i],1) for i in 1:esn_depth]) + res_sizes = vcat(0, [size(W[i], 1) for i in 1:esn_depth]) inner_states = [x[(1 + sum(res_sizes[1:i])):sum(res_sizes[1:(i + 1)])] for i in 1:esn_depth] inner_inputs = vcat([y], inner_states[1:(end - 1)]) for i in 1:esn_depth - inner_states[i] = (1 - rnn.leaky_coefficient) .* inner_states[i] + rnn.leaky_coefficient * rnn.activation_function.((W[i] * inner_states[i]) .+ @@ -180,9 +179,7 @@ This function creates an MRNN object with the specified activation functions, le "_A novel model of leaky integrator echo state network for time-series prediction._" Neurocomputing 159 (2015): 58-66. """ -function MRNN( - ; - activation_function = [tanh, sigmoid], +function MRNN(; activation_function = [tanh, sigmoid], leaky_coefficient = 1.0, scaling_factor = fill(leaky_coefficient, length(activation_function))) @assert length(activation_function) == length(scaling_factor) @@ -283,9 +280,7 @@ A GRUParams object containing the parameters needed for the GRU-based reservoir "_Learning phrase representations using RNN encoder-decoder for statistical machine translation._" arXiv preprint arXiv:1406.1078 (2014). """ -function GRU( - ; - activation_function = [NNlib.sigmoid, NNlib.sigmoid, tanh], +function GRU(; activation_function = [NNlib.sigmoid, NNlib.sigmoid, tanh], inner_layer = fill(scaled_rand, 2), reservoir = fill(rand_sparse, 2), bias = fill(scaled_rand, 2), diff --git a/src/reca/reca_input_encodings.jl b/src/reca/reca_input_encodings.jl index 8fdd5233..54195f90 100644 --- a/src/reca/reca_input_encodings.jl +++ b/src/reca/reca_input_encodings.jl @@ -66,7 +66,8 @@ function encoding(rm::RandomMaps, input_vector, tot_encoded_vector) new_tot_enc_vec = copy(tot_encoded_vector) for i in 1:(rm.permutations) - new_tot_enc_vec[((i - 1) * rm.expansion_size + 1):(i * rm.expansion_size)] = single_encoding(input_vector, + new_tot_enc_vec[((i - 1) * rm.expansion_size + 1):(i * rm.expansion_size)] = single_encoding( + input_vector, new_tot_enc_vec[((i - 1) * rm.expansion_size + 1):(i * rm.expansion_size)], rm.maps[i, :]) diff --git a/src/train/linear_regression.jl b/src/train/linear_regression.jl index f32b6601..5bb86c5c 100644 --- a/src/train/linear_regression.jl +++ b/src/train/linear_regression.jl @@ -39,8 +39,7 @@ models in the library. All the parameters have to be passed into ```regression_k apart from the solver choice. MLJLinearModels.jl needs to be called in order to use these models. """ -function LinearModel( - ; regression = LinearRegression, +function LinearModel(; regression = LinearRegression, solver = Analytical(), regression_kwargs = (;)) return LinearModel(regression, solver, regression_kwargs) diff --git a/test/esn/deepesn.jl b/test/esn/deepesn.jl index 6ebb9c95..38e06814 100644 --- a/test/esn/deepesn.jl +++ b/test/esn/deepesn.jl @@ -12,10 +12,9 @@ const reg = 10e-6 #test_types = [Float64, Float32, Float16] Random.seed!(77) -res = rand_sparse(; radius=1.2, sparsity=0.1) +res = rand_sparse(; radius = 1.2, sparsity = 0.1) esn = DeepESN(input_data, 1, res_size) output_layer = train(esn, target_data) output = esn(Generative(length(test)), output_layer) @test mean(abs.(test .- output)) ./ mean(abs.(test)) < 0.22 - diff --git a/test/esn/test_drivers.jl b/test/esn/test_drivers.jl index b7034484..7ad47f74 100644 --- a/test/esn/test_drivers.jl +++ b/test/esn/test_drivers.jl @@ -24,7 +24,7 @@ esn_configs = [ :reservoir_driver => GRU(variant = FullyGated(), reservoir = [ rand_sparse(; radius = 1.0, sparsity = 0.5), - rand_sparse(; radius = 1.2, sparsity = 0.1), + rand_sparse(; radius = 1.2, sparsity = 0.1) ])), Dict(:reservoir => rand_sparse(; radius = 1.2), :reservoir_driver => GRU(variant = Minimal(), @@ -33,7 +33,7 @@ esn_configs = [ bias = scaled_rand)), Dict(:reservoir => rand_sparse(; radius = 1.2), :reservoir_driver => MRNN(activation_function = (tanh, sigmoid), - scaling_factor = (0.8, 0.1))), + scaling_factor = (0.8, 0.1))) ] @testset "Test Drivers: $config" for config in esn_configs diff --git a/test/esn/test_inits.jl b/test/esn/test_inits.jl index 789f71e1..6d25ea34 100644 --- a/test/esn/test_inits.jl +++ b/test/esn/test_inits.jl @@ -23,14 +23,14 @@ reservoir_inits = [ delay_line_backward, cycle_jumps, simple_cycle, - pseudo_svd, + pseudo_svd ] input_inits = [ scaled_rand, weighted_init, sparse_init, minimal_init, - minimal_init(; sampling_type = :irrational), + minimal_init(; sampling_type = :irrational) ] @testset "Reservoir Initializers" begin @@ -55,7 +55,7 @@ input_inits = [ delay_line, delay_line_backward, cycle_jumps, - simple_cycle, + simple_cycle ] dl = init(res_size, res_size) if init === delay_line_backward @@ -82,7 +82,7 @@ end @testset "Minimum complexity: $init" for init in [ minimal_init, - minimal_init(; sampling_type = :irrational), + minimal_init(; sampling_type = :irrational) ] dl = init(res_size, in_size) @test sort(unique(dl)) == Float32.([-0.1, 0.1]) diff --git a/test/esn/test_train.jl b/test/esn/test_train.jl index a0f6a4c1..034bbca5 100644 --- a/test/esn/test_train.jl +++ b/test/esn/test_train.jl @@ -12,7 +12,7 @@ const reg = 10e-6 #test_types = [Float64, Float32, Float16] Random.seed!(77) -res = rand_sparse(; radius=1.2, sparsity=0.1) +res = rand_sparse(; radius = 1.2, sparsity = 0.1) esn = ESN(input_data, 1, res_size; reservoir = rand_sparse) @@ -20,7 +20,7 @@ training_methods = [ StandardRidge(regularization_coeff = reg), LinearModel(RidgeRegression, regression_kwargs = (; lambda = reg)), LinearModel(regression = RidgeRegression, regression_kwargs = (; lambda = reg)), - EpsilonSVR(), + EpsilonSVR() ] # TODO check types diff --git a/test/test_states.jl b/test/test_states.jl index 34570668..cd776715 100644 --- a/test/test_states.jl +++ b/test/test_states.jl @@ -30,4 +30,4 @@ pes = [(StandardStates(), test_array), @test states_output == expected_output @test eltype(states_output) == T end -end \ No newline at end of file +end From bd2b9a0a0e6810d37eb05385c6ecd059ba92d836 Mon Sep 17 00:00:00 2001 From: MartinuzziFrancesco Date: Tue, 27 Feb 2024 14:22:40 +0100 Subject: [PATCH 21/26] format --- src/esn/echostatenetwork.jl | 1 - src/esn/esn.jl | 46 +++++++++++--------- src/esn/esn_input_layers.jl | 78 ++++++++++++++++++++------------- src/esn/esn_reservoirs.jl | 87 ++++++++++++++++++++++--------------- src/esn/hybridesn.jl | 15 ++++--- 5 files changed, 135 insertions(+), 92 deletions(-) diff --git a/src/esn/echostatenetwork.jl b/src/esn/echostatenetwork.jl index 49fe2b51..bfa088bf 100644 --- a/src/esn/echostatenetwork.jl +++ b/src/esn/echostatenetwork.jl @@ -276,4 +276,3 @@ end function pad_esnstate!(variation, states_type, x_pad, x, args...) x_pad = pad_state!(states_type, x_pad, x) end - diff --git a/src/esn/esn.jl b/src/esn/esn.jl index 4e5fb41e..2b351ce3 100644 --- a/src/esn/esn.jl +++ b/src/esn/esn.jl @@ -18,27 +18,30 @@ end Creates an Echo State Network (ESN) using specified parameters and training data, suitable for various machine learning tasks. # Parameters -- `train_data`: Matrix of training data (columns as time steps, rows as features). -- `variation`: Variation of ESN (default: `Default()`). -- `input_layer`: Input layer of ESN (default: `DenseLayer()`). -- `reservoir`: Reservoir of the ESN (default: `RandSparseReservoir(100)`). -- `bias`: Bias vector for each time step (default: `NullLayer()`). -- `reservoir_driver`: Mechanism for evolving reservoir states (default: `RNN()`). -- `nla_type`: Non-linear activation type (default: `NLADefault()`). -- `states_type`: Format for storing states (default: `StandardStates()`). -- `washout`: Initial time steps to discard (default: `0`). -- `matrix_type`: Type of matrices used internally (default: type of `train_data`). + + - `train_data`: Matrix of training data (columns as time steps, rows as features). + - `variation`: Variation of ESN (default: `Default()`). + - `input_layer`: Input layer of ESN (default: `DenseLayer()`). + - `reservoir`: Reservoir of the ESN (default: `RandSparseReservoir(100)`). + - `bias`: Bias vector for each time step (default: `NullLayer()`). + - `reservoir_driver`: Mechanism for evolving reservoir states (default: `RNN()`). + - `nla_type`: Non-linear activation type (default: `NLADefault()`). + - `states_type`: Format for storing states (default: `StandardStates()`). + - `washout`: Initial time steps to discard (default: `0`). + - `matrix_type`: Type of matrices used internally (default: type of `train_data`). # Returns -- An initialized ESN instance with specified parameters. + + - An initialized ESN instance with specified parameters. # Examples + ```julia using ReservoirComputing train_data = rand(10, 100) # 10 features, 100 time steps -esn = ESN(train_data, reservoir=RandSparseReservoir(200), washout=10) +esn = ESN(train_data, reservoir = RandSparseReservoir(200), washout = 10) ``` """ function ESN(train_data, @@ -90,15 +93,17 @@ end Trains an Echo State Network (ESN) using the provided target data and a specified training method. # Parameters -- `esn::AbstractEchoStateNetwork`: The ESN instance to be trained. -- `target_data`: Supervised training data for the ESN. -- `training_method`: The method for training the ESN (default: `StandardRidge(0.0)`). + + - `esn::AbstractEchoStateNetwork`: The ESN instance to be trained. + - `target_data`: Supervised training data for the ESN. + - `training_method`: The method for training the ESN (default: `StandardRidge(0.0)`). # Returns -- The trained ESN model. Its type and structure depend on `training_method` and the ESN's implementation. + - The trained ESN model. Its type and structure depend on `training_method` and the ESN's implementation. # Returns + The trained ESN model. The exact type and structure of the return value depends on the `training_method` and the specific ESN implementation. @@ -106,20 +111,21 @@ The trained ESN model. The exact type and structure of the return value depends using ReservoirComputing # Initialize an ESN instance and target data -esn = ESN(train_data, reservoir=RandSparseReservoir(200), washout=10) +esn = ESN(train_data, reservoir = RandSparseReservoir(200), washout = 10) target_data = rand(size(train_data, 2)) # Train the ESN using the default training method trained_esn = train(esn, target_data) # Train the ESN using a custom training method -trained_esn = train(esn, target_data, training_method=StandardRidge(1.0)) +trained_esn = train(esn, target_data, training_method = StandardRidge(1.0)) ``` # Notes -- When using a `Hybrid` variation, the function extends the state matrix with data from the + + - When using a `Hybrid` variation, the function extends the state matrix with data from the physical model included in the `variation`. -- The training is handled by a lower-level `_train` function which takes the new state matrix + - The training is handled by a lower-level `_train` function which takes the new state matrix and performs the actual training using the specified `training_method`. """ function train(esn::AbstractEchoStateNetwork, diff --git a/src/esn/esn_input_layers.jl b/src/esn/esn_input_layers.jl index 10ea2330..58aab7c4 100644 --- a/src/esn/esn_input_layers.jl +++ b/src/esn/esn_input_layers.jl @@ -4,18 +4,22 @@ Create and return a matrix with random values, uniformly distributed within a range defined by `scaling`. This function is useful for initializing matrices, such as the layers of a neural network, with scaled random values. # Arguments -- `rng`: An instance of `AbstractRNG` for random number generation. -- `T`: The data type for the elements of the matrix. -- `dims`: Dimensions of the matrix. It must be a 2-element tuple specifying the number of rows and columns (e.g., `(res_size, in_size)`). -- `scaling`: A scaling factor to define the range of the uniform distribution. The matrix elements will be randomly chosen from the range `[-scaling, scaling]`. Defaults to `T(0.1)`. + + - `rng`: An instance of `AbstractRNG` for random number generation. + - `T`: The data type for the elements of the matrix. + - `dims`: Dimensions of the matrix. It must be a 2-element tuple specifying the number of rows and columns (e.g., `(res_size, in_size)`). + - `scaling`: A scaling factor to define the range of the uniform distribution. The matrix elements will be randomly chosen from the range `[-scaling, scaling]`. Defaults to `T(0.1)`. # Returns + A matrix of type with dimensions specified by `dims`. Each element of the matrix is a random number uniformly distributed between `-scaling` and `scaling`. # Example + ```julia rng = Random.default_rng() -matrix = scaled_rand(rng, Float64, (100, 50); scaling=0.2) +matrix = scaled_rand(rng, Float64, (100, 50); scaling = 0.2) +``` """ function scaled_rand(rng::AbstractRNG, ::Type{T}, @@ -32,20 +36,25 @@ end Create and return a matrix representing a weighted input layer for Echo State Networks (ESNs). This initializer generates a weighted input matrix with random non-zero elements distributed uniformly within the range [-`scaling`, `scaling`], inspired by the approach in [^Lu]. # Arguments -- `rng`: An instance of `AbstractRNG` for random number generation. -- `T`: The data type for the elements of the matrix. -- `dims`: A 2-element tuple specifying the approximate reservoir size and input size (e.g., `(approx_res_size, in_size)`). -- `scaling`: The scaling factor for the weight distribution. Defaults to `T(0.1)`. + + - `rng`: An instance of `AbstractRNG` for random number generation. + - `T`: The data type for the elements of the matrix. + - `dims`: A 2-element tuple specifying the approximate reservoir size and input size (e.g., `(approx_res_size, in_size)`). + - `scaling`: The scaling factor for the weight distribution. Defaults to `T(0.1)`. # Returns + A matrix representing the weighted input layer as defined in [^Lu2017]. The matrix dimensions will be adjusted to ensure each input unit connects to an equal number of reservoir units. # Example + ```julia rng = Random.default_rng() -input_layer = weighted_init(rng, Float64, (3, 300); scaling=0.2) +input_layer = weighted_init(rng, Float64, (3, 300); scaling = 0.2) ``` + # References + [^Lu2017]: Lu, Zhixin, et al. "Reservoir observers: Model-free inference of unmeasured variables in chaotic systems." Chaos: An Interdisciplinary Journal of Nonlinear Science 27.4 (2017): 041102. @@ -76,20 +85,22 @@ Create and return a sparse layer matrix for use in neural network models. The matrix will be of size specified by `dims`, with the specified `sparsity` and `scaling`. # Arguments -- `rng`: An instance of `AbstractRNG` for random number generation. -- `T`: The data type for the elements of the matrix. -- `dims`: Dimensions of the resulting sparse layer matrix. -- `scaling`: The scaling factor for the sparse layer matrix. Defaults to 0.1. -- `sparsity`: The sparsity level of the sparse layer matrix, controlling the fraction of zero elements. Defaults to 0.1. + + - `rng`: An instance of `AbstractRNG` for random number generation. + - `T`: The data type for the elements of the matrix. + - `dims`: Dimensions of the resulting sparse layer matrix. + - `scaling`: The scaling factor for the sparse layer matrix. Defaults to 0.1. + - `sparsity`: The sparsity level of the sparse layer matrix, controlling the fraction of zero elements. Defaults to 0.1. # Returns -A sparse layer matrix. +A sparse layer matrix. # Example + ```julia rng = Random.default_rng() -input_layer = sparse_init(rng, Float64, (3, 300); scaling=0.2, sparsity=0.1) +input_layer = sparse_init(rng, Float64, (3, 300); scaling = 0.2, sparsity = 0.1) ``` """ function sparse_init(rng::AbstractRNG, ::Type{T}, dims::Integer...; @@ -109,22 +120,25 @@ end Create a layer of a neural network. # Arguments -- `rng::AbstractRNG`: The random number generator. -- `T::Type`: The data type. -- `dims::Integer...`: The dimensions of the layer. -- `scaling::T = T(0.1)`: The scaling factor for the input matrix. -- `model_in_size`: The size of the input model. -- `gamma::T = T(0.5)`: The gamma value. + + - `rng::AbstractRNG`: The random number generator. + - `T::Type`: The data type. + - `dims::Integer...`: The dimensions of the layer. + - `scaling::T = T(0.1)`: The scaling factor for the input matrix. + - `model_in_size`: The size of the input model. + - `gamma::T = T(0.5)`: The gamma value. # Returns -- `input_matrix`: The created input matrix for the layer. + + - `input_matrix`: The created input matrix for the layer. # Example + ```julia rng = Random.default_rng() dims = (100, 200) model_in_size = 50 -input_matrix = informed_init(rng, Float64, dims; model_in_size=model_in_size) +input_matrix = informed_init(rng, Float64, dims; model_in_size = model_in_size) ``` """ function informed_init(rng::AbstractRNG, ::Type{T}, dims::Integer...; @@ -169,21 +183,25 @@ end Create a layer matrix using the provided random number generator and sampling parameters. # Arguments -- `rng::AbstractRNG`: The random number generator used to generate random numbers. -- `dims::Integer...`: The dimensions of the layer matrix. -- `weight`: The weight used to fill the layer matrix. Default is 0.1. -- `sampling`: The sampling parameters used to generate the input matrix. Default is IrrationalSample(irrational = pi, start = 1). + + - `rng::AbstractRNG`: The random number generator used to generate random numbers. + - `dims::Integer...`: The dimensions of the layer matrix. + - `weight`: The weight used to fill the layer matrix. Default is 0.1. + - `sampling`: The sampling parameters used to generate the input matrix. Default is IrrationalSample(irrational = pi, start = 1). # Returns + The layer matrix generated using the provided random number generator and sampling parameters. # Example + ```julia using Random rng = Random.default_rng() dims = (3, 2) weight = 0.5 -layer_matrix = irrational_sample_init(rng, Float64, dims; weight = weight, sampling = IrrationalSample(irrational = sqrt(2), start = 1)) +layer_matrix = irrational_sample_init(rng, Float64, dims; weight = weight, + sampling = IrrationalSample(irrational = sqrt(2), start = 1)) ``` """ function minimal_init(rng::AbstractRNG, ::Type{T}, dims::Integer...; diff --git a/src/esn/esn_reservoirs.jl b/src/esn/esn_reservoirs.jl index 390fb364..85dcf94f 100644 --- a/src/esn/esn_reservoirs.jl +++ b/src/esn/esn_reservoirs.jl @@ -4,16 +4,19 @@ Create and return a random sparse reservoir matrix for use in Echo State Networks (ESNs). The matrix will be of size specified by `dims`, with specified `sparsity` and scaled spectral radius according to `radius`. # Arguments -- `rng`: An instance of `AbstractRNG` for random number generation. -- `T`: The data type for the elements of the matrix. -- `dims`: Dimensions of the reservoir matrix. -- `radius`: The desired spectral radius of the reservoir. Defaults to 1.0. -- `sparsity`: The sparsity level of the reservoir matrix, controlling the fraction of zero elements. Defaults to 0.1. + + - `rng`: An instance of `AbstractRNG` for random number generation. + - `T`: The data type for the elements of the matrix. + - `dims`: Dimensions of the reservoir matrix. + - `radius`: The desired spectral radius of the reservoir. Defaults to 1.0. + - `sparsity`: The sparsity level of the reservoir matrix, controlling the fraction of zero elements. Defaults to 0.1. # Returns + A matrix representing the random sparse reservoir. # References + This type of reservoir initialization is commonly used in ESNs for capturing temporal dependencies in data. """ function rand_sparse(rng::AbstractRNG, @@ -38,20 +41,24 @@ end Create and return a delay line reservoir matrix for use in Echo State Networks (ESNs). A delay line reservoir is a deterministic structure where each unit is connected only to its immediate predecessor with a specified weight. This method is particularly useful for tasks that require specific temporal processing. # Arguments -- `rng`: An instance of `AbstractRNG` for random number generation. This argument is not used in the current implementation but is included for consistency with other initialization functions. -- `T`: The data type for the elements of the matrix. -- `dims`: Dimensions of the reservoir matrix. Typically, this should be a tuple of two equal integers representing a square matrix. -- `weight`: The weight determines the absolute value of all connections in the reservoir. Defaults to 0.1. + + - `rng`: An instance of `AbstractRNG` for random number generation. This argument is not used in the current implementation but is included for consistency with other initialization functions. + - `T`: The data type for the elements of the matrix. + - `dims`: Dimensions of the reservoir matrix. Typically, this should be a tuple of two equal integers representing a square matrix. + - `weight`: The weight determines the absolute value of all connections in the reservoir. Defaults to 0.1. # Returns + A delay line reservoir matrix with dimensions specified by `dims`. The matrix is initialized such that each element in the `i+1`th row and `i`th column is set to `weight`, and all other elements are zeros. # Example + ```julia -reservoir = delay_line(Float64, 100, 100; weight=0.2) +reservoir = delay_line(Float64, 100, 100; weight = 0.2) ``` # References + This type of reservoir initialization is described in: Rodan, Ali, and Peter Tino. "Minimum complexity echo state network." IEEE Transactions on Neural Networks 22.1 (2010): 131-144. """ @@ -78,19 +85,21 @@ as described in [^Rodan2010]. The `weight` and `fb_weight` can be passed as eith keyword arguments, and they determine the absolute values of the connections in the reservoir. # Arguments -- `rng::AbstractRNG`: Random number generator. -- `T::Type`: Type of the elements in the reservoir matrix. -- `dims::Integer...`: Dimensions of the reservoir matrix. -- `weight::T`: The weight determines the absolute value of forward connections in the reservoir, and is set to 0.1 by default. -- `fb_weight::T`: The `fb_weight` determines the absolute value of backward connections in the reservoir, and is set to 0.2 by default. + - `rng::AbstractRNG`: Random number generator. + - `T::Type`: Type of the elements in the reservoir matrix. + - `dims::Integer...`: Dimensions of the reservoir matrix. + - `weight::T`: The weight determines the absolute value of forward connections in the reservoir, and is set to 0.1 by default. + - `fb_weight::T`: The `fb_weight` determines the absolute value of backward connections in the reservoir, and is set to 0.2 by default. # Returns + Reservoir matrix with the dimensions specified by `dims` and weights. # References + [^Rodan2010]: Rodan, Ali, and Peter Tino. "Minimum complexity echo state network." -IEEE transactions on neural networks 22.1 (2010): 131-144. + IEEE transactions on neural networks 22.1 (2010): 131-144. """ function delay_line_backward(rng::AbstractRNG, ::Type{T}, @@ -115,19 +124,22 @@ end Create a cycle jumps reservoir with the specified dimensions, cycle weight, jump weight, and jump size. # Arguments -- `rng::AbstractRNG`: Random number generator. -- `T::Type`: Type of the elements in the reservoir matrix. -- `dims::Integer...`: Dimensions of the reservoir matrix. -- `cycle_weight::T = T(0.1)`: The weight of cycle connections. -- `jump_weight::T = T(0.1)`: The weight of jump connections. -- `jump_size::Int = 3`: The number of steps between jump connections. + + - `rng::AbstractRNG`: Random number generator. + - `T::Type`: Type of the elements in the reservoir matrix. + - `dims::Integer...`: Dimensions of the reservoir matrix. + - `cycle_weight::T = T(0.1)`: The weight of cycle connections. + - `jump_weight::T = T(0.1)`: The weight of jump connections. + - `jump_size::Int = 3`: The number of steps between jump connections. # Returns + Reservoir matrix with the specified dimensions, cycle weight, jump weight, and jump size. # References + [^Rodan2012]: Rodan, Ali, and Peter Tiňo. "Simple deterministically constructed cycle reservoirs -with regular jumps." Neural computation 24.7 (2012): 1822-1852. + with regular jumps." Neural computation 24.7 (2012): 1822-1852. """ function cycle_jumps(rng::AbstractRNG, ::Type{T}, @@ -163,17 +175,20 @@ end Create a simple cycle reservoir with the specified dimensions and weight. # Arguments -- `rng::AbstractRNG`: Random number generator. -- `T::Type`: Type of the elements in the reservoir matrix. -- `dims::Integer...`: Dimensions of the reservoir matrix. -- `weight::T = T(0.1)`: Weight of the connections in the reservoir matrix. + + - `rng::AbstractRNG`: Random number generator. + - `T::Type`: Type of the elements in the reservoir matrix. + - `dims::Integer...`: Dimensions of the reservoir matrix. + - `weight::T = T(0.1)`: Weight of the connections in the reservoir matrix. # Returns + Reservoir matrix with the dimensions specified by `dims` and weights. # References + [^Rodan2010]: Rodan, Ali, and Peter Tino. "Minimum complexity echo state network." -IEEE transactions on neural networks 22.1 (2010): 131-144. + IEEE transactions on neural networks 22.1 (2010): 131-144. """ function simple_cycle(rng::AbstractRNG, ::Type{T}, @@ -196,15 +211,17 @@ end Returns an initializer to build a sparse reservoir matrix with the given `sparsity` by using a pseudo-SVD approach as described in [^yang]. # Arguments -- `rng::AbstractRNG`: Random number generator. -- `T::Type`: Type of the elements in the reservoir matrix. -- `dims::Integer...`: Dimensions of the reservoir matrix. -- `max_value`: The maximum absolute value of elements in the matrix. -- `sparsity`: The desired sparsity level of the reservoir matrix. -- `sorted`: A boolean indicating whether to sort the singular values before creating the diagonal matrix. By default, it is set to `true`. -- `reverse_sort`: A boolean indicating whether to reverse the sorted singular values. By default, it is set to `false`. + + - `rng::AbstractRNG`: Random number generator. + - `T::Type`: Type of the elements in the reservoir matrix. + - `dims::Integer...`: Dimensions of the reservoir matrix. + - `max_value`: The maximum absolute value of elements in the matrix. + - `sparsity`: The desired sparsity level of the reservoir matrix. + - `sorted`: A boolean indicating whether to sort the singular values before creating the diagonal matrix. By default, it is set to `true`. + - `reverse_sort`: A boolean indicating whether to reverse the sorted singular values. By default, it is set to `false`. # Returns + Reservoir matrix with the specified dimensions, max value, and sparsity. # References diff --git a/src/esn/hybridesn.jl b/src/esn/hybridesn.jl index 37b40d0c..a190590a 100644 --- a/src/esn/hybridesn.jl +++ b/src/esn/hybridesn.jl @@ -25,20 +25,23 @@ end Hybrid(prior_model, u0, tspan, datasize) Constructs a `Hybrid` variation of Echo State Networks (ESNs) integrating a knowledge-based model -(`prior_model`) with ESNs for advanced training and prediction in chaotic systems. +(`prior_model`) with ESNs for advanced training and prediction in chaotic systems. # Parameters -- `prior_model`: A knowledge-based model function for integration with ESNs. -- `u0`: Initial conditions for the model. -- `tspan`: Time span as a tuple, indicating the duration for model operation. -- `datasize`: The size of the data to be processed. + + - `prior_model`: A knowledge-based model function for integration with ESNs. + - `u0`: Initial conditions for the model. + - `tspan`: Time span as a tuple, indicating the duration for model operation. + - `datasize`: The size of the data to be processed. # Returns -- A `Hybrid` struct instance representing the combined ESN and knowledge-based model. + + - A `Hybrid` struct instance representing the combined ESN and knowledge-based model. This method is effective for chaotic processes as highlighted in [^Pathak]. Reference: + [^Pathak]: Jaideep Pathak et al. "Hybrid Forecasting of Chaotic Processes: Using Machine Learning in Conjunction with a Knowledge-Based Model" (2018). From edfff16072e4c479d7c9eedf3f49e39f69b13b30 Mon Sep 17 00:00:00 2001 From: MartinuzziFrancesco Date: Tue, 27 Feb 2024 14:46:15 +0100 Subject: [PATCH 22/26] julia bump --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 2da4d5c1..3635952e 100644 --- a/Project.toml +++ b/Project.toml @@ -38,7 +38,7 @@ SparseArrays = "1.10" Statistics = "1.10" Test = "1" WeightInitializers = "0.1" -julia = "1.6" +julia = "1.10" [extras] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" From 8c9056f3ab26ce728e0bef6639cdc80c401cda2f Mon Sep 17 00:00:00 2001 From: MartinuzziFrancesco Date: Tue, 27 Feb 2024 14:49:05 +0100 Subject: [PATCH 23/26] bump random --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 3635952e..6fb383e7 100644 --- a/Project.toml +++ b/Project.toml @@ -32,7 +32,7 @@ MLJLinearModels = "0.9.2" NNlib = "0.8.4, 0.9" Optim = "1" PartialFunctions = "1.2" -Random = "1" +Random = "1.10" SafeTestsets = "0.1" SparseArrays = "1.10" Statistics = "1.10" From 00bafa557a46f86f81e97928c759becdd409c7f4 Mon Sep 17 00:00:00 2001 From: MartinuzziFrancesco Date: Tue, 27 Feb 2024 14:53:19 +0100 Subject: [PATCH 24/26] bumps for WI and Distributions --- Project.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 6fb383e7..6e3c8c54 100644 --- a/Project.toml +++ b/Project.toml @@ -25,7 +25,7 @@ Aqua = "0.8" CellularAutomata = "0.0.2" DifferentialEquations = "7" Distances = "0.10" -Distributions = "0.24.5, 0.25" +Distributions = "0.25" LIBSVM = "0.8" LinearAlgebra = "1.10" MLJLinearModels = "0.9.2" @@ -37,7 +37,7 @@ SafeTestsets = "0.1" SparseArrays = "1.10" Statistics = "1.10" Test = "1" -WeightInitializers = "0.1" +WeightInitializers = "0.1.5" julia = "1.10" [extras] From 15b7d452df9f0dac08343cb95dd6861fc218841d Mon Sep 17 00:00:00 2001 From: MartinuzziFrancesco Date: Tue, 27 Feb 2024 15:36:05 +0100 Subject: [PATCH 25/26] bump Distributions --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 6e3c8c54..6c5272ea 100644 --- a/Project.toml +++ b/Project.toml @@ -25,7 +25,7 @@ Aqua = "0.8" CellularAutomata = "0.0.2" DifferentialEquations = "7" Distances = "0.10" -Distributions = "0.25" +Distributions = "0.25.36" LIBSVM = "0.8" LinearAlgebra = "1.10" MLJLinearModels = "0.9.2" From c028bcd241400f1fe5d29c17da95b081f2bffb74 Mon Sep 17 00:00:00 2001 From: MartinuzziFrancesco Date: Tue, 27 Feb 2024 16:21:45 +0100 Subject: [PATCH 26/26] rm echostatenetwork.jl --- src/esn/echostatenetwork.jl | 278 ------------------------------------ 1 file changed, 278 deletions(-) delete mode 100644 src/esn/echostatenetwork.jl diff --git a/src/esn/echostatenetwork.jl b/src/esn/echostatenetwork.jl deleted file mode 100644 index bfa088bf..00000000 --- a/src/esn/echostatenetwork.jl +++ /dev/null @@ -1,278 +0,0 @@ -abstract type AbstractEchoStateNetwork <: AbstractReservoirComputer end -struct ESN{I, S, V, N, T, O, M, B, ST, W, IS} <: AbstractEchoStateNetwork - res_size::I - train_data::S - variation::V - nla_type::N - input_matrix::T - reservoir_driver::O - reservoir_matrix::M - bias_vector::B - states_type::ST - washout::W - states::IS -end - -""" - Default() - -The `Default` struct specifies the use of the standard model in Echo State Networks (ESNs). -It requires no parameters and is used when no specific variations or customizations of the ESN model are needed. -This struct is ideal for straightforward applications where the default ESN settings are sufficient. -""" -struct Default <: AbstractVariation end -struct Hybrid{T, K, O, I, S, D} <: AbstractVariation - prior_model::T - u0::K - tspan::O - dt::I - datasize::S - model_data::D -end - -""" - Hybrid(prior_model, u0, tspan, datasize) - -Constructs a `Hybrid` variation of Echo State Networks (ESNs) integrating a knowledge-based model -(`prior_model`) with ESNs for advanced training and prediction in chaotic systems. - -# Parameters - - - `prior_model`: A knowledge-based model function for integration with ESNs. - - `u0`: Initial conditions for the model. - - `tspan`: Time span as a tuple, indicating the duration for model operation. - - `datasize`: The size of the data to be processed. - -# Returns - - - A `Hybrid` struct instance representing the combined ESN and knowledge-based model. - -This method is effective for chaotic processes as highlighted in [^Pathak]. - -Reference: - -[^Pathak]: Jaideep Pathak et al. - "Hybrid Forecasting of Chaotic Processes: - Using Machine Learning in Conjunction with a Knowledge-Based Model" (2018). -""" -function Hybrid(prior_model, u0, tspan, datasize) - trange = collect(range(tspan[1], tspan[2], length = datasize)) - dt = trange[2] - trange[1] - tsteps = push!(trange, dt + trange[end]) - tspan_new = (tspan[1], dt + tspan[2]) - model_data = prior_model(u0, tspan_new, tsteps) - return Hybrid(prior_model, u0, tspan, dt, datasize, model_data) -end - -""" - ESN(train_data; kwargs...) -> ESN - -Creates an Echo State Network (ESN) using specified parameters and training data, suitable for various machine learning tasks. - -# Parameters - - - `train_data`: Matrix of training data (columns as time steps, rows as features). - - `variation`: Variation of ESN (default: `Default()`). - - `input_layer`: Input layer of ESN (default: `DenseLayer()`). - - `reservoir`: Reservoir of the ESN (default: `RandSparseReservoir(100)`). - - `bias`: Bias vector for each time step (default: `NullLayer()`). - - `reservoir_driver`: Mechanism for evolving reservoir states (default: `RNN()`). - - `nla_type`: Non-linear activation type (default: `NLADefault()`). - - `states_type`: Format for storing states (default: `StandardStates()`). - - `washout`: Initial time steps to discard (default: `0`). - - `matrix_type`: Type of matrices used internally (default: type of `train_data`). - -# Returns - - - An initialized ESN instance with specified parameters. - -# Examples - -```julia -using ReservoirComputing - -train_data = rand(10, 100) # 10 features, 100 time steps - -esn = ESN(train_data, reservoir = RandSparseReservoir(200), washout = 10) -``` -""" -function ESN(train_data; - variation = Default(), - input_layer = DenseLayer(), - reservoir = RandSparseReservoir(100), - bias = NullLayer(), - reservoir_driver = RNN(), - nla_type = NLADefault(), - states_type = StandardStates(), - washout = 0, - matrix_type = typeof(train_data)) - if variation isa Hybrid - train_data = vcat(train_data, variation.model_data[:, 1:(end - 1)]) - end - - if states_type isa AbstractPaddedStates - in_size = size(train_data, 1) + 1 - train_data = vcat(Adapt.adapt(matrix_type, ones(1, size(train_data, 2))), - train_data) - else - in_size = size(train_data, 1) - end - - input_matrix, reservoir_matrix, bias_vector, res_size = obtain_layers(in_size, - input_layer, - reservoir, bias; - matrix_type = matrix_type) - - inner_res_driver = reservoir_driver_params(reservoir_driver, res_size, in_size) - states = create_states(inner_res_driver, train_data, washout, reservoir_matrix, - input_matrix, bias_vector) - train_data = train_data[:, (washout + 1):end] - - ESN(sum(res_size), train_data, variation, nla_type, input_matrix, - inner_res_driver, reservoir_matrix, bias_vector, states_type, washout, - states) -end - -#shallow esn construction -function obtain_layers(in_size, - input_layer, - reservoir, - bias; - matrix_type = Matrix{Float64}) - input_res_size = get_ressize(reservoir) - input_matrix = create_layer(input_layer, input_res_size, in_size, - matrix_type = matrix_type) - res_size = size(input_matrix, 1) #WeightedInput actually changes the res size - reservoir_matrix = create_reservoir(reservoir, res_size, matrix_type = matrix_type) - @assert size(reservoir_matrix, 1) == res_size - bias_vector = create_layer(bias, res_size, 1, matrix_type = matrix_type) - return input_matrix, reservoir_matrix, bias_vector, res_size -end - -#deep esn construction -#there is a bug going on with WeightedLayer in this construction. -#it works for eny other though -function obtain_layers(in_size, - input_layer, - reservoir::Vector, - bias; - matrix_type = Matrix{Float64}) - esn_depth = length(reservoir) - input_res_sizes = [get_ressize(reservoir[i]) for i in 1:esn_depth] - in_sizes = zeros(Int, esn_depth) - in_sizes[2:end] = input_res_sizes[1:(end - 1)] - in_sizes[1] = in_size - - if input_layer isa Array - input_matrix = [create_layer(input_layer[j], input_res_sizes[j], in_sizes[j], - matrix_type = matrix_type) for j in 1:esn_depth] - else - _input_layer = fill(input_layer, esn_depth) - input_matrix = [create_layer(_input_layer[k], input_res_sizes[k], in_sizes[k], - matrix_type = matrix_type) for k in 1:esn_depth] - end - - res_sizes = [get_ressize(input_matrix[j]) for j in 1:esn_depth] - reservoir_matrix = [create_reservoir(reservoir[k], res_sizes[k], - matrix_type = matrix_type) for k in 1:esn_depth] - - if bias isa Array - bias_vector = [create_layer(bias[j], res_sizes[j], 1, matrix_type = matrix_type) - for j in 1:esn_depth] - else - _bias = fill(bias, esn_depth) - bias_vector = [create_layer(_bias[k], res_sizes[k], 1, matrix_type = matrix_type) - for k in 1:esn_depth] - end - - return input_matrix, reservoir_matrix, bias_vector, res_sizes -end - -function (esn::ESN)(prediction::AbstractPrediction, - output_layer::AbstractOutputLayer; - last_state = esn.states[:, [end]], - kwargs...) - variation = esn.variation - pred_len = prediction.prediction_len - - if variation isa Hybrid - model = variation.prior_model - predict_tsteps = [variation.tspan[2] + variation.dt] - [append!(predict_tsteps, predict_tsteps[end] + variation.dt) for i in 1:pred_len] - tspan_new = (variation.tspan[2] + variation.dt, predict_tsteps[end]) - u0 = variation.model_data[:, end] - model_pred_data = model(u0, tspan_new, predict_tsteps)[:, 2:end] - return obtain_esn_prediction(esn, prediction, last_state, output_layer, - model_pred_data; - kwargs...) - else - return obtain_esn_prediction(esn, prediction, last_state, output_layer; - kwargs...) - end -end - -#training dispatch on esn -""" - train(esn::AbstractEchoStateNetwork, target_data, training_method = StandardRidge(0.0)) - -Trains an Echo State Network (ESN) using the provided target data and a specified training method. - -# Parameters - - - `esn::AbstractEchoStateNetwork`: The ESN instance to be trained. - - `target_data`: Supervised training data for the ESN. - - `training_method`: The method for training the ESN (default: `StandardRidge(0.0)`). - -# Returns - - - The trained ESN model. Its type and structure depend on `training_method` and the ESN's implementation. - -# Returns - -The trained ESN model. The exact type and structure of the return value depends on the -`training_method` and the specific ESN implementation. - -```julia -using ReservoirComputing - -# Initialize an ESN instance and target data -esn = ESN(train_data, reservoir = RandSparseReservoir(200), washout = 10) -target_data = rand(size(train_data, 2)) - -# Train the ESN using the default training method -trained_esn = train(esn, target_data) - -# Train the ESN using a custom training method -trained_esn = train(esn, target_data, training_method = StandardRidge(1.0)) -``` - -# Notes - - - When using a `Hybrid` variation, the function extends the state matrix with data from the - physical model included in the `variation`. - - The training is handled by a lower-level `_train` function which takes the new state matrix - and performs the actual training using the specified `training_method`. -""" -function train(esn::AbstractEchoStateNetwork, - target_data, - training_method = StandardRidge(0.0)) - variation = esn.variation - - if esn.variation isa Hybrid - states = vcat(esn.states, esn.variation.model_data[:, 2:end]) - else - states = esn.states - end - states_new = esn.states_type(esn.nla_type, states, esn.train_data[:, 1:end]) - - return _train(states_new, target_data, training_method) -end - -function pad_esnstate(variation::Hybrid, states_type, x_pad, x, model_prediction_data) - x_tmp = vcat(x, model_prediction_data) - x_pad = pad_state!(states_type, x_pad, x_tmp) -end - -function pad_esnstate!(variation, states_type, x_pad, x, args...) - x_pad = pad_state!(states_type, x_pad, x) -end