From 62fcac362b5504d003e8112a6e0188389f98e115 Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Sun, 14 Aug 2022 19:24:52 +0530 Subject: [PATCH] Some refactors, some consistency, some features --- .github/workflows/CI.yml | 3 +- src/convnets/convmixer.jl | 15 ++-- src/convnets/densenet.jl | 15 ++-- src/convnets/efficientnets/core.jl | 10 +-- src/convnets/inceptions/inceptionresnetv2.jl | 6 +- src/convnets/inceptions/inceptionv4.jl | 6 +- src/convnets/inceptions/xception.jl | 4 +- src/convnets/mobilenets/mobilenetv1.jl | 9 +- src/convnets/mobilenets/mobilenetv2.jl | 9 +- src/convnets/mobilenets/mobilenetv3.jl | 37 +++++--- src/convnets/resnets/core.jl | 19 ++-- src/layers/Layers.jl | 5 +- src/layers/classifier.jl | 93 ++++++++++++++++++++ src/layers/conv.jl | 9 +- src/layers/drop.jl | 44 +++++---- src/layers/mlp.jl | 44 +-------- 16 files changed, 209 insertions(+), 119 deletions(-) create mode 100644 src/layers/classifier.jl diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 37cda3263..5304bc317 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -28,8 +28,7 @@ jobs: suite: - '["AlexNet", "VGG"]' - '["GoogLeNet", "SqueezeNet", "MobileNet"]' - - '"EfficientNet"' - - '"EfficientNetv2"' + - '"EfficientNet"' - 'r"/*/ResNet*"' - '[r"ResNeXt", r"SEResNet"]' - '[r"Res2Net", r"Res2NeXt"]' diff --git a/src/convnets/convmixer.jl b/src/convnets/convmixer.jl index 1ca8487a9..bc1a71a5f 100644 --- a/src/convnets/convmixer.jl +++ b/src/convnets/convmixer.jl @@ -17,16 +17,21 @@ Creates a ConvMixer model. 
- `nclasses`: number of classes in the output """ function convmixer(planes::Integer, depth::Integer; kernel_size::Dims{2} = (9, 9), - patch_size::Dims{2} = (7, 7), activation = gelu, + patch_size::Dims{2} = (7, 7), activation = gelu, dropout_rate = nothing, inchannels::Integer = 3, nclasses::Integer = 1000) - stem = conv_norm(patch_size, inchannels, planes, activation; preact = true, - stride = patch_size[1]) - blocks = [Chain(SkipConnection(Chain(conv_norm(kernel_size, planes, planes, activation; + layers = [] + # stem of the model + append!(layers, + conv_norm(patch_size, inchannels, planes, activation; preact = true, + stride = patch_size[1])) + # stages of the model + stages = [Chain(SkipConnection(Chain(conv_norm(kernel_size, planes, planes, activation; preact = true, groups = planes, pad = SamePad())), +), conv_norm((1, 1), planes, planes, activation; preact = true)...) for _ in 1:depth] - return Chain(Chain(stem..., Chain(blocks...)), create_classifier(planes, nclasses)) + append!(layers, stages) + return Chain(Chain(layers...), create_classifier(planes, nclasses; dropout_rate)) end const CONVMIXER_CONFIGS = Dict(:base => ((1536, 20), diff --git a/src/convnets/densenet.jl b/src/convnets/densenet.jl index ca81b78ea..a7c367c1c 100644 --- a/src/convnets/densenet.jl +++ b/src/convnets/densenet.jl @@ -55,7 +55,8 @@ function dense_block(inplanes::Integer, growth_rates) end """ - densenet(inplanes, growth_rates; reduction = 0.5, nclasses::Integer = 1000) + densenet(inplanes, growth_rates; reduction = 0.5, dropout_rate = nothing, + inchannels::Integer = 3, nclasses::Integer = 1000) Create a DenseNet model ([reference](https://arxiv.org/abs/1608.06993)). @@ -66,10 +67,11 @@ Create a DenseNet model - `growth_rates`: the growth rates of output feature maps within each [`dense_block`](#) (a vector of vectors) - `reduction`: the factor by which the number of feature maps is scaled across each transition + - `dropout_rate`: the dropout rate for the classifier head. 
Set to `nothing` to disable dropout. - `nclasses`: the number of output classes """ -function densenet(inplanes::Integer, growth_rates; reduction = 0.5, inchannels::Integer = 3, - nclasses::Integer = 1000) +function densenet(inplanes::Integer, growth_rates; reduction = 0.5, dropout_rate = nothing, + inchannels::Integer = 3, nclasses::Integer = 1000) layers = [] append!(layers, conv_norm((7, 7), inchannels, inplanes; stride = 2, pad = (3, 3))) @@ -83,7 +85,7 @@ function densenet(inplanes::Integer, growth_rates; reduction = 0.5, inchannels:: inplanes = floor(Int, outplanes * reduction) end push!(layers, BatchNorm(outplanes, relu)) - return Chain(Chain(layers...), create_classifier(outplanes, nclasses)) + return Chain(Chain(layers...), create_classifier(outplanes, nclasses; dropout_rate)) end """ @@ -100,9 +102,10 @@ Create a DenseNet model - `nclasses`: the number of output classes """ function densenet(nblocks::AbstractVector{<:Integer}; growth_rate::Integer = 32, - reduction = 0.5, inchannels::Integer = 3, nclasses::Integer = 1000) + reduction = 0.5, dropout_rate = nothing, inchannels::Integer = 3, + nclasses::Integer = 1000) return densenet(2 * growth_rate, [fill(growth_rate, n) for n in nblocks]; - reduction, inchannels, nclasses) + reduction, dropout_rate, inchannels, nclasses) end const DENSENET_CONFIGS = Dict(121 => [6, 12, 24, 16], diff --git a/src/convnets/efficientnets/core.jl b/src/convnets/efficientnets/core.jl index 1059cb538..7a221c0e4 100644 --- a/src/convnets/efficientnets/core.jl +++ b/src/convnets/efficientnets/core.jl @@ -4,13 +4,13 @@ struct MBConvConfig <: _MBConfig kernel_size::Dims{2} inplanes::Integer outplanes::Integer - expansion::Number + expansion::Real stride::Integer nrepeats::Integer end function MBConvConfig(kernel_size::Integer, inplanes::Integer, outplanes::Integer, - expansion::Number, stride::Integer, nrepeats::Integer, - width_mult::Number = 1, depth_mult::Number = 1) + expansion::Real, stride::Integer, nrepeats::Integer, + 
width_mult::Real = 1, depth_mult::Real = 1) inplanes = _round_channels(inplanes * width_mult, 8) outplanes = _round_channels(outplanes * width_mult, 8) nrepeats = ceil(Int, nrepeats * depth_mult) @@ -35,12 +35,12 @@ struct FusedMBConvConfig <: _MBConfig kernel_size::Dims{2} inplanes::Integer outplanes::Integer - expansion::Number + expansion::Real stride::Integer nrepeats::Integer end function FusedMBConvConfig(kernel_size::Integer, inplanes::Integer, outplanes::Integer, - expansion::Number, stride::Integer, nrepeats::Integer) + expansion::Real, stride::Integer, nrepeats::Integer) return FusedMBConvConfig((kernel_size, kernel_size), inplanes, outplanes, expansion, stride, nrepeats) end diff --git a/src/convnets/inceptions/inceptionresnetv2.jl b/src/convnets/inceptions/inceptionresnetv2.jl index 7f462c0cf..bd88648e9 100644 --- a/src/convnets/inceptions/inceptionresnetv2.jl +++ b/src/convnets/inceptions/inceptionresnetv2.jl @@ -64,7 +64,7 @@ function block8(scale = 1.0f0; activation = identity) end """ - inceptionresnetv2(; inchannels::Integer = 3, dropout_rate = 0.0, nclasses::Integer = 1000) + inceptionresnetv2(; inchannels::Integer = 3, dropout_rate = nothing, nclasses::Integer = 1000) Creates an InceptionResNetv2 model. ([reference](https://arxiv.org/abs/1602.07261)) @@ -72,10 +72,10 @@ Creates an InceptionResNetv2 model. # Arguments - `inchannels`: number of input channels. - - `dropout_rate`: rate of dropout in classifier head. + - `dropout_rate`: rate of dropout in classifier head. Set to `nothing` to disable dropout. - `nclasses`: the number of output classes. 
""" -function inceptionresnetv2(; dropout_rate = 0.0, inchannels::Integer = 3, +function inceptionresnetv2(; dropout_rate = nothing, inchannels::Integer = 3, nclasses::Integer = 1000) backbone = Chain(basic_conv_bn((3, 3), inchannels, 32; stride = 2)..., basic_conv_bn((3, 3), 32, 32)..., diff --git a/src/convnets/inceptions/inceptionv4.jl b/src/convnets/inceptions/inceptionv4.jl index b43f6bc1d..13d40da25 100644 --- a/src/convnets/inceptions/inceptionv4.jl +++ b/src/convnets/inceptions/inceptionv4.jl @@ -85,7 +85,7 @@ function inceptionv4_c() end """ - inceptionv4(; inchannels::Integer = 3, dropout_rate = 0.0, nclasses::Integer = 1000) + inceptionv4(; inchannels::Integer = 3, dropout_rate = nothing, nclasses::Integer = 1000) Create an Inceptionv4 model. ([reference](https://arxiv.org/abs/1602.07261)) @@ -93,10 +93,10 @@ Create an Inceptionv4 model. # Arguments - `inchannels`: number of input channels. - - `dropout_rate`: rate of dropout in classifier head. + - `dropout_rate`: rate of dropout in classifier head. Set to `nothing` to disable dropout. - `nclasses`: the number of output classes. """ -function inceptionv4(; dropout_rate = 0.0, inchannels::Integer = 3, +function inceptionv4(; dropout_rate = nothing, inchannels::Integer = 3, nclasses::Integer = 1000) backbone = Chain(basic_conv_bn((3, 3), inchannels, 32; stride = 2)..., basic_conv_bn((3, 3), 32, 32)..., diff --git a/src/convnets/inceptions/xception.jl b/src/convnets/inceptions/xception.jl index 33222e7be..171bddd19 100644 --- a/src/convnets/inceptions/xception.jl +++ b/src/convnets/inceptions/xception.jl @@ -43,14 +43,14 @@ function xception_block(inchannels::Integer, outchannels::Integer, nrepeats::Int end """ - xception(; dropout_rate = 0.0, inchannels::Integer = 3, nclasses::Integer = 1000) + xception(; dropout_rate = nothing, inchannels::Integer = 3, nclasses::Integer = 1000) Creates an Xception model. 
([reference](https://arxiv.org/abs/1610.02357)) # Arguments - - `dropout_rate`: rate of dropout in classifier head. + - `dropout_rate`: rate of dropout in classifier head. Set to `nothing` to disable dropout. - `inchannels`: number of input channels. - `nclasses`: the number of output classes. """ diff --git a/src/convnets/mobilenets/mobilenetv1.jl b/src/convnets/mobilenets/mobilenetv1.jl index caa899a53..542edec81 100644 --- a/src/convnets/mobilenets/mobilenetv1.jl +++ b/src/convnets/mobilenets/mobilenetv1.jl @@ -1,5 +1,6 @@ """ - mobilenetv1(width_mult::Real, config::AbstractVector{<:Tuple}; activation = relu, + mobilenetv1(width_mult::Real, config::AbstractVector{<:Tuple}; + activation = relu, dropout_rate = nothing, inchannels::Integer = 3, nclasses::Integer = 1000) Create a MobileNetv1 model ([reference](https://arxiv.org/abs/1704.04861v1)). @@ -16,10 +17,12 @@ Create a MobileNetv1 model ([reference](https://arxiv.org/abs/1704.04861v1)). + `s`: The stride of the convolutional kernel + `r`: The number of times this configuration block is repeated - `activation`: The activation function to use throughout the network + - `dropout_rate`: The dropout rate to use in the classifier head. Set to `nothing` to disable. - `inchannels`: The number of input channels. The default value is 3. 
- `nclasses`: The number of output classes """ -function mobilenetv1(width_mult::Real, config::AbstractVector{<:Tuple}; activation = relu, +function mobilenetv1(width_mult::Real, config::AbstractVector{<:Tuple}; + activation = relu, dropout_rate = nothing, inchannels::Integer = 3, nclasses::Integer = 1000) layers = [] for (dw, outchannels, stride, nrepeats) in config @@ -33,7 +36,7 @@ function mobilenetv1(width_mult::Real, config::AbstractVector{<:Tuple}; activati inchannels = outchannels end end - return Chain(Chain(layers...), create_classifier(inchannels, nclasses)) + return Chain(Chain(layers...), create_classifier(inchannels, nclasses; dropout_rate)) end # Layer configurations for MobileNetv1 diff --git a/src/convnets/mobilenets/mobilenetv2.jl b/src/convnets/mobilenets/mobilenetv2.jl index 232286309..d81256968 100644 --- a/src/convnets/mobilenets/mobilenetv2.jl +++ b/src/convnets/mobilenets/mobilenetv2.jl @@ -20,7 +20,7 @@ Create a MobileNetv2 model. (with 1 being the default in the paper) - `max_width`: The maximum number of feature maps in any layer of the network - `divisor`: The divisor used to round the number of feature maps in each block - - `dropout_rate`: rate of dropout in the classifier head + - `dropout_rate`: rate of dropout in the classifier head. Set to `nothing` to disable dropout. - `inchannels`: The number of input channels. - `nclasses`: The number of output classes """ @@ -33,12 +33,13 @@ function mobilenetv2(configs::AbstractVector{<:Tuple}; width_mult::Real = 1, append!(layers, conv_norm((3, 3), inchannels, inplanes; pad = 1, stride = 2)) # building inverted residual blocks - for (t, c, n, s, a) in configs + for (t, c, n, s, activation) in configs outplanes = _round_channels(c * width_mult, divisor) for i in 1:n + stride = i == 1 ? s : 1 push!(layers, - mbconv((3, 3), inplanes, round(Int, inplanes * t), outplanes, a; - stride = i == 1 ? 
s : 1)) + mbconv((3, 3), inplanes, round(Int, inplanes * t), outplanes, + activation; stride)) inplanes = outplanes end end diff --git a/src/convnets/mobilenets/mobilenetv3.jl b/src/convnets/mobilenets/mobilenetv3.jl index 78c55e144..82265e125 100644 --- a/src/convnets/mobilenets/mobilenetv3.jl +++ b/src/convnets/mobilenets/mobilenetv3.jl @@ -1,7 +1,7 @@ """ mobilenetv3(configs::AbstractVector{<:Tuple}; width_mult::Real = 1, - max_width::Integer = 1024, inchannels::Integer = 3, - nclasses::Integer = 1000) + max_width::Integer = 1024, dropout_rate = 0.2, + inchannels::Integer = 3, nclasses::Integer = 1000) Create a MobileNetv3 model. ([reference](https://arxiv.org/abs/1905.02244)). @@ -19,12 +19,14 @@ Create a MobileNetv3 model. - `width_mult`: Controls the number of output feature maps in each block (with 1 being the default in the paper; this is usually a value between 0.1 and 1.4.) - - `inchannels`: The number of input channels. - `max_width`: The maximum number of feature maps in any layer of the network + - `dropout_rate`: The dropout rate to use in the classifier head. Set to `nothing` to disable. + - `inchannels`: The number of input channels. 
- `nclasses`: the number of output classes """ function mobilenetv3(configs::AbstractVector{<:Tuple}; width_mult::Real = 1, - max_width::Integer = 1024, dropout_rate = 0.2, + max_width::Integer = 1024, reduced_tail::Bool = false, + tail_dilated::Bool = false, dropout_rate = 0.2, inchannels::Integer = 3, nclasses::Integer = 1000) # building first layer inplanes = _round_channels(16 * width_mult, 8) @@ -32,25 +34,34 @@ function mobilenetv3(configs::AbstractVector{<:Tuple}; width_mult::Real = 1, append!(layers, conv_norm((3, 3), inchannels, inplanes, hardswish; stride = 2, pad = 1)) explanes = 0 + nstages = length(configs) + reduced_divider = 1 # building inverted residual blocks - for (k, t, c, reduction, activation, stride) in configs + for (i, (k, t, c, reduction, activation, stride)) in enumerate(configs) + dilation = 1 + if nstages - i <= 2 + if reduced_tail + reduced_divider = 2 + c /= reduced_divider + end + if tail_dilated + dilation = 2 + end + end # inverted residual layers outplanes = _round_channels(c * width_mult, 8) explanes = _round_channels(inplanes * t, 8) push!(layers, mbconv((k, k), inplanes, explanes, outplanes, activation; - stride, reduction)) + stride, reduction, dilation)) inplanes = outplanes end # building last layers - headplanes = width_mult > 1.0 ? 
_round_channels(max_width * width_mult, 8) : - max_width + headplanes = _round_channels(max_width ÷ reduced_divider * width_mult, 8) append!(layers, conv_norm((1, 1), inplanes, explanes, hardswish)) - classifier = Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, - Dense(explanes, headplanes, hardswish), - Dropout(dropout_rate), - Dense(headplanes, nclasses)) - return Chain(Chain(layers...), classifier) + return Chain(Chain(layers...), + create_classifier(explanes, headplanes, nclasses, + (hardswish, identity); dropout_rate)) end # Layer configurations for small and large models for MobileNetv3 diff --git a/src/convnets/resnets/core.jl b/src/convnets/resnets/core.jl index 35bb34fc4..458481d73 100644 --- a/src/convnets/resnets/core.jl +++ b/src/convnets/resnets/core.jl @@ -27,12 +27,11 @@ function basicblock(inplanes::Integer, planes::Integer; stride::Integer = 1, drop_block = identity, drop_path = identity, attn_fn = planes -> identity) first_planes = planes ÷ reduction_factor - outplanes = planes conv_bn1 = conv_norm((3, 3), inplanes => first_planes, identity; norm_layer, revnorm, stride, pad = 1) - conv_bn2 = conv_norm((3, 3), first_planes => outplanes, identity; norm_layer, revnorm, + conv_bn2 = conv_norm((3, 3), first_planes => planes, identity; norm_layer, revnorm, pad = 1) - layers = [conv_bn1..., drop_block, activation, conv_bn2..., attn_fn(outplanes), + layers = [conv_bn1..., drop_block, activation, conv_bn2..., attn_fn(planes), drop_path] return Chain(filter!(!=(identity), layers)...) 
end @@ -201,7 +200,7 @@ function basicblock_builder(block_repeats::AbstractVector{<:Integer}; expansion::Integer = 1, norm_layer = BatchNorm, revnorm::Bool = false, activation = relu, attn_fn = planes -> identity, - drop_block_rate = 0.0, drop_path_rate = 0.0, + drop_block_rate = nothing, drop_path_rate = nothing, stride_fn = resnet_stride, planes_fn = resnet_planes, downsample_tuple = (downsample_conv, downsample_identity)) pathschedule = linear_scheduler(drop_path_rate; depth = sum(block_repeats)) @@ -236,7 +235,7 @@ function bottleneck_builder(block_repeats::AbstractVector{<:Integer}; expansion::Integer = 4, norm_layer = BatchNorm, revnorm::Bool = false, activation = relu, attn_fn = planes -> identity, - drop_block_rate = 0.0, drop_path_rate = 0.0, + drop_block_rate = nothing, drop_path_rate = nothing, stride_fn = resnet_stride, planes_fn = resnet_planes, downsample_tuple = (downsample_conv, downsample_identity)) pathschedule = linear_scheduler(drop_path_rate; depth = sum(block_repeats)) @@ -295,8 +294,8 @@ function resnet(block_type, block_repeats::AbstractVector{<:Integer}, inchannels::Integer = 3, stem_fn = resnet_stem, connection = addact, activation = relu, norm_layer = BatchNorm, revnorm::Bool = false, attn_fn = planes -> identity, pool_layer = AdaptiveMeanPool((1, 1)), - use_conv::Bool = false, drop_block_rate = 0.0, drop_path_rate = 0.0, - dropout_rate = 0.0, nclasses::Integer = 1000, kwargs...) + use_conv::Bool = false, drop_block_rate = nothing, drop_path_rate = nothing, + dropout_rate = nothing, nclasses::Integer = 1000, kwargs...) # Build stem stem = stem_fn(; inchannels) # Block builder @@ -319,8 +318,8 @@ function resnet(block_type, block_repeats::AbstractVector{<:Integer}, downsample_tuple = downsample_opt, kwargs...) elseif block_type == bottle2neck - @assert drop_block_rate==0.0 "DropBlock not supported for `bottle2neck`. Set `drop_block_rate` to 0.0" - @assert drop_path_rate==0.0 "DropPath not supported for `bottle2neck`. 
Set `drop_path_rate` to 0.0" + @assert isnothing(drop_block_rate) "DropBlock not supported for `bottle2neck`. Set `drop_block_rate` to nothing" + @assert isnothing(drop_path_rate) "DropPath not supported for `bottle2neck`. Set `drop_path_rate` to nothing" @assert reduction_factor==1 "Reduction factor not supported for `bottle2neck`. Set `reduction_factor` to 1" get_layers = bottle2neck_builder(block_repeats; inplanes, cardinality, base_width, activation, norm_layer, revnorm, attn_fn, @@ -347,7 +346,7 @@ const RESNET_CONFIGS = Dict(18 => (basicblock, [2, 2, 2, 2]), 50 => (bottleneck, [3, 4, 6, 3]), 101 => (bottleneck, [3, 4, 23, 3]), 152 => (bottleneck, [3, 8, 36, 3])) - +# larger ResNet-like models const LRESNET_CONFIGS = Dict(50 => (bottleneck, [3, 4, 6, 3]), 101 => (bottleneck, [3, 4, 23, 3]), 152 => (bottleneck, [3, 8, 36, 3])) diff --git a/src/layers/Layers.jl b/src/layers/Layers.jl index 72ace2c2c..45615df5e 100644 --- a/src/layers/Layers.jl +++ b/src/layers/Layers.jl @@ -28,7 +28,10 @@ include("embeddings.jl") export PatchEmbedding, ViPosEmbedding, ClassTokens include("mlp.jl") -export mlp_block, gated_mlp_block, create_fc, create_classifier +export mlp_block, gated_mlp_block + +include("classifier.jl") +export create_classifier include("normalise.jl") export prenorm, ChannelLayerNorm diff --git a/src/layers/classifier.jl b/src/layers/classifier.jl new file mode 100644 index 000000000..bebdc4099 --- /dev/null +++ b/src/layers/classifier.jl @@ -0,0 +1,93 @@ +""" + create_classifier(inplanes::Integer, nclasses::Integer, activation = identity; + use_conv::Bool = false, pool_layer = AdaptiveMeanPool((1, 1)), + dropout_rate = nothing) + +Creates a classifier head to be used for models. + +# Arguments + + - `inplanes`: number of input feature maps + - `nclasses`: number of output classes + - `activation`: activation function to use + - `use_conv`: whether to use a 1x1 convolutional layer instead of a `Dense` layer. + - `pool_layer`: pooling layer to use. 
This is passed in with the layer instantiated with + any arguments that are needed i.e. as `AdaptiveMeanPool((1, 1))`, for example. + - `dropout_rate`: dropout rate used in the classifier head. Set to `nothing` to disable dropout. +""" +function create_classifier(inplanes::Integer, nclasses::Integer, activation = identity; + use_conv::Bool = false, pool_layer = AdaptiveMeanPool((1, 1)), + dropout_rate = nothing) + # Decide whether to flatten the input or not + flatten_in_pool = !use_conv && pool_layer !== identity + if use_conv + @assert pool_layer === identity + "`pool_layer` must be identity if `use_conv` is true" + end + classifier = [] + if flatten_in_pool + push!(classifier, pool_layer, MLUtils.flatten) + else + push!(classifier, pool_layer) + end + # Dropout is applied after the pooling layer + isnothing(dropout_rate) ? nothing : push!(classifier, Dropout(dropout_rate)) + # Fully-connected layer + if use_conv + push!(classifier, Conv((1, 1), inplanes => nclasses, activation)) + else + push!(classifier, Dense(inplanes => nclasses, activation)) + end + return Chain(classifier...) +end + +""" + create_classifier(inplanes::Integer, hidden_planes::Integer, nclasses::Integer, + activations::NTuple{2} = (relu, identity); + use_conv::NTuple{2, Bool} = (false, false), + pool_layer = AdaptiveMeanPool((1, 1)), dropout_rate = nothing) + +Creates a classifier head to be used for models with an extra hidden layer. + +# Arguments + + - `inplanes`: number of input feature maps + - `hidden_planes`: number of hidden feature maps + - `nclasses`: number of output classes + - `activations`: activation functions to use for the hidden and output layers. This is a + tuple of two elements, the first being the activation function for the hidden layer and the + second for the output layer. + - `use_conv`: whether to use a 1x1 convolutional layer instead of a `Dense` layer. This + is a tuple of two booleans, the first for the hidden layer and the second for the output + layer. 
+ - `pool_layer`: pooling layer to use. This is passed in with the layer instantiated with + any arguments that are needed i.e. as `AdaptiveMeanPool((1, 1))`, for example. + - `dropout_rate`: dropout rate used in the classifier head. Set to `nothing` to disable dropout. +""" +function create_classifier(inplanes::Integer, hidden_planes::Integer, nclasses::Integer, + activations::NTuple{2, Any} = (relu, identity); + use_conv::NTuple{2, Bool} = (false, false), + pool_layer = AdaptiveMeanPool((1, 1)), dropout_rate = nothing) + fc_layers = [uc ? Conv$(1, 1) : Dense for uc in use_conv] + # Decide whether to flatten the input or not + flatten_in_pool = !use_conv[1] && pool_layer !== identity + if use_conv[1] + @assert pool_layer === identity + "`pool_layer` must be identity if `use_conv[1]` is true" + end + classifier = [] + if flatten_in_pool + push!(classifier, pool_layer, MLUtils.flatten) + else + push!(classifier, pool_layer) + end + # first fully-connected layer + if !isnothing(hidden_planes) + push!(classifier, fc_layers[1](inplanes => hidden_planes, activations[1])) + end + # Dropout is applied after the first dense layer + isnothing(dropout_rate) ? nothing : push!(classifier, Dropout(dropout_rate)) + # second fully-connected layer + push!(classifier, fc_layers[2](hidden_planes => nclasses, activations[2])) + return Chain(classifier...) 
+end diff --git a/src/layers/conv.jl b/src/layers/conv.jl index c94ceb045..bb39a0e07 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -146,9 +146,9 @@ Create a basic inverted residual block for MobileNet variants - `reduction`: The reduction factor for the number of hidden feature maps in a squeeze and excite layer (see [`squeeze_excite`](#)) """ -function mbconv(kernel_size::Dims{2}, inplanes::Integer, - explanes::Integer, outplanes::Integer, activation = relu; - stride::Integer, reduction::Union{Nothing, Integer} = nothing, +function mbconv(kernel_size::Dims{2}, inplanes::Integer, explanes::Integer, + outplanes::Integer, activation = relu; stride::Integer, + dilation::Integer = 1, reduction::Union{Nothing, Integer} = nothing, norm_layer = BatchNorm) @assert stride in [1, 2] "`stride` has to be 1 or 2" layers = [] @@ -158,9 +158,10 @@ function mbconv(kernel_size::Dims{2}, inplanes::Integer, conv_norm((1, 1), inplanes, explanes, activation; norm_layer)) end # depthwise + stride = dilation > 1 ? 1 : stride append!(layers, conv_norm(kernel_size, explanes, explanes, activation; norm_layer, - stride, pad = SamePad(), groups = explanes)) + stride, dilation, pad = SamePad(), groups = explanes)) # squeeze-excite layer if !isnothing(reduction) push!(layers, diff --git a/src/layers/drop.jl b/src/layers/drop.jl index b252584fe..387b562ef 100644 --- a/src/layers/drop.jl +++ b/src/layers/drop.jl @@ -20,7 +20,8 @@ regions of size `block_size` in the input. Otherwise, it simply returns the inpu - `rng`: can be used to pass in a custom RNG instead of the default. Custom RNGs are only supported on the CPU. - `x`: input array - - `drop_block_prob`: probability of dropping a block + - `drop_block_prob`: probability of dropping a block. If `nothing` is passed, it returns + `identity`. - `block_size`: size of the block to drop - `gamma_scale`: multiplicative factor for `gamma` used. For the calculations, refer to [the paper](https://arxiv.org/abs/1810.12890). 
@@ -56,11 +57,25 @@ dropblock_mask(rng, x, gamma, bs) = _dropblock_mask(rng, x, gamma, bs) The `DropBlock` layer. While training, it zeroes out contiguous regions of size `block_size` in the input. During inference, it simply returns the input `x`. +It can be used in two ways: either with all blocks having the same survival probability +or with a linear scaling rule across the blocks. This is performed only at training time. +At test time, the `DropBlock` layer is equivalent to `identity`. + +!!! warning + + In the case of the linear scaling rule, the calculations of survival probabilities for each + block may lead to a survival probability > 1 for a given block. This will lead to + `DropBlock` erroring. This usually happens with a low number of blocks and a high base + survival probability, so in such cases it is recommended to use a fixed base survival + probability across blocks. If this is not desired, then a lower base survival probability + is recommended. + ([reference](https://arxiv.org/abs/1810.12890)) # Arguments - - `drop_block_prob`: probability of dropping a block + - `drop_block_prob`: probability of dropping a block. If `nothing` is passed, it returns + `identity`. - `block_size`: size of the block to drop - `gamma_scale`: multiplicative factor for `gamma` used. For the calculation of gamma, refer to [the paper](https://arxiv.org/abs/1810.12890). @@ -90,11 +105,8 @@ ChainRulesCore.@non_differentiable _dropblock_checks(x, drop_block_prob, gamma_s function (m::DropBlock)(x) _dropblock_checks(x, m.drop_block_prob, m.gamma_scale) - if Flux._isactive(m) - return dropblock(m.rng, x, m.drop_block_prob, m.block_size, m.gamma_scale) - else - return x - end + return Flux._isactive(m) ? 
+ dropblock(m.rng, x, m.drop_block_prob, m.block_size, m.gamma_scale) : x end function Flux.testmode!(m::DropBlock, mode = true) @@ -103,7 +115,7 @@ end function DropBlock(drop_block_prob = 0.1, block_size::Integer = 7, gamma_scale = 1.0, rng = rng_from_array()) - if drop_block_prob == 0.0 + if isnothing(drop_block_prob) return identity end return DropBlock(drop_block_prob, block_size, gamma_scale, nothing, rng) @@ -120,8 +132,8 @@ end """ DropPath(p; [rng = rng_from_array(x)]) -Implements Stochastic Depth - equivalent to `Dropout(p; dims = 4)` when `0 < p ≤ 1` and -`identity` otherwise. +Implements Stochastic Depth - equivalent to `Dropout(p; dims = 4)` when `0 ≤ p ≤ 1` and +`identity` if p is `nothing`. ([reference](https://arxiv.org/abs/1603.09382)) This layer can be used to drop certain blocks in a residual structure and allow them to @@ -134,10 +146,10 @@ equivalent to `identity`. In the case of the linear scaling rule, the calculations of survival probabilities for each block may lead to a survival probability > 1 for a given block. This will lead to - `DropPath` returning `identity`, which may not be desirable. This usually happens with - a low number of blocks and a high base survival probability, so it is recommended to - use a fixed base survival probability across blocks. If this is not possible, then - a lower base survival probability is recommended. + `DropPath` erroring. This usually happens with a low number of blocks and a high base + survival probability, so in such cases it is recommended to use a fixed base survival + probability across blocks. If this is not desired, then a lower base survival probability + is recommended. # Arguments @@ -146,4 +158,6 @@ equivalent to `identity`. for more information on the behaviour of this argument. Custom RNGs are only supported on the CPU. """ -DropPath(p; rng = rng_from_array()) = 0 < p ≤ 1 ? Dropout(p; dims = 4, rng) : identity +function DropPath(p; rng = rng_from_array()) + return isnothing(p) ? 
identity : Dropout(p; dims = 4, rng) +end diff --git a/src/layers/mlp.jl b/src/layers/mlp.jl index 467df30a4..e6336de9c 100644 --- a/src/layers/mlp.jl +++ b/src/layers/mlp.jl @@ -1,3 +1,4 @@ +# TODO @theabhirath figure out consistent behaviour for dropout rates - 0.0 vs `nothing` """ mlp_block(inplanes::Integer, hidden_planes::Integer, outplanes::Integer = inplanes; dropout_rate = 0., activation = gelu) @@ -45,46 +46,3 @@ function gated_mlp_block(gate_layer, inplanes::Integer, hidden_planes::Integer, Dropout(dropout_rate)) end gated_mlp_block(::typeof(identity), args...; kwargs...) = mlp_block(args...; kwargs...) - -""" - create_classifier(inplanes::Integer, nclasses::Integer, activation = identity; - pool_layer = AdaptiveMeanPool((1, 1)), - dropout_rate = 0.0, use_conv::Bool = false) - -Creates a classifier head to be used for models. - -# Arguments - - - `inplanes`: number of input feature maps - - `nclasses`: number of output classes - - `activation`: activation function to use - - `pool_layer`: pooling layer to use. This is passed in with the layer instantiated with - any arguments that are needed i.e. as `AdaptiveMeanPool((1, 1))`, for example. - - `dropout_rate`: dropout rate used in the classifier head. - - `use_conv`: whether to use a 1x1 convolutional layer instead of a `Dense` layer. -""" -function create_classifier(inplanes::Integer, nclasses::Integer, activation = identity; - use_conv::Bool = false, pool_layer = AdaptiveMeanPool((1, 1)), - dropout_rate = nothing) - # Decide whether to flatten the input or not - flatten_in_pool = !use_conv && pool_layer !== identity - if use_conv - @assert pool_layer === identity - "`pool_layer` must be identity if `use_conv` is true" - end - classifier = [] - if flatten_in_pool - push!(classifier, pool_layer, MLUtils.flatten) - else - push!(classifier, pool_layer) - end - # Dropout is applied after the pooling layer - isnothing(dropout_rate) ? 
nothing : push!(classifier, Dropout(dropout_rate)) - # Fully-connected layer - if use_conv - push!(classifier, Conv((1, 1), inplanes => nclasses, activation)) - else - push!(classifier, Dense(inplanes => nclasses, activation)) - end - return Chain(classifier...) -end