Some refactors, some consistency, some features
theabhirath committed Aug 14, 2022
1 parent 34caab4 commit 62fcac3
Showing 16 changed files with 209 additions and 119 deletions.
3 changes: 1 addition & 2 deletions .github/workflows/CI.yml
@@ -28,8 +28,7 @@ jobs:
       suite:
         - '["AlexNet", "VGG"]'
         - '["GoogLeNet", "SqueezeNet", "MobileNet"]'
-        - '"EfficientNet"'
-        - '"EfficientNetv2"'
+        - '"EfficientNet"'
         - 'r"/*/ResNet*"'
         - '[r"ResNeXt", r"SEResNet"]'
         - '[r"Res2Net", r"Res2NeXt"]'
15 changes: 10 additions & 5 deletions src/convnets/convmixer.jl
@@ -17,16 +17,21 @@ Creates a ConvMixer model.
   - `nclasses`: number of classes in the output
 """
 function convmixer(planes::Integer, depth::Integer; kernel_size::Dims{2} = (9, 9),
-                   patch_size::Dims{2} = (7, 7), activation = gelu,
+                   patch_size::Dims{2} = (7, 7), activation = gelu, dropout_rate = nothing,
                    inchannels::Integer = 3, nclasses::Integer = 1000)
-    stem = conv_norm(patch_size, inchannels, planes, activation; preact = true,
-                     stride = patch_size[1])
-    blocks = [Chain(SkipConnection(Chain(conv_norm(kernel_size, planes, planes, activation;
+    layers = []
+    # stem of the model
+    append!(layers,
+            conv_norm(patch_size, inchannels, planes, activation; preact = true,
+                      stride = patch_size[1]))
+    # stages of the model
+    stages = [Chain(SkipConnection(Chain(conv_norm(kernel_size, planes, planes, activation;
                                                    preact = true, groups = planes,
                                                    pad = SamePad())), +),
                     conv_norm((1, 1), planes, planes, activation; preact = true)...)
               for _ in 1:depth]
-    return Chain(Chain(stem..., Chain(blocks...)), create_classifier(planes, nclasses))
+    append!(layers, stages)
+    return Chain(Chain(layers...), create_classifier(planes, nclasses; dropout_rate))
 end

 const CONVMIXER_CONFIGS = Dict(:base => ((1536, 20),
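The practical effect of the ConvMixer refactor is that dropout in the classifier head is now configurable from the low-level builder. A minimal usage sketch, assuming a Metalhead build that includes this commit (`convmixer` is an internal, unexported function, hence the module prefix):

```julia
using Metalhead

# dropout_rate = nothing (the new default) omits dropout entirely;
# any other value is forwarded to create_classifier.
m = Metalhead.convmixer(256, 8; kernel_size = (9, 9), patch_size = (7, 7),
                        dropout_rate = 0.2, nclasses = 10)
x = rand(Float32, 224, 224, 3, 1)  # WHCN image batch
size(m(x))                         # (10, 1)
```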
15 changes: 9 additions & 6 deletions src/convnets/densenet.jl
@@ -55,7 +55,8 @@ function dense_block(inplanes::Integer, growth_rates)
 end

 """
-    densenet(inplanes, growth_rates; reduction = 0.5, nclasses::Integer = 1000)
+    densenet(inplanes, growth_rates; reduction = 0.5, dropout_rate = nothing,
+             inchannels::Integer = 3, nclasses::Integer = 1000)
 Create a DenseNet model
 ([reference](https://arxiv.org/abs/1608.06993)).
@@ -66,10 +67,11 @@ Create a DenseNet model
   - `growth_rates`: the growth rates of output feature maps within each
     [`dense_block`](#) (a vector of vectors)
   - `reduction`: the factor by which the number of feature maps is scaled across each transition
+  - `dropout_rate`: the dropout rate for the classifier head. Set to `nothing` to disable dropout.
   - `nclasses`: the number of output classes
 """
-function densenet(inplanes::Integer, growth_rates; reduction = 0.5, inchannels::Integer = 3,
-                  nclasses::Integer = 1000)
+function densenet(inplanes::Integer, growth_rates; reduction = 0.5, dropout_rate = nothing,
+                  inchannels::Integer = 3, nclasses::Integer = 1000)
     layers = []
     append!(layers,
             conv_norm((7, 7), inchannels, inplanes; stride = 2, pad = (3, 3)))
@@ -83,7 +85,7 @@ function densenet(inplanes::Integer, growth_rates; reduction = 0.5, inchannels::
         inplanes = floor(Int, outplanes * reduction)
     end
     push!(layers, BatchNorm(outplanes, relu))
-    return Chain(Chain(layers...), create_classifier(outplanes, nclasses))
+    return Chain(Chain(layers...), create_classifier(outplanes, nclasses; dropout_rate))
 end

 """
@@ -100,9 +102,10 @@ Create a DenseNet model
   - `nclasses`: the number of output classes
 """
 function densenet(nblocks::AbstractVector{<:Integer}; growth_rate::Integer = 32,
-                  reduction = 0.5, inchannels::Integer = 3, nclasses::Integer = 1000)
+                  reduction = 0.5, dropout_rate = nothing, inchannels::Integer = 3,
+                  nclasses::Integer = 1000)
     return densenet(2 * growth_rate, [fill(growth_rate, n) for n in nblocks];
-                    reduction, inchannels, nclasses)
+                    reduction, dropout_rate, inchannels, nclasses)
 end

 const DENSENET_CONFIGS = Dict(121 => [6, 12, 24, 16],
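As with ConvMixer, the new `dropout_rate` keyword flows from the high-level block-count method down to `create_classifier`. A hedged sketch of calling the internal builder, with names taken from the diff above:

```julia
using Metalhead

# DenseNet-121 block configuration with a 10% classifier dropout;
# dropout_rate = nothing (the default) disables the Dropout layer.
m = Metalhead.densenet([6, 12, 24, 16]; growth_rate = 32,
                       dropout_rate = 0.1, nclasses = 100)
```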
10 changes: 5 additions & 5 deletions src/convnets/efficientnets/core.jl
@@ -4,13 +4,13 @@ struct MBConvConfig <: _MBConfig
     kernel_size::Dims{2}
     inplanes::Integer
     outplanes::Integer
-    expansion::Number
+    expansion::Real
     stride::Integer
     nrepeats::Integer
 end
 function MBConvConfig(kernel_size::Integer, inplanes::Integer, outplanes::Integer,
-                      expansion::Number, stride::Integer, nrepeats::Integer,
-                      width_mult::Number = 1, depth_mult::Number = 1)
+                      expansion::Real, stride::Integer, nrepeats::Integer,
+                      width_mult::Real = 1, depth_mult::Real = 1)
     inplanes = _round_channels(inplanes * width_mult, 8)
     outplanes = _round_channels(outplanes * width_mult, 8)
     nrepeats = ceil(Int, nrepeats * depth_mult)
@@ -35,12 +35,12 @@ struct FusedMBConvConfig <: _MBConfig
     kernel_size::Dims{2}
     inplanes::Integer
     outplanes::Integer
-    expansion::Number
+    expansion::Real
     stride::Integer
     nrepeats::Integer
 end
 function FusedMBConvConfig(kernel_size::Integer, inplanes::Integer, outplanes::Integer,
-                           expansion::Number, stride::Integer, nrepeats::Integer)
+                           expansion::Real, stride::Integer, nrepeats::Integer)
     return FusedMBConvConfig((kernel_size, kernel_size), inplanes, outplanes, expansion,
                              stride, nrepeats)
 end
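Narrowing `Number` to `Real` tightens dispatch without changing behaviour for the usual `Int`/`Float64`/`Rational` inputs; a complex-valued expansion factor, which was never meaningful, now fails at method selection. A small sketch (these config types are internal to Metalhead):

```julia
using Metalhead: MBConvConfig

# kernel 3×3, 16 → 24 channels, expansion 6, stride 2, repeated twice,
# scaled by width_mult = 1.0 and depth_mult = 1.1
cfg = MBConvConfig(3, 16, 24, 6, 2, 2, 1.0, 1.1)

# MBConvConfig(3, 16, 24, 6 + 0im, 2, 2)  # now a MethodError:
# Complex <: Number, but Complex <: Real does not hold
```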
6 changes: 3 additions & 3 deletions src/convnets/inceptions/inceptionresnetv2.jl
@@ -64,18 +64,18 @@ function block8(scale = 1.0f0; activation = identity)
 end

 """
-    inceptionresnetv2(; inchannels::Integer = 3, dropout_rate = 0.0, nclasses::Integer = 1000)
+    inceptionresnetv2(; inchannels::Integer = 3, dropout_rate = nothing, nclasses::Integer = 1000)
 Creates an InceptionResNetv2 model.
 ([reference](https://arxiv.org/abs/1602.07261))
 # Arguments
   - `inchannels`: number of input channels.
-  - `dropout_rate`: rate of dropout in classifier head.
+  - `dropout_rate`: rate of dropout in classifier head. Set to `nothing` to disable dropout.
   - `nclasses`: the number of output classes.
 """
-function inceptionresnetv2(; dropout_rate = 0.0, inchannels::Integer = 3,
+function inceptionresnetv2(; dropout_rate = nothing, inchannels::Integer = 3,
                            nclasses::Integer = 1000)
     backbone = Chain(basic_conv_bn((3, 3), inchannels, 32; stride = 2)...,
                      basic_conv_bn((3, 3), 32, 32)...,
6 changes: 3 additions & 3 deletions src/convnets/inceptions/inceptionv4.jl
@@ -85,18 +85,18 @@ function inceptionv4_c()
 end

 """
-    inceptionv4(; inchannels::Integer = 3, dropout_rate = 0.0, nclasses::Integer = 1000)
+    inceptionv4(; inchannels::Integer = 3, dropout_rate = nothing, nclasses::Integer = 1000)
 Create an Inceptionv4 model.
 ([reference](https://arxiv.org/abs/1602.07261))
 # Arguments
   - `inchannels`: number of input channels.
-  - `dropout_rate`: rate of dropout in classifier head.
+  - `dropout_rate`: rate of dropout in classifier head. Set to `nothing` to disable dropout.
   - `nclasses`: the number of output classes.
 """
-function inceptionv4(; dropout_rate = 0.0, inchannels::Integer = 3,
+function inceptionv4(; dropout_rate = nothing, inchannels::Integer = 3,
                      nclasses::Integer = 1000)
     backbone = Chain(basic_conv_bn((3, 3), inchannels, 32; stride = 2)...,
                      basic_conv_bn((3, 3), 32, 32)...,
4 changes: 2 additions & 2 deletions src/convnets/inceptions/xception.jl
@@ -43,14 +43,14 @@ function xception_block(inchannels::Integer, outchannels::Integer, nrepeats::Int
 end

 """
-    xception(; dropout_rate = 0.0, inchannels::Integer = 3, nclasses::Integer = 1000)
+    xception(; dropout_rate = nothing, inchannels::Integer = 3, nclasses::Integer = 1000)
 Creates an Xception model.
 ([reference](https://arxiv.org/abs/1610.02357))
 # Arguments
-  - `dropout_rate`: rate of dropout in classifier head.
+  - `dropout_rate`: rate of dropout in classifier head. Set to `nothing` to disable dropout.
   - `inchannels`: number of input channels.
   - `nclasses`: the number of output classes.
 """
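The recurring change across the Inception-family files is the default `dropout_rate = 0.0` becoming `nothing`. The distinction matters: `0.0` still splices a (no-op) `Dropout` layer into the head, while `nothing` lets the builder skip the layer altogether. A minimal sketch of the convention in plain Flux (this helper is an illustration, not the package's own code):

```julia
using Flux

# `nothing` means "no Dropout layer at all"; a numeric rate builds one.
maybe_dropout(rate) = isnothing(rate) ? identity : Dropout(rate)

head        = Chain(maybe_dropout(0.5), Dense(2048 => 1000))      # with dropout
head_nodrop = Chain(maybe_dropout(nothing), Dense(2048 => 1000))  # dropout skipped
```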
9 changes: 6 additions & 3 deletions src/convnets/mobilenets/mobilenetv1.jl
@@ -1,5 +1,6 @@
 """
-    mobilenetv1(width_mult::Real, config::AbstractVector{<:Tuple}; activation = relu,
+    mobilenetv1(width_mult::Real, config::AbstractVector{<:Tuple};
+                activation = relu, dropout_rate = nothing,
                 inchannels::Integer = 3, nclasses::Integer = 1000)
 Create a MobileNetv1 model ([reference](https://arxiv.org/abs/1704.04861v1)).
@@ -16,10 +17,12 @@ Create a MobileNetv1 model ([reference](https://arxiv.org/abs/1704.04861v1)).
       + `s`: The stride of the convolutional kernel
       + `r`: The number of times this configuration block is repeated
   - `activation`: The activation function to use throughout the network
+  - `dropout_rate`: The dropout rate to use in the classifier head. Set to `nothing` to disable.
   - `inchannels`: The number of input channels. The default value is 3.
   - `nclasses`: The number of output classes
 """
-function mobilenetv1(width_mult::Real, config::AbstractVector{<:Tuple}; activation = relu,
+function mobilenetv1(width_mult::Real, config::AbstractVector{<:Tuple};
+                     activation = relu, dropout_rate = nothing,
                      inchannels::Integer = 3, nclasses::Integer = 1000)
     layers = []
     for (dw, outchannels, stride, nrepeats) in config
@@ -33,7 +36,7 @@ function mobilenetv1(width_mult::Real, config::AbstractVector{<:Tuple}; activati
             inchannels = outchannels
         end
     end
-    return Chain(Chain(layers...), create_classifier(inchannels, nclasses))
+    return Chain(Chain(layers...), create_classifier(inchannels, nclasses; dropout_rate))
 end

 # Layer configurations for MobileNetv1
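For reference, the config tuples consumed by the loop above unpack as `(dw, outchannels, stride, nrepeats)`. A hedged usage sketch with a truncated, made-up configuration — the real MobileNetv1 configs live further down this file, and the reading of `dw` as "use a depthwise-separable block" is an assumption from the loop's variable names:

```julia
using Metalhead

tiny_config = [
    # (dw, outchannels, stride, nrepeats)
    (false, 32, 2, 1),  # assumed: plain 3×3 convolution
    (true, 64, 1, 1),   # assumed: depthwise-separable block
    (true, 128, 2, 2),  # two such blocks, the first with stride 2
]
m = Metalhead.mobilenetv1(1.0, tiny_config; dropout_rate = 0.2, nclasses = 10)
```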
9 changes: 5 additions & 4 deletions src/convnets/mobilenets/mobilenetv2.jl
@@ -20,7 +20,7 @@ Create a MobileNetv2 model.
     (with 1 being the default in the paper)
   - `max_width`: The maximum number of feature maps in any layer of the network
   - `divisor`: The divisor used to round the number of feature maps in each block
-  - `dropout_rate`: rate of dropout in the classifier head
+  - `dropout_rate`: rate of dropout in the classifier head. Set to `nothing` to disable dropout.
   - `inchannels`: The number of input channels.
   - `nclasses`: The number of output classes
 """
@@ -33,12 +33,13 @@ function mobilenetv2(configs::AbstractVector{<:Tuple}; width_mult::Real = 1,
     append!(layers,
             conv_norm((3, 3), inchannels, inplanes; pad = 1, stride = 2))
     # building inverted residual blocks
-    for (t, c, n, s, a) in configs
+    for (t, c, n, s, activation) in configs
         outplanes = _round_channels(c * width_mult, divisor)
         for i in 1:n
+            stride = i == 1 ? s : 1
             push!(layers,
-                  mbconv((3, 3), inplanes, round(Int, inplanes * t), outplanes, a;
-                         stride = i == 1 ? s : 1))
+                  mbconv((3, 3), inplanes, round(Int, inplanes * t), outplanes,
+                         activation; stride))
             inplanes = outplanes
         end
     end
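Hoisting `stride = i == 1 ? s : 1` out of the `mbconv` call makes the expansion of a config row easier to follow: each `(t, c, n, s, activation)` tuple becomes `n` inverted-residual blocks, only the first of which uses stride `s`. A worked trace in plain Julia (assuming `width_mult = 1`, so `outplanes == c`):

```julia
inplanes = 16
for (t, c, n, s, activation) in [(6, 24, 2, 2, "relu6")]
    for i in 1:n
        stride = i == 1 ? s : 1
        hidden = round(Int, inplanes * t)  # expansion width inside the block
        println("mbconv 3×3: $inplanes → $hidden → $c, stride = $stride")
        inplanes = c
    end
end
# mbconv 3×3: 16 → 96 → 24, stride = 2
# mbconv 3×3: 24 → 144 → 24, stride = 1
```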
37 changes: 24 additions & 13 deletions src/convnets/mobilenets/mobilenetv3.jl
@@ -1,7 +1,7 @@
 """
     mobilenetv3(configs::AbstractVector{<:Tuple}; width_mult::Real = 1,
-                max_width::Integer = 1024, inchannels::Integer = 3,
-                nclasses::Integer = 1000)
+                max_width::Integer = 1024, dropout_rate = 0.2,
+                inchannels::Integer = 3, nclasses::Integer = 1000)
 Create a MobileNetv3 model.
 ([reference](https://arxiv.org/abs/1905.02244)).
@@ -19,38 +19,49 @@ Create a MobileNetv3 model.
   - `width_mult`: Controls the number of output feature maps in each block
     (with 1 being the default in the paper; this is usually a value between 0.1 and 1.4.)
-  - `inchannels`: The number of input channels.
   - `max_width`: The maximum number of feature maps in any layer of the network
+  - `dropout_rate`: The dropout rate to use in the classifier head. Set to `nothing` to disable.
+  - `inchannels`: The number of input channels.
   - `nclasses`: the number of output classes
 """
 function mobilenetv3(configs::AbstractVector{<:Tuple}; width_mult::Real = 1,
-                     max_width::Integer = 1024, dropout_rate = 0.2,
+                     max_width::Integer = 1024, reduced_tail::Bool = false,
+                     tail_dilated::Bool = false, dropout_rate = 0.2,
                      inchannels::Integer = 3, nclasses::Integer = 1000)
     # building first layer
     inplanes = _round_channels(16 * width_mult, 8)
     layers = []
     append!(layers,
             conv_norm((3, 3), inchannels, inplanes, hardswish; stride = 2, pad = 1))
     explanes = 0
+    nstages = length(configs)
+    reduced_divider = 1
     # building inverted residual blocks
-    for (k, t, c, reduction, activation, stride) in configs
+    for (i, (k, t, c, reduction, activation, stride)) in enumerate(configs)
+        dilation = 1
+        if nstages - i <= 2
+            if reduced_tail
+                reduced_divider = 2
+                c /= reduced_divider
+            end
+            if tail_dilated
+                dilation = 2
+            end
+        end
         # inverted residual layers
         outplanes = _round_channels(c * width_mult, 8)
         explanes = _round_channels(inplanes * t, 8)
         push!(layers,
               mbconv((k, k), inplanes, explanes, outplanes, activation;
-                     stride, reduction))
+                     stride, reduction, dilation))
         inplanes = outplanes
     end
     # building last layers
-    headplanes = width_mult > 1.0 ? _round_channels(max_width * width_mult, 8) :
-                 max_width
+    headplanes = _round_channels(max_width ÷ reduced_divider * width_mult, 8)
     append!(layers, conv_norm((1, 1), inplanes, explanes, hardswish))
-    classifier = Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten,
-                       Dense(explanes, headplanes, hardswish),
-                       Dropout(dropout_rate),
-                       Dense(headplanes, nclasses))
-    return Chain(Chain(layers...), classifier)
+    return Chain(Chain(layers...),
+                 create_classifier(explanes, headplanes, nclasses,
                                    (hardswish, identity); dropout_rate))
 end

 # Layer configurations for small and large models for MobileNetv3
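The two genuinely new features land here: `reduced_tail` halves the channel counts of the last two stages (and, through `reduced_divider`, the head width), and `tail_dilated` replaces their downsampling with dilation — the usual adaptation when repurposing the backbone for dense prediction. A heavily hedged sketch; the name of the config constant is an assumption, since it is only alluded to by the comment above:

```julia
using Metalhead

# MOBILENETV3_CONFIGS[:small] is a guess at the constant defined below this
# hunk ("Layer configurations for small and large models for MobileNetv3").
cfg = Metalhead.MOBILENETV3_CONFIGS[:small]
m = Metalhead.mobilenetv3(cfg; width_mult = 1.0,
                          reduced_tail = true, tail_dilated = true)
```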
19 changes: 9 additions & 10 deletions src/convnets/resnets/core.jl
@@ -27,12 +27,11 @@ function basicblock(inplanes::Integer, planes::Integer; stride::Integer = 1,
                     drop_block = identity, drop_path = identity,
                     attn_fn = planes -> identity)
     first_planes = planes ÷ reduction_factor
-    outplanes = planes
     conv_bn1 = conv_norm((3, 3), inplanes => first_planes, identity; norm_layer, revnorm,
                          stride, pad = 1)
-    conv_bn2 = conv_norm((3, 3), first_planes => outplanes, identity; norm_layer, revnorm,
+    conv_bn2 = conv_norm((3, 3), first_planes => planes, identity; norm_layer, revnorm,
                          pad = 1)
-    layers = [conv_bn1..., drop_block, activation, conv_bn2..., attn_fn(outplanes),
+    layers = [conv_bn1..., drop_block, activation, conv_bn2..., attn_fn(planes),
               drop_path]
     return Chain(filter!(!=(identity), layers)...)
 end
@@ -201,7 +200,7 @@ function basicblock_builder(block_repeats::AbstractVector{<:Integer};
                             expansion::Integer = 1, norm_layer = BatchNorm,
                             revnorm::Bool = false, activation = relu,
                             attn_fn = planes -> identity,
-                            drop_block_rate = 0.0, drop_path_rate = 0.0,
+                            drop_block_rate = nothing, drop_path_rate = nothing,
                             stride_fn = resnet_stride, planes_fn = resnet_planes,
                             downsample_tuple = (downsample_conv, downsample_identity))
     pathschedule = linear_scheduler(drop_path_rate; depth = sum(block_repeats))
@@ -236,7 +235,7 @@ function bottleneck_builder(block_repeats::AbstractVector{<:Integer};
                             expansion::Integer = 4, norm_layer = BatchNorm,
                             revnorm::Bool = false, activation = relu,
                             attn_fn = planes -> identity,
-                            drop_block_rate = 0.0, drop_path_rate = 0.0,
+                            drop_block_rate = nothing, drop_path_rate = nothing,
                             stride_fn = resnet_stride, planes_fn = resnet_planes,
                             downsample_tuple = (downsample_conv, downsample_identity))
     pathschedule = linear_scheduler(drop_path_rate; depth = sum(block_repeats))
@@ -295,8 +294,8 @@ function resnet(block_type, block_repeats::AbstractVector{<:Integer},
                 inchannels::Integer = 3, stem_fn = resnet_stem, connection = addact,
                 activation = relu, norm_layer = BatchNorm, revnorm::Bool = false,
                 attn_fn = planes -> identity, pool_layer = AdaptiveMeanPool((1, 1)),
-                use_conv::Bool = false, drop_block_rate = 0.0, drop_path_rate = 0.0,
-                dropout_rate = 0.0, nclasses::Integer = 1000, kwargs...)
+                use_conv::Bool = false, drop_block_rate = nothing, drop_path_rate = nothing,
+                dropout_rate = nothing, nclasses::Integer = 1000, kwargs...)
     # Build stem
     stem = stem_fn(; inchannels)
     # Block builder
@@ -319,8 +318,8 @@ function resnet(block_type, block_repeats::AbstractVector{<:Integer},
                                         downsample_tuple = downsample_opt,
                                         kwargs...)
     elseif block_type == bottle2neck
-        @assert drop_block_rate==0.0 "DropBlock not supported for `bottle2neck`. Set `drop_block_rate` to 0.0"
-        @assert drop_path_rate==0.0 "DropPath not supported for `bottle2neck`. Set `drop_path_rate` to 0.0"
+        @assert isnothing(drop_block_rate) "DropBlock not supported for `bottle2neck`. Set `drop_block_rate` to nothing"
+        @assert isnothing(drop_path_rate) "DropPath not supported for `bottle2neck`. Set `drop_path_rate` to nothing"
         @assert reduction_factor==1 "Reduction factor not supported for `bottle2neck`. Set `reduction_factor` to 1"
         get_layers = bottle2neck_builder(block_repeats; inplanes, cardinality, base_width,
                                          activation, norm_layer, revnorm, attn_fn,
@@ -347,7 +346,7 @@ const RESNET_CONFIGS = Dict(18 => (basicblock, [2, 2, 2, 2]),
                             50 => (bottleneck, [3, 4, 6, 3]),
                             101 => (bottleneck, [3, 4, 23, 3]),
                             152 => (bottleneck, [3, 8, 36, 3]))
-
+# larger ResNet-like models
 const LRESNET_CONFIGS = Dict(50 => (bottleneck, [3, 4, 6, 3]),
                              101 => (bottleneck, [3, 4, 23, 3]),
                              152 => (bottleneck, [3, 8, 36, 3]))
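Switching the `drop_block_rate`/`drop_path_rate` sentinels from `0.0` to `nothing` lets `isnothing` checks (as in the `bottle2neck` asserts above) distinguish "disabled" from "a rate that happens to be zero", and lets `linear_scheduler` short-circuit. A sketch of how such a scheduler can treat the `nothing` case — an assumption about its behaviour, not the package source:

```julia
# Per-block stochastic-depth schedule: ramp linearly from start_value to
# drop_rate, or produce all-`nothing` when the feature is disabled.
function linear_scheduler_sketch(drop_rate; depth::Integer, start_value = 0.0)
    return isnothing(drop_rate) ? fill(nothing, depth) :
           LinRange(start_value, drop_rate, depth)
end

linear_scheduler_sketch(0.1; depth = 4)      # 0.0, 0.0333…, 0.0667…, 0.1
linear_scheduler_sketch(nothing; depth = 4)  # [nothing, nothing, nothing, nothing]
```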
5 changes: 4 additions & 1 deletion src/layers/Layers.jl
@@ -28,7 +28,10 @@ include("embeddings.jl")
 export PatchEmbedding, ViPosEmbedding, ClassTokens

 include("mlp.jl")
-export mlp_block, gated_mlp_block, create_fc, create_classifier
+export mlp_block, gated_mlp_block
+
+include("classifier.jl")
+export create_classifier

 include("normalise.jl")
 export prenorm, ChannelLayerNorm
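`create_classifier` moves out of `mlp.jl` into its own `classifier.jl`, and `create_fc` drops out of the export list. Judging by the call sites in this commit, the function now has at least two methods; a hedged usage sketch based only on those call sites:

```julia
using Flux: hardswish
using Metalhead: create_classifier

# Single-stage head, as used by ConvMixer/DenseNet/MobileNetv1 above:
head = create_classifier(2048, 1000; dropout_rate = 0.5)

# Two-stage head with per-stage activations, as used by MobileNetv3 above:
head3 = create_classifier(960, 1280, 1000, (hardswish, identity);
                          dropout_rate = 0.2)
```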