Implementation of EfficientNetv2 and MNASNet #198

Merged
merged 36 commits into from Aug 24, 2022
Changes from all commits (36 commits)
49faf09
Initial commit for EfficientNetv2
theabhirath Aug 4, 2022
d428da0
Cleanup
theabhirath Aug 5, 2022
093f573
Add docs for EfficientNetv2
theabhirath Aug 11, 2022
add4d41
Add tests
theabhirath Aug 11, 2022
7d56396
Fix Inception bug, and other misc. cleanup
theabhirath Aug 11, 2022
d3e4add
Refactor: `mbconv` instead of `invertedresidual`
theabhirath Aug 11, 2022
6504f25
Refactor EfficientNets
theabhirath Aug 12, 2022
3e7e5f7
Refactor EfficientNets
theabhirath Aug 12, 2022
20a4b37
Fixes
theabhirath Aug 13, 2022
34caab4
Merge branch 'effnetv2' of https://github.com/theabhirath/Metalhead.j…
theabhirath Aug 13, 2022
ea13dd5
Some refactors, some consistency, some features
theabhirath Aug 14, 2022
c998951
The real hero was `block_idx` all along
theabhirath Aug 15, 2022
fc03d70
Fix minor hiccups
theabhirath Aug 15, 2022
f2461a5
Moving closer to the one true function
theabhirath Aug 16, 2022
9f6b987
Some more reorganisation
theabhirath Aug 16, 2022
785a95a
Huge refactor of MobileNet and EfficientNet families
theabhirath Aug 18, 2022
9e91783
Initial commit for EfficientNetv2
theabhirath Aug 4, 2022
4a94569
Cleanup
theabhirath Aug 5, 2022
69563a6
Add docs for EfficientNetv2
theabhirath Aug 11, 2022
70841f6
Add tests
theabhirath Aug 11, 2022
245eda0
Fix Inception bug, and other misc. cleanup
theabhirath Aug 11, 2022
1c65159
Refactor: `mbconv` instead of `invertedresidual`
theabhirath Aug 11, 2022
744f214
Refactor EfficientNets
theabhirath Aug 12, 2022
9bc75fc
Fixes
theabhirath Aug 13, 2022
3a37a70
Some refactors, some consistency, some features
theabhirath Aug 14, 2022
593752f
The real hero was `block_idx` all along
theabhirath Aug 15, 2022
55bc544
Fix minor hiccups
theabhirath Aug 15, 2022
3d63f72
Moving closer to the one true function
theabhirath Aug 16, 2022
818c584
Some more reorganisation
theabhirath Aug 16, 2022
8092818
Huge refactor of MobileNet and EfficientNet families
theabhirath Aug 18, 2022
510e913
Add MNASNet
theabhirath Aug 20, 2022
8d323e6
Merge branch 'effnetv2' of https://github.com/theabhirath/Metalhead.j…
theabhirath Aug 20, 2022
ab2a15e
Add tests for MNASNet
theabhirath Aug 20, 2022
2d310a9
Final cleanup, hopefully
theabhirath Aug 23, 2022
f76fadb
Minor refactor of `cnn_stages`
theabhirath Aug 23, 2022
992f6a6
`_round_channels` all the way
theabhirath Aug 23, 2022
6 changes: 3 additions & 3 deletions .github/workflows/CI.yml
@@ -27,10 +27,10 @@ jobs:
  - x64
  suite:
  - '["AlexNet", "VGG"]'
- - '["GoogLeNet", "SqueezeNet", "MobileNet"]'
- - '["EfficientNet"]'
+ - '["GoogLeNet", "SqueezeNet", "MobileNets"]'
+ - '"EfficientNet"'
  - 'r"/*/ResNet*"'
- - '[r"ResNeXt", r"SEResNet"]'
+ - 'r"/*/SEResNet*"'
  - '[r"Res2Net", r"Res2NeXt"]'
  - '"Inception"'
  - '"DenseNet"'
2 changes: 1 addition & 1 deletion Project.toml
@@ -23,10 +23,10 @@ Flux = "0.13"
Functors = "0.2, 0.3"
CUDA = "3"
ChainRulesCore = "1"
PartialFunctions = "1"
MLUtils = "0.2.10"
NNlib = "0.8"
NNlibCUDA = "0.2"
PartialFunctions = "1"
julia = "1.6"

[publish]
38 changes: 25 additions & 13 deletions src/Metalhead.jl
@@ -19,6 +19,11 @@ include("layers/Layers.jl")
 using .Layers

 # CNN models
+## Builders
+include("convnets/builders/core.jl")
+include("convnets/builders/mbconv.jl")
+include("convnets/builders/resblocks.jl")
+## AlexNet and VGG
 include("convnets/alexnet.jl")
 include("convnets/vgg.jl")
 ## ResNets
@@ -28,19 +33,23 @@ include("convnets/resnets/resnext.jl")
include("convnets/resnets/seresnet.jl")
include("convnets/resnets/res2net.jl")
## Inceptions
include("convnets/inception/googlenet.jl")
include("convnets/inception/inceptionv3.jl")
include("convnets/inception/inceptionv4.jl")
include("convnets/inception/inceptionresnetv2.jl")
include("convnets/inception/xception.jl")
include("convnets/inceptions/googlenet.jl")
include("convnets/inceptions/inceptionv3.jl")
include("convnets/inceptions/inceptionv4.jl")
include("convnets/inceptions/inceptionresnetv2.jl")
include("convnets/inceptions/xception.jl")
## EfficientNets
include("convnets/efficientnets/core.jl")
include("convnets/efficientnets/efficientnet.jl")
include("convnets/efficientnets/efficientnetv2.jl")
## MobileNets
include("convnets/mobilenet/mobilenetv1.jl")
include("convnets/mobilenet/mobilenetv2.jl")
include("convnets/mobilenet/mobilenetv3.jl")
include("convnets/mobilenets/mobilenetv1.jl")
include("convnets/mobilenets/mobilenetv2.jl")
include("convnets/mobilenets/mobilenetv3.jl")
include("convnets/mobilenets/mnasnet.jl")
## Others
include("convnets/densenet.jl")
include("convnets/squeezenet.jl")
include("convnets/efficientnet.jl")
include("convnets/convnext.jl")
include("convnets/convmixer.jl")

@@ -61,13 +70,16 @@ export AlexNet, VGG, VGG11, VGG13, VGG16, VGG19,
        WideResNet, ResNeXt, SEResNet, SEResNeXt, Res2Net, Res2NeXt,
        DenseNet, DenseNet121, DenseNet161, DenseNet169, DenseNet201,
        GoogLeNet, Inception3, Inceptionv3, Inceptionv4, InceptionResNetv2, Xception,
-       SqueezeNet, MobileNetv1, MobileNetv2, MobileNetv3, EfficientNet,
+       SqueezeNet, MobileNetv1, MobileNetv2, MobileNetv3, MNASNet,
+       EfficientNet, EfficientNetv2,
        MLPMixer, ResMLP, gMLP, ViT, ConvMixer, ConvNeXt

 # use Flux._big_show to pretty print large models
-for T in (:AlexNet, :VGG, :ResNet, :ResNeXt, :DenseNet, :SEResNet, :SEResNeXt,
-          :Res2Net, :Res2NeXt, :GoogLeNet, :Inceptionv3, :Inceptionv4,
-          :Xception, :SqueezeNet, :MobileNetv1, :MobileNetv2, :MobileNetv3, :EfficientNet,
+for T in (:AlexNet, :VGG, :SqueezeNet, :ResNet, :WideResNet, :ResNeXt,
+          :SEResNet, :SEResNeXt, :Res2Net, :Res2NeXt, :GoogLeNet, :DenseNet,
+          :Inceptionv3, :Inceptionv4, :InceptionResNetv2, :Xception,
+          :MobileNetv1, :MobileNetv2, :MobileNetv3, :MNASNet,
+          :EfficientNet, :EfficientNetv2,
           :MLPMixer, :ResMLP, :gMLP, :ViT, :ConvMixer, :ConvNeXt)
     @eval Base.show(io::IO, ::MIME"text/plain", model::$T) = _maybe_big_show(io, model)
 end
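For reference, a minimal sketch of what the new exports enable once the PR is merged. The constructors for the new models are defined in files elsewhere in this PR, so the config symbols and keywords below are assumptions for illustration only, not taken from this diff:

using Metalhead, Flux

# Hypothetical usage of the newly exported models; argument names/values are assumed.
effnet = EfficientNetv2(:small; nclasses = 1000)
mnas = MNASNet(:B1; nclasses = 1000)

# Both types are covered by the `Base.show` loop above, so they pretty-print
# via Flux._big_show like the existing models.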
19 changes: 19 additions & 0 deletions src/convnets/builders/core.jl
@@ -0,0 +1,19 @@
function cnn_stages(get_layers, block_repeats::AbstractVector{<:Integer},
                    connection = nothing)
    # Construct each stage
    stages = []
    for (stage_idx, nblocks) in enumerate(block_repeats)
        # Construct the blocks for each stage
        blocks = map(1:nblocks) do block_idx
            branches = get_layers(stage_idx, block_idx)
            if isnothing(connection)
                @assert length(branches)==1 "get_layers should return a single branch for
                each block if no connection is specified"
            end
            return length(branches) == 1 ? only(branches) :
                   Parallel(connection, branches...)
        end
        push!(stages, Chain(blocks...))
    end
    return stages
end
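To make the contract concrete, here is a rough sketch of calling `cnn_stages` directly. The `Dense` layers and stage sizes are placeholders, not code from this PR:

using Flux  # Chain, Dense, Parallel

# Placeholder callback: one branch per block, so no `connection` is needed and
# each block collapses to `only(branches)` inside `cnn_stages`.
toy_get_layers(stage_idx, block_idx) = (Dense(8 => 8, relu),)

stages = cnn_stages(toy_get_layers, [2, 3])   # 2 blocks in stage 1, 3 in stage 2
backbone = Chain(stages...)                   # each stage is a plain `Chain`

# With two branches per block, the connection (here `+`) wraps them in `Parallel`:
residual_layers(stage_idx, block_idx) = (identity, Dense(8 => 8, relu))
res_stages = cnn_stages(residual_layers, [2, 2], +)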
107 changes: 107 additions & 0 deletions src/convnets/builders/mbconv.jl
@@ -0,0 +1,107 @@
function dwsepconv_builder(block_configs, inplanes::Integer, stage_idx::Integer,
                           width_mult::Real; norm_layer = BatchNorm, kwargs...)
    block_fn, k, outplanes, stride, nrepeats, activation = block_configs[stage_idx]
    outplanes = _round_channels(outplanes * width_mult)
    if stage_idx != 1
        inplanes = _round_channels(block_configs[stage_idx - 1][3] * width_mult)
    end
    function get_layers(block_idx::Integer)
        inplanes = block_idx == 1 ? inplanes : outplanes
        stride = block_idx == 1 ? stride : 1
        block = Chain(block_fn((k, k), inplanes, outplanes, activation;
                               stride, pad = SamePad(), norm_layer, kwargs...)...)
        return (block,)
    end
    return get_layers, nrepeats
end

function mbconv_builder(block_configs, inplanes::Integer, stage_idx::Integer,
                        scalings::NTuple{2, Real}; norm_layer = BatchNorm,
                        divisor::Integer = 8, se_from_explanes::Bool = false,
                        kwargs...)
    width_mult, depth_mult = scalings
    block_fn, k, outplanes, expansion, stride, nrepeats, reduction, activation = block_configs[stage_idx]
    # calculate number of reduced channels for squeeze-excite layer from explanes instead of inplanes
    if !isnothing(reduction)
        reduction = !se_from_explanes ? reduction * expansion : reduction
    end
    if stage_idx != 1
        inplanes = _round_channels(block_configs[stage_idx - 1][3] * width_mult, divisor)
    end
    outplanes = _round_channels(outplanes * width_mult, divisor)
    function get_layers(block_idx::Integer)
        inplanes = block_idx == 1 ? inplanes : outplanes
        explanes = _round_channels(inplanes * expansion, divisor)
        stride = block_idx == 1 ? stride : 1
        block = block_fn((k, k), inplanes, explanes, outplanes, activation; norm_layer,
                         stride, reduction, kwargs...)
        return stride == 1 && inplanes == outplanes ? (identity, block) : (block,)
    end
    return get_layers, ceil(Int, nrepeats * depth_mult)
end

function mbconv_builder(block_configs, inplanes::Integer, stage_idx::Integer,
                        width_mult::Real; norm_layer = BatchNorm, kwargs...)
    return mbconv_builder(block_configs, inplanes, stage_idx, (width_mult, 1);
                          norm_layer, kwargs...)
end

function fused_mbconv_builder(block_configs, inplanes::Integer, stage_idx::Integer;
                              norm_layer = BatchNorm, kwargs...)
    block_fn, k, outplanes, expansion, stride, nrepeats, activation = block_configs[stage_idx]
    inplanes = stage_idx == 1 ? inplanes : block_configs[stage_idx - 1][3]
    function get_layers(block_idx::Integer)
        inplanes = block_idx == 1 ? inplanes : outplanes
        explanes = _round_channels(inplanes * expansion, 8)
        stride = block_idx == 1 ? stride : 1
        block = block_fn((k, k), inplanes, explanes, outplanes, activation;
                         norm_layer, stride, kwargs...)
        return stride == 1 && inplanes == outplanes ? (identity, block) : (block,)
    end
    return get_layers, nrepeats
end

# TODO - these builders need to be more flexible to potentially specify stuff like
# activation functions and reductions that don't change
function _get_builder(::typeof(dwsep_conv_bn), block_configs::AbstractVector{<:Tuple},
                      inplanes::Integer, stage_idx::Integer;
                      scalings::Union{Nothing, NTuple{2, Real}} = nothing,
                      width_mult::Union{Nothing, Number} = nothing, norm_layer, kwargs...)
    @assert isnothing(scalings) "dwsep_conv_bn does not support the `scalings` argument"
    return dwsepconv_builder(block_configs, inplanes, stage_idx, width_mult; norm_layer,
                             kwargs...)
end

function _get_builder(::typeof(mbconv), block_configs::AbstractVector{<:Tuple},
                      inplanes::Integer, stage_idx::Integer;
                      scalings::Union{Nothing, NTuple{2, Real}} = nothing,
                      width_mult::Union{Nothing, Number} = nothing, norm_layer, kwargs...)
    if isnothing(scalings)
        return mbconv_builder(block_configs, inplanes, stage_idx, width_mult; norm_layer,
                              kwargs...)
    elseif isnothing(width_mult)
        return mbconv_builder(block_configs, inplanes, stage_idx, scalings; norm_layer,
                              kwargs...)
    else
        throw(ArgumentError("Only one of `scalings` and `width_mult` can be specified"))
    end
end

function _get_builder(::typeof(fused_mbconv), block_configs::AbstractVector{<:Tuple},
                      inplanes::Integer, stage_idx::Integer;
                      scalings::Union{Nothing, NTuple{2, Real}} = nothing,
                      width_mult::Union{Nothing, Number} = nothing, norm_layer)
    @assert isnothing(width_mult) "fused_mbconv does not support the `width_mult` argument."
    @assert isnothing(scalings)||scalings == (1, 1) "fused_mbconv does not support the `scalings` argument"
    return fused_mbconv_builder(block_configs, inplanes, stage_idx; norm_layer)
end

function mbconv_stack_builder(block_configs::AbstractVector{<:Tuple}, inplanes::Integer;
                              scalings::Union{Nothing, NTuple{2, Real}} = nothing,
                              width_mult::Union{Nothing, Number} = nothing,
                              norm_layer = BatchNorm, kwargs...)
    bxs = [_get_builder(block_configs[idx][1], block_configs, inplanes, idx; scalings,
                        width_mult, norm_layer, kwargs...)
           for idx in eachindex(block_configs)]
    return (stage_idx, block_idx) -> first.(bxs)[stage_idx](block_idx), last.(bxs)
end
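To show how these pieces compose, here is a sketch that feeds a small, made-up `block_configs` table through `mbconv_stack_builder` and into `cnn_stages`. The tuple layout follows the destructuring in `mbconv_builder` above, but the concrete numbers and the `+` connection are assumptions for illustration, and `mbconv`/`relu6` are assumed to be in scope as they are inside the Metalhead module:

# Assumed tuple layout (from the destructuring in `mbconv_builder`):
# (block_fn, kernel, outplanes, expansion, stride, nrepeats, reduction, activation)
block_configs = [
    (mbconv, 3, 16, 1, 1, 1, nothing, relu6),
    (mbconv, 3, 24, 6, 2, 2, nothing, relu6),
]
inplanes = 32
get_layers, block_repeats = mbconv_stack_builder(block_configs, inplanes;
                                                 width_mult = 1.0)
stages = cnn_stages(get_layers, block_repeats, +)  # `+` joins (identity, block) residual pairs
backbone = Chain(stages...)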
71 changes: 71 additions & 0 deletions src/convnets/builders/resblocks.jl
@@ -0,0 +1,71 @@
function basicblock_builder(block_repeats::AbstractVector{<:Integer};
                            inplanes::Integer = 64, reduction_factor::Integer = 1,
                            expansion::Integer = 1, norm_layer = BatchNorm,
                            revnorm::Bool = false, activation = relu,
                            attn_fn = planes -> identity,
                            drop_block_rate = nothing, drop_path_rate = nothing,
                            stride_fn = resnet_stride, planes_fn = resnet_planes,
                            downsample_tuple = (downsample_conv, downsample_identity))
    # DropBlock, DropPath both take in rates based on a linear scaling schedule
    # Also get `planes_vec` needed for block `inplanes` and `planes` calculations
    pathschedule = linear_scheduler(drop_path_rate; depth = sum(block_repeats))
    blockschedule = linear_scheduler(drop_block_rate; depth = sum(block_repeats))
    planes_vec = collect(planes_fn(block_repeats))
    # closure over `idxs`
    function get_layers(stage_idx::Integer, block_idx::Integer)
        # DropBlock, DropPath both take in rates based on a linear scaling schedule
        # This is also needed for block `inplanes` and `planes` calculations
        schedule_idx = sum(block_repeats[1:(stage_idx - 1)]) + block_idx
        planes = planes_vec[schedule_idx]
        inplanes = schedule_idx == 1 ? inplanes : planes_vec[schedule_idx - 1] * expansion
        # `resnet_stride` is a callback that the user can tweak to change the stride of the
        # blocks. It defaults to the standard behaviour as in the paper
        stride = stride_fn(stage_idx, block_idx)
        downsample_fn = stride != 1 || inplanes != planes * expansion ?
                        downsample_tuple[1] : downsample_tuple[2]
        drop_path = DropPath(pathschedule[schedule_idx])
        drop_block = DropBlock(blockschedule[schedule_idx])
        block = basicblock(inplanes, planes; stride, reduction_factor, activation,
                           norm_layer, revnorm, attn_fn, drop_path, drop_block)
        downsample = downsample_fn(inplanes, planes * expansion; stride, norm_layer,
                                   revnorm)
        return block, downsample
    end
    return get_layers
end

function bottleneck_builder(block_repeats::AbstractVector{<:Integer};
                            inplanes::Integer = 64, cardinality::Integer = 1,
                            base_width::Integer = 64, reduction_factor::Integer = 1,
                            expansion::Integer = 4, norm_layer = BatchNorm,
                            revnorm::Bool = false, activation = relu,
                            attn_fn = planes -> identity,
                            drop_block_rate = nothing, drop_path_rate = nothing,
                            stride_fn = resnet_stride, planes_fn = resnet_planes,
                            downsample_tuple = (downsample_conv, downsample_identity))
    pathschedule = linear_scheduler(drop_path_rate; depth = sum(block_repeats))
    blockschedule = linear_scheduler(drop_block_rate; depth = sum(block_repeats))
    planes_vec = collect(planes_fn(block_repeats))
    # closure over `idxs`
    function get_layers(stage_idx::Integer, block_idx::Integer)
        # DropBlock, DropPath both take in rates based on a linear scaling schedule
        # This is also needed for block `inplanes` and `planes` calculations
        schedule_idx = sum(block_repeats[1:(stage_idx - 1)]) + block_idx
        planes = planes_vec[schedule_idx]
        inplanes = schedule_idx == 1 ? inplanes : planes_vec[schedule_idx - 1] * expansion
        # `resnet_stride` is a callback that the user can tweak to change the stride of the
        # blocks. It defaults to the standard behaviour as in the paper
        stride = stride_fn(stage_idx, block_idx)
        downsample_fn = stride != 1 || inplanes != planes * expansion ?
                        downsample_tuple[1] : downsample_tuple[2]
        drop_path = DropPath(pathschedule[schedule_idx])
        drop_block = DropBlock(blockschedule[schedule_idx])
        block = bottleneck(inplanes, planes; stride, cardinality, base_width,
                           reduction_factor, activation, norm_layer, revnorm,
                           attn_fn, drop_path, drop_block)
        downsample = downsample_fn(inplanes, planes * expansion; stride, norm_layer,
                                   revnorm)
        return block, downsample
    end
    return get_layers
end
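A sketch of how these builders plug into `cnn_stages`. The stage depths and the `+` connection are illustrative only; the actual ResNet code in Metalhead wires its own connection, stem, and classifier around this:

# ResNet-18-like stage layout, relying on the default `resnet_planes`/`resnet_stride` callbacks.
block_repeats = [2, 2, 2, 2]
get_layers = basicblock_builder(block_repeats; inplanes = 64)
stages = cnn_stages(get_layers, block_repeats, +)  # (block, downsample) branches joined by `+`
backbone = Chain(stages...)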
23 changes: 14 additions & 9 deletions src/convnets/convmixer.jl
@@ -1,5 +1,5 @@
"""
convmixer(planes::Integer, depth::Integer; kernel_size = (9, 9),
convmixer(planes::Integer, depth::Integer; kernel_size::Dims{2} = (9, 9),
patch_size::Dims{2} = (7, 7), activation = gelu,
inchannels::Integer = 3, nclasses::Integer = 1000)

@@ -13,20 +13,25 @@ Creates a ConvMixer model.
   - `kernel_size`: kernel size of the convolutional layers
   - `patch_size`: size of the patches
   - `activation`: activation function used after the convolutional layers
-  - `inchannels`: The number of channels in the input.
+  - `inchannels`: number of input channels
   - `nclasses`: number of classes in the output
 """
-function convmixer(planes::Integer, depth::Integer; kernel_size = (9, 9),
-                   patch_size::Dims{2} = (7, 7), activation = gelu,
+function convmixer(planes::Integer, depth::Integer; kernel_size::Dims{2} = (9, 9),
+                   patch_size::Dims{2} = (7, 7), activation = gelu, dropout_rate = nothing,
                    inchannels::Integer = 3, nclasses::Integer = 1000)
-    stem = conv_norm(patch_size, inchannels, planes, activation; preact = true,
-                     stride = patch_size[1])
-    blocks = [Chain(SkipConnection(Chain(conv_norm(kernel_size, planes, planes, activation;
+    layers = []
+    # stem of the model
+    append!(layers,
+            conv_norm(patch_size, inchannels, planes, activation; preact = true,
+                      stride = patch_size[1]))
+    # stages of the model
+    stages = [Chain(SkipConnection(Chain(conv_norm(kernel_size, planes, planes, activation;
                                                    preact = true, groups = planes,
                                                    pad = SamePad())), +),
                     conv_norm((1, 1), planes, planes, activation; preact = true)...)
               for _ in 1:depth]
-    return Chain(Chain(stem..., Chain(blocks...)), create_classifier(planes, nclasses))
+    append!(layers, stages)
+    return Chain(Chain(layers...), create_classifier(planes, nclasses; dropout_rate))
 end

 const CONVMIXER_CONFIGS = Dict(:base => ((1536, 20),
@@ -48,7 +53,7 @@ Creates a ConvMixer model.
 # Arguments

   - `config`: the size of the model, either `:base`, `:small` or `:large`
-  - `inchannels`: The number of channels in the input.
+  - `inchannels`: number of input channels
   - `nclasses`: number of classes in the output
 """
 struct ConvMixer
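The refactor above also threads a `dropout_rate` keyword through to `create_classifier`. A quick sketch of calling the functional form directly; the width/depth values here are made up rather than one of the `CONVMIXER_CONFIGS` presets:

# Hypothetical small ConvMixer built via the functional interface from this diff.
m = convmixer(256, 8; kernel_size = (7, 7), patch_size = (7, 7), dropout_rate = 0.2)
m(rand(Float32, 224, 224, 3, 1)) |> size   # should be (1000, 1) with the default nclasses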
2 changes: 1 addition & 1 deletion src/convnets/convnext.jl
@@ -84,7 +84,7 @@ Creates a ConvNeXt model.
 # Arguments

   - `config`: The size of the model, one of `tiny`, `small`, `base`, `large` or `xlarge`.
-  - `inchannels`: The number of channels in the input.
+  - `inchannels`: number of input channels
   - `nclasses`: number of output classes

 See also [`Metalhead.convnext`](#).