From caba8527e50a5e55ab81c9ffe0d1ec58726c72b7 Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Thu, 23 Jun 2022 19:07:22 +1200 Subject: [PATCH 1/9] Add Metalhead as dep --- Project.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Project.toml b/Project.toml index a2f70565..dd10ab8b 100644 --- a/Project.toml +++ b/Project.toml @@ -9,6 +9,7 @@ ColorTypes = "3da002f7-5984-5a60-b8a6-cbb66c0b333f" ComputationalResources = "ed09eef8-17a6-5b46-8889-db040fac31e3" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea" +Metalhead = "dbeba491-748d-5e0e-a39e-b530a07fa0cc" ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" @@ -19,6 +20,7 @@ CategoricalArrays = "0.10" ColorTypes = "0.10.3, 0.11" ComputationalResources = "0.3.2" Flux = "0.10.4, 0.11, 0.12, 0.13" +Metalhead = "0.7" MLJModelInterface = "1.1.1" ProgressMeter = "1.7.1" Tables = "1.0" From f63654db2333ca54c93d8469d90a136a7b19986d Mon Sep 17 00:00:00 2001 From: "Anthony D. 
Blaom" Date: Thu, 23 Jun 2022 19:56:31 +1200 Subject: [PATCH 2/9] first attempt Metalhead integration (with hack); tests lacking minor --- src/MLJFlux.jl | 3 +- src/builders.jl | 128 +++++++++++++++++++++++++++++++++++++++++++++-- src/types.jl | 4 +- test/builders.jl | 16 ++++-- 4 files changed, 143 insertions(+), 8 deletions(-) diff --git a/src/MLJFlux.jl b/src/MLJFlux.jl index 84bce73f..471fd45a 100644 --- a/src/MLJFlux.jl +++ b/src/MLJFlux.jl @@ -1,4 +1,4 @@ -module MLJFlux +module MLJFlux export CUDALibs, CPU1 @@ -13,6 +13,7 @@ using Statistics using ColorTypes using ComputationalResources using Random +import Metalhead include("penalizers.jl") include("core.jl") diff --git a/src/builders.jl b/src/builders.jl index 2c417c20..f87d092b 100644 --- a/src/builders.jl +++ b/src/builders.jl @@ -1,4 +1,4 @@ -## BUILDING CHAINS A FROM HYPERPARAMETERS + INPUT/OUTPUT SHAPE +# # BUILDING CHAINS A FROM HYPERPARAMETERS + INPUT/OUTPUT SHAPE # We introduce chain builders as a way of exposing neural network # hyperparameters (describing, architecture, dropout rates, etc) to @@ -9,7 +9,7 @@ # input/output dimensions/shape. # Below n or (n1, n2) etc refers to network inputs, while m or (m1, -# m2) etc refers to outputs. +# m2) etc refers to outputs. abstract type Builder <: MLJModelInterface.MLJType end @@ -38,7 +38,7 @@ using `n_hidden` nodes in the hidden layer and the specified `dropout` (defaulting to 0.5). An activation function `σ` is applied between the hidden and final layers. If `n_hidden=0` (the default) then `n_hidden` is the geometric mean of the number of input and output nodes. The -number of input and output nodes is determined from the data. +number of input and output nodes is determined from the data. Each layer is initialized using `Flux.glorot_uniform(rng)`. 
If `rng` is an integer, it is instead used as the seed for a @@ -96,6 +96,128 @@ function MLJFlux.build(mlp::MLP, rng, n_in, n_out) end +# # METALHEAD BUILDERS + +#= + +TODO: After https://github.com/FluxML/Metalhead.jl/issues/176: + +- Export and externally document `metal` method + +- Delete definition of `ResNetHack` below + +- Change default builder in ImageClassifier (see /src/types.jl) from + `metal(ResNetHack(...))` to `metal(Metalhead.ResNet(...))`, + +- Add nicer `show` methods for `MetalheadBuilder` instances + +=# + + +# # Wrapper types and `metal` wrapping function + +struct MetalheadPreBuilder{F} <: MLJFlux.Builder + metalhead_constructor::F +end + +struct MetalheadBuilder{F} <: MLJFlux.Builder + metalhead_constructor::F + args + kwargs +end + +""" + metal(constructor)(args...; kwargs...) + +Return an MLJFlux builder object based on the Metalhead.jl constructor/type +`constructor` (eg, `Metalhead.ResNet`). Here `args` and `kwargs` are +passed to the `MetalheadType` constructor at "build time", along with +the extra keyword specifiers `imsize=...`, `inchannels=...` and +`nclasses=...`, with values inferred from the data. + +# Example + +If in Metalhead.jl you would do + +```julia +using Metalhead +model = ResNet(50, pretrain=true, inchannels=1, nclasses=10) +``` + +then in MLJFlux, it suffices to do + +```julia +using MLJFlux, Metalhead +builder = metal(ResNet)(50, pretrain=true) +``` + +which can be used in `ImageClassifier` as in + +```julia +clf = ImageClassifier( + builder=builder, + epochs=500, + optimiser=Flux.ADAM(0.001), + loss=Flux.crossentropy, + batch_size=5, +) +``` + +""" +metal(metalhead_constructor) = MetalheadPreBuilder(metalhead_constructor) + +(pre_builder::MetalheadPreBuilder)(args...; kwargs...) 
= MetalheadBuilder( + pre_builder.metalhead_constructor, args, kwargs) + +MLJFlux.build( + b::MetalheadBuilder, + rng, + n_in, + n_out, + n_channels +) = b.metalhead_constructor( + b.args...; + b.kwargs..., + imsize=n_in, + inchannels=n_channels, + nclasses=n_out +) + +# See above "TODO" list. +function VGGHack( + depth::Integer=16; + imsize=nothing, + inchannels=3, + nclasses=1000, + batchnorm=false, + pretrain=false, +) + + # Note `imsize` is ignored, as here: + # https://github.com/FluxML/Metalhead.jl/blob/9edff63222720ff84671b8087dd71eb370a6c35a/src/convnets/vgg.jl#L165 + + @assert( + depth in keys(Metalhead.vgg_config), + "depth must be from one in $(sort(collect(keys(Metalhead.vgg_config))))" + ) + model = Metalhead.VGG((224, 224); + config = Metalhead.vgg_conv_config[Metalhead.vgg_config[depth]], + inchannels, + batchnorm, + nclasses, + fcsize = 4096, + dropout = 0.5) + if pretrain && !batchnorm + Metalhead.loadpretrain!(model, string("VGG", depth)) + elseif pretrain + Metalhead.loadpretrain!(model, "VGG$(depth)-BN)") + end + return model +end + + +# # BUILER MACRO + struct GenericBuilder{F} <: Builder apply::F end diff --git a/src/types.jl b/src/types.jl index bf5674af..13f43d87 100644 --- a/src/types.jl +++ b/src/types.jl @@ -50,6 +50,8 @@ doc_classifier(model_name) = doc_regressor(model_name)*""" for Model in [:NeuralNetworkClassifier, :ImageClassifier] + default_builder_ex = Model == :ImageClassifier ? 
:(metal(VGGHack)()) : Short() + ex = quote mutable struct $Model{B,F,O,L} <: MLJFluxProbabilistic builder::B @@ -65,7 +67,7 @@ for Model in [:NeuralNetworkClassifier, :ImageClassifier] acceleration::AbstractResource # eg, `CPU1()` or `CUDALibs()` end - function $Model(; builder::B = Short() + function $Model(; builder::B = $default_builder_ex , finaliser::F = Flux.softmax , optimiser::O = Flux.Optimise.ADAM() , loss::L = Flux.crossentropy diff --git a/test/builders.jl b/test/builders.jl index 030cbfa0..cd9d4f00 100644 --- a/test/builders.jl +++ b/test/builders.jl @@ -1,3 +1,11 @@ +# # Helpers + +function an_image(rng, n_in, n_channels) + n_channels == 3 && + return coerce(rand(rng, Float32, n_in..., 3), ColorImage) + return coerce(rand(rng, Float32, n_in...), GreyImage) +end + # to control chain initialization: myinit(n, m) = reshape(convert(Vector{Float32}, (1:n*m)), n , m) @@ -52,9 +60,11 @@ end end @testset_accelerated "@builder" accel begin - builder = MLJFlux.@builder(Flux.Chain(Flux.Dense(n_in, 4, - init = (out, in) -> randn(rng, out, in)), - Flux.Dense(4, n_out))) + builder = MLJFlux.@builder(Flux.Chain(Flux.Dense( + n_in, + 4, + init = (out, in) -> randn(rng, out, in) + ), Flux.Dense(4, n_out))) rng = StableRNGs.StableRNG(123) chain = MLJFlux.build(builder, rng, 5, 3) ps = Flux.params(chain) From 4134e442c935db40986fe3b1c1bd12407d7ba5c7 Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Fri, 24 Jun 2022 08:12:05 +1200 Subject: [PATCH 3/9] add docstring comment --- src/builders.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/builders.jl b/src/builders.jl index f87d092b..226ef5f2 100644 --- a/src/builders.jl +++ b/src/builders.jl @@ -163,6 +163,9 @@ clf = ImageClassifier( ) ``` +The keyord arguments `imsize`, `inchannels` and `nclasses` are +dissallowed in `kwargs` (see above). 
+ """ metal(metalhead_constructor) = MetalheadPreBuilder(metalhead_constructor) From a4f212389058dbc0f2ac7a42f4f224fc46d4380f Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Mon, 27 Jun 2022 16:28:00 +1200 Subject: [PATCH 4/9] rm invalidated test --- test/mlj_model_interface.jl | 4 ---- 1 file changed, 4 deletions(-) diff --git a/test/mlj_model_interface.jl b/test/mlj_model_interface.jl index 6b15aca4..24b9a59e 100644 --- a/test/mlj_model_interface.jl +++ b/test/mlj_model_interface.jl @@ -6,10 +6,6 @@ ModelType = MLJFlux.NeuralNetworkRegressor @test model == clone clone.optimiser.eta *= 10 @test model != clone - - clone = deepcopy(model) - clone.builder.dropout *= 0.5 - @test clone != model end @testset "clean!" begin From 0e72e4cfb5eac54f3f2f52c69d5028c11dc863b4 Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Mon, 27 Jun 2022 16:30:45 +1200 Subject: [PATCH 5/9] mv metalhead stuff out to separate src file --- src/MLJFlux.jl | 1 + src/builders.jl | 123 ----------------------------------------------- src/metalhead.jl | 119 +++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 120 insertions(+), 123 deletions(-) create mode 100644 src/metalhead.jl diff --git a/src/MLJFlux.jl b/src/MLJFlux.jl index 471fd45a..d3a88064 100644 --- a/src/MLJFlux.jl +++ b/src/MLJFlux.jl @@ -18,6 +18,7 @@ import Metalhead include("penalizers.jl") include("core.jl") include("builders.jl") +include("metalhead.jl") include("types.jl") include("regressor.jl") include("classifier.jl") diff --git a/src/builders.jl b/src/builders.jl index 226ef5f2..b106058a 100644 --- a/src/builders.jl +++ b/src/builders.jl @@ -96,129 +96,6 @@ function MLJFlux.build(mlp::MLP, rng, n_in, n_out) end -# # METALHEAD BUILDERS - -#= - -TODO: After https://github.com/FluxML/Metalhead.jl/issues/176: - -- Export and externally document `metal` method - -- Delete definition of `ResNetHack` below - -- Change default builder in ImageClassifier (see /src/types.jl) from - `metal(ResNetHack(...))` to 
`metal(Metalhead.ResNet(...))`, - -- Add nicer `show` methods for `MetalheadBuilder` instances - -=# - - -# # Wrapper types and `metal` wrapping function - -struct MetalheadPreBuilder{F} <: MLJFlux.Builder - metalhead_constructor::F -end - -struct MetalheadBuilder{F} <: MLJFlux.Builder - metalhead_constructor::F - args - kwargs -end - -""" - metal(constructor)(args...; kwargs...) - -Return an MLJFlux builder object based on the Metalhead.jl constructor/type -`constructor` (eg, `Metalhead.ResNet`). Here `args` and `kwargs` are -passed to the `MetalheadType` constructor at "build time", along with -the extra keyword specifiers `imsize=...`, `inchannels=...` and -`nclasses=...`, with values inferred from the data. - -# Example - -If in Metalhead.jl you would do - -```julia -using Metalhead -model = ResNet(50, pretrain=true, inchannels=1, nclasses=10) -``` - -then in MLJFlux, it suffices to do - -```julia -using MLJFlux, Metalhead -builder = metal(ResNet)(50, pretrain=true) -``` - -which can be used in `ImageClassifier` as in - -```julia -clf = ImageClassifier( - builder=builder, - epochs=500, - optimiser=Flux.ADAM(0.001), - loss=Flux.crossentropy, - batch_size=5, -) -``` - -The keyord arguments `imsize`, `inchannels` and `nclasses` are -dissallowed in `kwargs` (see above). - -""" -metal(metalhead_constructor) = MetalheadPreBuilder(metalhead_constructor) - -(pre_builder::MetalheadPreBuilder)(args...; kwargs...) = MetalheadBuilder( - pre_builder.metalhead_constructor, args, kwargs) - -MLJFlux.build( - b::MetalheadBuilder, - rng, - n_in, - n_out, - n_channels -) = b.metalhead_constructor( - b.args...; - b.kwargs..., - imsize=n_in, - inchannels=n_channels, - nclasses=n_out -) - -# See above "TODO" list. 
-function VGGHack( - depth::Integer=16; - imsize=nothing, - inchannels=3, - nclasses=1000, - batchnorm=false, - pretrain=false, -) - - # Note `imsize` is ignored, as here: - # https://github.com/FluxML/Metalhead.jl/blob/9edff63222720ff84671b8087dd71eb370a6c35a/src/convnets/vgg.jl#L165 - - @assert( - depth in keys(Metalhead.vgg_config), - "depth must be from one in $(sort(collect(keys(Metalhead.vgg_config))))" - ) - model = Metalhead.VGG((224, 224); - config = Metalhead.vgg_conv_config[Metalhead.vgg_config[depth]], - inchannels, - batchnorm, - nclasses, - fcsize = 4096, - dropout = 0.5) - if pretrain && !batchnorm - Metalhead.loadpretrain!(model, string("VGG", depth)) - elseif pretrain - Metalhead.loadpretrain!(model, "VGG$(depth)-BN)") - end - return model -end - - # # BUILER MACRO struct GenericBuilder{F} <: Builder diff --git a/src/metalhead.jl b/src/metalhead.jl new file mode 100644 index 00000000..5b68fcf5 --- /dev/null +++ b/src/metalhead.jl @@ -0,0 +1,119 @@ +#= + +TODO: After https://github.com/FluxML/Metalhead.jl/issues/176: + +- Export and externally document `metal` method + +- Delete definition of `ResNetHack` below + +- Change default builder in ImageClassifier (see /src/types.jl) from + `metal(ResNetHack(...))` to `metal(Metalhead.ResNet(...))`, + +- Add nicer `show` methods for `MetalheadBuilder` instances + +=# + + +# # Wrapper types and `metal` wrapping function + +struct MetalheadPreBuilder{F} <: MLJFlux.Builder + metalhead_constructor::F +end + +struct MetalheadBuilder{F} <: MLJFlux.Builder + metalhead_constructor::F + args + kwargs +end + +""" + metal(constructor)(args...; kwargs...) + +Return an MLJFlux builder object based on the Metalhead.jl constructor/type +`constructor` (eg, `Metalhead.ResNet`). Here `args` and `kwargs` are +passed to the `MetalheadType` constructor at "build time", along with +the extra keyword specifiers `imsize=...`, `inchannels=...` and +`nclasses=...`, with values inferred from the data. 
+ +# Example + +If in Metalhead.jl you would do + +```julia +using Metalhead +model = ResNet(50, pretrain=true, inchannels=1, nclasses=10) +``` + +then in MLJFlux, it suffices to do + +```julia +using MLJFlux, Metalhead +builder = metal(ResNet)(50, pretrain=true) +``` + +which can be used in `ImageClassifier` as in + +```julia +clf = ImageClassifier( + builder=builder, + epochs=500, + optimiser=Flux.ADAM(0.001), + loss=Flux.crossentropy, + batch_size=5, +) +``` + +The keyord arguments `imsize`, `inchannels` and `nclasses` are +dissallowed in `kwargs` (see above). + +""" +metal(metalhead_constructor) = MetalheadPreBuilder(metalhead_constructor) + +(pre_builder::MetalheadPreBuilder)(args...; kwargs...) = MetalheadBuilder( + pre_builder.metalhead_constructor, args, kwargs) + +MLJFlux.build( + b::MetalheadBuilder, + rng, + n_in, + n_out, + n_channels +) = b.metalhead_constructor( + b.args...; + b.kwargs..., + imsize=n_in, + inchannels=n_channels, + nclasses=n_out +) + +# See above "TODO" list. +function VGGHack( + depth::Integer=16; + imsize=nothing, + inchannels=3, + nclasses=1000, + batchnorm=false, + pretrain=false, +) + + # Note `imsize` is ignored, as here: + # https://github.com/FluxML/Metalhead.jl/blob/9edff63222720ff84671b8087dd71eb370a6c35a/src/convnets/vgg.jl#L165 + + @assert( + depth in keys(Metalhead.vgg_config), + "depth must be from one in $(sort(collect(keys(Metalhead.vgg_config))))" + ) + model = Metalhead.VGG((224, 224); + config = Metalhead.vgg_conv_config[Metalhead.vgg_config[depth]], + inchannels, + batchnorm, + nclasses, + fcsize = 4096, + dropout = 0.5) + if pretrain && !batchnorm + Metalhead.loadpretrain!(model, string("VGG", depth)) + elseif pretrain + Metalhead.loadpretrain!(model, "VGG$(depth)-BN)") + end + return model +end From 55ee4e1aae7c7568744ac8ec48b1ea70717c9c96 Mon Sep 17 00:00:00 2001 From: "Anthony D. 
Blaom" Date: Mon, 27 Jun 2022 17:09:03 +1200 Subject: [PATCH 6/9] add show methods for Metalhead wraps --- src/metalhead.jl | 31 +++++++++++++++++++++++++++---- test/runtests.jl | 4 ++++ 2 files changed, 31 insertions(+), 4 deletions(-) diff --git a/src/metalhead.jl b/src/metalhead.jl index 5b68fcf5..91510773 100644 --- a/src/metalhead.jl +++ b/src/metalhead.jl @@ -14,9 +14,9 @@ TODO: After https://github.com/FluxML/Metalhead.jl/issues/176: =# -# # Wrapper types and `metal` wrapping function +# # WRAPPING -struct MetalheadPreBuilder{F} <: MLJFlux.Builder +struct MetalheadWrapper{F} <: MLJFlux.Builder metalhead_constructor::F end @@ -26,6 +26,29 @@ struct MetalheadBuilder{F} <: MLJFlux.Builder kwargs end +Base.show(io::IO, w::MetalheadWrapper) = + print(io, "metal($(repr(w.metalhead_constructor)))") + +function Base.show(io::IO, ::MIME"text/plain", w::MetalheadBuilder) + println(io, "builder wrapping $(w.metalhead_constructor)") + if !isempty(w.args) + println(io, " args:") + for (i, arg) in enumerate(w.args) + println(io, " $i: $arg") + end + end + if !isempty(w.kwargs) + println(io, " kwargs:") + for kwarg in w.kwargs + println(io, " $(first(kwarg)) = $(last(kwarg))") + end + end +end + +Base.show(io::IO, w::MetalheadBuilder) = + print(io, "metal($(repr(w.metalhead_constructor)))(…)") + + """ metal(constructor)(args...; kwargs...) @@ -67,9 +90,9 @@ The keyord arguments `imsize`, `inchannels` and `nclasses` are dissallowed in `kwargs` (see above). """ -metal(metalhead_constructor) = MetalheadPreBuilder(metalhead_constructor) +metal(metalhead_constructor) = MetalheadWrapper(metalhead_constructor) -(pre_builder::MetalheadPreBuilder)(args...; kwargs...) 
= MetalheadBuilder( pre_builder.metalhead_constructor, args, kwargs) MLJFlux.build( diff --git a/test/runtests.jl b/test/runtests.jl index ab44a92f..b0e84fd0 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -57,6 +57,10 @@ end include("builders.jl") end +@testset "metalhead" begin + include("metalhead.jl") +end + @testset "mlj_model_interface" begin include("mlj_model_interface.jl") end From 19c516fa39b99a6b233a3056e7fd0fbc6d02262d Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Tue, 28 Jun 2022 08:39:14 +1200 Subject: [PATCH 7/9] add forgotten files with tests --- src/metalhead.jl | 14 ++++++++++-- test/metalhead.jl | 57 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 69 insertions(+), 2 deletions(-) create mode 100644 test/metalhead.jl diff --git a/src/metalhead.jl b/src/metalhead.jl index 91510773..481ff332 100644 --- a/src/metalhead.jl +++ b/src/metalhead.jl @@ -13,6 +13,12 @@ TODO: After https://github.com/FluxML/Metalhead.jl/issues/176: =# +const DISALLOWED_KWARGS = [:imsize, :inchannels, :nclasses] +const human_disallowed_kwargs = join(map(s->"`$s`", DISALLOWED_KWARGS), ", ", " and ") +const ERR_METALHEAD_DISALLOWED_KWARGS = ArgumentError( + "Keyword arguments $human_disallowed_kwargs are disallowed "* + "as their values are inferred from data. " +) # # WRAPPING @@ -92,8 +98,12 @@ dissallowed in `kwargs` (see above). """ metal(metalhead_constructor) = MetalheadWrapper(metalhead_constructor) -(pre_builder::MetalheadWrapper)(args...; kwargs...) = MetalheadBuilder( - pre_builder.metalhead_constructor, args, kwargs) +function (pre_builder::MetalheadWrapper)(args...; kwargs...) 
+ kw_names = keys(kwargs) + isempty(intersect(kw_names, DISALLOWED_KWARGS)) || + throw(ERR_METALHEAD_DISALLOWED_KWARGS) + return MetalheadBuilder(pre_builder.metalhead_constructor, args, kwargs) +end MLJFlux.build( b::MetalheadBuilder, diff --git a/test/metalhead.jl b/test/metalhead.jl new file mode 100644 index 00000000..6c43d62c --- /dev/null +++ b/test/metalhead.jl @@ -0,0 +1,57 @@ +using StableRNGs +using MLJFlux +const Metalhead = MLJFlux.Metalhead + +@testset "display" begin + io = IOBuffer() + builder = MLJFlux.metal(MLJFlux.Metalhead.ResNet)(50, pretrain=false) + show(io, builder) + @test String(take!(io)) == + "builder wrapping Metalhead.ResNet\n args:\n"* + " 1: 50\n kwargs:\n pretrain = false\n" + close(io) +end + +@testset "disallowed kwargs" begin + @test_throws( + MLJFlux.ERR_METALHEAD_DISALLOWED_KWARGS, + MLJFlux.metal(MLJFlux.Metalhead.VGG)(imsize=(241, 241)), + ) + @test_throws( + MLJFlux.ERR_METALHEAD_DISALLOWED_KWARGS, + MLJFlux.metal(MLJFlux.Metalhead.VGG)(inchannels=2), + ) + @test_throws( + MLJFlux.ERR_METALHEAD_DISALLOWED_KWARGS, + MLJFlux.metal(MLJFlux.Metalhead.VGG)(nclasses=10), + ) +end + +@testset "constructors" begin + depth = 16 + imsize = (128, 128) + nclasses = 10 + inchannels = 1 + wrapped = MLJFlux.metal(Metalhead.VGG) + @test wrapped.metalhead_constructor == Metalhead.VGG + builder = wrapped(depth, batchnorm=true) + @test builder.metalhead_constructor == Metalhead.VGG + @test builder.args == (depth, ) + @test (; builder.kwargs...) 
== (; batchnorm=true) + ref_chain = Metalhead.VGG( + imsize; + config = Metalhead.vgg_conv_config[Metalhead.vgg_config[depth]], + inchannels, + batchnorm=true, + nclasses, + fcsize = 4096, + dropout = 0.5 + ) + # needs https://github.com/FluxML/Metalhead.jl/issues/176 + # chain = + # MLJFlux.build(builder, StableRNGs.StableRNG(123), imsize, nclasses, inchannels) + # @test length.(MLJFlux.Flux.params(ref_chain)) == + # length.(MLJFlux.Flux.params(chain)) +end + +true From 82880a3fa4a9491dfc93419f61cd3f86243eb011 Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Tue, 28 Jun 2022 09:30:34 +1200 Subject: [PATCH 8/9] fix test --- test/metalhead.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/metalhead.jl b/test/metalhead.jl index 6c43d62c..5fb4560c 100644 --- a/test/metalhead.jl +++ b/test/metalhead.jl @@ -5,10 +5,12 @@ const Metalhead = MLJFlux.Metalhead @testset "display" begin io = IOBuffer() builder = MLJFlux.metal(MLJFlux.Metalhead.ResNet)(50, pretrain=false) - show(io, builder) + show(io, MIME("text/plain"), builder) @test String(take!(io)) == "builder wrapping Metalhead.ResNet\n args:\n"* " 1: 50\n kwargs:\n pretrain = false\n" + show(io, builder) + @test String(take!(io)) == "metal(Metalhead.ResNet)(…)" close(io) end From 64f9b3dec756d367ae83fdf41b625c24e22b32c8 Mon Sep 17 00:00:00 2001 From: "Anthony D. 
Blaom" Date: Tue, 28 Jun 2022 11:05:09 +1200 Subject: [PATCH 9/9] rename metal -> image_builder --- src/metalhead.jl | 12 ++++++------ src/types.jl | 2 +- test/metalhead.jl | 12 ++++++------ 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/src/metalhead.jl b/src/metalhead.jl index 481ff332..d0ec1a07 100644 --- a/src/metalhead.jl +++ b/src/metalhead.jl @@ -7,7 +7,7 @@ TODO: After https://github.com/FluxML/Metalhead.jl/issues/176: - Delete definition of `ResNetHack` below - Change default builder in ImageClassifier (see /src/types.jl) from - `metal(ResNetHack(...))` to `metal(Metalhead.ResNet(...))`, + `image_builder(ResNetHack(...))` to `image_builder(Metalhead.ResNet(...))`, - Add nicer `show` methods for `MetalheadBuilder` instances @@ -33,7 +33,7 @@ struct MetalheadBuilder{F} <: MLJFlux.Builder end Base.show(io::IO, w::MetalheadWrapper) = - print(io, "metal($(repr(w.metalhead_constructor)))") + print(io, "image_builder($(repr(w.metalhead_constructor)))") function Base.show(io::IO, ::MIME"text/plain", w::MetalheadBuilder) println(io, "builder wrapping $(w.metalhead_constructor)") @@ -52,11 +52,11 @@ function Base.show(io::IO, ::MIME"text/plain", w::MetalheadBuilder) end Base.show(io::IO, w::MetalheadBuilder) = - print(io, "metal($(repr(w.metalhead_constructor)))(…)") + print(io, "image_builder($(repr(w.metalhead_constructor)))(…)") """ - metal(constructor)(args...; kwargs...) + image_builder(constructor)(args...; kwargs...) Return an MLJFlux builder object based on the Metalhead.jl constructor/type `constructor` (eg, `Metalhead.ResNet`). Here `args` and `kwargs` are @@ -77,7 +77,7 @@ then in MLJFlux, it suffices to do ```julia using MLJFlux, Metalhead -builder = metal(ResNet)(50, pretrain=true) +builder = image_builder(ResNet)(50, pretrain=true) ``` which can be used in `ImageClassifier` as in @@ -96,7 +96,7 @@ The keyord arguments `imsize`, `inchannels` and `nclasses` are dissallowed in `kwargs` (see above). 
""" -metal(metalhead_constructor) = MetalheadWrapper(metalhead_constructor) +image_builder(metalhead_constructor) = MetalheadWrapper(metalhead_constructor) function (pre_builder::MetalheadWrapper)(args...; kwargs...) kw_names = keys(kwargs) diff --git a/src/types.jl b/src/types.jl index 13f43d87..6a36c2be 100644 --- a/src/types.jl +++ b/src/types.jl @@ -50,7 +50,7 @@ doc_classifier(model_name) = doc_regressor(model_name)*""" for Model in [:NeuralNetworkClassifier, :ImageClassifier] - default_builder_ex = Model == :ImageClassifier ? :(metal(VGGHack)()) : Short() + default_builder_ex = Model == :ImageClassifier ? :(image_builder(VGGHack)()) : Short() ex = quote mutable struct $Model{B,F,O,L} <: MLJFluxProbabilistic diff --git a/test/metalhead.jl b/test/metalhead.jl index 5fb4560c..8c937e54 100644 --- a/test/metalhead.jl +++ b/test/metalhead.jl @@ -4,28 +4,28 @@ const Metalhead = MLJFlux.Metalhead @testset "display" begin io = IOBuffer() - builder = MLJFlux.metal(MLJFlux.Metalhead.ResNet)(50, pretrain=false) + builder = MLJFlux.image_builder(MLJFlux.Metalhead.ResNet)(50, pretrain=false) show(io, MIME("text/plain"), builder) @test String(take!(io)) == "builder wrapping Metalhead.ResNet\n args:\n"* " 1: 50\n kwargs:\n pretrain = false\n" show(io, builder) - @test String(take!(io)) == "metal(Metalhead.ResNet)(…)" + @test String(take!(io)) == "image_builder(Metalhead.ResNet)(…)" close(io) end @testset "disallowed kwargs" begin @test_throws( MLJFlux.ERR_METALHEAD_DISALLOWED_KWARGS, - MLJFlux.metal(MLJFlux.Metalhead.VGG)(imsize=(241, 241)), + MLJFlux.image_builder(MLJFlux.Metalhead.VGG)(imsize=(241, 241)), ) @test_throws( MLJFlux.ERR_METALHEAD_DISALLOWED_KWARGS, - MLJFlux.metal(MLJFlux.Metalhead.VGG)(inchannels=2), + MLJFlux.image_builder(MLJFlux.Metalhead.VGG)(inchannels=2), ) @test_throws( MLJFlux.ERR_METALHEAD_DISALLOWED_KWARGS, - MLJFlux.metal(MLJFlux.Metalhead.VGG)(nclasses=10), + MLJFlux.image_builder(MLJFlux.Metalhead.VGG)(nclasses=10), ) end @@ -34,7 +34,7 @@ end 
imsize = (128, 128) nclasses = 10 inchannels = 1 - wrapped = MLJFlux.metal(Metalhead.VGG) + wrapped = MLJFlux.image_builder(Metalhead.VGG) @test wrapped.metalhead_constructor == Metalhead.VGG builder = wrapped(depth, batchnorm=true) @test builder.metalhead_constructor == Metalhead.VGG