Implementation of EfficientNetv2 and MNASNet #198

Merged: 36 commits, merged on Aug 24, 2022

Changes from 11 commits

Commits (36, all by theabhirath):
49faf09  Initial commit for EfficientNetv2 (Aug 4, 2022)
d428da0  Cleanup (Aug 5, 2022)
093f573  Add docs for EfficientNetv2 (Aug 11, 2022)
add4d41  Add tests (Aug 11, 2022)
7d56396  Fix Inception bug, and other misc. cleanup (Aug 11, 2022)
d3e4add  Refactor: `mbconv` instead of `invertedresidual` (Aug 11, 2022)
6504f25  Refactor EfficientNets (Aug 12, 2022)
3e7e5f7  Refactor EfficientNets (Aug 12, 2022)
20a4b37  Fixes (Aug 13, 2022)
34caab4  Merge branch 'effnetv2' of https://github.com/theabhirath/Metalhead.j… (Aug 13, 2022)
ea13dd5  Some refactors, some consistency, some features (Aug 14, 2022)
c998951  The real hero was `block_idx` all along (Aug 15, 2022)
fc03d70  Fix minor hiccups (Aug 15, 2022)
f2461a5  Moving closer to the one true function (Aug 16, 2022)
9f6b987  Some more reorganisation (Aug 16, 2022)
785a95a  Huge refactor of MobileNet and EfficientNet families (Aug 18, 2022)
9e91783  Initial commit for EfficientNetv2 (Aug 4, 2022)
4a94569  Cleanup (Aug 5, 2022)
69563a6  Add docs for EfficientNetv2 (Aug 11, 2022)
70841f6  Add tests (Aug 11, 2022)
245eda0  Fix Inception bug, and other misc. cleanup (Aug 11, 2022)
1c65159  Refactor: `mbconv` instead of `invertedresidual` (Aug 11, 2022)
744f214  Refactor EfficientNets (Aug 12, 2022)
9bc75fc  Fixes (Aug 13, 2022)
3a37a70  Some refactors, some consistency, some features (Aug 14, 2022)
593752f  The real hero was `block_idx` all along (Aug 15, 2022)
55bc544  Fix minor hiccups (Aug 15, 2022)
3d63f72  Moving closer to the one true function (Aug 16, 2022)
818c584  Some more reorganisation (Aug 16, 2022)
8092818  Huge refactor of MobileNet and EfficientNet families (Aug 18, 2022)
510e913  Add MNASNet (Aug 20, 2022)
8d323e6  Merge branch 'effnetv2' of https://github.com/theabhirath/Metalhead.j… (Aug 20, 2022)
ab2a15e  Add tests for MNASNet (Aug 20, 2022)
2d310a9  Final cleanup, hopefully (Aug 23, 2022)
f76fadb  Minor refactor of `cnn_stages` (Aug 23, 2022)
992f6a6  `_round_channels` all the way (Aug 23, 2022)
.github/workflows/CI.yml (2 changes: 1 addition & 1 deletion)

@@ -28,7 +28,7 @@ jobs:
         suite:
           - '["AlexNet", "VGG"]'
           - '["GoogLeNet", "SqueezeNet", "MobileNet"]'
-          - '["EfficientNet"]'
+          - '"EfficientNet"'
           - 'r"/*/ResNet*"'
           - '[r"ResNeXt", r"SEResNet"]'
           - '[r"Res2Net", r"Res2NeXt"]'
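Each `suite` entry is a Julia literal (a string, a regex, or a vector of patterns) selecting which test groups a CI job runs; this change unwraps the lone "EfficientNet" pattern from its vector. A hedged sketch of how such an entry might be consumed on the test side; the `GROUP` environment variable name and the ReTest-style call are assumptions for illustration, not something shown in this diff:

# Hypothetical sketch of how runtests.jl could consume a suite entry.
# Assumes the workflow exports the matrix entry as ENV["GROUP"]; the actual
# mechanism in Metalhead's test harness may differ.
using ReTest
pattern = eval(Meta.parse(ENV["GROUP"]))  # parse the literal into a String or Regex
retest(pattern)                           # run only the testsets matching the pattern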
src/Metalhead.jl (30 changes: 17 additions & 13 deletions)

@@ -28,19 +28,22 @@ include("convnets/resnets/resnext.jl")
 include("convnets/resnets/seresnet.jl")
 include("convnets/resnets/res2net.jl")
 ## Inceptions
-include("convnets/inception/googlenet.jl")
-include("convnets/inception/inceptionv3.jl")
-include("convnets/inception/inceptionv4.jl")
-include("convnets/inception/inceptionresnetv2.jl")
-include("convnets/inception/xception.jl")
+include("convnets/inceptions/googlenet.jl")
+include("convnets/inceptions/inceptionv3.jl")
+include("convnets/inceptions/inceptionv4.jl")
+include("convnets/inceptions/inceptionresnetv2.jl")
+include("convnets/inceptions/xception.jl")
+## EfficientNets
+include("convnets/efficientnets/core.jl")
+include("convnets/efficientnets/efficientnet.jl")
+include("convnets/efficientnets/efficientnetv2.jl")
 ## MobileNets
-include("convnets/mobilenet/mobilenetv1.jl")
-include("convnets/mobilenet/mobilenetv2.jl")
-include("convnets/mobilenet/mobilenetv3.jl")
+include("convnets/mobilenets/mobilenetv1.jl")
+include("convnets/mobilenets/mobilenetv2.jl")
+include("convnets/mobilenets/mobilenetv3.jl")
 ## Others
 include("convnets/densenet.jl")
 include("convnets/squeezenet.jl")
-include("convnets/efficientnet.jl")
 include("convnets/convnext.jl")
 include("convnets/convmixer.jl")

@@ -61,13 +64,14 @@ export AlexNet, VGG, VGG11, VGG13, VGG16, VGG19,
        WideResNet, ResNeXt, SEResNet, SEResNeXt, Res2Net, Res2NeXt,
        DenseNet, DenseNet121, DenseNet161, DenseNet169, DenseNet201,
        GoogLeNet, Inception3, Inceptionv3, Inceptionv4, InceptionResNetv2, Xception,
-       SqueezeNet, MobileNetv1, MobileNetv2, MobileNetv3, EfficientNet,
+       SqueezeNet, MobileNetv1, MobileNetv2, MobileNetv3, EfficientNet, EfficientNetv2,
        MLPMixer, ResMLP, gMLP, ViT, ConvMixer, ConvNeXt

 # use Flux._big_show to pretty print large models
-for T in (:AlexNet, :VGG, :ResNet, :ResNeXt, :DenseNet, :SEResNet, :SEResNeXt,
-          :Res2Net, :Res2NeXt, :GoogLeNet, :Inceptionv3, :Inceptionv4,
-          :Xception, :SqueezeNet, :MobileNetv1, :MobileNetv2, :MobileNetv3, :EfficientNet,
+for T in (:AlexNet, :VGG, :SqueezeNet, :ResNet, :WideResNet, :ResNeXt,
+          :SEResNet, :SEResNeXt, :Res2Net, :Res2NeXt, :GoogLeNet, :DenseNet,
+          :Inceptionv3, :Inceptionv4, :InceptionResNetv2, :Xception,
+          :MobileNetv1, :MobileNetv2, :MobileNetv3, :EfficientNet, :EfficientNetv2,
           :MLPMixer, :ResMLP, :gMLP, :ViT, :ConvMixer, :ConvNeXt)
     @eval Base.show(io::IO, ::MIME"text/plain", model::$T) = _maybe_big_show(io, model)
 end
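Each iteration of the `@eval` loop above defines one pretty-printing method. Written out by hand for a single type, the generated code is simply:

# Hand-expanded form of one @eval iteration (illustrative):
Base.show(io::IO, ::MIME"text/plain", model::EfficientNetv2) = _maybe_big_show(io, model)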
src/convnets/convmixer.jl (23 changes: 14 additions & 9 deletions)

@@ -1,5 +1,5 @@
 """
-    convmixer(planes::Integer, depth::Integer; kernel_size = (9, 9),
+    convmixer(planes::Integer, depth::Integer; kernel_size::Dims{2} = (9, 9),
               patch_size::Dims{2} = (7, 7), activation = gelu,
              inchannels::Integer = 3, nclasses::Integer = 1000)

@@ -13,20 +13,25 @@ Creates a ConvMixer model.
 - `kernel_size`: kernel size of the convolutional layers
 - `patch_size`: size of the patches
 - `activation`: activation function used after the convolutional layers
-- `inchannels`: The number of channels in the input.
+- `inchannels`: number of input channels
 - `nclasses`: number of classes in the output
 """
-function convmixer(planes::Integer, depth::Integer; kernel_size = (9, 9),
-                   patch_size::Dims{2} = (7, 7), activation = gelu,
+function convmixer(planes::Integer, depth::Integer; kernel_size::Dims{2} = (9, 9),
+                   patch_size::Dims{2} = (7, 7), activation = gelu, dropout_rate = nothing,
                    inchannels::Integer = 3, nclasses::Integer = 1000)
-    stem = conv_norm(patch_size, inchannels, planes, activation; preact = true,
-                     stride = patch_size[1])
-    blocks = [Chain(SkipConnection(Chain(conv_norm(kernel_size, planes, planes, activation;
+    layers = []
+    # stem of the model
+    append!(layers,
+            conv_norm(patch_size, inchannels, planes, activation; preact = true,
+                      stride = patch_size[1]))
+    # stages of the model
+    stages = [Chain(SkipConnection(Chain(conv_norm(kernel_size, planes, planes, activation;
                                                    preact = true, groups = planes,
                                                    pad = SamePad())), +),
                     conv_norm((1, 1), planes, planes, activation; preact = true)...)
               for _ in 1:depth]
-    return Chain(Chain(stem..., Chain(blocks...)), create_classifier(planes, nclasses))
+    append!(layers, stages)
+    return Chain(Chain(layers...), create_classifier(planes, nclasses; dropout_rate))
 end

 const CONVMIXER_CONFIGS = Dict(:base => ((1536, 20),

@@ -48,7 +53,7 @@ Creates a ConvMixer model.
 # Arguments

 - `config`: the size of the model, either `:base`, `:small` or `:large`
-- `inchannels`: The number of channels in the input.
+- `inchannels`: number of input channels
 - `nclasses`: number of classes in the output
 """
 struct ConvMixer
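The refactor also threads a `dropout_rate` keyword through to `create_classifier`. A minimal sketch of a call under the new signature; the plane count, depth, and dropout rate below are illustrative values, not defaults from this PR:

# Illustrative call to the refactored builder; dropout_rate = nothing
# (the default) disables dropout in the classifier head entirely.
m = Metalhead.convmixer(256, 8; kernel_size = (5, 5), patch_size = (7, 7),
                        dropout_rate = 0.1)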
src/convnets/convnext.jl (2 changes: 1 addition & 1 deletion)

@@ -84,7 +84,7 @@ Creates a ConvNeXt model.
 # Arguments

 - `config`: The size of the model, one of `tiny`, `small`, `base`, `large` or `xlarge`.
-- `inchannels`: The number of channels in the input.
+- `inchannels`: number of input channels
 - `nclasses`: number of output classes

 See also [`Metalhead.convnext`](#).
src/convnets/densenet.jl (30 changes: 17 additions & 13 deletions)

@@ -10,12 +10,12 @@ Create a Densenet bottleneck layer
 - `outplanes`: number of output feature maps on bottleneck branch
   (and scaling factor for inner feature maps; see ref)
 """
-function dense_bottleneck(inplanes::Integer, outplanes::Integer)
-    inner_channels = 4 * outplanes
-    return SkipConnection(Chain(conv_norm((1, 1), inplanes, inner_channels; bias = false,
+function dense_bottleneck(inplanes::Integer, outplanes::Integer; expansion::Integer = 4)
+    inner_channels = expansion * outplanes
+    return SkipConnection(Chain(conv_norm((1, 1), inplanes, inner_channels;
                                           revnorm = true)...,
                                 conv_norm((3, 3), inner_channels, outplanes; pad = 1,
-                                          bias = false, revnorm = true)...),
+                                          revnorm = true)...),
                           cat_channels)
 end

@@ -31,7 +31,7 @@ Create a DenseNet transition sequence
 - `outplanes`: number of output feature maps
 """
 function transition(inplanes::Integer, outplanes::Integer)
-    return Chain(conv_norm((1, 1), inplanes, outplanes; bias = false, revnorm = true)...,
+    return Chain(conv_norm((1, 1), inplanes, outplanes; revnorm = true)...,
                  MeanPool((2, 2)))
 end

@@ -55,7 +55,8 @@ function dense_block(inplanes::Integer, growth_rates)
 end

 """
-    densenet(inplanes, growth_rates; reduction = 0.5, nclasses::Integer = 1000)
+    densenet(inplanes, growth_rates; reduction = 0.5, dropout_rate = nothing,
+             inchannels::Integer = 3, nclasses::Integer = 1000)

 Create a DenseNet model
 ([reference](https://arxiv.org/abs/1608.06993)).

@@ -66,13 +67,14 @@ Create a DenseNet model
 - `growth_rates`: the growth rates of output feature maps within each
   [`dense_block`](#) (a vector of vectors)
 - `reduction`: the factor by which the number of feature maps is scaled across each transition
+- `dropout_rate`: the dropout rate for the classifier head. Set to `nothing` to disable dropout.
 - `nclasses`: the number of output classes
 """
-function densenet(inplanes::Integer, growth_rates; reduction = 0.5, inchannels::Integer = 3,
-                  nclasses::Integer = 1000)
+function densenet(inplanes::Integer, growth_rates; reduction = 0.5, dropout_rate = nothing,
+                  inchannels::Integer = 3, nclasses::Integer = 1000)
     layers = []
     append!(layers,
-            conv_norm((7, 7), inchannels, inplanes; stride = 2, pad = (3, 3), bias = false))
+            conv_norm((7, 7), inchannels, inplanes; stride = 2, pad = (3, 3)))
     push!(layers, MaxPool((3, 3); stride = 2, pad = (1, 1)))
     outplanes = 0
     for (i, rates) in enumerate(growth_rates)

@@ -83,7 +85,7 @@ function densenet(inplanes::Integer, growth_rates; reduction = 0.5, inchannels::
         inplanes = floor(Int, outplanes * reduction)
     end
     push!(layers, BatchNorm(outplanes, relu))
-    return Chain(Chain(layers...), create_classifier(outplanes, nclasses))
+    return Chain(Chain(layers...), create_classifier(outplanes, nclasses; dropout_rate))
 end

 """

@@ -100,9 +102,10 @@ Create a DenseNet model
 - `nclasses`: the number of output classes
 """
 function densenet(nblocks::AbstractVector{<:Integer}; growth_rate::Integer = 32,
-                  reduction = 0.5, inchannels::Integer = 3, nclasses::Integer = 1000)
+                  reduction = 0.5, dropout_rate = nothing, inchannels::Integer = 3,
+                  nclasses::Integer = 1000)
     return densenet(2 * growth_rate, [fill(growth_rate, n) for n in nblocks];
-                    reduction, inchannels, nclasses)
+                    reduction, dropout_rate, inchannels, nclasses)
 end

 const DENSENET_CONFIGS = Dict(121 => [6, 12, 24, 16],

@@ -132,7 +135,8 @@ end
 function DenseNet(config::Integer; pretrain::Bool = false, growth_rate::Integer = 32,
                   reduction = 0.5, inchannels::Integer = 3, nclasses::Integer = 1000)
     _checkconfig(config, keys(DENSENET_CONFIGS))
-    layers = densenet(DENSENET_CONFIGS[config]; growth_rate, reduction, inchannels, nclasses)
+    layers = densenet(DENSENET_CONFIGS[config]; growth_rate, reduction, inchannels,
+                      nclasses)
     if pretrain
         loadpretrain!(layers, string("densenet", config))
     end
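With `dropout_rate` now threaded through both `densenet` methods, classifier dropout can be enabled from the builder level. A minimal sketch; the 0.2 rate is an illustrative value, while `[6, 12, 24, 16]` is the DenseNet-121 configuration from `DENSENET_CONFIGS` above:

# Illustrative: a DenseNet-121-style network with classifier dropout.
backbone = Metalhead.densenet([6, 12, 24, 16]; growth_rate = 32,
                              reduction = 0.5, dropout_rate = 0.2)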
src/convnets/efficientnet.jl (116 changes: 0 additions & 116 deletions)

This file was deleted; its contents are superseded by the new src/convnets/efficientnets/ folder below.
src/convnets/efficientnets/core.jl (78 changes: 78 additions & 0 deletions, new file)

@@ -0,0 +1,78 @@
abstract type _MBConfig end

struct MBConvConfig <: _MBConfig
    kernel_size::Dims{2}
    inplanes::Integer
    outplanes::Integer
    expansion::Real
    stride::Integer
    nrepeats::Integer
end
function MBConvConfig(kernel_size::Integer, inplanes::Integer, outplanes::Integer,
                      expansion::Real, stride::Integer, nrepeats::Integer,
                      width_mult::Real = 1, depth_mult::Real = 1)
    inplanes = _round_channels(inplanes * width_mult, 8)
    outplanes = _round_channels(outplanes * width_mult, 8)
    nrepeats = ceil(Int, nrepeats * depth_mult)
    return MBConvConfig((kernel_size, kernel_size), inplanes, outplanes, expansion,
                        stride, nrepeats)
end

function efficientnetblock(m::MBConvConfig, norm_layer)
    layers = []
    explanes = _round_channels(m.inplanes * m.expansion, 8)
    push!(layers,
          mbconv(m.kernel_size, m.inplanes, explanes, m.outplanes, swish; norm_layer,
                 stride = m.stride, reduction = 4))
    explanes = _round_channels(m.outplanes * m.expansion, 8)
    append!(layers,
            [mbconv(m.kernel_size, m.outplanes, explanes, m.outplanes, swish; norm_layer,
                    stride = 1, reduction = 4) for _ in 1:(m.nrepeats - 1)])
    return Chain(layers...)
end

struct FusedMBConvConfig <: _MBConfig
    kernel_size::Dims{2}
    inplanes::Integer
    outplanes::Integer
    expansion::Real
    stride::Integer
    nrepeats::Integer
end
function FusedMBConvConfig(kernel_size::Integer, inplanes::Integer, outplanes::Integer,
                           expansion::Real, stride::Integer, nrepeats::Integer)
    return FusedMBConvConfig((kernel_size, kernel_size), inplanes, outplanes, expansion,
                             stride, nrepeats)
end

function efficientnetblock(m::FusedMBConvConfig, norm_layer)
    layers = []
    explanes = _round_channels(m.inplanes * m.expansion, 8)
    push!(layers,
          fused_mbconv(m.kernel_size, m.inplanes, explanes, m.outplanes, swish;
                       norm_layer, stride = m.stride))
    explanes = _round_channels(m.outplanes * m.expansion, 8)
    append!(layers,
            [fused_mbconv(m.kernel_size, m.outplanes, explanes, m.outplanes, swish;
                          norm_layer, stride = 1) for _ in 1:(m.nrepeats - 1)])
    return Chain(layers...)
end

function efficientnet(block_configs::AbstractVector{<:_MBConfig};
                      headplanes::Union{Nothing, Integer} = nothing,
                      norm_layer = BatchNorm, dropout_rate = nothing,
                      inchannels::Integer = 3, nclasses::Integer = 1000)
    layers = []
    # stem of the model
    append!(layers,
            conv_norm((3, 3), inchannels, block_configs[1].inplanes, swish; norm_layer,
                      stride = 2, pad = SamePad()))
    # building inverted residual blocks
    append!(layers, [efficientnetblock(cfg, norm_layer) for cfg in block_configs])
    # building last layers
    outplanes = block_configs[end].outplanes
    headplanes = isnothing(headplanes) ? outplanes * 4 : headplanes
    append!(layers,
            conv_norm((1, 1), outplanes, headplanes, swish; pad = SamePad()))
    return Chain(Chain(layers...), create_classifier(headplanes, nclasses; dropout_rate))
end
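To make the new API concrete, here is a hedged sketch that composes the pieces in this file; the two-stage configuration is invented for illustration and is not one of the real EfficientNet or EfficientNetv2 configurations (those live in efficientnet.jl and efficientnetv2.jl, outside this file):

# Illustrative only: a tiny two-stage backbone built from the new config types.
# Both config structs subtype _MBConfig, so they can be mixed in one vector.
configs = [FusedMBConvConfig(3, 24, 24, 1, 1, 2),  # fused stage: 3x3, 24 => 24, stride 1
           MBConvConfig(3, 24, 48, 4, 2, 2)]       # MBConv stage: 3x3, 24 => 48, stride 2
model = efficientnet(configs; dropout_rate = 0.2, nclasses = 10)

The trailing `width_mult`/`depth_mult` arguments of the `MBConvConfig` convenience constructor scale the channel counts (snapped to multiples of 8 by `_round_channels`) and round the repeat count up, which is presumably how the scaled EfficientNet variants are derived from the base configuration.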