diff --git a/Project.toml b/Project.toml index b3207f6e..6c5272ea 100644 --- a/Project.toml +++ b/Project.toml @@ -13,8 +13,11 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MLJLinearModels = "6ee0df7b-362f-4a72-a706-9e79364fb692" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Optim = "429524aa-4258-5aef-a3af-852621145aeb" +PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +WeightInitializers = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d" [compat] Adapt = "3.3.3, 4" @@ -22,18 +25,20 @@ Aqua = "0.8" CellularAutomata = "0.0.2" DifferentialEquations = "7" Distances = "0.10" -Distributions = "0.24.5, 0.25" +Distributions = "0.25.36" LIBSVM = "0.8" LinearAlgebra = "1.10" MLJLinearModels = "0.9.2" NNlib = "0.8.4, 0.9" Optim = "1" -Random = "1" +PartialFunctions = "1.2" +Random = "1.10" SafeTestsets = "0.1" SparseArrays = "1.10" Statistics = "1.10" Test = "1" -julia = "1.6" +WeightInitializers = "0.1.5" +julia = "1.10" [extras] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" diff --git a/README.md b/README.md index 9172725e..b4667be0 100644 --- a/README.md +++ b/README.md @@ -51,14 +51,15 @@ test = data[:, (shift + train_len):(shift + train_len + predict_len - 1)] Now that we have the data we can initialize the ESN with the chosen parameters. Given that this is a quick example we are going to change the least amount of possible parameters. For more detailed examples and explanations of the functions please refer to the documentation. ```julia +input_size = 3 res_size = 300 -esn = ESN(input_data; - reservoir = RandSparseReservoir(res_size, radius = 1.2, sparsity = 6 / res_size), - input_layer = WeightedLayer(), +esn = ESN(input_data, input_size, res_size; + reservoir = rand_sparse(; radius = 1.2, sparsity = 6 / res_size), + input_layer = weighted_init, nla_type = NLAT2()) ``` -The echo state network can now be trained and tested. If not specified, the training will always be Ordinary Least Squares regression. The full range of training methods is detailed in the documentation. +The echo state network can now be trained and tested. If not specified, the training will always be ordinary least squares regression. The full range of training methods is detailed in the documentation. ```julia output_layer = train(esn, target_data) @@ -103,3 +104,7 @@ If you use this library in your work, please cite: url = {http://jmlr.org/papers/v23/22-0611.html} } ``` + +## Acknowledgements + +This project was possible thanks to initial funding through the [Google summer of code](https://summerofcode.withgoogle.com/) 2020 program. Francesco M. further acknowledges [ScaDS.AI](https://scads.ai/) and [RSC4Earth](https://rsc4earth.de/) for supporting the current progress on the library. diff --git a/docs/src/esn_tutorials/hybrid.md b/docs/src/esn_tutorials/hybrid.md index bf274f01..25797089 100644 --- a/docs/src/esn_tutorials/hybrid.md +++ b/docs/src/esn_tutorials/hybrid.md @@ -1,6 +1,6 @@ # Hybrid Echo State Networks -Following the idea of giving physical information to machine learning models, the hybrid echo state networks [^1] try to achieve this results by feeding model data into the ESN. In this example, it is explained how to create and leverage such models in ReservoirComputing.jl. The full script for this example is available [here](https://github.com/MartinuzziFrancesco/reservoir-computing-examples/blob/main/hybrid/hybrid.jl). This example was run on Julia v1.7.2. +Following the idea of giving physical information to machine learning models, the hybrid echo state networks [^1] try to achieve this results by feeding model data into the ESN. In this example, it is explained how to create and leverage such models in ReservoirComputing.jl. ## Generating the data @@ -47,17 +47,21 @@ function prior_model_data_generator(u0, tspan, tsteps, model = lorenz) end ``` -Given the initial condition, time span, and time steps, this function returns the data for the chosen model. Now, using the `Hybrid` method, it is possible to input all this information to the model. +Given the initial condition, time span, and time steps, this function returns the data for the chosen model. Now, using the `KnowledgeModel` method, it is possible to input all this information to `HybridESN`. ```@example hybrid using ReservoirComputing, Random Random.seed!(42) -hybrid = Hybrid(prior_model_data_generator, u0, tspan_train, train_len) +km = KnowledgeModel(prior_model_data_generator, u0, tspan_train, train_len) -esn = ESN(input_data, - reservoir = RandSparseReservoir(300), - variation = hybrid) +in_size = 3 +res_size = 300 +hesn = HybridESN(km, + input_data, + in_size, + res_size; + reservoir = rand_sparse) ``` ## Training and Prediction @@ -65,7 +69,7 @@ esn = ESN(input_data, The training and prediction of the Hybrid ESN can proceed as usual: ```@example hybrid -output_layer = train(esn, target_data, StandardRidge(0.3)) +output_layer = train(hesn, target_data, StandardRidge(0.3)) output = esn(Generative(predict_len), output_layer) ``` diff --git a/docs/src/esn_tutorials/lorenz_basic.md b/docs/src/esn_tutorials/lorenz_basic.md index f820a36d..1a66f834 100644 --- a/docs/src/esn_tutorials/lorenz_basic.md +++ b/docs/src/esn_tutorials/lorenz_basic.md @@ -1,6 +1,6 @@ # Lorenz System Forecasting -This example expands on the readme Lorenz system forecasting to better showcase how to use methods and functions provided in the library for Echo State Networks. Here the prediction method used is `Generative`, for a more detailed explanation of the differences between `Generative` and `Predictive` please refer to the other examples given in the documentation. The full script for this example is available [here](https://github.com/MartinuzziFrancesco/reservoir-computing-examples/blob/main/lorenz_basic/lorenz_basic.jl). This example was run on Julia v1.7.2. +This example expands on the readme Lorenz system forecasting to better showcase how to use methods and functions provided in the library for Echo State Networks. Here the prediction method used is `Generative`, for a more detailed explanation of the differences between `Generative` and `Predictive` please refer to the other examples given in the documentation. ## Generating the data @@ -46,15 +46,15 @@ using ReservoirComputing #define ESN parameters res_size = 300 +in_size = 3 res_radius = 1.2 res_sparsity = 6 / 300 input_scaling = 0.1 #build ESN struct -esn = ESN(input_data; - variation = Default(), - reservoir = RandSparseReservoir(res_size, radius = res_radius, sparsity = res_sparsity), - input_layer = WeightedLayer(scaling = input_scaling), +esn = ESN(input_data, in_size, res_size; + reservoir = rand_sparse(; radius = res_radius, sparsity = res_sparsity), + input_layer = weighted_init(; scaling = input_scaling), reservoir_driver = RNN(), nla_type = NLADefault(), states_type = StandardStates()) @@ -62,9 +62,9 @@ esn = ESN(input_data; Most of the parameters chosen here mirror the default ones, so a direct call is not necessary. The readme example is identical to this one, except for the explicit call. Going line by line to see what is happening, starting from `res_size`: this value determines the dimensions of the reservoir matrix. In this case, a size of 300 has been chosen, so the reservoir matrix will be 300 x 300. This is not always the case, since some input layer constructions can modify the dimensions of the reservoir, but in that case, everything is taken care of internally. -The `res_radius` determines the scaling of the spectral radius of the reservoir matrix; a proper scaling is necessary to assure the Echo State Property. The default value in the `RandSparseReservoir()` method is 1.0 in accordance with the most commonly followed guidelines found in the literature (see [^2] and references therein). The `sparsity` of the reservoir matrix in this case is obtained by choosing a degree of connections and dividing that by the reservoir size. Of course, it is also possible to simply choose any value between 0.0 and 1.0 to test behaviors for different sparsity values. In this example, the call to the parameters inside `RandSparseReservoir()` was done explicitly to showcase the meaning of each of them, but it is also possible to simply pass the values directly, like so `RandSparseReservoir(1.2, 6/300)`. +The `res_radius` determines the scaling of the spectral radius of the reservoir matrix; a proper scaling is necessary to assure the Echo State Property. The default value in the `rand_sparse` method is 1.0 in accordance with the most commonly followed guidelines found in the literature (see [^2] and references therein). The `sparsity` of the reservoir matrix in this case is obtained by choosing a degree of connections and dividing that by the reservoir size. Of course, it is also possible to simply choose any value between 0.0 and 1.0 to test behaviors for different sparsity values. -The value of `input_scaling` determines the upper and lower bounds of the uniform distribution of the weights in the `WeightedLayer()`. Like before, this value can be passed either as an argument or as a keyword argument `WeightedLayer(0.1)`. The value of 0.1 represents the default. The default input layer is the `DenseLayer`, a fully connected layer. The details of the weighted version can be found in [^3], for this example, this version returns the best results. +The value of `input_scaling` determines the upper and lower bounds of the uniform distribution of the weights in the `weighted_init`. The value of 0.1 represents the default. The default input layer is the `scaled_rand`, a dense matrix. The details of the weighted version can be found in [^3], for this example, this version returns the best results. The reservoir driver represents the dynamics of the reservoir. In the standard ESN definition, these dynamics are obtained through a Recurrent Neural Network (RNN), and this is reflected by calling the `RNN` driver for the `ESN` struct. This option is set as the default, and unless there is the need to change parameters, it is not needed. The full equation is the following: diff --git a/src/ReservoirComputing.jl b/src/ReservoirComputing.jl index 7a5d97ae..9a186e1a 100644 --- a/src/ReservoirComputing.jl +++ b/src/ReservoirComputing.jl @@ -9,20 +9,21 @@ using LinearAlgebra using MLJLinearModels using NNlib using Optim +using PartialFunctions +using Random using SparseArrays using Statistics +using WeightInitializers export NLADefault, NLAT1, NLAT2, NLAT3 export StandardStates, ExtendedStates, PaddedStates, PaddedExtendedStates export StandardRidge, LinearModel -export AbstractLayer, create_layer -export WeightedLayer, DenseLayer, SparseLayer, MinimumLayer, InformedLayer, NullLayer -export BernoulliSample, IrrationalSample -export AbstractReservoir, create_reservoir -export RandSparseReservoir, PseudoSVDReservoir, DelayLineReservoir -export DelayLineBackwardReservoir, SimpleCycleReservoir, CycleJumpsReservoir, NullReservoir +export scaled_rand, weighted_init, sparse_init, informed_init, minimal_init +export rand_sparse, delay_line, delay_line_backward, cycle_jumps, simple_cycle, pseudo_svd export RNN, MRNN, GRU, GRUParams, FullyGated, Minimal -export ESN, Default, Hybrid, train +export ESN, train +export HybridESN, KnowledgeModel +export DeepESN export RECA, train export RandomMapping, RandomMaps export Generative, Predictive, OutputLayer @@ -72,6 +73,31 @@ function Predictive(prediction_data) Predictive(prediction_data, prediction_len) end +#fallbacks for initializers +for initializer in (:rand_sparse, :delay_line, :delay_line_backward, :cycle_jumps, + :simple_cycle, :pseudo_svd, + :scaled_rand, :weighted_init, :sparse_init, :informed_init, :minimal_init) + NType = ifelse(initializer === :rand_sparse, Real, Number) + @eval function ($initializer)(dims::Integer...; kwargs...) + return $initializer(_default_rng(), Float32, dims...; kwargs...) + end + @eval function ($initializer)(rng::AbstractRNG, dims::Integer...; kwargs...) + return $initializer(rng, Float32, dims...; kwargs...) + end + @eval function ($initializer)(::Type{T}, + dims::Integer...; kwargs...) where {T <: $NType} + return $initializer(_default_rng(), T, dims...; kwargs...) + end + @eval function ($initializer)(rng::AbstractRNG; kwargs...) + return __partial_apply($initializer, (rng, (; kwargs...))) + end + @eval function ($initializer)(rng::AbstractRNG, + ::Type{T}; kwargs...) where {T <: $NType} + return __partial_apply($initializer, ((rng, T), (; kwargs...))) + end + @eval ($initializer)(; kwargs...) = __partial_apply($initializer, (; kwargs...)) +end + #general include("states.jl") include("predict.jl") @@ -84,7 +110,9 @@ include("train/supportvector_regression.jl") include("esn/esn_input_layers.jl") include("esn/esn_reservoirs.jl") include("esn/esn_reservoir_drivers.jl") -include("esn/echostatenetwork.jl") +include("esn/esn.jl") +include("esn/deepesn.jl") +include("esn/hybridesn.jl") include("esn/esn_predict.jl") #reca diff --git a/src/esn/deepesn.jl b/src/esn/deepesn.jl new file mode 100644 index 00000000..8fe8f798 --- /dev/null +++ b/src/esn/deepesn.jl @@ -0,0 +1,46 @@ +struct DeepESN{I, S, N, T, O, M, B, ST, W, IS} <: AbstractEchoStateNetwork + res_size::I + train_data::S + nla_type::N + input_matrix::T + reservoir_driver::O + reservoir_matrix::M + bias_vector::B + states_type::ST + washout::W + states::IS +end + +function DeepESN(train_data, + in_size::Int, + res_size::Int; + depth::Int = 2, + input_layer = fill(scaled_rand, depth), + bias = fill(zeros64, depth), + reservoir = fill(rand_sparse, depth), + reservoir_driver = RNN(), + nla_type = NLADefault(), + states_type = StandardStates(), + washout::Int = 0, + rng = _default_rng(), + T = Float64, + matrix_type = typeof(train_data)) + if states_type isa AbstractPaddedStates + in_size = size(train_data, 1) + 1 + train_data = vcat(Adapt.adapt(matrix_type, ones(1, size(train_data, 2))), + train_data) + end + + reservoir_matrix = [reservoir[i](rng, T, res_size, res_size) for i in 1:depth] + input_matrix = [i == 1 ? input_layer[i](rng, T, res_size, in_size) : + input_layer[i](rng, T, res_size, res_size) for i in 1:depth] + bias_vector = [bias[i](rng, res_size) for i in 1:depth] + inner_res_driver = reservoir_driver_params(reservoir_driver, res_size, in_size) + states = create_states(inner_res_driver, train_data, washout, reservoir_matrix, + input_matrix, bias_vector) + train_data = train_data[:, (washout + 1):end] + + DeepESN(res_size, train_data, nla_type, input_matrix, + inner_res_driver, reservoir_matrix, bias_vector, states_type, washout, + states) +end diff --git a/src/esn/echostatenetwork.jl b/src/esn/echostatenetwork.jl deleted file mode 100644 index bfa088bf..00000000 --- a/src/esn/echostatenetwork.jl +++ /dev/null @@ -1,278 +0,0 @@ -abstract type AbstractEchoStateNetwork <: AbstractReservoirComputer end -struct ESN{I, S, V, N, T, O, M, B, ST, W, IS} <: AbstractEchoStateNetwork - res_size::I - train_data::S - variation::V - nla_type::N - input_matrix::T - reservoir_driver::O - reservoir_matrix::M - bias_vector::B - states_type::ST - washout::W - states::IS -end - -""" - Default() - -The `Default` struct specifies the use of the standard model in Echo State Networks (ESNs). -It requires no parameters and is used when no specific variations or customizations of the ESN model are needed. -This struct is ideal for straightforward applications where the default ESN settings are sufficient. -""" -struct Default <: AbstractVariation end -struct Hybrid{T, K, O, I, S, D} <: AbstractVariation - prior_model::T - u0::K - tspan::O - dt::I - datasize::S - model_data::D -end - -""" - Hybrid(prior_model, u0, tspan, datasize) - -Constructs a `Hybrid` variation of Echo State Networks (ESNs) integrating a knowledge-based model -(`prior_model`) with ESNs for advanced training and prediction in chaotic systems. - -# Parameters - - - `prior_model`: A knowledge-based model function for integration with ESNs. - - `u0`: Initial conditions for the model. - - `tspan`: Time span as a tuple, indicating the duration for model operation. - - `datasize`: The size of the data to be processed. - -# Returns - - - A `Hybrid` struct instance representing the combined ESN and knowledge-based model. - -This method is effective for chaotic processes as highlighted in [^Pathak]. - -Reference: - -[^Pathak]: Jaideep Pathak et al. - "Hybrid Forecasting of Chaotic Processes: - Using Machine Learning in Conjunction with a Knowledge-Based Model" (2018). -""" -function Hybrid(prior_model, u0, tspan, datasize) - trange = collect(range(tspan[1], tspan[2], length = datasize)) - dt = trange[2] - trange[1] - tsteps = push!(trange, dt + trange[end]) - tspan_new = (tspan[1], dt + tspan[2]) - model_data = prior_model(u0, tspan_new, tsteps) - return Hybrid(prior_model, u0, tspan, dt, datasize, model_data) -end - -""" - ESN(train_data; kwargs...) -> ESN - -Creates an Echo State Network (ESN) using specified parameters and training data, suitable for various machine learning tasks. - -# Parameters - - - `train_data`: Matrix of training data (columns as time steps, rows as features). - - `variation`: Variation of ESN (default: `Default()`). - - `input_layer`: Input layer of ESN (default: `DenseLayer()`). - - `reservoir`: Reservoir of the ESN (default: `RandSparseReservoir(100)`). - - `bias`: Bias vector for each time step (default: `NullLayer()`). - - `reservoir_driver`: Mechanism for evolving reservoir states (default: `RNN()`). - - `nla_type`: Non-linear activation type (default: `NLADefault()`). - - `states_type`: Format for storing states (default: `StandardStates()`). - - `washout`: Initial time steps to discard (default: `0`). - - `matrix_type`: Type of matrices used internally (default: type of `train_data`). - -# Returns - - - An initialized ESN instance with specified parameters. - -# Examples - -```julia -using ReservoirComputing - -train_data = rand(10, 100) # 10 features, 100 time steps - -esn = ESN(train_data, reservoir = RandSparseReservoir(200), washout = 10) -``` -""" -function ESN(train_data; - variation = Default(), - input_layer = DenseLayer(), - reservoir = RandSparseReservoir(100), - bias = NullLayer(), - reservoir_driver = RNN(), - nla_type = NLADefault(), - states_type = StandardStates(), - washout = 0, - matrix_type = typeof(train_data)) - if variation isa Hybrid - train_data = vcat(train_data, variation.model_data[:, 1:(end - 1)]) - end - - if states_type isa AbstractPaddedStates - in_size = size(train_data, 1) + 1 - train_data = vcat(Adapt.adapt(matrix_type, ones(1, size(train_data, 2))), - train_data) - else - in_size = size(train_data, 1) - end - - input_matrix, reservoir_matrix, bias_vector, res_size = obtain_layers(in_size, - input_layer, - reservoir, bias; - matrix_type = matrix_type) - - inner_res_driver = reservoir_driver_params(reservoir_driver, res_size, in_size) - states = create_states(inner_res_driver, train_data, washout, reservoir_matrix, - input_matrix, bias_vector) - train_data = train_data[:, (washout + 1):end] - - ESN(sum(res_size), train_data, variation, nla_type, input_matrix, - inner_res_driver, reservoir_matrix, bias_vector, states_type, washout, - states) -end - -#shallow esn construction -function obtain_layers(in_size, - input_layer, - reservoir, - bias; - matrix_type = Matrix{Float64}) - input_res_size = get_ressize(reservoir) - input_matrix = create_layer(input_layer, input_res_size, in_size, - matrix_type = matrix_type) - res_size = size(input_matrix, 1) #WeightedInput actually changes the res size - reservoir_matrix = create_reservoir(reservoir, res_size, matrix_type = matrix_type) - @assert size(reservoir_matrix, 1) == res_size - bias_vector = create_layer(bias, res_size, 1, matrix_type = matrix_type) - return input_matrix, reservoir_matrix, bias_vector, res_size -end - -#deep esn construction -#there is a bug going on with WeightedLayer in this construction. -#it works for eny other though -function obtain_layers(in_size, - input_layer, - reservoir::Vector, - bias; - matrix_type = Matrix{Float64}) - esn_depth = length(reservoir) - input_res_sizes = [get_ressize(reservoir[i]) for i in 1:esn_depth] - in_sizes = zeros(Int, esn_depth) - in_sizes[2:end] = input_res_sizes[1:(end - 1)] - in_sizes[1] = in_size - - if input_layer isa Array - input_matrix = [create_layer(input_layer[j], input_res_sizes[j], in_sizes[j], - matrix_type = matrix_type) for j in 1:esn_depth] - else - _input_layer = fill(input_layer, esn_depth) - input_matrix = [create_layer(_input_layer[k], input_res_sizes[k], in_sizes[k], - matrix_type = matrix_type) for k in 1:esn_depth] - end - - res_sizes = [get_ressize(input_matrix[j]) for j in 1:esn_depth] - reservoir_matrix = [create_reservoir(reservoir[k], res_sizes[k], - matrix_type = matrix_type) for k in 1:esn_depth] - - if bias isa Array - bias_vector = [create_layer(bias[j], res_sizes[j], 1, matrix_type = matrix_type) - for j in 1:esn_depth] - else - _bias = fill(bias, esn_depth) - bias_vector = [create_layer(_bias[k], res_sizes[k], 1, matrix_type = matrix_type) - for k in 1:esn_depth] - end - - return input_matrix, reservoir_matrix, bias_vector, res_sizes -end - -function (esn::ESN)(prediction::AbstractPrediction, - output_layer::AbstractOutputLayer; - last_state = esn.states[:, [end]], - kwargs...) - variation = esn.variation - pred_len = prediction.prediction_len - - if variation isa Hybrid - model = variation.prior_model - predict_tsteps = [variation.tspan[2] + variation.dt] - [append!(predict_tsteps, predict_tsteps[end] + variation.dt) for i in 1:pred_len] - tspan_new = (variation.tspan[2] + variation.dt, predict_tsteps[end]) - u0 = variation.model_data[:, end] - model_pred_data = model(u0, tspan_new, predict_tsteps)[:, 2:end] - return obtain_esn_prediction(esn, prediction, last_state, output_layer, - model_pred_data; - kwargs...) - else - return obtain_esn_prediction(esn, prediction, last_state, output_layer; - kwargs...) - end -end - -#training dispatch on esn -""" - train(esn::AbstractEchoStateNetwork, target_data, training_method = StandardRidge(0.0)) - -Trains an Echo State Network (ESN) using the provided target data and a specified training method. - -# Parameters - - - `esn::AbstractEchoStateNetwork`: The ESN instance to be trained. - - `target_data`: Supervised training data for the ESN. - - `training_method`: The method for training the ESN (default: `StandardRidge(0.0)`). - -# Returns - - - The trained ESN model. Its type and structure depend on `training_method` and the ESN's implementation. - -# Returns - -The trained ESN model. The exact type and structure of the return value depends on the -`training_method` and the specific ESN implementation. - -```julia -using ReservoirComputing - -# Initialize an ESN instance and target data -esn = ESN(train_data, reservoir = RandSparseReservoir(200), washout = 10) -target_data = rand(size(train_data, 2)) - -# Train the ESN using the default training method -trained_esn = train(esn, target_data) - -# Train the ESN using a custom training method -trained_esn = train(esn, target_data, training_method = StandardRidge(1.0)) -``` - -# Notes - - - When using a `Hybrid` variation, the function extends the state matrix with data from the - physical model included in the `variation`. - - The training is handled by a lower-level `_train` function which takes the new state matrix - and performs the actual training using the specified `training_method`. -""" -function train(esn::AbstractEchoStateNetwork, - target_data, - training_method = StandardRidge(0.0)) - variation = esn.variation - - if esn.variation isa Hybrid - states = vcat(esn.states, esn.variation.model_data[:, 2:end]) - else - states = esn.states - end - states_new = esn.states_type(esn.nla_type, states, esn.train_data[:, 1:end]) - - return _train(states_new, target_data, training_method) -end - -function pad_esnstate(variation::Hybrid, states_type, x_pad, x, model_prediction_data) - x_tmp = vcat(x, model_prediction_data) - x_pad = pad_state!(states_type, x_pad, x_tmp) -end - -function pad_esnstate!(variation, states_type, x_pad, x, args...) - x_pad = pad_state!(states_type, x_pad, x) -end diff --git a/src/esn/esn.jl b/src/esn/esn.jl new file mode 100644 index 00000000..2b351ce3 --- /dev/null +++ b/src/esn/esn.jl @@ -0,0 +1,146 @@ +abstract type AbstractEchoStateNetwork <: AbstractReservoirComputer end +struct ESN{I, S, N, T, O, M, B, ST, W, IS} <: AbstractEchoStateNetwork + res_size::I + train_data::S + nla_type::N + input_matrix::T + reservoir_driver::O + reservoir_matrix::M + bias_vector::B + states_type::ST + washout::W + states::IS +end + +""" + ESN(train_data; kwargs...) -> ESN + +Creates an Echo State Network (ESN) using specified parameters and training data, suitable for various machine learning tasks. + +# Parameters + + - `train_data`: Matrix of training data (columns as time steps, rows as features). + - `variation`: Variation of ESN (default: `Default()`). + - `input_layer`: Input layer of ESN (default: `DenseLayer()`). + - `reservoir`: Reservoir of the ESN (default: `RandSparseReservoir(100)`). + - `bias`: Bias vector for each time step (default: `NullLayer()`). + - `reservoir_driver`: Mechanism for evolving reservoir states (default: `RNN()`). + - `nla_type`: Non-linear activation type (default: `NLADefault()`). + - `states_type`: Format for storing states (default: `StandardStates()`). + - `washout`: Initial time steps to discard (default: `0`). + - `matrix_type`: Type of matrices used internally (default: type of `train_data`). + +# Returns + + - An initialized ESN instance with specified parameters. + +# Examples + +```julia +using ReservoirComputing + +train_data = rand(10, 100) # 10 features, 100 time steps + +esn = ESN(train_data, reservoir = RandSparseReservoir(200), washout = 10) +``` +""" +function ESN(train_data, + in_size::Int, + res_size::Int; + input_layer = scaled_rand, + reservoir = rand_sparse, + bias = zeros64, + reservoir_driver = RNN(), + nla_type = NLADefault(), + states_type = StandardStates(), + washout = 0, + rng = _default_rng(), + T = Float32, + matrix_type = typeof(train_data)) + if states_type isa AbstractPaddedStates + in_size = size(train_data, 1) + 1 + train_data = vcat(Adapt.adapt(matrix_type, ones(1, size(train_data, 2))), + train_data) + end + + reservoir_matrix = reservoir(rng, T, res_size, res_size) + input_matrix = input_layer(rng, T, res_size, in_size) + bias_vector = bias(rng, res_size) + inner_res_driver = reservoir_driver_params(reservoir_driver, res_size, in_size) + states = create_states(inner_res_driver, train_data, washout, reservoir_matrix, + input_matrix, bias_vector) + train_data = train_data[:, (washout + 1):end] + + ESN(res_size, train_data, nla_type, input_matrix, + inner_res_driver, reservoir_matrix, bias_vector, states_type, washout, + states) +end + +function (esn::AbstractEchoStateNetwork)(prediction::AbstractPrediction, + output_layer::AbstractOutputLayer; + last_state = esn.states[:, [end]], + kwargs...) + pred_len = prediction.prediction_len + + return obtain_esn_prediction(esn, prediction, last_state, output_layer; + kwargs...) +end + +#training dispatch on esn +""" + train(esn::AbstractEchoStateNetwork, target_data, training_method = StandardRidge(0.0)) + +Trains an Echo State Network (ESN) using the provided target data and a specified training method. + +# Parameters + + - `esn::AbstractEchoStateNetwork`: The ESN instance to be trained. + - `target_data`: Supervised training data for the ESN. + - `training_method`: The method for training the ESN (default: `StandardRidge(0.0)`). + +# Returns + + - The trained ESN model. Its type and structure depend on `training_method` and the ESN's implementation. + +# Returns + +The trained ESN model. The exact type and structure of the return value depends on the +`training_method` and the specific ESN implementation. + +```julia +using ReservoirComputing + +# Initialize an ESN instance and target data +esn = ESN(train_data, reservoir = RandSparseReservoir(200), washout = 10) +target_data = rand(size(train_data, 2)) + +# Train the ESN using the default training method +trained_esn = train(esn, target_data) + +# Train the ESN using a custom training method +trained_esn = train(esn, target_data, training_method = StandardRidge(1.0)) +``` + +# Notes + + - When using a `Hybrid` variation, the function extends the state matrix with data from the + physical model included in the `variation`. + - The training is handled by a lower-level `_train` function which takes the new state matrix + and performs the actual training using the specified `training_method`. +""" +function train(esn::AbstractEchoStateNetwork, + target_data, + training_method = StandardRidge(0.0)) + states_new = esn.states_type(esn.nla_type, esn.states, esn.train_data[:, 1:end]) + + return _train(states_new, target_data, training_method) +end + +#function pad_esnstate(variation::Hybrid, states_type, x_pad, x, model_prediction_data) +# x_tmp = vcat(x, model_prediction_data) +# x_pad = pad_state!(states_type, x_pad, x_tmp) +#end + +#function pad_esnstate!(variation, states_type, x_pad, x, args...) +# x_pad = pad_state!(states_type, x_pad, x) +#end diff --git a/src/esn/esn_input_layers.jl b/src/esn/esn_input_layers.jl index 36b553a1..58aab7c4 100644 --- a/src/esn/esn_input_layers.jl +++ b/src/esn/esn_input_layers.jl @@ -1,394 +1,271 @@ -abstract type AbstractLayer end - -struct WeightedLayer{T} <: AbstractLayer - scaling::T -end - """ - WeightedInput(scaling) - WeightedInput(;scaling=0.1) + scaled_rand(rng::AbstractRNG, ::Type{T}, dims::Integer...; scaling=T(0.1)) where {T <: Number} -Creates a `WeightedInput` layer initializer for Echo State Networks. -This initializer generates a weighted input matrix with random non-zero -elements distributed uniformly within the range [-`scaling`, `scaling`], -following the approach in [^Lu]. +Create and return a matrix with random values, uniformly distributed within a range defined by `scaling`. This function is useful for initializing matrices, such as the layers of a neural network, with scaled random values. -# Parameters +# Arguments - - `scaling`: The scaling factor for the weight distribution (default: 0.1). + - `rng`: An instance of `AbstractRNG` for random number generation. + - `T`: The data type for the elements of the matrix. + - `dims`: Dimensions of the matrix. It must be a 2-element tuple specifying the number of rows and columns (e.g., `(res_size, in_size)`). + - `scaling`: A scaling factor to define the range of the uniform distribution. The matrix elements will be randomly chosen from the range `[-scaling, scaling]`. Defaults to `T(0.1)`. # Returns - - A `WeightedInput` instance to be used for initializing the input layer of an ESN. +A matrix of type with dimensions specified by `dims`. Each element of the matrix is a random number uniformly distributed between `-scaling` and `scaling`. -Reference: +# Example -[^Lu]: Lu, Zhixin, et al. - "Reservoir observers: Model-free inference of unmeasured variables in chaotic systems." - Chaos: An Interdisciplinary Journal of Nonlinear Science 27.4 (2017): 041102. +```julia +rng = Random.default_rng() +matrix = scaled_rand(rng, Float64, (100, 50); scaling = 0.2) +``` """ -function WeightedLayer(; scaling = 0.1) - return WeightedLayer(scaling) -end - -function create_layer(input_layer::WeightedLayer, - approx_res_size, - in_size; - matrix_type = Matrix{Float64}) - scaling = input_layer.scaling - res_size = Int(floor(approx_res_size / in_size) * in_size) - layer_matrix = zeros(res_size, in_size) - q = floor(Int, res_size / in_size) - - for i in 1:in_size - layer_matrix[((i - 1) * q + 1):((i) * q), i] = rand(Uniform(-scaling, scaling), 1, - q) - end - - return Adapt.adapt(matrix_type, layer_matrix) -end - -function create_layer(layer, args...; kwargs...) - return layer +function scaled_rand(rng::AbstractRNG, + ::Type{T}, + dims::Integer...; + scaling = T(0.1)) where {T <: Number} + res_size, in_size = dims + layer_matrix = T.(rand(rng, Uniform(-scaling, scaling), res_size, in_size)) + return layer_matrix end """ - DenseLayer(scaling) - DenseLayer(;scaling=0.1) + weighted_init(rng::AbstractRNG, ::Type{T}, dims::Integer...; scaling=T(0.1)) where {T <: Number} -Creates a `DenseLayer` initializer for Echo State Networks, generating a fully connected input layer. -The layer is initialized with random weights uniformly distributed within [-`scaling`, `scaling`]. -This scaling factor can be provided either as an argument or a keyword argument. -The `DenseLayer` is the default input layer in `ESN` construction. +Create and return a matrix representing a weighted input layer for Echo State Networks (ESNs). This initializer generates a weighted input matrix with random non-zero elements distributed uniformly within the range [-`scaling`, `scaling`], inspired by the approach in [^Lu]. -# Parameters +# Arguments - - `scaling`: The scaling factor for weight distribution (default: 0.1). + - `rng`: An instance of `AbstractRNG` for random number generation. + - `T`: The data type for the elements of the matrix. + - `dims`: A 2-element tuple specifying the approximate reservoir size and input size (e.g., `(approx_res_size, in_size)`). + - `scaling`: The scaling factor for the weight distribution. Defaults to `T(0.1)`. # Returns - - A `DenseLayer` instance for initializing the ESN's input layer. -""" -struct DenseLayer{T} <: AbstractLayer - scaling::T -end - -function DenseLayer(; scaling = 0.1) - return DenseLayer(scaling) -end +A matrix representing the weighted input layer as defined in [^Lu2017]. The matrix dimensions will be adjusted to ensure each input unit connects to an equal number of reservoir units. -""" - create_layer(input_layer::AbstractLayer, res_size, in_size) +# Example -Generates a matrix layer of size `res_size` x `in_size`, constructed according to the specifications of the `input_layer`. +```julia +rng = Random.default_rng() +input_layer = weighted_init(rng, Float64, (3, 300); scaling = 0.2) +``` -# Parameters +# References - - `input_layer`: An instance of `AbstractLayer` determining the layer construction. - - `res_size`: The number of rows (reservoir size) for the layer. - - `in_size`: The number of columns (input size) for the layer. +[^Lu2017]: Lu, Zhixin, et al. + "Reservoir observers: Model-free inference of unmeasured variables in chaotic systems." + Chaos: An Interdisciplinary Journal of Nonlinear Science 27.4 (2017): 041102. +""" +function weighted_init(rng::AbstractRNG, + ::Type{T}, + dims::Integer...; + scaling = T(0.1)) where {T <: Number} + approx_res_size, in_size = dims + res_size = Int(floor(approx_res_size / in_size) * in_size) + layer_matrix = zeros(T, res_size, in_size) + q = floor(Int, res_size / in_size) -# Returns + for i in 1:in_size + layer_matrix[((i - 1) * q + 1):((i) * q), i] = rand(rng, + Uniform(-scaling, scaling), + q) + end - - A matrix representing the constructed layer. -""" -function create_layer(input_layer::DenseLayer, - res_size, - in_size; - matrix_type = Matrix{Float64}) - scaling = input_layer.scaling - layer_matrix = rand(Uniform(-scaling, scaling), res_size, in_size) - return Adapt.adapt(matrix_type, layer_matrix) + return layer_matrix end +# TODO: @MartinuzziFrancesco remove when pr gets into WeightInitializers """ - SparseLayer(scaling, sparsity) - SparseLayer(scaling; sparsity=0.1) - SparseLayer(;scaling=0.1, sparsity=0.1) + sparse_init(rng::AbstractRNG, ::Type{T}, dims::Integer...; scaling=T(0.1), sparsity=T(0.1)) where {T <: Number} -Creates a `SparseLayer` initializer for Echo State Networks, generating a sparse input layer. -The layer is initialized with weights distributed within [-`scaling`, `scaling`] -and a specified `sparsity` level. Both `scaling` and `sparsity` can be set as arguments or keyword arguments. +Create and return a sparse layer matrix for use in neural network models. +The matrix will be of size specified by `dims`, with the specified `sparsity` and `scaling`. -# Parameters +# Arguments - - `scaling`: Scaling factor for weight distribution (default: 0.1). - - `sparsity`: Sparsity level of the layer (default: 0.1). + - `rng`: An instance of `AbstractRNG` for random number generation. + - `T`: The data type for the elements of the matrix. + - `dims`: Dimensions of the resulting sparse layer matrix. + - `scaling`: The scaling factor for the sparse layer matrix. Defaults to 0.1. + - `sparsity`: The sparsity level of the sparse layer matrix, controlling the fraction of zero elements. Defaults to 0.1. # Returns - - A `SparseLayer` instance for initializing ESN's input layer with sparse connections. -""" -struct SparseLayer{T} <: AbstractLayer - scaling::T - sparsity::T -end - -function SparseLayer(; scaling = 0.1, sparsity = 0.1) - return SparseLayer(scaling, sparsity) -end +A sparse layer matrix. -function SparseLayer(scaling_arg; scaling = scaling_arg, sparsity = 0.1) - return SparseLayer(scaling, sparsity) -end +# Example -function create_layer(input_layer::SparseLayer, - res_size, - in_size; - matrix_type = Matrix{Float64}) - layer_matrix = Matrix(sprand(res_size, in_size, input_layer.sparsity)) - layer_matrix = 2.0 .* (layer_matrix .- 0.5) - replace!(layer_matrix, -1.0 => 0.0) - layer_matrix = input_layer.scaling .* layer_matrix - return Adapt.adapt(matrix_type, layer_matrix) -end - -#from "minimum complexity echo state network" Rodan -#and "simple deterministically constructed cycle reservoirs with regular jumps" -#by Rodan and Tino -struct BernoulliSample{T} - p::T +```julia +rng = Random.default_rng() +input_layer = sparse_init(rng, Float64, (3, 300); scaling = 0.2, sparsity = 0.1) +``` +""" +function sparse_init(rng::AbstractRNG, ::Type{T}, dims::Integer...; + scaling = T(0.1), sparsity = T(0.1)) where {T <: Number} + res_size, in_size = dims + layer_matrix = Matrix(sprand(rng, T, res_size, in_size, sparsity)) + layer_matrix = T.(2.0) .* (layer_matrix .- T.(0.5)) + replace!(layer_matrix, T(-1.0) => T(0.0)) + layer_matrix = scaling .* layer_matrix + + return layer_matrix end """ - BernoulliSample(p) - BernoulliSample(;p=0.5) + informed_init(rng::AbstractRNG, ::Type{T}, dims::Integer...; scaling=T(0.1), model_in_size, gamma=T(0.5)) where {T <: Number} -Creates a `BernoulliSample` constructor for the `MinimumLayer`. -It uses a Bernoulli distribution to determine the sign of weights in the input layer. -The parameter `p` sets the probability of a weight being positive, as per the `Distributions` package. -This method of sign weight determination for input layers is based on the approach in [^Rodan]. +Create a layer of a neural network. -# Parameters +# Arguments - - `p`: Probability of a positive weight (default: 0.5). + - `rng::AbstractRNG`: The random number generator. + - `T::Type`: The data type. + - `dims::Integer...`: The dimensions of the layer. + - `scaling::T = T(0.1)`: The scaling factor for the input matrix. + - `model_in_size`: The size of the input model. + - `gamma::T = T(0.5)`: The gamma value. # Returns - - A `BernoulliSample` instance for generating sign weights in `MinimumLayer`. - -Reference: - -[^Rodan]: Rodan, Ali, and Peter Tino. - "Minimum complexity echo state network." - IEEE Transactions on Neural Networks 22.1 (2010): 131-144. -""" -function BernoulliSample(; p = 0.5) - return BernoulliSample(p) -end + - `input_matrix`: The created input matrix for the layer. -struct IrrationalSample{K} - irrational::Irrational - start::K -end +# Example +```julia +rng = Random.default_rng() +dims = (100, 200) +model_in_size = 50 +input_matrix = informed_init(rng, Float64, dims; model_in_size = model_in_size) +``` """ - IrrationalSample(irrational, start) - IrrationalSample(;irrational=pi, start=1) - -Creates an `IrrationalSample` constructor for the `MinimumLayer`. -It determines the sign of weights in the input layer based on the decimal expansion of an `irrational` number. -The `start` parameter sets the starting point in the decimal sequence. -The signs are assigned based on the thresholding of each decimal digit against 4.5, as described in [^Rodan]. - -# Parameters +function informed_init(rng::AbstractRNG, ::Type{T}, dims::Integer...; + scaling = T(0.1), model_in_size, gamma = T(0.5)) where {T <: Number} + res_size, in_size = dims + state_size = in_size - model_in_size - - `irrational`: An irrational number for weight sign determination (default: π). - - `start`: Starting index in the decimal expansion (default: 1). - -# Returns + if state_size <= 0 + throw(DimensionMismatch("in_size must be greater than model_in_size")) + end - - An `IrrationalSample` instance for generating sign weights in `MinimumLayer`. + input_matrix = zeros(res_size, in_size) + zero_connections = zeros(in_size) + num_for_state = floor(Int, res_size * gamma) + num_for_model = floor(Int, res_size * (1 - gamma)) -Reference: + for i in 1:num_for_state + idxs = findall(Bool[zero_connections .== input_matrix[i, :] + for i in 1:size(input_matrix, 1)]) + random_row_idx = idxs[rand(rng, 1:end)] + random_clm_idx = range(1, state_size, step = 1)[rand(rng, 1:end)] + input_matrix[random_row_idx, random_clm_idx] = rand(rng, Uniform(-scaling, scaling)) + end -[^Rodan]: Rodan, Ali, and Peter Tiňo. - "Simple deterministically constructed cycle reservoirs with regular jumps." - Neural Computation 24.7 (2012): 1822-1852. -""" -function IrrationalSample(; irrational = pi, start = 1) - return IrrationalSample(irrational, start) -end + for i in 1:num_for_model + idxs = findall(Bool[zero_connections .== input_matrix[i, :] + for i in 1:size(input_matrix, 1)]) + random_row_idx = idxs[rand(rng, 1:end)] + random_clm_idx = range(state_size + 1, in_size, step = 1)[rand(rng, 1:end)] + input_matrix[random_row_idx, random_clm_idx] = rand(rng, Uniform(-scaling, scaling)) + end -struct MinimumLayer{T, K} <: AbstractLayer - weight::T - sampling::K + return input_matrix end """ - MinimumLayer(weight, sampling) - MinimumLayer(weight; sampling=BernoulliSample(0.5)) - MinimumLayer(;weight=0.1, sampling=BernoulliSample(0.5)) + irrational_sample_init(rng::AbstractRNG, ::Type{T}, dims::Integer...; + weight = 0.1, + sampling = IrrationalSample(; irrational = pi, start = 1) + ) where {T <: Number} -Creates a `MinimumLayer` initializer for Echo State Networks, generating a fully connected input layer. -This layer has a uniform absolute weight value (`weight`) with the sign of each -weight determined by the `sampling` method. This approach, as detailed in [^Rodan1] and [^Rodan2], -allows for controlled weight distribution in the layer. +Create a layer matrix using the provided random number generator and sampling parameters. -# Parameters +# Arguments - - `weight`: Absolute value of weights in the layer. - - `sampling`: Method for determining the sign of weights (default: `BernoulliSample(0.5)`). + - `rng::AbstractRNG`: The random number generator used to generate random numbers. + - `dims::Integer...`: The dimensions of the layer matrix. + - `weight`: The weight used to fill the layer matrix. Default is 0.1. + - `sampling`: The sampling parameters used to generate the input matrix. Default is IrrationalSample(irrational = pi, start = 1). # Returns - - A `MinimumLayer` instance for initializing the ESN's input layer. +The layer matrix generated using the provided random number generator and sampling parameters. -References: +# Example -[^Rodan1]: Rodan, Ali, and Peter Tino. - "Minimum complexity echo state network." - IEEE Transactions on Neural Networks 22.1 (2010): 131-144. -[^Rodan2]: Rodan, Ali, and Peter Tiňo. - "Simple deterministically constructed cycle reservoirs with regular jumps." - Neural Computation 24.7 (2012): 1822-1852. +```julia +using Random +rng = Random.default_rng() +dims = (3, 2) +weight = 0.5 +layer_matrix = irrational_sample_init(rng, Float64, dims; weight = weight, + sampling = IrrationalSample(irrational = sqrt(2), start = 1)) +``` """ -function MinimumLayer(weight; sampling = BernoulliSample(0.5)) - return MinimumLayer(weight, sampling) -end - -function MinimumLayer(; weight = 0.1, sampling = BernoulliSample(0.5)) - return MinimumLayer(weight, sampling) -end - -function create_layer(input_layer::MinimumLayer, - res_size, - in_size; - matrix_type = Matrix{Float64}) - sampling = input_layer.sampling - weight = input_layer.weight - layer_matrix = create_minimum_input(sampling, res_size, in_size, weight) - return Adapt.adapt(matrix_type, layer_matrix) +function minimal_init(rng::AbstractRNG, ::Type{T}, dims::Integer...; + sampling_type::Symbol = :bernoulli, + weight::Number = T(0.1), + irrational::Real = pi, + start::Int = 1, + p::Number = T(0.5)) where {T <: Number} + res_size, in_size = dims + if sampling_type == :bernoulli + layer_matrix = _create_bernoulli(p, res_size, in_size, weight, rng, T) + elseif sampling_type == :irrational + layer_matrix = _create_irrational(irrational, + start, + res_size, + in_size, + weight, + rng, + T) + else + error("Sampling type not allowed. Please use one of :bernoulli or :irrational") + end + return layer_matrix end -function create_minimum_input(sampling::BernoulliSample, res_size, in_size, weight) - p = sampling.p - input_matrix = zeros(res_size, in_size) +function _create_bernoulli(p::T, + res_size::Int, + in_size::Int, + weight::T, + rng::AbstractRNG, + ::Type{T}) where {T <: Number} + input_matrix = zeros(T, res_size, in_size) for i in 1:res_size for j in 1:in_size - rand(Bernoulli(p)) ? input_matrix[i, j] = weight : input_matrix[i, j] = -weight + rand(rng, Bernoulli(p)) ? (input_matrix[i, j] = weight) : + (input_matrix[i, j] = -weight) end end - return input_matrix end -function create_minimum_input(sampling::IrrationalSample, res_size, in_size, weight) - setprecision(BigFloat, Int(ceil(log2(10) * (res_size * in_size + sampling.start + 1)))) - ir_string = string(BigFloat(sampling.irrational)) |> collect +function _create_irrational(irrational::Irrational, + start::Int, + res_size::Int, + in_size::Int, + weight::T, + rng::AbstractRNG, + ::Type{T}) where {T <: Number} + setprecision(BigFloat, Int(ceil(log2(10) * (res_size * in_size + start + 1)))) + ir_string = string(BigFloat(irrational)) |> collect deleteat!(ir_string, findall(x -> x == '.', ir_string)) ir_array = zeros(length(ir_string)) - input_matrix = zeros(res_size, in_size) + input_matrix = zeros(T, res_size, in_size) for i in 1:length(ir_string) ir_array[i] = parse(Int, ir_string[i]) end - co = sampling.start - counter = 1 - for i in 1:res_size for j in 1:in_size - ir_array[counter] < 5 ? input_matrix[i, j] = -weight : - input_matrix[i, j] = weight - counter += 1 + random_number = rand(rng, T) + input_matrix[i, j] = random_number < 0.5 ? -weight : weight end end - return input_matrix -end - -struct InformedLayer{T, K, M} <: AbstractLayer - scaling::T - gamma::K - model_in_size::M -end - -""" - InformedLayer(model_in_size; scaling=0.1, gamma=0.5) - -Creates an `InformedLayer` initializer for Echo State Networks (ESNs) that generates -a weighted input layer matrix. The matrix contains random non-zero elements drawn from -the range [-`scaling`, `scaling`]. This initializer ensures that a fraction (`gamma`) -of reservoir nodes are exclusively connected to the raw inputs, while the rest are -connected to the outputs of a prior knowledge model, as described in [^Pathak]. - -# Arguments - - - `model_in_size`: The size of the prior knowledge model's output, - which determines the number of columns in the input layer matrix. - -# Keyword Arguments - - - `scaling`: The absolute value of the weights (default: 0.1). - - `gamma`: The fraction of reservoir nodes connected exclusively to raw inputs (default: 0.5). - -# Returns - - - An `InformedLayer` instance for initializing the ESN's input layer matrix. - -Reference: - -[^Pathak]: Jaideep Pathak et al. - "Hybrid Forecasting of Chaotic Processes: Using Machine Learning in Conjunction with a Knowledge-Based Model" (2018). -""" -function InformedLayer(model_in_size; scaling = 0.1, gamma = 0.5) - return InformedLayer(scaling, gamma, model_in_size) -end - -function create_layer(input_layer::InformedLayer, - res_size, - in_size; - matrix_type = Matrix{Float64}) - scaling = input_layer.scaling - state_size = in_size - input_layer.model_in_size - - if state_size <= 0 - throw(DimensionMismatch("in_size must be greater than model_in_size")) - end - - input_matrix = zeros(res_size, in_size) - #Vector used to find res nodes not yet connected - zero_connections = zeros(in_size) - #Num of res nodes allotted for raw states - num_for_state = floor(Int, res_size * input_layer.gamma) - #Num of res nodes allotted for prior model input - num_for_model = floor(Int, (res_size * (1 - input_layer.gamma))) - - for i in 1:num_for_state - #find res nodes with no connections - idxs = findall(Bool[zero_connections == input_matrix[i, :] - for i in 1:size(input_matrix, 1)]) - random_row_idx = idxs[rand(1:end)] - random_clm_idx = range(1, state_size, step = 1)[rand(1:end)] - input_matrix[random_row_idx, random_clm_idx] = rand(Uniform(-scaling, scaling)) - end - - for i in 1:num_for_model - idxs = findall(Bool[zero_connections == input_matrix[i, :] - for i in 1:size(input_matrix, 1)]) - random_row_idx = idxs[rand(1:end)] - random_clm_idx = range(state_size + 1, in_size, step = 1)[rand(1:end)] - input_matrix[random_row_idx, random_clm_idx] = rand(Uniform(-scaling, scaling)) - end - - return Adapt.adapt(matrix_type, input_matrix) -end - -""" - NullLayer() - -Creates a `NullLayer` initializer for Echo State Networks (ESNs) that generates a vector of zeros. - -# Returns - - - A `NullLayer` instance for initializing the ESN's input layer matrix. -""" -struct NullLayer <: AbstractLayer end - -function create_layer(input_layer::NullLayer, - res_size, - in_size; - matrix_type = Matrix{Float64}) - return Adapt.adapt(matrix_type, zeros(res_size, in_size)) + return T.(input_matrix) end diff --git a/src/esn/esn_predict.jl b/src/esn/esn_predict.jl index daa6fc34..cc7cdc5d 100644 --- a/src/esn/esn_predict.jl +++ b/src/esn/esn_predict.jl @@ -13,7 +13,7 @@ function obtain_esn_prediction(esn, out = initial_conditions states = similar(esn.states, size(esn.states, 1), prediction_len) - out_pad = allocate_outpad(esn.variation, esn.states_type, out) + out_pad = allocate_outpad(esn, esn.states_type, out) tmp_array = allocate_tmp(esn.reservoir_driver, typeof(esn.states), esn.res_size) x_new = esn.states_type(esn.nla_type, x, out_pad) @@ -43,7 +43,7 @@ function obtain_esn_prediction(esn, out = initial_conditions states = similar(esn.states, size(esn.states, 1), prediction_len) - out_pad = allocate_outpad(esn.variation, esn.states_type, out) + out_pad = allocate_outpad(esn, esn.states_type, out) tmp_array = allocate_tmp(esn.reservoir_driver, typeof(esn.states), esn.res_size) x_new = esn.states_type(esn.nla_type, x, out_pad) @@ -59,31 +59,18 @@ function obtain_esn_prediction(esn, end #prediction dispatch on esn -function next_state_prediction!(esn::ESN, x, x_new, out, out_pad, i, tmp_array, args...) - return _variation_prediction!(esn.variation, esn, x, x_new, out, out_pad, i, tmp_array, - args...) -end - -#dispatch the prediction on the esn variation -function _variation_prediction!(variation, - esn, - x, - x_new, - out, - out_pad, - i, - tmp_array, - args...) +function next_state_prediction!( + esn::AbstractEchoStateNetwork, x, x_new, out, out_pad, i, tmp_array, args...) out_pad = pad_state!(esn.states_type, out_pad, out) xv = @view x[1:(esn.res_size)] - x = next_state!(x, esn.reservoir_driver, xv, out_pad, + x = next_state!(x, esn.reservoir_driver, x, out_pad, esn.reservoir_matrix, esn.input_matrix, esn.bias_vector, tmp_array) x_new = esn.states_type(esn.nla_type, x, out_pad) return x, x_new end -function _variation_prediction!(variation::Hybrid, - esn, +#TODO fixme @MatrinuzziFra +function next_state_prediction!(hesn::HybridESN, x, x_new, out, @@ -92,20 +79,20 @@ function _variation_prediction!(variation::Hybrid, tmp_array, model_prediction_data) out_tmp = vcat(out, model_prediction_data[:, i]) - out_pad = pad_state!(esn.states_type, out_pad, out_tmp) - x = next_state!(x, esn.reservoir_driver, x[1:(esn.res_size)], out_pad, - esn.reservoir_matrix, esn.input_matrix, esn.bias_vector, tmp_array) + out_pad = pad_state!(hesn.states_type, out_pad, out_tmp) + x = next_state!(x, hesn.reservoir_driver, x[1:(hesn.res_size)], out_pad, + hesn.reservoir_matrix, hesn.input_matrix, hesn.bias_vector, tmp_array) x_tmp = vcat(x, model_prediction_data[:, i]) - x_new = esn.states_type(esn.nla_type, x_tmp, out_pad) + x_new = hesn.states_type(hesn.nla_type, x_tmp, out_pad) return x, x_new end -function allocate_outpad(variation, states_type, out) +function allocate_outpad(ens::AbstractEchoStateNetwork, states_type, out) return allocate_singlepadding(states_type, out) end -function allocate_outpad(variation::Hybrid, states_type, out) - pad_length = length(out) + size(variation.model_data[:, 1], 1) +function allocate_outpad(hesn::HybridESN, states_type, out) + pad_length = length(out) + size(hesn.model.model_data[:, 1], 1) out_tmp = Adapt.adapt(typeof(out), zeros(pad_length)) return allocate_singlepadding(states_type, out_tmp) end diff --git a/src/esn/esn_reservoir_drivers.jl b/src/esn/esn_reservoir_drivers.jl index 97e7b548..76cabae6 100644 --- a/src/esn/esn_reservoir_drivers.jl +++ b/src/esn/esn_reservoir_drivers.jl @@ -132,7 +132,7 @@ end function next_state!(out, rnn::RNN, x, y, W::Vector, W_in, b, tmp_array) esn_depth = length(W) - res_sizes = vcat(0, [get_ressize(W[i]) for i in 1:esn_depth]) + res_sizes = vcat(0, [size(W[i], 1) for i in 1:esn_depth]) inner_states = [x[(1 + sum(res_sizes[1:i])):sum(res_sizes[1:(i + 1)])] for i in 1:esn_depth] inner_inputs = vcat([y], inner_states[1:(end - 1)]) @@ -291,9 +291,9 @@ A GRUParams object containing the parameters needed for the GRU-based reservoir arXiv preprint arXiv:1406.1078 (2014). """ function GRU(; activation_function = [NNlib.sigmoid, NNlib.sigmoid, tanh], - inner_layer = fill(DenseLayer(), 2), - reservoir = fill(RandSparseReservoir(0), 2), - bias = fill(DenseLayer(), 2), + inner_layer = fill(scaled_rand, 2), + reservoir = fill(rand_sparse, 2), + bias = fill(scaled_rand, 2), variant = FullyGated()) return GRU(activation_function, inner_layer, reservoir, bias, variant) end @@ -318,22 +318,22 @@ end #dispatch on the different gru variations function create_gru_layers(gru, variant::FullyGated, res_size, in_size) - Wz_in = create_layer(gru.inner_layer[1], res_size, in_size) - Wz = create_reservoir(gru.reservoir[1], res_size) - bz = create_layer(gru.bias[1], res_size, 1) + Wz_in = gru.inner_layer[1](res_size, in_size) + Wz = gru.reservoir[1](res_size, res_size) + bz = gru.bias[1](res_size, 1) - Wr_in = create_layer(gru.inner_layer[2], res_size, in_size) - Wr = create_reservoir(gru.reservoir[2], res_size) - br = create_layer(gru.bias[2], res_size, 1) + Wr_in = gru.inner_layer[2](res_size, in_size) + Wr = gru.reservoir[2](res_size, res_size) + br = gru.bias[2](res_size, 1) return GRUParams(gru.activation_function, variant, Wz_in, Wz, bz, Wr_in, Wr, br) end #check this one, not sure function create_gru_layers(gru, variant::Minimal, res_size, in_size) - Wz_in = create_layer(gru.inner_layer, res_size, in_size) - Wz = create_reservoir(gru.reservoir, res_size) - bz = create_layer(gru.bias, res_size, 1) + Wz_in = gru.inner_layer(res_size, in_size) + Wz = gru.reservoir(res_size, res_size) + bz = gru.bias(res_size, 1) Wr_in = nothing Wr = nothing diff --git a/src/esn/esn_reservoirs.jl b/src/esn/esn_reservoirs.jl index 2c727afb..85dcf94f 100644 --- a/src/esn/esn_reservoirs.jl +++ b/src/esn/esn_reservoirs.jl @@ -1,416 +1,312 @@ -abstract type AbstractReservoir end - -function get_ressize(reservoir::AbstractReservoir) - return reservoir.res_size -end - -function get_ressize(reservoir) - return size(reservoir, 1) -end - -struct RandSparseReservoir{T, C} <: AbstractReservoir - res_size::Int - radius::T - sparsity::C -end - """ - RandSparseReservoir(res_size, radius, sparsity) - RandSparseReservoir(res_size; radius=1.0, sparsity=0.1) + rand_sparse(rng::AbstractRNG, ::Type{T}, dims::Integer...; radius=1.0, sparsity=0.1) -Returns a random sparse reservoir initializer, which generates a matrix of size `res_size x res_size` with the specified `sparsity` and scaled spectral radius according to `radius`. This type of reservoir initializer is commonly used in Echo State Networks (ESNs) for capturing complex temporal dependencies. +Create and return a random sparse reservoir matrix for use in Echo State Networks (ESNs). The matrix will be of size specified by `dims`, with specified `sparsity` and scaled spectral radius according to `radius`. # Arguments - - `res_size`: The size of the reservoir matrix. - - `radius`: The desired spectral radius of the reservoir. By default, it is set to 1.0. - - `sparsity`: The sparsity level of the reservoir matrix, controlling the fraction of zero elements. By default, it is set to 0.1. + - `rng`: An instance of `AbstractRNG` for random number generation. + - `T`: The data type for the elements of the matrix. + - `dims`: Dimensions of the reservoir matrix. + - `radius`: The desired spectral radius of the reservoir. Defaults to 1.0. + - `sparsity`: The sparsity level of the reservoir matrix, controlling the fraction of zero elements. Defaults to 0.1. # Returns -A RandSparseReservoir object that can be used as a reservoir initializer in ESN construction. +A matrix representing the random sparse reservoir. # References -This type of reservoir initialization is a common choice in ESN construction for its ability to capture temporal dependencies in data. However, there is no specific reference associated with this function. +This type of reservoir initialization is commonly used in ESNs for capturing temporal dependencies in data. """ -function RandSparseReservoir(res_size; radius = 1.0, sparsity = 0.1) - return RandSparseReservoir(res_size, radius, sparsity) +function rand_sparse(rng::AbstractRNG, + ::Type{T}, + dims::Integer...; + radius = T(1.0), + sparsity = T(0.1)) where {T <: Number} + reservoir_matrix = Matrix{T}(sprand(rng, dims..., sparsity)) + reservoir_matrix = T(2.0) .* (reservoir_matrix .- T(0.5)) + replace!(reservoir_matrix, T(-1.0) => T(0.0)) + rho_w = maximum(abs.(eigvals(reservoir_matrix))) + reservoir_matrix .*= radius / rho_w + if Inf in unique(reservoir_matrix) || -Inf in unique(reservoir_matrix) + error("Sparsity too low for size of the matrix. Increase res_size or increase sparsity") + end + return reservoir_matrix end """ - create_reservoir(reservoir::AbstractReservoir, res_size) - create_reservoir(reservoir, args...) + delay_line(rng::AbstractRNG, ::Type{T}, dims::Integer...; weight=0.1) where {T <: Number} -Given an `AbstractReservoir` constructor and the size of the reservoir (`res_size`), this function returns the corresponding reservoir matrix. Alternatively, it accepts a pre-generated matrix. +Create and return a delay line reservoir matrix for use in Echo State Networks (ESNs). A delay line reservoir is a deterministic structure where each unit is connected only to its immediate predecessor with a specified weight. This method is particularly useful for tasks that require specific temporal processing. # Arguments - - `reservoir`: An `AbstractReservoir` object or constructor. - - `res_size`: The size of the reservoir matrix. - - `matrix_type`: The type of the resulting matrix. By default, it is set to `Matrix{Float64}`. + - `rng`: An instance of `AbstractRNG` for random number generation. This argument is not used in the current implementation but is included for consistency with other initialization functions. + - `T`: The data type for the elements of the matrix. + - `dims`: Dimensions of the reservoir matrix. Typically, this should be a tuple of two equal integers representing a square matrix. + - `weight`: The weight determines the absolute value of all connections in the reservoir. Defaults to 0.1. # Returns -A matrix representing the reservoir, generated based on the properties of the specified `reservoir` object or constructor. - -# References - -The choice of reservoir initialization is crucial in Echo State Networks (ESNs) for achieving effective temporal modeling. Specific references for reservoir initialization methods may vary based on the type of reservoir used, but the practice of initializing reservoirs for ESNs is widely documented in the ESN literature. -""" -function create_reservoir(reservoir::RandSparseReservoir, - res_size; - matrix_type = Matrix{Float64}) - reservoir_matrix = Matrix(sprand(res_size, res_size, reservoir.sparsity)) - reservoir_matrix = 2.0 .* (reservoir_matrix .- 0.5) - replace!(reservoir_matrix, -1.0 => 0.0) - rho_w = maximum(abs.(eigvals(reservoir_matrix))) - reservoir_matrix .*= reservoir.radius / rho_w - #TODO: change to explicit if - Inf in unique(reservoir_matrix) || -Inf in unique(reservoir_matrix) ? - error("Sparsity too low for size of the matrix. - Increase res_size or increase sparsity") : nothing - return Adapt.adapt(matrix_type, reservoir_matrix) -end +A delay line reservoir matrix with dimensions specified by `dims`. The matrix is initialized such that each element in the `i+1`th row and `i`th column is set to `weight`, and all other elements are zeros. -function create_reservoir(reservoir, args...; kwargs...) - return reservoir -end +# Example -#= -function create_reservoir(res_size, reservoir::RandReservoir) - sparsity = degree/res_size - W = Matrix(sprand(Float64, res_size, res_size, sparsity)) - W = 2.0 .*(W.-0.5) - replace!(W, -1.0=>0.0) - rho_w = maximum(abs.(eigvals(W))) - W .*= radius/rho_w - W -end -=# - -struct PseudoSVDReservoir{T, C} <: AbstractReservoir - res_size::Int - max_value::T - sparsity::C - sorted::Bool - reverse_sort::Bool -end - -function PseudoSVDReservoir(res_size; - max_value = 1.0, - sparsity = 0.1, - sorted = true, - reverse_sort = false) - return PseudoSVDReservoir(res_size, max_value, sparsity, sorted, reverse_sort) -end - -""" - PseudoSVDReservoir(max_value, sparsity, sorted, reverse_sort) - PseudoSVDReservoir(max_value, sparsity; sorted=true, reverse_sort=false) - -Returns an initializer to build a sparse reservoir matrix with the given `sparsity` by using a pseudo-SVD approach as described in [^yang]. - -# Arguments - - - `res_size`: The size of the reservoir matrix. - - `max_value`: The maximum absolute value of elements in the matrix. - - `sparsity`: The desired sparsity level of the reservoir matrix. - - `sorted`: A boolean indicating whether to sort the singular values before creating the diagonal matrix. By default, it is set to `true`. - - `reverse_sort`: A boolean indicating whether to reverse the sorted singular values. By default, it is set to `false`. - -# Returns - -A PseudoSVDReservoir object that can be used as a reservoir initializer in ESN construction. +```julia +reservoir = delay_line(Float64, 100, 100; weight = 0.2) +``` # References -This reservoir initialization method, based on a pseudo-SVD approach, is inspired by the work in [^yang], which focuses on designing polynomial echo state networks for time series prediction. - -[^yang]: Yang, Cuili, et al. "_Design of polynomial echo state networks for time series prediction._" Neurocomputing 290 (2018): 148-160. +This type of reservoir initialization is described in: +Rodan, Ali, and Peter Tino. "Minimum complexity echo state network." IEEE Transactions on Neural Networks 22.1 (2010): 131-144. """ -function PseudoSVDReservoir(res_size, max_value, sparsity; sorted = true, - reverse_sort = false) - return PseudoSVDReservoir(res_size, max_value, sparsity, sorted, reverse_sort) -end - -function create_reservoir(reservoir::PseudoSVDReservoir, - res_size; - matrix_type = Matrix{Float64}) - sorted = reservoir.sorted - reverse_sort = reservoir.reverse_sort - reservoir_matrix = create_diag(res_size, reservoir.max_value, sorted = sorted, - reverse_sort = reverse_sort) - tmp_sparsity = get_sparsity(reservoir_matrix, res_size) - - while tmp_sparsity <= reservoir.sparsity - reservoir_matrix *= create_qmatrix(res_size, rand(1:res_size), rand(1:res_size), - rand() * 2 - 1) - tmp_sparsity = get_sparsity(reservoir_matrix, res_size) - end - - return Adapt.adapt(matrix_type, reservoir_matrix) -end - -function create_diag(dim, max_value; sorted = true, reverse_sort = false) - diagonal_matrix = zeros(dim, dim) - if sorted == true - if reverse_sort == true - diagonal_values = sort(rand(dim) .* max_value, rev = true) - diagonal_values[1] = max_value - else - diagonal_values = sort(rand(dim) .* max_value) - diagonal_values[end] = max_value - end - else - diagonal_values = rand(dim) .* max_value - end - - for i in 1:dim - diagonal_matrix[i, i] = diagonal_values[i] - end - - return diagonal_matrix -end - -function create_qmatrix(dim, coord_i, coord_j, theta) - qmatrix = zeros(dim, dim) - - for i in 1:dim - qmatrix[i, i] = 1.0 +function delay_line(rng::AbstractRNG, + ::Type{T}, + dims::Integer...; + weight = T(0.1)) where {T <: Number} + reservoir_matrix = zeros(T, dims...) + @assert length(dims) == 2&&dims[1] == dims[2] "The dimensions must define a square matrix (e.g., (100, 100))" + + for i in 1:(dims[1] - 1) + reservoir_matrix[i + 1, i] = weight end - qmatrix[coord_i, coord_i] = cos(theta) - qmatrix[coord_j, coord_j] = cos(theta) - qmatrix[coord_i, coord_j] = -sin(theta) - qmatrix[coord_j, coord_i] = sin(theta) - return qmatrix -end - -function get_sparsity(M, dim) - return size(M[M .!= 0], 1) / (dim * dim - size(M[M .!= 0], 1)) #nonzero/zero elements -end - -#from "minimum complexity echo state network" Rodan -# Delay Line Reservoir - -struct DelayLineReservoir{T} <: AbstractReservoir - res_size::Int - weight::T + return reservoir_matrix end """ - DelayLineReservoir(res_size, weight) - DelayLineReservoir(res_size; weight=0.1) + delay_line_backward(rng::AbstractRNG, ::Type{T}, dims::Integer...; + weight = T(0.1), fb_weight = T(0.2)) where {T <: Number} -Returns a Delay Line Reservoir matrix constructor to obtain a deterministic reservoir as -described in [^Rodan2010]. +Create a delay line backward reservoir with the specified by `dims` and weights. Creates a matrix with backward connections +as described in [^Rodan2010]. The `weight` and `fb_weight` can be passed as either arguments or +keyword arguments, and they determine the absolute values of the connections in the reservoir. # Arguments - - `res_size::Int`: The size of the reservoir. - - `weight::T`: The weight determines the absolute value of all the connections in the reservoir. + - `rng::AbstractRNG`: Random number generator. + - `T::Type`: Type of the elements in the reservoir matrix. + - `dims::Integer...`: Dimensions of the reservoir matrix. + - `weight::T`: The weight determines the absolute value of forward connections in the reservoir, and is set to 0.1 by default. + - `fb_weight::T`: The `fb_weight` determines the absolute value of backward connections in the reservoir, and is set to 0.2 by default. # Returns -A `DelayLineReservoir` object. +Reservoir matrix with the dimensions specified by `dims` and weights. # References [^Rodan2010]: Rodan, Ali, and Peter Tino. "Minimum complexity echo state network." IEEE transactions on neural networks 22.1 (2010): 131-144. """ -function DelayLineReservoir(res_size; weight = 0.1) - return DelayLineReservoir(res_size, weight) -end - -function create_reservoir(reservoir::DelayLineReservoir, - res_size; - matrix_type = Matrix{Float64}) - reservoir_matrix = zeros(res_size, res_size) +function delay_line_backward(rng::AbstractRNG, + ::Type{T}, + dims::Integer...; + weight = T(0.1), + fb_weight = T(0.2)) where {T <: Number} + res_size = first(dims) + reservoir_matrix = zeros(T, dims...) for i in 1:(res_size - 1) - reservoir_matrix[i + 1, i] = reservoir.weight + reservoir_matrix[i + 1, i] = weight + reservoir_matrix[i, i + 1] = fb_weight end - return Adapt.adapt(matrix_type, reservoir_matrix) -end - -#from "minimum complexity echo state network" Rodan -# Delay Line Reservoir with backward connections -struct DelayLineBackwardReservoir{T} <: AbstractReservoir - res_size::Int - weight::T - fb_weight::T + return reservoir_matrix end """ - DelayLineBackwardReservoir(res_size, weight, fb_weight) - DelayLineBackwardReservoir(res_size; weight=0.1, fb_weight=0.2) + cycle_jumps(rng::AbstractRNG, ::Type{T}, dims::Integer...; + cycle_weight = T(0.1), jump_weight = T(0.1), jump_size = 3) where {T <: Number} -Returns a Delay Line Reservoir constructor to create a matrix with backward connections -as described in [^Rodan2010]. The `weight` and `fb_weight` can be passed as either arguments or -keyword arguments, and they determine the absolute values of the connections in the reservoir. +Create a cycle jumps reservoir with the specified dimensions, cycle weight, jump weight, and jump size. # Arguments - - `res_size::Int`: The size of the reservoir. - - `weight::T`: The weight determines the absolute value of forward connections in the reservoir. - - `fb_weight::T`: The `fb_weight` determines the absolute value of backward connections in the reservoir. + - `rng::AbstractRNG`: Random number generator. + - `T::Type`: Type of the elements in the reservoir matrix. + - `dims::Integer...`: Dimensions of the reservoir matrix. + - `cycle_weight::T = T(0.1)`: The weight of cycle connections. + - `jump_weight::T = T(0.1)`: The weight of jump connections. + - `jump_size::Int = 3`: The number of steps between jump connections. # Returns -A `DelayLineBackwardReservoir` object. +Reservoir matrix with the specified dimensions, cycle weight, jump weight, and jump size. # References -[^Rodan2010]: Rodan, Ali, and Peter Tino. "Minimum complexity echo state network." - IEEE transactions on neural networks 22.1 (2010): 131-144. +[^Rodan2012]: Rodan, Ali, and Peter Tiňo. "Simple deterministically constructed cycle reservoirs + with regular jumps." Neural computation 24.7 (2012): 1822-1852. """ -function DelayLineBackwardReservoir(res_size; weight = 0.1, fb_weight = 0.2) - return DelayLineBackwardReservoir(res_size, weight, fb_weight) -end - -function create_reservoir(reservoir::DelayLineBackwardReservoir, - res_size; - matrix_type = Matrix{Float64}) - reservoir_matrix = zeros(res_size, res_size) +function cycle_jumps(rng::AbstractRNG, + ::Type{T}, + dims::Integer...; + cycle_weight::Number = T(0.1), + jump_weight::Number = T(0.1), + jump_size::Int = 3) where {T <: Number} + res_size = first(dims) + reservoir_matrix = zeros(T, dims...) for i in 1:(res_size - 1) - reservoir_matrix[i + 1, i] = reservoir.weight - reservoir_matrix[i, i + 1] = reservoir.fb_weight + reservoir_matrix[i + 1, i] = cycle_weight end - return Adapt.adapt(matrix_type, reservoir_matrix) -end + reservoir_matrix[1, res_size] = cycle_weight -#from "minimum complexity echo state network" Rodan -# Simple cycle reservoir -struct SimpleCycleReservoir{T} <: AbstractReservoir - res_size::Int - weight::T + for i in 1:jump_size:(res_size - jump_size) + tmp = (i + jump_size) % res_size + if tmp == 0 + tmp = res_size + end + reservoir_matrix[i, tmp] = jump_weight + reservoir_matrix[tmp, i] = jump_weight + end + + return reservoir_matrix end """ - SimpleCycleReservoir(res_size, weight) - SimpleCycleReservoir(res_size; weight=0.1) + simple_cycle(rng::AbstractRNG, ::Type{T}, dims::Integer...; + weight = T(0.1)) where {T <: Number} -Returns a Simple Cycle Reservoir constructor to build a reservoir matrix as -described in [^Rodan2010]. The `weight` can be passed as an argument or a keyword argument, and it determines the -absolute value of all the connections in the reservoir. +Create a simple cycle reservoir with the specified dimensions and weight. # Arguments - - `res_size::Int`: The size of the reservoir. - - `weight::T`: The weight determines the absolute value of connections in the reservoir. + - `rng::AbstractRNG`: Random number generator. + - `T::Type`: Type of the elements in the reservoir matrix. + - `dims::Integer...`: Dimensions of the reservoir matrix. + - `weight::T = T(0.1)`: Weight of the connections in the reservoir matrix. # Returns -A `SimpleCycleReservoir` object. +Reservoir matrix with the dimensions specified by `dims` and weights. # References [^Rodan2010]: Rodan, Ali, and Peter Tino. "Minimum complexity echo state network." IEEE transactions on neural networks 22.1 (2010): 131-144. """ -function SimpleCycleReservoir(res_size; weight = 0.1) - return SimpleCycleReservoir(res_size, weight) -end - -function create_reservoir(reservoir::SimpleCycleReservoir, - res_size; - matrix_type = Matrix{Float64}) - reservoir_matrix = zeros(Float64, res_size, res_size) - - for i in 1:(res_size - 1) - reservoir_matrix[i + 1, i] = reservoir.weight +function simple_cycle(rng::AbstractRNG, + ::Type{T}, + dims::Integer...; + weight = T(0.1)) where {T <: Number} + reservoir_matrix = zeros(T, dims...) + + for i in 1:(dims[1] - 1) + reservoir_matrix[i + 1, i] = weight end - reservoir_matrix[1, res_size] = reservoir.weight - return Adapt.adapt(matrix_type, reservoir_matrix) -end - -#from "simple deterministically constructed cycle reservoirs with regular jumps" by Rodan and Tino -# Cycle Reservoir with Jumps -struct CycleJumpsReservoir{T} <: AbstractReservoir - res_size::Int - cycle_weight::T - jump_weight::T - jump_size::Int + reservoir_matrix[1, dims[1]] = weight + return reservoir_matrix end """ - CycleJumpsReservoir(res_size; cycle_weight=0.1, jump_weight=0.1, jump_size=3) - CycleJumpsReservoir(res_size, cycle_weight, jump_weight, jump_size) + pseudo_svd(rng::AbstractRNG, ::Type{T}, dims::Integer...; + max_value, sparsity, sorted = true, reverse_sort = false) where {T <: Number} -Return a Cycle Reservoir with Jumps constructor to create a reservoir matrix as described -in [^Rodan2012]. The `cycle_weight`, `jump_weight`, and `jump_size` can be passed as arguments or keyword arguments, and they -determine the absolute values of connections in the reservoir. The `jump_size` determines the jumps between `jump_weight`s. + Returns an initializer to build a sparse reservoir matrix with the given `sparsity` by using a pseudo-SVD approach as described in [^yang]. # Arguments - - `res_size::Int`: The size of the reservoir. - - `cycle_weight::T`: The weight of cycle connections. - - `jump_weight::T`: The weight of jump connections. - - `jump_size::Int`: The number of steps between jump connections. + - `rng::AbstractRNG`: Random number generator. + - `T::Type`: Type of the elements in the reservoir matrix. + - `dims::Integer...`: Dimensions of the reservoir matrix. + - `max_value`: The maximum absolute value of elements in the matrix. + - `sparsity`: The desired sparsity level of the reservoir matrix. + - `sorted`: A boolean indicating whether to sort the singular values before creating the diagonal matrix. By default, it is set to `true`. + - `reverse_sort`: A boolean indicating whether to reverse the sorted singular values. By default, it is set to `false`. # Returns -A `CycleJumpsReservoir` object. +Reservoir matrix with the specified dimensions, max value, and sparsity. # References -[^Rodan2012]: Rodan, Ali, and Peter Tiňo. "Simple deterministically constructed cycle reservoirs - with regular jumps." Neural computation 24.7 (2012): 1822-1852. -""" -function CycleJumpsReservoir(res_size; cycle_weight = 0.1, jump_weight = 0.1, jump_size = 3) - return CycleJumpsReservoir(res_size, cycle_weight, jump_weight, jump_size) -end - -function create_reservoir(reservoir::CycleJumpsReservoir, - res_size; - matrix_type = Matrix{Float64}) - reservoir_matrix = zeros(res_size, res_size) +This reservoir initialization method, based on a pseudo-SVD approach, is inspired by the work in [^yang], which focuses on designing polynomial echo state networks for time series prediction. - for i in 1:(res_size - 1) - reservoir_matrix[i + 1, i] = reservoir.cycle_weight +[^yang]: Yang, Cuili, et al. "_Design of polynomial echo state networks for time series prediction._" Neurocomputing 290 (2018): 148-160. +""" +function pseudo_svd(rng::AbstractRNG, + ::Type{T}, + dims::Integer...; + max_value::Number = T(1.0), + sparsity::Number = 0.1, + sorted::Bool = true, + reverse_sort::Bool = false) where {T <: Number} + reservoir_matrix = create_diag(dims[1], + max_value, + T; + sorted = sorted, + reverse_sort = reverse_sort) + tmp_sparsity = get_sparsity(reservoir_matrix, dims[1]) + + while tmp_sparsity <= sparsity + reservoir_matrix *= create_qmatrix(dims[1], + rand(1:dims[1]), + rand(1:dims[1]), + rand(T) * T(2) - T(1), + T) + tmp_sparsity = get_sparsity(reservoir_matrix, dims[1]) end - reservoir_matrix[1, res_size] = reservoir.cycle_weight + return reservoir_matrix +end - for i in 1:(reservoir.jump_size):(res_size - reservoir.jump_size) - tmp = (i + reservoir.jump_size) % res_size - if tmp == 0 - tmp = res_size +function create_diag(dim::Number, max_value::Number, ::Type{T}; + sorted::Bool = true, reverse_sort::Bool = false) where {T <: Number} + diagonal_matrix = zeros(T, dim, dim) + if sorted == true + if reverse_sort == true + diagonal_values = sort(rand(T, dim) .* max_value, rev = true) + diagonal_values[1] = max_value + else + diagonal_values = sort(rand(T, dim) .* max_value) + diagonal_values[end] = max_value end - reservoir_matrix[i, tmp] = reservoir.jump_weight - reservoir_matrix[tmp, i] = reservoir.jump_weight + else + diagonal_values = rand(T, dim) .* max_value end - return Adapt.adapt(matrix_type, reservoir_matrix) -end - -""" - NullReservoir() - -Return a constructor for a matrix of zeros with dimensions `res_size x res_size`. - -# Arguments + for i in 1:dim + diagonal_matrix[i, i] = diagonal_values[i] + end - - None + return diagonal_matrix +end -# Returns +function create_qmatrix(dim::Number, + coord_i::Number, + coord_j::Number, + theta::Number, + ::Type{T}) where {T <: Number} + qmatrix = zeros(T, dim, dim) -A `NullReservoir` object. + for i in 1:dim + qmatrix[i, i] = 1.0 + end -# References + qmatrix[coord_i, coord_i] = cos(theta) + qmatrix[coord_j, coord_j] = cos(theta) + qmatrix[coord_i, coord_j] = -sin(theta) + qmatrix[coord_j, coord_i] = sin(theta) + return qmatrix +end - - None -""" -struct NullReservoir <: AbstractReservoir end +function get_sparsity(M, dim) + return size(M[M .!= 0], 1) / (dim * dim - size(M[M .!= 0], 1)) #nonzero/zero elements +end -function create_reservoir(reservoir::NullReservoir, - res_size; - matrix_type = Matrix{Float64}) - return Adapt.adapt(matrix_type, zeros(res_size, res_size)) +# from WeightInitializers.jl, TODO @MartinuzziFrancesco consider importing package +function _default_rng() + @static if VERSION >= v"1.7" + return Xoshiro(1234) + else + return MersenneTwister(1234) + end end + +__partial_apply(fn, inp) = fn$inp diff --git a/src/esn/hybridesn.jl b/src/esn/hybridesn.jl new file mode 100644 index 00000000..a190590a --- /dev/null +++ b/src/esn/hybridesn.jl @@ -0,0 +1,122 @@ +struct HybridESN{I, S, V, N, T, O, M, B, ST, W, IS} <: AbstractEchoStateNetwork + res_size::I + train_data::S + model::V + nla_type::N + input_matrix::T + reservoir_driver::O + reservoir_matrix::M + bias_vector::B + states_type::ST + washout::W + states::IS +end + +struct KnowledgeModel{T, K, O, I, S, D} + prior_model::T + u0::K + tspan::O + dt::I + datasize::S + model_data::D +end + +""" + Hybrid(prior_model, u0, tspan, datasize) + +Constructs a `Hybrid` variation of Echo State Networks (ESNs) integrating a knowledge-based model +(`prior_model`) with ESNs for advanced training and prediction in chaotic systems. + +# Parameters + + - `prior_model`: A knowledge-based model function for integration with ESNs. + - `u0`: Initial conditions for the model. + - `tspan`: Time span as a tuple, indicating the duration for model operation. + - `datasize`: The size of the data to be processed. + +# Returns + + - A `Hybrid` struct instance representing the combined ESN and knowledge-based model. + +This method is effective for chaotic processes as highlighted in [^Pathak]. + +Reference: + +[^Pathak]: Jaideep Pathak et al. + "Hybrid Forecasting of Chaotic Processes: + Using Machine Learning in Conjunction with a Knowledge-Based Model" (2018). +""" +function KnowledgeModel(prior_model, u0, tspan, datasize) + trange = collect(range(tspan[1], tspan[2], length = datasize)) + dt = trange[2] - trange[1] + tsteps = push!(trange, dt + trange[end]) + tspan_new = (tspan[1], dt + tspan[2]) + model_data = prior_model(u0, tspan_new, tsteps) + return KnowledgeModel(prior_model, u0, tspan, dt, datasize, model_data) +end + +function HybridESN(model, + train_data, + in_size::Int, + res_size::Int; + input_layer = scaled_rand, + reservoir = rand_sparse, + bias = zeros64, + reservoir_driver = RNN(), + nla_type = NLADefault(), + states_type = StandardStates(), + washout = 0, + rng = _default_rng(), + T = Float32, + matrix_type = typeof(train_data)) + train_data = vcat(train_data, model.model_data[:, 1:(end - 1)]) + + if states_type isa AbstractPaddedStates + in_size = size(train_data, 1) + 1 + train_data = vcat(Adapt.adapt(matrix_type, ones(1, size(train_data, 2))), + train_data) + else + in_size = size(train_data, 1) + end + + reservoir_matrix = reservoir(rng, T, res_size, res_size) + #different from ESN, why? + input_matrix = input_layer(rng, T, res_size, in_size) + bias_vector = bias(rng, res_size) + inner_res_driver = reservoir_driver_params(reservoir_driver, res_size, in_size) + states = create_states(inner_res_driver, train_data, washout, reservoir_matrix, + input_matrix, bias_vector) + train_data = train_data[:, (washout + 1):end] + + HybridESN(res_size, train_data, model, nla_type, input_matrix, + inner_res_driver, reservoir_matrix, bias_vector, states_type, washout, + states) +end + +function (hesn::HybridESN)(prediction::AbstractPrediction, + output_layer::AbstractOutputLayer; + last_state = hesn.states[:, [end]], + kwargs...) + km = hesn.model + pred_len = prediction.prediction_len + + model = km.prior_model + predict_tsteps = [km.tspan[2] + km.dt] + [append!(predict_tsteps, predict_tsteps[end] + km.dt) for i in 1:pred_len] + tspan_new = (km.tspan[2] + km.dt, predict_tsteps[end]) + u0 = km.model_data[:, end] + model_pred_data = model(u0, tspan_new, predict_tsteps)[:, 2:end] + + return obtain_esn_prediction(hesn, prediction, last_state, output_layer, + model_pred_data; + kwargs...) +end + +function train(hesn::HybridESN, + target_data, + training_method = StandardRidge(0.0)) + states = vcat(hesn.states, hesn.model.model_data[:, 2:end]) + states_new = hesn.states_type(hesn.nla_type, states, hesn.train_data[:, 1:end]) + + return _train(states_new, target_data, training_method) +end diff --git a/test/esn/deepesn.jl b/test/esn/deepesn.jl index 8b137891..38e06814 100644 --- a/test/esn/deepesn.jl +++ b/test/esn/deepesn.jl @@ -1 +1,20 @@ +using ReservoirComputing, Random, Statistics +const res_size = 20 +const ts = 0.0:0.1:50.0 +const data = sin.(ts) +const train_len = 400 +const predict_len = 100 +const input_data = reduce(hcat, data[1:(train_len - 1)]) +const target_data = reduce(hcat, data[2:train_len]) +const test = reduce(hcat, data[(train_len + 1):(train_len + predict_len)]) +const reg = 10e-6 +#test_types = [Float64, Float32, Float16] + +Random.seed!(77) +res = rand_sparse(; radius = 1.2, sparsity = 0.1) +esn = DeepESN(input_data, 1, res_size) + +output_layer = train(esn, target_data) +output = esn(Generative(length(test)), output_layer) +@test mean(abs.(test .- output)) ./ mean(abs.(test)) < 0.22 diff --git a/test/esn/test_drivers.jl b/test/esn/test_drivers.jl index 54045925..7ad47f74 100644 --- a/test/esn/test_drivers.jl +++ b/test/esn/test_drivers.jl @@ -1,65 +1,41 @@ using ReservoirComputing, Random, Statistics, NNlib -const res_size = 20 +const res_size = 50 const ts = 0.0:0.1:50.0 const data = sin.(ts) const train_len = 400 const input_data = reduce(hcat, data[1:(train_len - 1)]) const target_data = reduce(hcat, data[2:train_len]) const predict_len = 100 -const test = reduce(hcat, data[(train_len + 1):(train_len + predict_len)]) +const test_data = reduce(hcat, data[(train_len + 1):(train_len + predict_len)]) const training_method = StandardRidge(10e-6) - Random.seed!(77) -esn = ESN(input_data; - reservoir = RandSparseReservoir(res_size, 1.2, 0.1), - reservoir_driver = GRU(variant = FullyGated(), - reservoir = [ - RandSparseReservoir(res_size, 1.0, 0.5), - RandSparseReservoir(res_size, 1.2, 0.1) - ])) - -output_layer = train(esn, target_data, training_method) -output = esn(Predictive(target_data), output_layer, initial_conditions = target_data[1]) -@test mean(abs.(target_data .- output)) ./ mean(abs.(target_data)) < 0.11 - -esn = ESN(input_data; - reservoir = RandSparseReservoir(res_size, 1.2, 0.1), - reservoir_driver = GRU(variant = Minimal(), - reservoir = RandSparseReservoir(res_size, 1.0, 0.5), - inner_layer = DenseLayer(), - bias = DenseLayer())) - -output_layer = train(esn, target_data, training_method) -output = esn(Predictive(target_data), output_layer, initial_conditions = target_data[1]) -@test mean(abs.(target_data .- output)) ./ mean(abs.(target_data)) < 0.11 - -#multiple rnn -esn = ESN(input_data; - reservoir = RandSparseReservoir(res_size, 1.2, 0.1), - reservoir_driver = MRNN(activation_function = (tanh, sigmoid), - scaling_factor = (0.8, 0.1))) -output_layer = train(esn, target_data, training_method) -output = esn(Predictive(target_data), output_layer, initial_conditions = target_data[1]) -@test mean(abs.(target_data .- output)) ./ mean(abs.(target_data)) < 0.11 - -#deep esn -esn = ESN(input_data; - reservoir = [ - RandSparseReservoir(res_size, 1.2, 0.1), - RandSparseReservoir(res_size, 1.2, 0.1) - ]) -output_layer = train(esn, target_data, training_method) -output = esn(Predictive(target_data), output_layer, initial_conditions = target_data[1]) -@test mean(abs.(target_data .- output)) ./ mean(abs.(target_data)) < 0.11 -esn = ESN(input_data; - reservoir = [ - RandSparseReservoir(res_size, 1.2, 0.1), - RandSparseReservoir(res_size, 1.2, 0.1) - ], - input_layer = [DenseLayer(), DenseLayer()], - bias = [NullLayer(), NullLayer()]) -output_layer = train(esn, target_data, training_method) -output = esn(Predictive(target_data), output_layer, initial_conditions = target_data[1]) -@test mean(abs.(target_data .- output)) ./ mean(abs.(target_data)) < 0.11 +function test_esn(input_data, target_data, training_method, esn_config) + esn = ESN(input_data, 1, res_size; esn_config...) + output_layer = train(esn, target_data, training_method) + + output = esn(Predictive(target_data), output_layer, initial_conditions = target_data[1]) + @test mean(abs.(target_data .- output)) ./ mean(abs.(target_data)) < 0.15 +end + +esn_configs = [ + Dict(:reservoir => rand_sparse(; radius = 1.2), + :reservoir_driver => GRU(variant = FullyGated(), + reservoir = [ + rand_sparse(; radius = 1.0, sparsity = 0.5), + rand_sparse(; radius = 1.2, sparsity = 0.1) + ])), + Dict(:reservoir => rand_sparse(; radius = 1.2), + :reservoir_driver => GRU(variant = Minimal(), + reservoir = rand_sparse(; radius = 1.0, sparsity = 0.5), + inner_layer = scaled_rand, + bias = scaled_rand)), + Dict(:reservoir => rand_sparse(; radius = 1.2), + :reservoir_driver => MRNN(activation_function = (tanh, sigmoid), + scaling_factor = (0.8, 0.1))) +] + +@testset "Test Drivers: $config" for config in esn_configs + test_esn(input_data, target_data, training_method, config) +end diff --git a/test/esn/test_hybrid.jl b/test/esn/test_hybrid.jl index 4f858208..415b1343 100644 --- a/test/esn/test_hybrid.jl +++ b/test/esn/test_hybrid.jl @@ -30,15 +30,17 @@ test_data = ode_data[:, (train_len + 1):end][:, 1:1000] predict_len = size(test_data, 2) tspan_train = (tspan[1], ode_sol.t[train_len]) -hybrid = Hybrid(prior_model_data_generator, u0, tspan_train, train_len) +km = KnowledgeModel(prior_model_data_generator, u0, tspan_train, train_len) Random.seed!(77) -esn = ESN(input_data, - reservoir = RandSparseReservoir(300), - variation = hybrid) +hesn = HybridESN(km, + input_data, + 3, + 300; + reservoir = rand_sparse) -output_layer = train(esn, target_data, StandardRidge(0.3)) +output_layer = train(hesn, target_data, StandardRidge(0.3)) -output = esn(Generative(predict_len), output_layer) +output = hesn(Generative(predict_len), output_layer) @test mean(abs.(test_data[1:100] .- output[1:100])) ./ mean(abs.(test_data[1:100])) < 0.11 diff --git a/test/esn/test_inits.jl b/test/esn/test_inits.jl new file mode 100644 index 00000000..6d25ea34 --- /dev/null +++ b/test/esn/test_inits.jl @@ -0,0 +1,90 @@ +using ReservoirComputing +using LinearAlgebra +using Random + +const res_size = 30 +const in_size = 3 +const radius = 1.0 +const sparsity = 0.1 +const weight = 0.2 +const jump_size = 3 +const rng = Random.default_rng() + +function check_radius(matrix, target_radius; tolerance = 1e-5) + eigenvalues = eigvals(matrix) + spectral_radius = maximum(abs.(eigenvalues)) + return isapprox(spectral_radius, target_radius, atol = tolerance) +end + +ft = [Float16, Float32, Float64] +reservoir_inits = [ + rand_sparse, + delay_line, + delay_line_backward, + cycle_jumps, + simple_cycle, + pseudo_svd +] +input_inits = [ + scaled_rand, + weighted_init, + sparse_init, + minimal_init, + minimal_init(; sampling_type = :irrational) +] + +@testset "Reservoir Initializers" begin + @testset "Sizes and types: $init $T" for init in reservoir_inits, T in ft + #sizes + @test size(init(res_size, res_size)) == (res_size, res_size) + @test size(init(rng, res_size, res_size)) == (res_size, res_size) + #types + @test eltype(init(T, res_size, res_size)) == T + @test eltype(init(rng, T, res_size, res_size)) == T + #closure + cl = init(rng) + @test eltype(cl(T, res_size, res_size)) == T + end + + @testset "Check spectral radius" begin + sp = rand_sparse(res_size, res_size) + @test check_radius(sp, radius) + end + + @testset "Minimum complexity: $init" for init in [ + delay_line, + delay_line_backward, + cycle_jumps, + simple_cycle + ] + dl = init(res_size, res_size) + if init === delay_line_backward + @test unique(dl) == Float32.([0.0, 0.1, 0.2]) + else + @test unique(dl) == Float32.([0.0, 0.1]) + end + end +end + +# TODO: @MartinuzziFrancesco Missing tests for informed_init +@testset "Input Initializers" begin + @testset "Sizes and types: $init $T" for init in input_inits, T in ft + #sizes + @test size(init(res_size, in_size)) == (res_size, in_size) + @test size(init(rng, res_size, in_size)) == (res_size, in_size) + #types + @test eltype(init(T, res_size, in_size)) == T + @test eltype(init(rng, T, res_size, in_size)) == T + #closure + cl = init(rng) + @test eltype(cl(T, res_size, in_size)) == T + end + + @testset "Minimum complexity: $init" for init in [ + minimal_init, + minimal_init(; sampling_type = :irrational) + ] + dl = init(res_size, in_size) + @test sort(unique(dl)) == Float32.([-0.1, 0.1]) + end +end diff --git a/test/esn/test_input_layers.jl b/test/esn/test_input_layers.jl deleted file mode 100644 index adc655da..00000000 --- a/test/esn/test_input_layers.jl +++ /dev/null @@ -1,55 +0,0 @@ -using ReservoirComputing - -const res_size = 10 -const in_size = 3 -const scaling = 0.1 -const weight = 0.2 - -#testing WeightedLayer implicit and esplicit constructors -input_constructor = WeightedLayer(scaling) -input_matrix = create_layer(input_constructor, res_size, in_size) -@test size(input_matrix) == (Int(floor(res_size / in_size) * in_size), in_size) -@test maximum(input_matrix) <= scaling - -input_constructor = WeightedLayer(scaling = scaling) -input_matrix = create_layer(input_constructor, res_size, in_size) -@test size(input_matrix) == (Int(floor(res_size / in_size) * in_size), in_size) -@test maximum(input_matrix) <= scaling - -#testing DenseLayer implicit and esplicit constructors -input_constructor = DenseLayer(scaling) -input_matrix = create_layer(input_constructor, res_size, in_size) -@test size(input_matrix) == (res_size, in_size) -@test maximum(input_matrix) <= scaling - -input_constructor = DenseLayer(scaling = scaling) -input_matrix = create_layer(input_constructor, res_size, in_size) -@test size(input_matrix) == (res_size, in_size) -@test maximum(input_matrix) <= scaling - -#testing SparseLayer implicit and esplicit constructors -input_constructor = SparseLayer(scaling) -input_matrix = create_layer(input_constructor, res_size, in_size) -@test size(input_matrix) == (res_size, in_size) -@test maximum(input_matrix) <= scaling - -input_constructor = SparseLayer(scaling = scaling) -input_matrix = create_layer(input_constructor, res_size, in_size) -@test size(input_matrix) == (res_size, in_size) -@test maximum(input_matrix) <= scaling - -#testing MinimumLayer implicit and esplicit constructors -input_constructor = MinimumLayer(weight) -input_matrix = create_layer(input_constructor, res_size, in_size) -@test size(input_matrix) == (res_size, in_size) -@test maximum(input_matrix) == weight - -input_constructor = MinimumLayer(weight = weight) -input_matrix = create_layer(input_constructor, res_size, in_size) -@test size(input_matrix) == (res_size, in_size) -@test maximum(input_matrix) == weight - -#testing NullLayer constructor -input_constructor = NullLayer() -input_matrix = create_layer(input_constructor, res_size, in_size) -@test size(input_matrix) == (res_size, in_size) diff --git a/test/esn/test_nla.jl b/test/esn/test_nla.jl deleted file mode 100644 index f5ac42f8..00000000 --- a/test/esn/test_nla.jl +++ /dev/null @@ -1,23 +0,0 @@ -using ReservoirComputing - -states = [1, 2, 3, 4, 5, 6, 7, 8, 9] -nla1_states = [1, 2, 9, 4, 25, 6, 49, 8, 81] -nla2_states = [1, 2, 2, 4, 12, 6, 30, 8, 9] -nla3_states = [1, 2, 8, 4, 24, 6, 48, 8, 9] - -test_types = [Float64, Float32, Float16] - -for tt in test_types - # test default - nla_states = ReservoirComputing.nla(NLADefault(), tt.(states)) - @test nla_states == tt.(states) - # test NLAT1 - nla_states = ReservoirComputing.nla(NLAT1(), tt.(states)) - @test nla_states = tt.(nla1_states) - # test nlat2 - nla_states = ReservoirComputing.nla(NLAT2(), tt.(states)) - @test nla_states = tt.(nla2_states) - # test nlat3 - nla_states = ReservoirComputing.nla(NLAT3(), tt.(states)) - @test nla_states = tt.(nla3_states) -end diff --git a/test/esn/test_reservoirs.jl b/test/esn/test_reservoirs.jl deleted file mode 100644 index ac751712..00000000 --- a/test/esn/test_reservoirs.jl +++ /dev/null @@ -1,79 +0,0 @@ -using ReservoirComputing - -const res_size = 20 -const radius = 1.0 -const sparsity = 0.1 -const weight = 0.2 -const jump_size = 3 - -#testing RandSparseReservoir implicit and esplicit constructors -reservoir_constructor = RandSparseReservoir(res_size, radius, sparsity) -reservoir_matrix = create_reservoir(reservoir_constructor, res_size) -@test size(reservoir_matrix) == (res_size, res_size) - -reservoir_constructor = RandSparseReservoir(res_size, radius = radius, sparsity = sparsity) -reservoir_matrix = create_reservoir(reservoir_constructor, res_size) -@test size(reservoir_matrix) == (res_size, res_size) - -#testing PseudoSVDReservoir implicit and esplicit constructors -reservoir_constructor = PseudoSVDReservoir(res_size, radius, sparsity) -reservoir_matrix = create_reservoir(reservoir_constructor, res_size) -@test size(reservoir_matrix) == (res_size, res_size) -@test maximum(reservoir_matrix) <= radius - -reservoir_constructor = PseudoSVDReservoir(res_size, max_value = radius, - sparsity = sparsity) -reservoir_matrix = create_reservoir(reservoir_constructor, res_size) -@test size(reservoir_matrix) == (res_size, res_size) -@test maximum(reservoir_matrix) <= radius - -#testing DelayLineReservoir implicit and esplicit constructors -reservoir_constructor = DelayLineReservoir(res_size, weight) -reservoir_matrix = create_reservoir(reservoir_constructor, res_size) -@test size(reservoir_matrix) == (res_size, res_size) -@test maximum(reservoir_matrix) == weight - -reservoir_constructor = DelayLineReservoir(res_size, weight = weight) -reservoir_matrix = create_reservoir(reservoir_constructor, res_size) -@test size(reservoir_matrix) == (res_size, res_size) -@test maximum(reservoir_matrix) == weight - -#testing DelayLineReservoir implicit and esplicit constructors -reservoir_constructor = DelayLineBackwardReservoir(res_size, weight, weight) -reservoir_matrix = create_reservoir(reservoir_constructor, res_size) -@test size(reservoir_matrix) == (res_size, res_size) -@test maximum(reservoir_matrix) == weight - -reservoir_constructor = DelayLineBackwardReservoir(res_size, weight = weight, - fb_weight = weight) -reservoir_matrix = create_reservoir(reservoir_constructor, res_size) -@test size(reservoir_matrix) == (res_size, res_size) -@test maximum(reservoir_matrix) == weight - -#testing SimpleCycleReservoir implicit and esplicit constructors -reservoir_constructor = SimpleCycleReservoir(res_size, weight) -reservoir_matrix = create_reservoir(reservoir_constructor, res_size) -@test size(reservoir_matrix) == (res_size, res_size) -@test maximum(reservoir_matrix) == weight - -reservoir_constructor = SimpleCycleReservoir(res_size, weight = weight) -reservoir_matrix = create_reservoir(reservoir_constructor, res_size) -@test size(reservoir_matrix) == (res_size, res_size) -@test maximum(reservoir_matrix) == weight - -#testing CycleJumpsReservoir implicit and esplicit constructors -reservoir_constructor = CycleJumpsReservoir(res_size, weight, weight, jump_size) -reservoir_matrix = create_reservoir(reservoir_constructor, res_size) -@test size(reservoir_matrix) == (res_size, res_size) -@test maximum(reservoir_matrix) == weight - -reservoir_constructor = CycleJumpsReservoir(res_size, cycle_weight = weight, - jump_weight = weight, jump_size = jump_size) -reservoir_matrix = create_reservoir(reservoir_constructor, res_size) -@test size(reservoir_matrix) == (res_size, res_size) -@test maximum(reservoir_matrix) == weight - -#testing NullReservoir constructors -reservoir_constructor = NullReservoir() -reservoir_matrix = create_reservoir(reservoir_constructor, res_size) -@test size(reservoir_matrix) == (res_size, res_size) diff --git a/test/esn/test_states.jl b/test/esn/test_states.jl deleted file mode 100644 index 479d29c9..00000000 --- a/test/esn/test_states.jl +++ /dev/null @@ -1,46 +0,0 @@ -using ReservoirComputing - -test_types = [Float64, Float32, Float16] -states = [1, 2, 3, 4, 5, 6, 7, 8, 9] -in_data = fill(1, 3) - -states_types = [StandardStates, ExtendedStates, PaddedStates, PaddedExtendedStates] - -# testing extension and padding -for tt in test_types - st_states = StandardStates()(NLADefault(), tt.(states), tt.(in_data)) - @test length(st_states) == length(states) - @test typeof(st_states) == typeof(tt.(states)) - - st_states = ExtendedStates()(NLADefault(), tt.(states), tt.(in_data)) - @test length(st_states) == length(states) + length(in_data) - @test typeof(st_states) == typeof(tt.(states)) - - st_states = PaddedStates()(NLADefault(), tt.(states), tt.(in_data)) - @test length(st_states) == length(states) + 1 - @test typeof(st_states[1]) == typeof(tt.(states)[1]) - - st_states = PaddedExtendedStates()(NLADefault(), tt.(states), tt.(in_data)) - @test length(st_states) == length(states) + length(in_data) + 1 - @test typeof(st_states[1]) == typeof(tt.(states)[1]) -end - -## testing non linear algos -nla1_states = [1, 2, 9, 4, 25, 6, 49, 8, 81] -nla2_states = [1, 2, 2, 4, 12, 6, 30, 8, 9] -nla3_states = [1, 2, 8, 4, 24, 6, 48, 8, 9] - -for tt in test_types - # test default - nla_states = ReservoirComputing.nla(NLADefault(), tt.(states)) - @test nla_states == tt.(states) - # test NLAT1 - nla_states = ReservoirComputing.nla(NLAT1(), tt.(states)) - @test nla_states == tt.(nla1_states) - # test nlat2 - nla_states = ReservoirComputing.nla(NLAT2(), tt.(states)) - @test nla_states == tt.(nla2_states) - # test nlat3 - nla_states = ReservoirComputing.nla(NLAT3(), tt.(states)) - @test nla_states == tt.(nla3_states) -end diff --git a/test/esn/test_train.jl b/test/esn/test_train.jl index 178d4c75..034bbca5 100644 --- a/test/esn/test_train.jl +++ b/test/esn/test_train.jl @@ -9,10 +9,12 @@ const input_data = reduce(hcat, data[1:(train_len - 1)]) const target_data = reduce(hcat, data[2:train_len]) const test = reduce(hcat, data[(train_len + 1):(train_len + predict_len)]) const reg = 10e-6 +#test_types = [Float64, Float32, Float16] Random.seed!(77) -esn = ESN(input_data; - reservoir = RandSparseReservoir(res_size, 1.2, 0.1)) +res = rand_sparse(; radius = 1.2, sparsity = 0.1) +esn = ESN(input_data, 1, res_size; + reservoir = rand_sparse) training_methods = [ StandardRidge(regularization_coeff = reg), @@ -21,14 +23,9 @@ training_methods = [ EpsilonSVR() ] -for t in training_methods - output_layer = train(esn, target_data, t) +# TODO check types +@testset "Training Algo Tests: $ta" for ta in training_methods + output_layer = train(esn, target_data, ta) output = esn(Predictive(input_data), output_layer) @test mean(abs.(target_data .- output)) ./ mean(abs.(target_data)) < 0.22 end - -for t in training_methods - output_layer = train(esn, target_data, t) - output, states = esn(Predictive(input_data), output_layer, save_states = true) - @test size(states) == (res_size, size(input_data, 2)) -end diff --git a/test/runtests.jl b/test/runtests.jl index bac8443e..27a8ed2c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -7,12 +7,11 @@ using Test end @testset "Echo State Networks" begin - @safetestset "ESN Input Layers" include("esn/test_input_layers.jl") - @safetestset "ESN Reservoirs" include("esn/test_reservoirs.jl") - @safetestset "ESN States" include("esn/test_states.jl") + @safetestset "ESN Input Layers" include("esn/test_inits.jl") @safetestset "ESN Train and Predict" include("esn/test_train.jl") @safetestset "ESN Drivers" include("esn/test_drivers.jl") @safetestset "Hybrid ESN" include("esn/test_hybrid.jl") + @safetestset "Deep ESN" include("esn/deepesn.jl") end @testset "CA based Reservoirs" begin diff --git a/test/test_states.jl b/test/test_states.jl index c8808bbf..cd776715 100644 --- a/test/test_states.jl +++ b/test/test_states.jl @@ -1,51 +1,33 @@ using ReservoirComputing -#padding test_array = [1, 2, 3, 4, 5, 6, 7, 8, 9] -standard_array = zeros(length(test_array), 1) extension = [0, 0, 0] -padded_array = zeros(length(test_array) + 1, 1) -extended_array = zeros(length(test_array) + length(extension), 1) -padded_extended_array = zeros(length(test_array) + length(extension) + 1, 1) padding = 10.0 - -#testing non linear algos -nla_array = ReservoirComputing.nla(NLADefault(), test_array) -@test nla_array == test_array - -nla_array = ReservoirComputing.nla(NLAT1(), test_array) -@test nla_array == [1, 2, 9, 4, 25, 6, 49, 8, 81] - -nla_array = ReservoirComputing.nla(NLAT2(), test_array) -@test nla_array == [1, 2, 2, 4, 12, 6, 30, 8, 9] - -nla_array = ReservoirComputing.nla(NLAT3(), test_array) -@test nla_array == [1, 2, 8, 4, 24, 6, 48, 8, 9] - -#testing padding and extension -states_type = StandardStates() -standard_array = states_type(NLADefault(), test_array, extension) -@test standard_array == test_array - -states_type = PaddedStates(padding = padding) -padded_array = states_type(NLADefault(), test_array, extension) -@test padded_array == reshape(vcat(padding, test_array), length(test_array) + 1, 1) - -states_type = PaddedStates(padding) -padded_array = states_type(NLADefault(), test_array, extension) -@test padded_array == reshape(vcat(padding, test_array), length(test_array) + 1, 1) - -states_type = PaddedExtendedStates(padding = padding) -padded_extended_array = states_type(NLADefault(), test_array, extension) -@test padded_extended_array == reshape(vcat(padding, extension, test_array), - length(test_array) + length(extension) + 1, 1) - -states_type = PaddedExtendedStates(padding) -padded_extended_array = states_type(NLADefault(), test_array, extension) -@test padded_extended_array == reshape(vcat(padding, extension, test_array), - length(test_array) + length(extension) + 1, 1) - -states_type = ExtendedStates() -extended_array = states_type(NLADefault(), test_array, extension) -@test extended_array == vcat(extension, test_array) -#reshape(vcat(extension, test_array), length(test_array)+length(extension), 1) +test_types = [Float64, Float32, Float16] + +nlas = [(NLADefault(), test_array), + (NLAT1(), [1, 2, 9, 4, 25, 6, 49, 8, 81]), + (NLAT2(), [1, 2, 2, 4, 12, 6, 30, 8, 9]), + (NLAT3(), [1, 2, 8, 4, 24, 6, 48, 8, 9])] + +pes = [(StandardStates(), test_array), + (PaddedStates(padding = padding), + reshape(vcat(padding, test_array), length(test_array) + 1, 1)), + (PaddedExtendedStates(padding = padding), + reshape(vcat(padding, extension, test_array), + length(test_array) + length(extension) + 1, + 1)), + (ExtendedStates(), vcat(extension, test_array))] + +@testset "States Testing" for T in test_types + @testset "Nonlinear Algorithms Testing: $algo $T" for (algo, expected_output) in nlas + nla_array = ReservoirComputing.nla(algo, T.(test_array)) + @test nla_array == expected_output + @test eltype(nla_array) == T + end + @testset "States Testing: $state_type $T" for (state_type, expected_output) in pes + states_output = state_type(NLADefault(), T.(test_array), T.(extension)) + @test states_output == expected_output + @test eltype(states_output) == T + end +end