Merge pull request #168 from FluxML/dev
For a 0.2 release
ablaom authored Jun 24, 2021
2 parents f32c21f + 5e5d698 commit d3b8cdb
Showing 16 changed files with 597 additions and 685 deletions.
3 changes: 2 additions & 1 deletion .gitignore
@@ -5,4 +5,5 @@
*.bu
.DS_Store
sandbox/
docs/build
docs/build
/examples/mnist/mnist_machine*
6 changes: 4 additions & 2 deletions Project.toml
@@ -1,7 +1,7 @@
name = "MLJFlux"
uuid = "094fc8d1-fd35-5302-93ea-dabda2abf845"
authors = ["Anthony D. Blaom <[email protected]>", "Ayush Shridhar <[email protected]>"]
version = "0.1.17"
version = "0.2.0"

[deps]
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
@@ -10,6 +10,7 @@ ComputationalResources = "ed09eef8-17a6-5b46-8889-db040fac31e3"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"

@@ -29,9 +30,10 @@ MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
MLJScientificTypes = "2e2323e0-db8b-457b-ae0d-bdfb3bc63afd"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["LinearAlgebra", "MLDatasets", "MLJBase", "MLJScientificTypes", "Random", "Statistics", "StatsBase", "Test"]
test = ["LinearAlgebra", "MLDatasets", "MLJBase", "MLJScientificTypes", "Random", "StableRNGs", "Statistics", "StatsBase", "Test"]
200 changes: 122 additions & 78 deletions README.md
@@ -84,18 +84,18 @@ NeuralNetworkClassifier = @load NeuralNetworkClassifier

julia> clf = NeuralNetworkClassifier()
NeuralNetworkClassifier(
builder = Short(
n_hidden = 0,
dropout = 0.5,
σ = NNlib.σ),
finaliser = NNlib.softmax,
optimiser = ADAM(0.001, (0.9, 0.999), IdDict{Any,Any}()),
loss = Flux.crossentropy,
epochs = 10,
batch_size = 1,
lambda = 0.0,
alpha = 0.0,
optimiser_changes_trigger_retraining = false) @ 160
builder = Short(
n_hidden = 0,
dropout = 0.5,
σ = NNlib.σ),
finaliser = NNlib.softmax,
optimiser = ADAM(0.001, (0.9, 0.999), IdDict{Any,Any}()),
loss = Flux.crossentropy,
epochs = 10,
batch_size = 1,
lambda = 0.0,
alpha = 0.0,
optimiser_changes_trigger_retraining = false) @ 160
```

#### Incremental training
@@ -121,8 +121,8 @@ julia> fit!(mach, verbosity=2)
[ Info: Loss is 0.7347
Machine{NeuralNetworkClassifier{Short,},} @804 trained 2 times; caches data
args:
1: Source @985`Table{AbstractVector{Continuous}}`
2: Source @367`AbstractVector{Multiclass{3}}`
1: Source @985`Table{AbstractVector{Continuous}}`
2: Source @367`AbstractVector{Multiclass{3}}`

julia> training_loss = cross_entropy(predict(mach, X), y) |> mean
0.7347092796453824
@@ -140,15 +140,15 @@ Chain(Chain(Dense(4, 3, σ), Flux.Dropout{Float64}(0.5, false), Dense(3, 3)), so
```julia
r = range(clf, :epochs, lower=1, upper=200, scale=:log10)
curve = learning_curve(clf, X, y,
range=r,
resampling=Holdout(fraction_train=0.7),
measure=cross_entropy)
range=r,
resampling=Holdout(fraction_train=0.7),
measure=cross_entropy)
using Plots
plot(curve.parameter_values,
curve.measurements,
xlab=curve.parameter_name,
xscale=curve.parameter_scale,
ylab = "Cross Entropy")
curve.measurements,
xlab=curve.parameter_name,
xscale=curve.parameter_scale,
ylab = "Cross Entropy")

```
@@ -239,13 +239,31 @@ CPU at the conclusion of `fit!`, and made available as
`fitted_params(mach)`.
### Random number generators and reproducibility
Every MLJFlux model includes an `rng` hyper-parameter that is passed
to builders for the purposes of weight initialization. This can be
any `AbstractRNG` or the seed (integer) for a `MersenneTwister` that
will be reset on every cold restart of model (machine) training.
Until there is a [mechanism for
doing so](https://github.com/FluxML/Flux.jl/issues/1617), `rng` is *not*
passed to dropout layers and one must manually seed the `GLOBAL_RNG`
for reproducibility purposes when using a builder that includes
`Dropout` (such as `MLJFlux.Short`). If training models on a
GPU (i.e., `acceleration isa CUDALibs`) one must additionally call
`CUDA.seed!(...)`.
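For example, here is a minimal sketch of a reproducible training run, assuming the classifier `clf` and data `(X, y)` defined earlier, and that the StableRNGs package (a test dependency above) is available; the seed values are illustrative only:

```julia
using Random
using StableRNGs        # any AbstractRNG works; StableRNG is just one choice

clf.rng = StableRNG(123)   # rng passed to the builder for weight initialization

Random.seed!(123)   # seed GLOBAL_RNG, on which Dropout layers still depend
# CUDA.seed!(123)   # additionally, if `acceleration isa CUDALibs`

mach = machine(clf, X, y)
fit!(mach)
```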
### Built-in builders
MLJ provides two simple builders out of the box:
MLJ provides two simple builders out of the box. In all cases weights
are initialized using `glorot_uniform(rng)` where `rng` is the RNG
(or `MersenneTwister` seed) specified by the MLJFlux model.
- `MLJFlux.Linear(σ=...)` builds a fully connected two layer
network with `n_in` inputs and `n_out` outputs, with activation
function `σ`, defaulting to a `MLJFlux.relu`.
- `MLJFlux.Linear(σ=...)` builds a fully connected two layer network
with `n_in` inputs and `n_out` outputs, with activation function
`σ`, defaulting to `MLJFlux.relu`.
- `MLJFlux.Short(n_hidden=..., dropout=..., σ=...)` builds a
fully connected three-layer network with `n_in` inputs and `n_out`
@@ -268,7 +286,8 @@ All models share the following hyper-parameters:
2. `optimiser`: The optimiser to use for training. Default =
`Flux.ADAM()`
3. `loss`: The loss function used for training. Default = `Flux.mse` (regressors) and `Flux.crossentropy` (classifiers)
3. `loss`: The loss function used for training. Default = `Flux.mse`
(regressors) and `Flux.crossentropy` (classifiers)
4. `epochs`: Number of epochs to train for. Default = `10`
@@ -278,9 +297,15 @@ All models share the following hyper-parameters:
7. `alpha`: The L2/L1 mix of regularization. Default = 0. Range = [0, 1]
8. `acceleration`: Use `CUDALibs()` for training on GPU; default is `CPU1()`.
8. `rng`: The random number generator (RNG) passed to builders, for
weight initialization, for example. Can be any `AbstractRNG` or
the seed (integer) for a `MersenneTwister` that is reset on every
cold restart of model (machine) training. Default =
`GLOBAL_RNG`.
9. `acceleration`: Use `CUDALibs()` for training on GPU; default is `CPU1()`.
9. `optimiser_changes_trigger_retraining`: True if fitting an
10. `optimiser_changes_trigger_retraining`: True if fitting an
associated machine should trigger retraining from scratch whenever
the optimiser changes. Default = `false`
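For example, several of these hyper-parameters can be overridden at construction time. The following is a sketch only, assuming `NeuralNetworkClassifier` has been loaded as above and that `Flux` and `MLJFlux` are in scope; the particular values (`n_hidden=16`, learning rate `0.01`, and so on) are illustrative:

```julia
clf = NeuralNetworkClassifier(builder=MLJFlux.Short(n_hidden=16, dropout=0.2),
                              optimiser=Flux.ADAM(0.01),
                              epochs=20,
                              lambda=0.01,
                              alpha=0.5,
                              rng=123)
```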
@@ -309,13 +334,16 @@ any of the first three models in Table 1. The definition includes one
mutable struct and one method:
```julia
mutable struct MyNetwork <: MLJFlux.Builder
n1 :: Int
n2 :: Int
mutable struct MyBuilder <: MLJFlux.Builder
n1 :: Int
n2 :: Int
end

function MLJFlux.build(nn::MyNetwork, n_in, n_out)
return Chain(Dense(n_in, nn.n1), Dense(nn.n1, nn.n2), Dense(nn.n2, n_out))
function MLJFlux.build(nn::MyBuilder, rng, n_in, n_out)
init = Flux.glorot_uniform(rng)
return Chain(Dense(n_in, nn.n1, init=init),
Dense(nn.n1, nn.n2, init=init),
Dense(nn.n2, n_out, init=init))
end
```
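Such a builder is then used like any other. A sketch, assuming one of the models in Table 1 — here `NeuralNetworkRegressor` — has been loaded with `@load`, and that `X`, `y` are suitable regression data:

```julia
model = NeuralNetworkRegressor(builder=MyBuilder(16, 8), rng=123)
mach = machine(model, X, y)   # X: a table of Continuous features; y: a Continuous vector
fit!(mach)
```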
@@ -330,21 +358,22 @@ sub-typing `MLJFlux.Builder` and defining a new `MLJFlux.build` method
with one of these signatures:
```julia
MLJFlux.build(builder::MyNetwork, n_in, n_out)
MLJFlux.build(builder::MyNetwork, n_in, n_out, n_channels) # for use with `ImageClassifier`
MLJFlux.build(builder::MyBuilder, rng, n_in, n_out)
MLJFlux.build(builder::MyBuilder, rng, n_in, n_out, n_channels) # for use with `ImageClassifier`
```
This method must return a `Flux.Chain` instance, `chain`, subject to the
following conditions:
- `chain(x)` must make sense:
- for any `x <: Vector{<:AbstractFloat}` of length `n_in` (for use
with one of the first three model types); or
- for any `x <: Array{<:AbstractFloat, 2}` of size `(n_in,
batch_size)` where `batch_size` is any integer (for use with one
of the first three model types); or
- for any `x <: Array{<:Float32, 4}` of size `(W, H, n_channels,
batch_size)`, where `(W, H) = n_in`, `n_channels` is 1 or 3, and
`batch_size` is any integer (for use with `ImageClassifier`)
- for any `x <: Array{<:Float32, 4}` of size `(W, H, n_channels,
batch_size)`, where `(W, H) = n_in`, `n_channels` is 1 or 3, and
`batch_size` is any integer (for use with `ImageClassifier`)
- The object returned by `chain(x)` must be an `AbstractFloat` vector
of length `n_out`.
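As a quick check that a custom builder meets these conditions, one might run something like the following sketch, using the hypothetical `MyBuilder` defined above; here `4` and `2` stand in for `n_in` and `n_out`:

```julia
import MLJFlux
using Flux, Random

chain = MLJFlux.build(MyBuilder(16, 8), MersenneTwister(123), 4, 2)

x = rand(Float32, 4)            # one pattern with n_in = 4 features
@assert length(chain(x)) == 2   # output length is n_out = 2

Xbatch = rand(Float32, 4, 16)   # a batch of 16 patterns
@assert size(chain(Xbatch)) == (2, 16)
```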
@@ -388,76 +417,74 @@ using MLDatasets

# helper function
function flatten(x::AbstractArray)
return reshape(x, :, size(x)[end])
return reshape(x, :, size(x)[end])
end

import MLJFlux
mutable struct MyConvBuilder
filter_size::Int
channels1::Int
channels2::Int
channels3::Int
filter_size::Int
channels1::Int
channels2::Int
channels3::Int
end

function MLJFlux.build(b::MyConvBuilder, n_in, n_out, n_channels)

k, c1, c2, c3 = b.filter_size, b.channels1, b.channels2, b.channels3
function MLJFlux.build(b::MyConvBuilder, rng, n_in, n_out, n_channels)

mod(k, 2) == 1 || error("`filter_size` must be odd. ")
k, c1, c2, c3 = b.filter_size, b.channels1, b.channels2, b.channels3

# padding to preserve image size on convolution:
p = div(k - 1, 2)
mod(k, 2) == 1 || error("`filter_size` must be odd. ")

# compute size, in first two dims, of output of final maxpool layer:
half(x) = div(x, 2)
h = n_in[1] |> half |> half |> half
w = n_in[2] |> half |> half |> half
# padding to preserve image size on convolution:
p = div(k - 1, 2)

return Chain(
Conv((k, k), n_channels => c1, pad=(p, p), relu),
MaxPool((2, 2)),
Conv((k, k), c1 => c2, pad=(p, p), relu),
MaxPool((2, 2)),
Conv((k, k), c2 => c3, pad=(p, p), relu),
MaxPool((2 ,2)),
flatten,
Dense(h*w*c3, n_out))
front = Chain(
Conv((k, k), n_channels => c1, pad=(p, p), relu),
MaxPool((2, 2)),
Conv((k, k), c1 => c2, pad=(p, p), relu),
MaxPool((2, 2)),
Conv((k, k), c2 => c3, pad=(p, p), relu),
MaxPool((2 ,2)),
flatten)
d = Flux.outputsize(front, (n_in..., n_channels, 1)) |> first
return Chain(front, Dense(d, n_out))
end
```
Next, we load some of the MNIST data and check scientific types
conform to those in the table above:
```julia
N = 1000
X, y = MNIST.traindata();
N = 500
Xraw, yraw = MNIST.traindata();
Xraw = Xraw[:,:,1:N];
yraw = yraw[1:N];

julia> scitype(X)
AbstractArray{GrayImage{28,28},1}
julia> scitype(Xraw)
AbstractArray{Unknown, 3}

julia> scitype(y)
julia> scitype(yraw)
AbstractArray{Count,1}
```
Inputs should have scitype `GrayImage`
Inputs should have element scitype `GrayImage`:
```julia
X = coerce(X, GrayImage);
X = coerce(Xraw, GrayImage);
```
For classifiers, target must have element scitype `<: Finite`, so we fix this:
For classifiers, target must have element scitype `<: Finite`:
```julia
y = coerce(y, Multiclass);
y = coerce(yraw, Multiclass);
```
Instantiating an image classifier model:
```julia
ImageClassifier = @load ImageClassifier
clf = ImageClassifier(builder=MyConvBuilder(3, 16, 32, 32),
epochs=10,
loss=Flux.crossentropy)
epochs=10,
loss=Flux.crossentropy)
```
And evaluating the accuracy of the model on a 30% holdout set:
@@ -466,12 +493,29 @@
mach = machine(clf, X, y)

julia> evaluate!(mach,
resampling=Holdout(rng=123, fraction_train=0.7),
operation=predict_mode,
measure=misclassification_rate)
resampling=Holdout(rng=123, fraction_train=0.7),
operation=predict_mode,
measure=misclassification_rate)
┌────────────────────────┬───────────────┬────────────┐
│ _.measure │ _.measurement │ _.per_fold │
├────────────────────────┼───────────────┼────────────┤
│ misclassification_rate │ 0.0467 │ [0.0467] │
└────────────────────────┴───────────────┴────────────┘
```
### Adding new models to MLJFlux (advanced)
This section is mainly for MLJFlux developers. It assumes familiarity
with the [MLJ model
API](https://alan-turing-institute.github.io/MLJ.jl/dev/adding_models_for_general_use/).
If one declares a new model type as a subtype of either
`MLJFlux.MLJFluxProbabilistic` or `MLJFlux.MLJFluxDeterministic`, then,
instead of defining new methods for `MLJModelInterface.fit` and
`MLJModelInterface.update`, one can make use of fallbacks by
implementing the lower-level methods `shape`, `build`, and
`fitresult`. See the [classifier source code](/src/classifier.jl) for
an example.
One still needs to implement a new `predict` method.
5 changes: 4 additions & 1 deletion src/MLJFlux.jl
@@ -1,4 +1,4 @@
module MLJFlux
module MLJFlux

export CUDALibs, CPU1

@@ -12,8 +12,11 @@ using Tables
using Statistics
using ColorTypes
using ComputationalResources
using Random

include("core.jl")
include("builders.jl")
include("types.jl")
include("regressor.jl")
include("classifier.jl")
include("image.jl")