Commit
Merge pull request #55 from darsnack/optimisers-jl-scheduler
Add initial Optimisers.jl scheduler
darsnack authored Feb 3, 2024
2 parents 1059200 + 112f577 commit 7b3c081
Showing 10 changed files with 148 additions and 176 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
@@ -58,7 +58,7 @@ jobs:
- uses: actions/checkout@v2
- uses: julia-actions/setup-julia@v1
with:
version: '1.6'
version: '1'
- run: |
julia --project=docs -e '
using Pkg
9 changes: 5 additions & 4 deletions Project.toml
@@ -1,24 +1,25 @@
name = "ParameterSchedulers"
uuid = "d7d3b36b-41b8-4d0d-a2bf-768c6151755e"
authors = ["Kyle Daruwalla"]
version = "0.3.7"
version = "0.4.0"

[deps]
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
InfiniteArrays = "4858937d-0d70-526a-a4dd-2d5cb5dd786c"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"

[compat]
Flux = "0.11.2, 0.12, 0.13, 0.14"
InfiniteArrays = "0.10.4, 0.11, 0.12, 0.13"
Optimisers = "0.3.1"
julia = "1.6"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[publish]
ignore = ["^(gh-pages|juliamnt|julia.dmg)$"]
theme = "_flux-theme"
title = "ParameterSchedulers.jl"

[targets]
test = ["Test"]
test = ["Test", "Zygote"]
2 changes: 1 addition & 1 deletion README.md
@@ -8,7 +8,7 @@ ParameterSchedulers.jl provides common machine learning (ML) schedulers for hype
using Flux, ParameterSchedulers
using ParameterSchedulers: Scheduler

opt = Scheduler(Exp(λ = 1e-2, γ = 0.8), Momentum())
opt = Scheduler(Momentum, Exp(λ = 1e-2, γ = 0.8))
```

## Available Schedules
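The README example above only constructs the scheduler. The schedule half of it is an ordinary iterator, so the learning rates it will feed to `Momentum` can be previewed on their own; a quick sketch (values shown are approximate, following the usual `Exp` convention of decaying λ by γ each step):

```julia
using ParameterSchedulers

s = Exp(λ = 1e-2, γ = 0.8)
# Preview the first few learning rates the wrapped Momentum rule would see.
println(collect(Iterators.take(s, 4)))   # ≈ [0.01, 0.008, 0.0064, 0.00512]
```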
1 change: 1 addition & 0 deletions docs/Project.toml
@@ -1,6 +1,7 @@
[deps]
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
ParameterSchedulers = "d7d3b36b-41b8-4d0d-a2bf-768c6151755e"
Revise = "295af30f-e4ad-537b-8983-00126c2a3abe"
UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
17 changes: 10 additions & 7 deletions docs/src/tutorials/complex-schedules.md
@@ -78,19 +78,22 @@ Notice that our schedule changes around 1 second (half way through the simulatio
For the second example, we'll look at a machine learning use-case. We want to write our schedule in terms of epochs, but our training loop iterates the scheduler every mini-batch.
```@example complex-schedules
using Flux
using Optimisers
using ParameterSchedulers: Scheduler
nepochs = 3
data = [(rand(4, 10), rand([-1, 1], 1, 10)) for _ in 1:3]
data = [(Flux.rand32(4, 10), rand([-1, 1], 1, 10)) for _ in 1:3]
m = Chain(Dense(4, 4, tanh), Dense(4, 1, tanh))
p = Flux.params(m)
s = Interpolator(Sequence(1e-2 => 1, Exp(1e-2, 2.0) => 2), length(data))
opt = Scheduler(s, Descent())
s = Interpolator(Sequence(1f-2 => 1, Exp(1f-2, 2f0) => 2), length(data))
opt = Scheduler(Optimisers.Descent, s)
opt_st = Flux.setup(opt, m)
for epoch in 1:nepochs
for (i, (x, y)) in enumerate(data)
g = Flux.gradient(() -> Flux.mse(m(x), y), p)
Flux.update!(opt, p, g)
println("epoch: $epoch, batch: $i, η: $(opt.optim.eta)")
global opt_st, m
step = opt_st.layers[1].weight.state.t
println("epoch: $epoch, batch: $i, sched step = $step")
g = Flux.gradient(m -> Flux.mse(m(x), y), m)[1]
opt_st, m = Flux.update!(opt_st, m, g)
end
end
```
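A note on the example above: `Interpolator(schedule, rate)` stretches an epoch-based schedule so that `rate` scheduler steps (here `length(data)` mini-batches) correspond to one step of the inner schedule. Since every schedule is an iterator, the per-batch values can be inspected without any training loop; a small sketch (the `3` stands in for `length(data)`, and the exact values depend on the `Sequence`/`Exp` conventions):

```julia
using ParameterSchedulers

# Epoch-level schedule: 1f-2 for one epoch, then an Exp schedule for two epochs,
# interpolated so that three scheduler steps make up one "epoch".
s = Interpolator(Sequence(1f-2 => 1, Exp(1f-2, 2f0) => 2), 3)

# Schedules are plain iterators, so zip gives one value per mini-batch step.
for (step, η) in zip(1:9, s)
    println("step $step: η = $η")
end
```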
61 changes: 33 additions & 28 deletions docs/src/tutorials/optimizers.md
@@ -7,31 +7,35 @@ A schedule by itself is not helpful; we need to use the schedules to adjust para
Since every schedule is a standard iterator, we can insert it into a training loop by simply zipping up with another iterator. For example, the following code adjusts the learning rate of the optimizer before each batch of training.
```@example optimizers
using Flux, ParameterSchedulers
using Optimisers: Descent, adjust!
data = [(Flux.rand32(4, 10), rand([-1, 1], 1, 10)) for _ in 1:3]
m = Chain(Dense(4, 4, tanh), Dense(4, 1, tanh))
p = Flux.params(m)
opt = Descent()
opt_st = Flux.setup(opt, m)
s = Exp(λ = 1e-1, γ = 0.2)
for (η, (x, y)) in zip(s, data)
opt.eta = η
g = Flux.gradient(() -> Flux.mse(m(x), y), p)
Flux.update!(opt, p, g)
println("η: ", opt.eta)
for (eta, (x, y)) in zip(s, data)
global opt_st, m
adjust!(opt_st, eta)
g = Flux.gradient(m -> Flux.mse(m(x), y), m)[1]
opt_st, m = Flux.update!(opt_st, m, g)
println("opt state: ", opt_st.layers[1].weight.rule)
end
```

We can also adjust the learning on an epoch basis instead. All that is required is to change what we zip our schedule with.
```@example optimizers
nepochs = 6
s = Step(λ = 1e-1, γ = 0.2, step_sizes = [3, 2, 1])
for (η, epoch) in zip(s, 1:nepochs)
opt.eta = η
for (eta, epoch) in zip(s, 1:nepochs)
global opt_st
adjust!(opt_st, eta)
for (i, (x, y)) in enumerate(data)
g = Flux.gradient(() -> Flux.mse(m(x), y), p)
Flux.update!(opt, p, g)
println("epoch: $epoch, batch: $i, η: $(opt.eta)")
global m
g = Flux.gradient(m -> Flux.mse(m(x), y), m)[1]
opt_st, m = Flux.update!(opt_st, m, g)
println("epoch: $epoch, batch: $i, opt state: $(opt_st.layers[1].weight.rule)")
end
end
```
@@ -45,44 +45,45 @@ nepochs = 3
s = ParameterSchedulers.Stateful(Inv(λ = 1e-1, γ = 0.2, p = 2))
for epoch in 1:nepochs
for (i, (x, y)) in enumerate(data)
opt.eta = ParameterSchedulers.next!(s)
g = Flux.gradient(() -> Flux.mse(m(x), y), p)
Flux.update!(opt, p, g)
println("epoch: $epoch, batch: $i, η: $(opt.eta)")
global opt_st, m
adjust!(opt_st, ParameterSchedulers.next!(s))
g = Flux.gradient(m -> Flux.mse(m(x), y), m)[1]
opt_st, m = Flux.update!(opt_st, m, g)
println("epoch: $epoch, batch: $i, opt state: $(opt_st.layers[1].weight.rule)")
end
end
```

## Working with Flux optimizers

!!! warning
Currently, we are porting `Scheduler` to Flux.jl.
It may be renamed once it is ported out of this package.
The API will also undergo minor changes.

While the approaches above can be helpful when dealing with fine-grained training loops, it is usually simpler to just use a [`ParameterSchedulers.Scheduler`](@ref).
```@example optimizers
using ParameterSchedulers: Scheduler
nepochs = 3
s = Inv(λ = 1e-1, p = 2, γ = 0.2)
opt = Scheduler(s, Descent())
opt = Scheduler(Descent, s)
opt_st = Flux.setup(opt, m)
for epoch in 1:nepochs
for (i, (x, y)) in enumerate(data)
g = Flux.gradient(() -> Flux.mse(m(x), y), p)
Flux.update!(opt, p, g)
println("epoch: $epoch, batch: $i, η: $(opt.optim.eta)")
global opt_st, m
sched_step = opt_st.layers[1].weight.state.t
println("epoch: $epoch, batch: $i, sched state: $sched_step")
g = Flux.gradient(m -> Flux.mse(m(x), y), m)[1]
opt_st, m = Flux.update!(opt_st, m, g)
end
end
```
The scheduler, `opt`, can be used anywhere a Flux optimizer can. For example, it can be passed to `Flux.train!`:
```@example optimizers
s = Inv(λ = 1e-1, p = 2, γ = 0.2)
opt = Scheduler(s, Descent())
loss(x, y, m) = Flux.mse(m(x), y)
cb = () -> @show(opt.optim.eta)
opt = Scheduler(Descent, s)
opt_st = Flux.setup(opt, m)
loss(m, x, y) = Flux.mse(m(x), y)
for epoch in 1:nepochs
Flux.train!((x, y) -> loss(x, y, m), Flux.params(m), data, opt, cb = cb)
sched_step = opt_st.layers[1].weight.state.t
println("epoch: $epoch, sched state: $sched_step")
Flux.train!(loss, m, data, opt_st)
end
```

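One detail worth isolating from the tutorial above: `ParameterSchedulers.Stateful` keeps its own step counter, which is what lets the `Stateful` example advance the schedule without zipping it against the data. Stripped of the training loop, the mechanism looks roughly like this (the printed values are simply whatever `Inv` produces):

```julia
using ParameterSchedulers

s = ParameterSchedulers.Stateful(Inv(λ = 1e-1, γ = 0.2, p = 2))

# Each next! call returns the current value and advances the internal counter,
# mirroring the one-value-per-mini-batch usage in the tutorial.
for step in 1:4
    η = ParameterSchedulers.next!(s)
    println("step $step: η = $η")
end
```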
83 changes: 4 additions & 79 deletions src/ParameterSchedulers.jl
@@ -1,8 +1,9 @@
module ParameterSchedulers

using Base.Iterators
using Flux
using InfiniteArrays: OneToInf
using Optimisers: AbstractRule
import Optimisers

include("interface.jl")

@@ -19,83 +20,7 @@ export Sequence, Loop, Interpolator, Shifted, ComposedSchedule

include("utils.jl")

# TODO
# Remove this once Optimisers.jl has support
# for schedules + optimizers
"""
Scheduler{T, O, F}(schedule::AbstractSchedule, opt, update_func)
Scheduler(schedule, opt; update_func = (o, s) -> (o.eta = s))
Wrap a `schedule` and `opt` together with a `Scheduler`.
The `schedule` is iterated on every call to
[`Flux.apply!`](https://github.com/FluxML/Flux.jl/blob/master/src/optimise/optimisers.jl).
The `Scheduler` can be used anywhere a Flux optimizer is used.
By default, the learning rate (i.e. `opt.eta`) is scheduled.
Set `update_func = (opt, schedule_val) -> ...` to schedule an alternate field.
If `opt` does not have a field `eta`, then there is no default behavior
(you must manually set `update_func`).
# Arguments
- `schedule`: the schedule to use
- `opt`: a Flux optimizer
- `update_func`: a mutating function of with inputs `(optim, param)`
that mutates `optim`'s fields based on the current `param` value
# Examples
```julia
# cosine annealing schedule for Descent
julia> s = CosAnneal(λ0 = 0.1, λ1 = 0.8, period = 10);
julia> opt = Scheduler(s, Descent())
Scheduler(CosAnneal{Float64,Int64}(0.1, 0.8, 10), Descent(0.1))
# schedule the momentum term of Momentum
julia> opt = Scheduler(s, Momentum(); update_func = (o, s) -> o.rho = s)
Scheduler(CosAnneal{Float64,Int64}(0.1, 0.8, 10), Momentum(0.01, 0.9, IdDict{Any,Any}()))
```
"""
mutable struct Scheduler{T, O, F} <: Flux.Optimise.AbstractOptimiser
state::IdDict{Any, Int}
schedule::T
optim::O
update_func::F

function Scheduler(state::IdDict{Any, Int},
schedule::T,
optim::O,
update_func::F) where {T, O, F}
Base.depwarn("""`Scheduler` will transition to explicit Optimisers.jl style
optimizers in the next release""", :Scheduler)

return new{T, O, F}(state, schedule, optim, update_func)
end
end
Scheduler(schedule, opt, update_func) =
Scheduler(IdDict{Any, Int}(), schedule, opt, update_func)

Base.show(io::IO, s::Scheduler) =
print(io, "Scheduler(", s.schedule, ", ", s.optim, ")")

function Flux.Optimise.apply!(opt::Scheduler, x, Δ)
# get iteration
t = get!(opt.state, x, 1)
opt.state[x] = t + 1

# set param
opt.update_func(opt.optim, opt.schedule(t))

# do normal apply
return Flux.Optimise.apply!(opt.optim, x, Δ)
end

for Opt in (Descent, Momentum, Nesterov, RMSProp,
Adam, RAdam, AdaMax, OAdam, AdaGrad,
AdaDelta, AMSGrad, NAdam, AdaBelief)
@eval begin
Scheduler(schedule, opt::$Opt; update_func = (o, s) -> (o.eta = s)) =
Scheduler(schedule, opt, update_func)
end
end
include("scheduler.jl")
export Scheduler

end
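The implicit-mode `Scheduler` deleted above is replaced by `include("scheduler.jl")`, whose contents are not part of this view. For readers unfamiliar with how a rule plugs into Optimisers.jl, the sketch below shows the general shape such a wrapper can take — it implements the `Optimisers.init`/`Optimisers.apply!` interface, but the names and details are illustrative, not the package's actual implementation:

```julia
import Optimisers
using Optimisers: AbstractRule

# Illustrative wrapper: rebuilds an inner rule from constructor(schedule(t)) at every step.
struct ScheduledRule{C, S} <: AbstractRule
    constructor::C   # e.g. Optimisers.Descent
    schedule::S      # e.g. Exp(λ = 1e-2, γ = 0.8)
end

Optimisers.init(o::ScheduledRule, x::AbstractArray) =
    (t = 1, inner = Optimisers.init(o.constructor(o.schedule(1)), x))

function Optimisers.apply!(o::ScheduledRule, state, x, dx)
    rule = o.constructor(o.schedule(state.t))                # rule at the scheduled value
    inner, dx = Optimisers.apply!(rule, state.inner, x, dx)  # delegate the actual update
    return (t = state.t + 1, inner = inner), dx
end
```

The doc examples above read the current step as `opt_st.layers[1].weight.state.t`, so the real scheduler's per-parameter state evidently carries a step counter much like the `t` in this sketch.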
50 changes: 6 additions & 44 deletions src/cyclic.jl
@@ -23,13 +23,8 @@ struct Triangle{T, S<:Integer} <: AbstractSchedule{false}
offset::T
period::S
end
function Triangle(range::T, offset::T, period::S) where {T, S}
@warn """Triangle(range0, range1, period) is now Triangle(range, offset, period).
To specify by endpoints, use the keyword argument form.
This message will be removed in the next version.""" _id=(:tri) maxlog=1

Triangle(range::T, offset::T, period::S) where {T, S} =
Triangle{T, S}(range, offset, period)
end
Triangle(;λ0, λ1, period) = Triangle(abs(λ0 - λ1), min(λ0, λ1), period)

Base.eltype(::Type{<:Triangle{T}}) where T = T
@@ -54,13 +49,7 @@ where `Triangle(t)` is `(2 / π) * abs(asin(sin(π * (t - 1) / schedule.period))
- `range1`/`λ1`: the second range endpoint
- `period::Integer`: the period
"""
function TriangleDecay2(range, offset, period)
@warn """TriangleDecay2(range0, range1, period) is now TriangleDecay2(range, offset, period).
To specify by endpoints, use the keyword argument form.
This message will be removed in the next version.""" _id=(:tri) maxlog=1

return _tridecay2(range, offset, period)
end
TriangleDecay2(range, offset, period) = _tridecay2(range, offset, period)
TriangleDecay2(;λ0, λ1, period) = _tridecay2(abs(λ0 - λ1), min(λ0, λ1), period)

function _tridecay2(range::T, offset, period) where T
@@ -89,13 +78,7 @@ where `Triangle(t)` is `(2 / π) * abs(asin(sin(π * (t - 1) / schedule.period))
- `period::Integer`: the period
- `decay`/`γ`: the decay rate
"""
function TriangleExp(range, offset, period, γ)
@warn """TriangleExp(range0, range1, period, γ) is now TriangleExp(range, offset, period, γ).
To specify by endpoints, use the keyword argument form.
This message will be removed in the next version.""" _id=(:tri) maxlog=1

return _triexp(range, offset, period, γ)
end
TriangleExp(range, offset, period, γ) = _triexp(range, offset, period, γ)
TriangleExp(;λ0, λ1, period, γ) = _triexp(abs(λ0 - λ1), min(λ0, λ1), period, γ)

_triexp(range, offset, period, γ) =
@@ -121,13 +104,7 @@ struct Sin{T, S<:Integer} <: AbstractSchedule{false}
offset::T
period::S
end
function Sin(range::T, offset::T, period::S) where {T, S}
@warn """Sin(range0, range1, period) is now Sin(range, offset, period).
To specify by endpoints, use the keyword argument form.
This message will be removed in the next version.""" _id=(:sine) maxlog=1

Sin{T, S}(range, offset, period)
end
Sin(range::T, offset::T, period::S) where {T, S} = Sin{T, S}(range, offset, period)
Sin(;λ0, λ1, period) = Sin(abs(λ0 - λ1), min(λ0, λ1), period)

Base.eltype(::Type{<:Sin{T}}) where T = T
@@ -150,13 +127,7 @@ where `Sin(t)` is `abs(sin(π * (t - 1) / period))` (see [`Sin`](@ref)).
- `offset == min(λ0, λ1)`: the offset / minimum value
- `period::Integer`: the period
"""
function SinDecay2(range, offset, period)
@warn """SinDecay2(range0, range1, period) is now SinDecay2(range, offset, period).
To specify by endpoints, use the keyword argument form.
This message will be removed in the next version.""" _id=(:sine) maxlog=1

return _sindecay2(range, offset, period)
end
SinDecay2(range, offset, period) = _sindecay2(range, offset, period)
SinDecay2(;λ0, λ1, period) = _sindecay2(abs(λ0 - λ1), min(λ0, λ1), period)

function _sindecay2(range::T, offset, period) where T
@@ -182,13 +153,7 @@ where `Sin(t)` is `abs(sin(π * (t - 1) / period))` (see [`Sin`](@ref)).
- `period::Integer`: the period
- `γ`: the decay rate
"""
function SinExp(range, offset, period, γ)
@warn """SinExp(range0, range1, period, γ) is now SinExp(range, offset, period, γ).
To specify by endpoints, use the keyword argument form.
This message will be removed in the next version.""" _id=(:sine) maxlog=1

return _sinexp(range, offset, period, γ)
end
SinExp(range, offset, period, γ) = _sinexp(range, offset, period, γ)
SinExp(;λ0, λ1, period, γ) = _sinexp(abs(λ0 - λ1), min(λ0, λ1), period, γ)

_sinexp(range, offset, period, γ) =
@@ -231,6 +196,3 @@ function (schedule::CosAnneal)(t)

return schedule.range * (1 + cos(π * d / schedule.period)) / 2 + schedule.offset
end

Base.@deprecate Cos(range0, range1, period) CosAnneal(λ0 = range0, λ1 = range1, period = period)
Base.@deprecate Cos(;λ0, λ1, period) CosAnneal(λ0 = λ0, λ1 = λ1, period = period)
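The pattern repeated throughout this file is that the keyword constructors accept two endpoints (`λ0`, `λ1`) and convert them to the positional `(range, offset, ...)` form via `abs(λ0 - λ1)` and `min(λ0, λ1)`. A quick check of that mapping, using `Triangle` as the example (printed values are approximate floating-point results):

```julia
using ParameterSchedulers

# Endpoints λ0 = 0.1 and λ1 = 0.5 map to range ≈ 0.4 and offset = 0.1,
# i.e. roughly Triangle(0.4, 0.1, 4) in the positional form kept above.
t = Triangle(λ0 = 0.1, λ1 = 0.5, period = 4)
println((t.range, t.offset, t.period))
```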

2 comments on commit 7b3c081

@darsnack (Member, Author)

@JuliaRegistrator register

Release notes:

Breaking changes

`Scheduler` now supports explicit Optimisers.jl-style optimizers only; support for Flux's implicit optimizers has been dropped.
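In practice the breaking change amounts to reversing the argument order of `Scheduler` (the rule constructor now comes first, the schedule second) and moving from implicit `Flux.params` training to the explicit `setup`/`update!` style. A rough migration sketch, using a placeholder model and loss (the "before" lines reflect the old v0.3 API shown in the removed docs above):

```julia
using Flux, ParameterSchedulers
using ParameterSchedulers: Scheduler

m = Dense(4 => 1)                          # placeholder model
x, y = rand(Float32, 4, 8), rand(Float32, 1, 8)
loss(m) = Flux.mse(m(x), y)

# Before (v0.3.x, implicit Flux optimisers):
#   opt = Scheduler(Exp(λ = 1e-2, γ = 0.8), Momentum())
#   g   = Flux.gradient(() -> loss(m), Flux.params(m))
#   Flux.update!(opt, Flux.params(m), g)

# After (v0.4.0, explicit Optimisers.jl-style rules):
opt = Scheduler(Momentum, Exp(λ = 1e-2, γ = 0.8))
opt_st = Flux.setup(opt, m)
g = Flux.gradient(loss, m)[1]
opt_st, m = Flux.update!(opt_st, m, g)
```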

@JuliaRegistrator

Registration pull request created: JuliaRegistries/General/100173

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.4.0 -m "<description of version>" 7b3c0810ba651f48cdc5c77e74aa2e337353e1ad
git push origin v0.4.0
