Changes to training #205

Merged
merged 1 commit into from Mar 1, 2024
12 changes: 9 additions & 3 deletions Project.toml
@@ -8,16 +8,22 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
CellularAutomata = "878138dc-5b27-11ea-1a71-cb95d38d6b29"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
LIBSVM = "b1bec4e5-fd48-53fe-b0cb-9723c09d164b"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLJLinearModels = "6ee0df7b-362f-4a72-a706-9e79364fb692"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
WeightInitializers = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d"

[weakdeps]
LIBSVM = "b1bec4e5-fd48-53fe-b0cb-9723c09d164b"
MLJLinearModels = "6ee0df7b-362f-4a72-a706-9e79364fb692"

[extensions]
RCMLJLinearModelsExt = "MLJLinearModels"
RCLIBSVMExt = "LIBSVM"

[compat]
Adapt = "3.3.3, 4"
Aqua = "0.8"
@@ -46,4 +52,4 @@ SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Aqua", "Test", "SafeTestsets", "Random", "DifferentialEquations"]
test = ["Aqua", "Test", "SafeTestsets", "Random", "DifferentialEquations", "MLJLinearModels", "LIBSVM"]
32 changes: 32 additions & 0 deletions ext/RCLIBSVMExt.jl
@@ -0,0 +1,32 @@
module RCLIBSVMExt
using ReservoirComputing
using LIBSVM

function ReservoirComputing.train(svr::LIBSVM.AbstractSVR, states, target)
out_size = size(target, 1)
output_matrix = []

if out_size == 1
output_matrix = LIBSVM.fit!(svr, states', vec(target))
else
for i in 1:out_size
push!(output_matrix, LIBSVM.fit!(svr, states', target[i, :]))
end
end

return OutputLayer(svr, output_matrix, out_size, target[:, end])
end

function ReservoirComputing.get_prediction(
training_method::LIBSVM.AbstractSVR, output_layer, x)
out = zeros(output_layer.out_size)

for i in 1:(output_layer.out_size)
x_new = reshape(x, 1, length(x))
out[i] = LIBSVM.predict(output_layer.output_matrix[i], x_new)[1]
end

return out
end

end #module
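
A minimal usage sketch of the new extension path (the data and sizes below are made up for illustration, not taken from the PR):

using ReservoirComputing
using LIBSVM   # loading the weak dependency activates RCLIBSVMExt

input_data = rand(3, 300)    # hypothetical driving series: 3 features x 300 steps
target_data = rand(3, 300)   # hypothetical targets, one row per output dimension

esn = ESN(input_data, 3, 100)
# train dispatches on the method; one SVR is fit per target row
output_layer = train(esn, target_data, EpsilonSVR())
output = esn(Predictive(input_data), output_layer)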
25 changes: 25 additions & 0 deletions ext/RCMLJLinearModelsExt.jl
@@ -0,0 +1,25 @@
module RCMLJLinearModelsExt
using ReservoirComputing
using MLJLinearModels

function ReservoirComputing.train(regressor::MLJLinearModels.GeneralizedLinearRegression,
states::AbstractArray{T},
target::AbstractArray{T};
kwargs...) where {T <: Number}
out_size = size(target, 1)
output_layer = similar(target, size(target, 1), size(states, 1))

if regressor.fit_intercept
throw(ArgumentError("fit_intercept=true is not yet supported.
Please add fit_intercept=false to the MLJ regressor"))
end

for i in axes(target, 1)
output_layer[i, :] = MLJLinearModels.fit(regressor, states',
target[i, :]; kwargs...)
end

return OutputLayer(regressor, output_layer, out_size, target[:, end])
end

end #module
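
A matching sketch for the MLJLinearModels path; the extension fits one regression per target row and, as the guard above enforces, the regressor must be constructed with fit_intercept = false (data below is synthetic, for illustration only):

using ReservoirComputing
using MLJLinearModels   # loading the weak dependency activates RCMLJLinearModelsExt

states = rand(100, 500)   # hypothetical reservoir states: 100 units x 500 steps
target = rand(2, 500)     # two output rows

regressor = RidgeRegression(; lambda = 1e-6, fit_intercept = false)
output_layer = train(regressor, states, target)   # kwargs... go to MLJLinearModels.fit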
15 changes: 7 additions & 8 deletions src/ReservoirComputing.jl
@@ -4,9 +4,7 @@ using Adapt
using CellularAutomata
using Distances
using Distributions
using LIBSVM
using LinearAlgebra
using MLJLinearModels
using NNlib
using Optim
using PartialFunctions
@@ -16,7 +14,7 @@ using WeightInitializers

export NLADefault, NLAT1, NLAT2, NLAT3
export StandardStates, ExtendedStates, PaddedStates, PaddedExtendedStates
export StandardRidge, LinearModel
export StandardRidge
export scaled_rand, weighted_init, informed_init, minimal_init
export rand_sparse, delay_line, delay_line_backward, cycle_jumps, simple_cycle, pseudo_svd
export RNN, MRNN, GRU, GRUParams, FullyGated, Minimal
@@ -31,11 +29,7 @@ export Generative, Predictive, OutputLayer
abstract type AbstractReservoirComputer end
abstract type AbstractOutputLayer end
abstract type AbstractPrediction end
#training methods
abstract type AbstractLinearModel end
abstract type AbstractSupportVector end
#should probably move some of these
abstract type AbstractVariation end
abstract type AbstractGRUVariant end

#general output layer struct
@@ -104,7 +98,6 @@ include("predict.jl")

#general training
include("train/linear_regression.jl")
include("train/supportvector_regression.jl")

#esn
include("esn/esn_input_layers.jl")
@@ -119,4 +112,10 @@ include("esn/esn_predict.jl")
include("reca/reca.jl")
include("reca/reca_input_encodings.jl")

# Julia < 1.9 support
if !isdefined(Base, :get_extension)
include("../ext/RCMLJLinearModelsExt.jl")
include("../ext/RCLIBSVMExt.jl")
end

end #module
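
On Julia >= 1.9 nothing is included eagerly; the extension modules only come into existence once the weak dependency is loaded. A quick way to observe this (a sketch, assuming Julia >= 1.9):

using ReservoirComputing
Base.get_extension(ReservoirComputing, :RCLIBSVMExt)   # returns nothing, LIBSVM not loaded

using LIBSVM
Base.get_extension(ReservoirComputing, :RCLIBSVMExt)   # now returns the RCLIBSVMExt module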
5 changes: 3 additions & 2 deletions src/esn/esn.jl
@@ -123,10 +123,11 @@ trained_esn = train(esn, target_data, training_method = StandardRidge(1.0))
"""
function train(esn::AbstractEchoStateNetwork,
target_data,
training_method = StandardRidge(0.0))
training_method = StandardRidge();
kwargs...)
states_new = esn.states_type(esn.nla_type, esn.states, esn.train_data[:, 1:end])

return _train(states_new, target_data, training_method)
return train(training_method, states_new, target_data; kwargs...)
end

#function pad_esnstate(variation::Hybrid, states_type, x_pad, x, model_prediction_data)
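
The practical effect of the new signature: the training method is a positional dispatch target, and keyword arguments flow through to the method-specific train. Illustrative calls, reusing the esn and target_data names from the sketches above:

# ridge readout with an explicit regularization coefficient
output_layer = train(esn, target_data, StandardRidge(1e-6))

# with an MLJLinearModels regressor, kwargs... reach MLJLinearModels.fit,
# e.g. a solver choice (hypothetical usage, though the fit kwarg exists):
# output_layer = train(esn, target_data, regressor; solver = Analytical())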
5 changes: 3 additions & 2 deletions src/esn/hybridesn.jl
@@ -114,9 +114,10 @@ end

function train(hesn::HybridESN,
target_data,
training_method = StandardRidge(0.0))
training_method = StandardRidge();
kwargs...)
states = vcat(hesn.states, hesn.model.model_data[:, 2:end])
states_new = hesn.states_type(hesn.nla_type, states, hesn.train_data[:, 1:end])

return _train(states_new, target_data, training_method)
return train(training_method, states_new, target_data; kwargs...)
end
14 changes: 1 addition & 13 deletions src/predict.jl
@@ -42,22 +42,10 @@ function obtain_prediction(rc::AbstractReservoirComputer,
end

#linear models
function get_prediction(training_method::AbstractLinearModel, output_layer, x)
function get_prediction(training_method, output_layer, x)
return output_layer.output_matrix * x
end

#support vector regression
function get_prediction(training_method::LIBSVM.AbstractSVR, output_layer, x)
out = zeros(output_layer.out_size)

for i in 1:(output_layer.out_size)
x_new = reshape(x, 1, length(x))
out[i] = LIBSVM.predict(output_layer.output_matrix[i], x_new)[1]
end

return out
end

#single matrix for other training methods
function output_storing(training_method, out_size, prediction_len, storing_type)
return Adapt.adapt(storing_type, zeros(out_size, prediction_len))
4 changes: 2 additions & 2 deletions src/reca/reca.jl
@@ -39,9 +39,9 @@ function RECA(train_data,
end

#training dispatch
function train(reca::AbstractReca, target_data, training_method = StandardRidge(0.0))
function train(reca::AbstractReca, target_data, training_method = StandardRidge(); kwargs...)
states_new = reca.states_type(reca.nla_type, reca.states, reca.train_data)
return _train(states_new, target_data, training_method)
return train(training_method, states_new, target_data; kwargs...)
end

#predict dispatch
72 changes: 15 additions & 57 deletions src/train/linear_regression.jl
@@ -1,64 +1,22 @@
struct StandardRidge{T} <: AbstractLinearModel
regularization_coeff::T
struct StandardRidge
reg::Number
end

"""
StandardRidge(regularization_coeff)
StandardRidge(;regularization_coeff=0.0)

Ridge regression training for all the models in the library. The
`regularization_coeff` is the regularization coefficient; it can be passed as a
positional or keyword argument.
"""
function StandardRidge(; regularization_coeff = 0.0)
return StandardRidge(regularization_coeff)
end

#default training - OLS
function _train(states, target_data, sr::StandardRidge = StandardRidge(0.0))
output_layer = ((states * states' + sr.regularization_coeff * I) \
(states * target_data'))'
#output_layer = (target_data*states')*inv(states*states'+sr.regularization_coeff*I)
return OutputLayer(sr, output_layer, size(target_data, 1), target_data[:, end])
function StandardRidge(::Type{T}, reg) where {T <: Number}
return StandardRidge(T.(reg))
end

#mlj interface
struct LinearModel{T, S, K} <: AbstractLinearModel
regression::T
solver::S
regression_kwargs::K
function StandardRidge()
return StandardRidge(0.0)
end

"""
LinearModel(;regression=LinearRegression,
solver=Analytical(),
regression_kwargs=(;))

Linear regression training based on
[MLJLinearModels](https://juliaai.github.io/MLJLinearModels.jl/stable/) for all the
models in the library. All the parameters have to be passed into `regression_kwargs`,
apart from the solver choice. MLJLinearModels.jl needs to be called in order
to use these models.
"""
function LinearModel(; regression = LinearRegression,
solver = Analytical(),
regression_kwargs = (;))
return LinearModel(regression, solver, regression_kwargs)
end

function LinearModel(regression;
solver = Analytical(),
regression_kwargs = (;))
return LinearModel(regression, solver, regression_kwargs)
end

function _train(states, target_data, linear::LinearModel)
out_size = size(target_data, 1)
output_layer = zeros(size(target_data, 1), size(states, 1))
for i in 1:size(target_data, 1)
regressor = linear.regression(; fit_intercept = false, linear.regression_kwargs...)
output_layer[i, :] = MLJLinearModels.fit(regressor, states',
target_data[i, :], solver = linear.solver)
end

return OutputLayer(linear, output_layer, out_size, target_data[:, end])
function train(sr::StandardRidge,
states::AbstractArray{T},
target_data::AbstractArray{T}) where {T <: Number}
#A = states * states' + sr.reg * I
#b = states * target_data
#output_layer = (A \ b)'
output_layer = Matrix(((states * states' + sr.reg * I) \
(states * target_data'))')
return OutputLayer(sr, output_layer, size(target_data, 1), target_data[:, end])
end
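
As the commented lines spell out, this is the closed-form ridge solution, output_layer' = (states * states' + reg * I) \ (states * target_data'). A standalone sanity check of that identity on synthetic data (not part of the PR):

using LinearAlgebra

S = rand(20, 200)       # states: 20 features x 200 timesteps
W_true = rand(3, 20)    # readout to recover
Y = W_true * S          # noiseless targets
reg = 1e-8              # near-zero regularization

W = Matrix(((S * S' + reg * I) \ (S * Y'))')
@assert isapprox(W, W_true; rtol = 1e-6)   # recovers the readout when reg is near zero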
11 changes: 0 additions & 11 deletions src/train/supportvector_regression.jl

This file was deleted.

30 changes: 18 additions & 12 deletions test/esn/test_train.jl
@@ -14,18 +14,24 @@ const reg = 10e-6
Random.seed!(77)
res = rand_sparse(; radius = 1.2, sparsity = 0.1)
esn = ESN(input_data, 1, res_size;
reservoir = rand_sparse)

training_methods = [
StandardRidge(regularization_coeff = reg),
LinearModel(RidgeRegression, regression_kwargs = (; lambda = reg)),
LinearModel(regression = RidgeRegression, regression_kwargs = (; lambda = reg)),
EpsilonSVR()
]
reservoir = res)
# different models that implement a train dispatch
# TODO add classification
linear_training = [StandardRidge(0.0), LinearRegression(; fit_intercept = false),
RidgeRegression(; fit_intercept = false), LassoRegression(; fit_intercept = false),
ElasticNetRegression(; fit_intercept = false), HuberRegression(; fit_intercept = false),
QuantileRegression(; fit_intercept = false), LADRegression(; fit_intercept = false)]
svm_training = [EpsilonSVR(), NuSVR()]

# TODO check types
@testset "Training Algo Tests: $ta" for ta in training_methods
output_layer = train(esn, target_data, ta)
output = esn(Predictive(input_data), output_layer)
@test mean(abs.(target_data .- output)) ./ mean(abs.(target_data)) < 0.22
@testset "Linear training: $lt" for lt in linear_training
output_layer = train(esn, target_data, lt)
@test output_layer isa OutputLayer
@test output_layer.output_matrix isa AbstractArray
end

@testset "SVM training: $st" for st in svm_training
output_layer = train(esn, target_data, st)
@test output_layer isa OutputLayer
@test output_layer.output_matrix isa typeof(st)
end