diff --git a/src/convnets/efficientnets/core.jl b/src/convnets/efficientnets/core.jl index 139252a3b..22ddf8172 100644 --- a/src/convnets/efficientnets/core.jl +++ b/src/convnets/efficientnets/core.jl @@ -1,7 +1,7 @@ function mbconv_builder(block_configs::AbstractVector{NTuple{6, Int}}, stage_idx::Integer; scalings::NTuple{2, Real} = (1, 1), norm_layer = BatchNorm) - depth_mult, width_mult = scalings + width_mult, depth_mult = scalings k, inplanes, outplanes, expansion, stride, nrepeats = block_configs[stage_idx] inplanes = _round_channels(inplanes * width_mult, 8) outplanes = _round_channels(outplanes * width_mult, 8) @@ -17,7 +17,8 @@ function mbconv_builder(block_configs::AbstractVector{NTuple{6, Int}}, end function fused_mbconv_builder(block_configs::AbstractVector{NTuple{6, Int}}, - stage_idx::Integer; norm_layer = BatchNorm) + stage_idx::Integer; scalings::NTuple{2, Real} = (1, 1), + norm_layer = BatchNorm) k, inplanes, outplanes, expansion, stride, nrepeats = block_configs[stage_idx] function get_layers(block_idx) inplanes = block_idx == 1 ? inplanes : outplanes @@ -40,22 +41,22 @@ end function efficientnet(block_configs::AbstractVector{NTuple{6, Int}}, residual_fns::AbstractVector; scalings::NTuple{2, Real} = (1, 1), - headplanes::Integer = _round_channels(block_configs[end][3] * - scalings[2], 8) * 4, + headplanes::Integer = block_configs[end][3] * 4, norm_layer = BatchNorm, dropout_rate = nothing, inchannels::Integer = 3, nclasses::Integer = 1000) layers = [] # stem of the model append!(layers, - conv_norm((3, 3), inchannels, block_configs[1][2], swish; norm_layer, - stride = 2, pad = SamePad())) + conv_norm((3, 3), inchannels, + _round_channels(block_configs[1][2] * scalings[1], 8), swish; + norm_layer, stride = 2, pad = SamePad())) # building inverted residual blocks get_layers, block_repeats = efficientnet_builder(block_configs, residual_fns; scalings, norm_layer) append!(layers, resnet_stages(get_layers, block_repeats, +)) # building last layers append!(layers, - conv_norm((1, 1), _round_channels(block_configs[end][3] * scalings[2], 8), + conv_norm((1, 1), _round_channels(block_configs[end][3] * scalings[1], 8), headplanes, swish; pad = SamePad())) return Chain(Chain(layers...), create_classifier(headplanes, nclasses; dropout_rate)) end diff --git a/src/convnets/efficientnets/efficientnetv2.jl b/src/convnets/efficientnets/efficientnetv2.jl index ecbeed07a..d9a9d0d77 100644 --- a/src/convnets/efficientnets/efficientnetv2.jl +++ b/src/convnets/efficientnets/efficientnetv2.jl @@ -56,7 +56,8 @@ function EfficientNetv2(config::Symbol; pretrain::Bool = false, inchannels::Integer = 3, nclasses::Integer = 1000) _checkconfig(config, sort(collect(keys(EFFNETV2_CONFIGS)))) layers = efficientnet(EFFNETV2_CONFIGS[config], - vcat(fill(fused_mbconv_builder, 3), fill(mbconv_builder, 4)); + vcat(fill(fused_mbconv_builder, 3), + fill(mbconv_builder, length(EFFNETV2_CONFIGS[config]) - 3)); headplanes = 1280, inchannels, nclasses) if pretrain loadpretrain!(layers, string("efficientnetv2")) diff --git a/src/convnets/resnets/core.jl b/src/convnets/resnets/core.jl index 39e283fd0..95291bc6f 100644 --- a/src/convnets/resnets/core.jl +++ b/src/convnets/resnets/core.jl @@ -275,7 +275,7 @@ function resnet_stages(get_layers, block_repeats::AbstractVector{<:Integer}, con # Construct the blocks for each stage blocks = map(1:nblocks) do block_idx branches = get_layers(stage_idx, block_idx) - return (length(branches) == 1) ? only(branches) : + return length(branches) == 1 ? only(branches) : Parallel(connection, branches...) end push!(stages, Chain(blocks...)) @@ -283,10 +283,10 @@ function resnet_stages(get_layers, block_repeats::AbstractVector{<:Integer}, con return Chain(stages...) end -function resnet(img_dims, stem, builders, block_repeats::AbstractVector{<:Integer}, +function resnet(img_dims, stem, get_layers, block_repeats::AbstractVector{<:Integer}, connection, classifier_fn) # Build stages of the ResNet - stage_blocks = resnet_stages(builders, block_repeats, connection) + stage_blocks = resnet_stages(get_layers, block_repeats, connection) backbone = Chain(stem, stage_blocks) # Add classifier to the backbone nfeaturemaps = Flux.outputsize(backbone, img_dims; padbatch = true)[3] @@ -308,17 +308,19 @@ function resnet(block_type, block_repeats::AbstractVector{<:Integer}, if block_type == basicblock @assert cardinality==1 "Cardinality must be 1 for `basicblock`" @assert base_width==64 "Base width must be 64 for `basicblock`" - builder = basicblock_builder(block_repeats; inplanes, reduction_factor, - activation, norm_layer, revnorm, attn_fn, - drop_block_rate, drop_path_rate, - stride_fn = resnet_stride, planes_fn = resnet_planes, - downsample_tuple = downsample_opt, kwargs...) + get_layers = basicblock_builder(block_repeats; inplanes, reduction_factor, + activation, norm_layer, revnorm, attn_fn, + drop_block_rate, drop_path_rate, + stride_fn = resnet_stride, + planes_fn = resnet_planes, + downsample_tuple = downsample_opt, kwargs...) elseif block_type == bottleneck - builder = bottleneck_builder(block_repeats; inplanes, cardinality, - base_width, reduction_factor, activation, norm_layer, - revnorm, attn_fn, drop_block_rate, drop_path_rate, - stride_fn = resnet_stride, planes_fn = resnet_planes, - downsample_tuple = downsample_opt, kwargs...) + get_layers = bottleneck_builder(block_repeats; inplanes, cardinality, base_width, + reduction_factor, activation, norm_layer, revnorm, + attn_fn, drop_block_rate, drop_path_rate, + stride_fn = resnet_stride, + planes_fn = resnet_planes, + downsample_tuple = downsample_opt, kwargs...) elseif block_type == bottle2neck @assert isnothing(drop_block_rate) "DropBlock not supported for `bottle2neck`. Set `drop_block_rate` to nothing" @@ -337,8 +339,8 @@ function resnet(block_type, block_repeats::AbstractVector{<:Integer}, end classifier_fn = nfeatures -> create_classifier(nfeatures, nclasses; dropout_rate, pool_layer, use_conv) - return resnet((imsize..., inchannels), stem, fill(builder, length(block_repeats)), - block_repeats, connection$activation, classifier_fn) + return resnet((imsize..., inchannels), stem, get_layers, block_repeats, + connection$activation, classifier_fn) end function resnet(block_fn, block_repeats, downsample_opt::Symbol = :B; kwargs...) return resnet(block_fn, block_repeats, RESNET_SHORTCUTS[downsample_opt]; kwargs...)