Metalhead integration #206

Closed · wants to merge 9 commits
2 changes: 2 additions & 0 deletions Project.toml
@@ -9,6 +9,7 @@ ColorTypes = "3da002f7-5984-5a60-b8a6-cbb66c0b333f"
ComputationalResources = "ed09eef8-17a6-5b46-8889-db040fac31e3"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
Metalhead = "dbeba491-748d-5e0e-a39e-b530a07fa0cc"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
@@ -19,6 +20,7 @@ CategoricalArrays = "0.10"
ColorTypes = "0.10.3, 0.11"
ComputationalResources = "0.3.2"
Flux = "0.10.4, 0.11, 0.12, 0.13"
Metalhead = "0.7"
MLJModelInterface = "1.1.1"
ProgressMeter = "1.7.1"
Tables = "1.0"
4 changes: 3 additions & 1 deletion src/MLJFlux.jl
@@ -1,4 +1,4 @@
module MLJFlux

export CUDALibs, CPU1

@@ -13,10 +13,12 @@ using Statistics
using ColorTypes
using ComputationalResources
using Random
import Metalhead

include("penalizers.jl")
include("core.jl")
include("builders.jl")
include("metalhead.jl")
include("types.jl")
include("regressor.jl")
include("classifier.jl")
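Since `Metalhead` is imported here rather than re-exported, downstream code reaches its constructors through the `MLJFlux` namespace — the same pattern the new tests use below. A minimal sketch:

```julia
using MLJFlux

# Metalhead becomes an internal dependency of MLJFlux with this change,
# so access its constructors through the module rather than `using Metalhead`:
const Metalhead = MLJFlux.Metalhead
Metalhead.ResNet  # a Metalhead constructor, now reachable via MLJFlux
```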
8 changes: 5 additions & 3 deletions src/builders.jl
@@ -1,4 +1,4 @@
# # BUILDING CHAINS FROM HYPERPARAMETERS + INPUT/OUTPUT SHAPE

# We introduce chain builders as a way of exposing neural network
# hyperparameters (describing architecture, dropout rates, etc) to
@@ -9,7 +9,7 @@
# input/output dimensions/shape.

# Below n or (n1, n2) etc refers to network inputs, while m or (m1,
# m2) etc refers to outputs.

abstract type Builder <: MLJModelInterface.MLJType end

@@ -38,7 +38,7 @@ using `n_hidden` nodes in the hidden layer and the specified `dropout`
(defaulting to 0.5). An activation function `σ` is applied between the
hidden and final layers. If `n_hidden=0` (the default) then `n_hidden`
is the geometric mean of the number of input and output nodes. The
number of input and output nodes is determined from the data.

Each layer is initialized using `Flux.glorot_uniform(rng)`. If
`rng` is an integer, it is instead used as the seed for a
@@ -96,6 +96,8 @@ function MLJFlux.build(mlp::MLP, rng, n_in, n_out)
end


# # BUILDER MACRO

struct GenericBuilder{F} <: Builder
apply::F
end
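For orientation on the `GenericBuilder` struct under the new `# # BUILDER MACRO` heading: it stores a single function `apply`, which the existing `MLJFlux.@builder` macro wraps. A minimal usage sketch, assuming the macro behaves as exercised in this PR's tests (where `n_in` and `n_out` are placeholders that MLJFlux fills in from the data at build time):

```julia
using MLJFlux, Flux

# A two-layer network whose input/output widths are resolved at build
# time; n_in and n_out are bound by the @builder macro, not by the user.
builder = MLJFlux.@builder(Flux.Chain(
    Flux.Dense(n_in, 16, Flux.relu),
    Flux.Dense(16, n_out),
))
```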
152 changes: 152 additions & 0 deletions src/metalhead.jl
@@ -0,0 +1,152 @@
#=

TODO: After https://github.com/FluxML/Metalhead.jl/issues/176:

- Export and externally document `image_builder` method

- Delete definition of `VGGHack` below

- Change default builder in ImageClassifier (see /src/types.jl) from
  `image_builder(VGGHack)(...)` to `image_builder(Metalhead.VGG)(...)`

- Add nicer `show` methods for `MetalheadBuilder` instances

=#

const DISALLOWED_KWARGS = [:imsize, :inchannels, :nclasses]
const human_disallowed_kwargs = join(map(s->"`$s`", DISALLOWED_KWARGS), ", ", " and ")
const ERR_METALHEAD_DISALLOWED_KWARGS = ArgumentError(
"Keyword arguments $human_disallowed_kwargs are disallowed "*
"as their values are inferred from data. "
)

# # WRAPPING

struct MetalheadWrapper{F} <: MLJFlux.Builder
metalhead_constructor::F
end

struct MetalheadBuilder{F} <: MLJFlux.Builder
metalhead_constructor::F
args
kwargs
end

Base.show(io::IO, w::MetalheadWrapper) =
print(io, "image_builder($(repr(w.metalhead_constructor)))")

function Base.show(io::IO, ::MIME"text/plain", w::MetalheadBuilder)
println(io, "builder wrapping $(w.metalhead_constructor)")
if !isempty(w.args)
println(io, " args:")
for (i, arg) in enumerate(w.args)
println(io, " 1: $arg")
end
end
if !isempty(w.kwargs)
println(io, " kwargs:")
for kwarg in w.kwargs
println(io, " $(first(kwarg)) = $(last(kwarg))")
end
end
end

Base.show(io::IO, w::MetalheadBuilder) =
print(io, "image_builder($(repr(w.metalhead_constructor)))(…)")


"""
image_builder(constructor)(args...; kwargs...)

Return an MLJFlux builder object based on the Metalhead.jl constructor/type
`constructor` (eg, `Metalhead.ResNet`). Here `args` and `kwargs` are
passed to `constructor` at "build time", along with
the extra keyword specifiers `imsize=...`, `inchannels=...` and
`nclasses=...`, with values inferred from the data.

# Example

If in Metalhead.jl you would do

```julia
using Metalhead
model = ResNet(50, pretrain=true, inchannels=1, nclasses=10)
```

then in MLJFlux, it suffices to do

```julia
using MLJFlux, Metalhead
builder = image_builder(ResNet)(50, pretrain=true)
```

which can be used in `ImageClassifier` as in

```julia
clf = ImageClassifier(
builder=builder,
epochs=500,
optimiser=Flux.ADAM(0.001),
loss=Flux.crossentropy,
batch_size=5,
)
```

The keyword arguments `imsize`, `inchannels` and `nclasses` are
disallowed in `kwargs` (see above).

"""
image_builder(metalhead_constructor) = MetalheadWrapper(metalhead_constructor)

function (pre_builder::MetalheadWrapper)(args...; kwargs...)
kw_names = keys(kwargs)
isempty(intersect(kw_names, DISALLOWED_KWARGS)) ||
throw(ERR_METALHEAD_DISALLOWED_KWARGS)
return MetalheadBuilder(pre_builder.metalhead_constructor, args, kwargs)
end

MLJFlux.build(
b::MetalheadBuilder,
rng,
n_in,
n_out,
n_channels
) = b.metalhead_constructor(
b.args...;
b.kwargs...,
imsize=n_in,
inchannels=n_channels,
nclasses=n_out
)

# See above "TODO" list.
function VGGHack(
depth::Integer=16;
imsize=nothing,
inchannels=3,
nclasses=1000,
batchnorm=false,
pretrain=false,
)

# Note `imsize` is ignored, as here:
# https://github.com/FluxML/Metalhead.jl/blob/9edff63222720ff84671b8087dd71eb370a6c35a/src/convnets/vgg.jl#L165

@assert(
depth in keys(Metalhead.vgg_config),
"depth must be from one in $(sort(collect(keys(Metalhead.vgg_config))))"
)
model = Metalhead.VGG((224, 224);
config = Metalhead.vgg_conv_config[Metalhead.vgg_config[depth]],
inchannels,
batchnorm,
nclasses,
fcsize = 4096,
dropout = 0.5)
if pretrain && !batchnorm
Metalhead.loadpretrain!(model, string("VGG", depth))
elseif pretrain
Metalhead.loadpretrain!(model, "VGG$(depth)-BN)")
end
return model
end
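To make the new wiring concrete, here is a hedged end-to-end sketch based on the docstring above. The data `X` (a vector of images) and `y` (a categorical vector of labels) are hypothetical stand-ins, and `image_builder` is reached through `MLJFlux` since it is not yet exported (see the TODO list):

```julia
using MLJ, MLJFlux, Flux
import Metalhead

# Only architecture choices go here; imsize, inchannels and nclasses
# are injected later by MLJFlux.build, inferred from the data:
builder = MLJFlux.image_builder(Metalhead.ResNet)(18, pretrain=false)

clf = ImageClassifier(
    builder=builder,
    optimiser=Flux.Optimise.ADAM(0.001),
    loss=Flux.crossentropy,
    epochs=10,
    batch_size=32,
)

# X and y are hypothetical: X an AbstractVector of images (ColorImage
# or GrayImage scitype), y a CategoricalVector of class labels.
mach = machine(clf, X, y)
fit!(mach)
```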
4 changes: 3 additions & 1 deletion src/types.jl
@@ -50,6 +50,8 @@ doc_classifier(model_name) = doc_regressor(model_name)*"""

for Model in [:NeuralNetworkClassifier, :ImageClassifier]

default_builder_ex = Model == :ImageClassifier ? :(image_builder(VGGHack)()) : Short()

ex = quote
mutable struct $Model{B,F,O,L} <: MLJFluxProbabilistic
builder::B
@@ -65,7 +67,7 @@ for Model in [:NeuralNetworkClassifier, :ImageClassifier]
acceleration::AbstractResource # eg, `CPU1()` or `CUDALibs()`
end

function $Model(; builder::B = $default_builder_ex
, finaliser::F = Flux.softmax
, optimiser::O = Flux.Optimise.ADAM()
, loss::L = Flux.crossentropy
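The practical effect of this change, in a small sketch: an `ImageClassifier` constructed with no arguments now falls back to the VGG-based Metalhead builder instead of `Short()`, while `NeuralNetworkClassifier` keeps its previous default:

```julia
using MLJFlux

clf = ImageClassifier()          # builder defaults to image_builder(VGGHack)()
nnc = NeuralNetworkClassifier()  # builder still defaults to Short()
```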
16 changes: 13 additions & 3 deletions test/builders.jl
@@ -1,3 +1,11 @@
# # Helpers

function an_image(rng, n_in, n_channels)
n_channels == 3 &&
return coerce(rand(rng, Float32, n_in..., 3), ColorImage)
return coerce(rand(rng, Float32, n_in...), GrayImage)
end

# to control chain initialization:
myinit(n, m) = reshape(convert(Vector{Float32}, (1:n*m)), n , m)

@@ -52,9 +60,11 @@ end
end

@testset_accelerated "@builder" accel begin
builder = MLJFlux.@builder(Flux.Chain(Flux.Dense(
n_in,
4,
init = (out, in) -> randn(rng, out, in)
), Flux.Dense(4, n_out)))
rng = StableRNGs.StableRNG(123)
chain = MLJFlux.build(builder, rng, 5, 3)
ps = Flux.params(chain)
59 changes: 59 additions & 0 deletions test/metalhead.jl
@@ -0,0 +1,59 @@
using StableRNGs
using MLJFlux
const Metalhead = MLJFlux.Metalhead

@testset "display" begin
io = IOBuffer()
builder = MLJFlux.image_builder(MLJFlux.Metalhead.ResNet)(50, pretrain=false)
show(io, MIME("text/plain"), builder)
@test String(take!(io)) ==
"builder wrapping Metalhead.ResNet\n args:\n"*
" 1: 50\n kwargs:\n pretrain = false\n"
show(io, builder)
@test String(take!(io)) == "image_builder(Metalhead.ResNet)(…)"
close(io)
end

@testset "disallowed kwargs" begin
@test_throws(
MLJFlux.ERR_METALHEAD_DISALLOWED_KWARGS,
MLJFlux.image_builder(MLJFlux.Metalhead.VGG)(imsize=(241, 241)),
)
@test_throws(
MLJFlux.ERR_METALHEAD_DISALLOWED_KWARGS,
MLJFlux.image_builder(MLJFlux.Metalhead.VGG)(inchannels=2),
)
@test_throws(
MLJFlux.ERR_METALHEAD_DISALLOWED_KWARGS,
MLJFlux.image_builder(MLJFlux.Metalhead.VGG)(nclasses=10),
)
end

@testset "constructors" begin
depth = 16
imsize = (128, 128)
nclasses = 10
inchannels = 1
wrapped = MLJFlux.image_builder(Metalhead.VGG)
@test wrapped.metalhead_constructor == Metalhead.VGG
builder = wrapped(depth, batchnorm=true)
@test builder.metalhead_constructor == Metalhead.VGG
@test builder.args == (depth, )
@test (; builder.kwargs...) == (; batchnorm=true)
ref_chain = Metalhead.VGG(
imsize;
config = Metalhead.vgg_conv_config[Metalhead.vgg_config[depth]],
inchannels,
batchnorm=true,
nclasses,
fcsize = 4096,
dropout = 0.5
)
# needs https://github.com/FluxML/Metalhead.jl/issues/176
# chain =
# MLJFlux.build(builder, StableRNGs.StableRNG(123), imsize, nclasses, inchannels)
# @test length.(MLJFlux.Flux.params(ref_chain)) ==
# length.(MLJFlux.Flux.params(chain))
end

true
4 changes: 0 additions & 4 deletions test/mlj_model_interface.jl
@@ -6,10 +6,6 @@ ModelType = MLJFlux.NeuralNetworkRegressor
@test model == clone
clone.optimiser.eta *= 10
@test model != clone

clone = deepcopy(model)
clone.builder.dropout *= 0.5
@test clone != model
end

@testset "clean!" begin
4 changes: 4 additions & 0 deletions test/runtests.jl
@@ -57,6 +57,10 @@ end
include("builders.jl")
end

@testset "metalhead" begin
include("metalhead.jl")
end

@testset "mlj_model_interface" begin
include("mlj_model_interface.jl")
end