Skip to content

Commit

Permalink
Fix minor hiccups
Browse files Browse the repository at this point in the history
  • Loading branch information
theabhirath committed Aug 15, 2022
1 parent c998951 commit fc03d70
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 23 deletions.
15 changes: 8 additions & 7 deletions src/convnets/efficientnets/core.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
function mbconv_builder(block_configs::AbstractVector{NTuple{6, Int}},
stage_idx::Integer; scalings::NTuple{2, Real} = (1, 1),
norm_layer = BatchNorm)
depth_mult, width_mult = scalings
width_mult, depth_mult = scalings
k, inplanes, outplanes, expansion, stride, nrepeats = block_configs[stage_idx]
inplanes = _round_channels(inplanes * width_mult, 8)
outplanes = _round_channels(outplanes * width_mult, 8)
Expand All @@ -17,7 +17,8 @@ function mbconv_builder(block_configs::AbstractVector{NTuple{6, Int}},
end

function fused_mbconv_builder(block_configs::AbstractVector{NTuple{6, Int}},
stage_idx::Integer; norm_layer = BatchNorm)
stage_idx::Integer; scalings::NTuple{2, Real} = (1, 1),
norm_layer = BatchNorm)
k, inplanes, outplanes, expansion, stride, nrepeats = block_configs[stage_idx]
function get_layers(block_idx)
inplanes = block_idx == 1 ? inplanes : outplanes
Expand All @@ -40,22 +41,22 @@ end

"""
    efficientnet(block_configs, residual_fns; scalings = (1, 1),
                 headplanes = block_configs[end][3] * 4, norm_layer = BatchNorm,
                 dropout_rate = nothing, inchannels = 3, nclasses = 1000)

Build an EfficientNet-style backbone plus classifier.

# Arguments
  - `block_configs`: per-stage settings, each an `NTuple{6, Int}` of
    `(kernel, inplanes, outplanes, expansion, stride, nrepeats)`.
  - `residual_fns`: one block-builder function per stage (e.g. `mbconv_builder`
    or `fused_mbconv_builder`).

# Keywords
  - `scalings`: `(width_mult, depth_mult)` compound-scaling coefficients.
    NOTE(review): widths below are scaled by `scalings[1]` — assumed to be the
    width multiplier; confirm against the stage builders' convention.
  - `headplanes`: channel count of the final 1×1 conv / classifier input.
  - `norm_layer`: normalisation layer used in the stem and stage blocks.
  - `dropout_rate`: dropout before the classifier; `nothing` disables it.
  - `inchannels`: number of input image channels.
  - `nclasses`: number of output classes.
"""
function efficientnet(block_configs::AbstractVector{NTuple{6, Int}},
                      residual_fns::AbstractVector; scalings::NTuple{2, Real} = (1, 1),
                      headplanes::Integer = block_configs[end][3] * 4,
                      norm_layer = BatchNorm, dropout_rate = nothing,
                      inchannels::Integer = 3, nclasses::Integer = 1000)
    layers = []
    # stem of the model: 3x3 stride-2 conv; width is scaled by the width
    # multiplier and rounded to the nearest multiple of 8
    append!(layers,
            conv_norm((3, 3), inchannels,
                      _round_channels(block_configs[1][2] * scalings[1], 8), swish;
                      norm_layer, stride = 2, pad = SamePad()))
    # building inverted residual blocks
    get_layers, block_repeats = efficientnet_builder(block_configs, residual_fns;
                                                     scalings, norm_layer)
    append!(layers, resnet_stages(get_layers, block_repeats, +))
    # building last layers: 1x1 conv expanding the (width-scaled) final stage
    # output to `headplanes` channels
    append!(layers,
            conv_norm((1, 1), _round_channels(block_configs[end][3] * scalings[1], 8),
                      headplanes, swish; pad = SamePad()))
    return Chain(Chain(layers...), create_classifier(headplanes, nclasses; dropout_rate))
end
3 changes: 2 additions & 1 deletion src/convnets/efficientnets/efficientnetv2.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ function EfficientNetv2(config::Symbol; pretrain::Bool = false,
inchannels::Integer = 3, nclasses::Integer = 1000)
_checkconfig(config, sort(collect(keys(EFFNETV2_CONFIGS))))
layers = efficientnet(EFFNETV2_CONFIGS[config],
vcat(fill(fused_mbconv_builder, 3), fill(mbconv_builder, 4));
vcat(fill(fused_mbconv_builder, 3),
fill(mbconv_builder, length(EFFNETV2_CONFIGS[config]) - 3));
headplanes = 1280, inchannels, nclasses)
if pretrain
loadpretrain!(layers, string("efficientnetv2"))
Expand Down
32 changes: 17 additions & 15 deletions src/convnets/resnets/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -275,18 +275,18 @@ function resnet_stages(get_layers, block_repeats::AbstractVector{<:Integer}, con
# Construct the blocks for each stage
blocks = map(1:nblocks) do block_idx
branches = get_layers(stage_idx, block_idx)
return (length(branches) == 1) ? only(branches) :
return length(branches) == 1 ? only(branches) :
Parallel(connection, branches...)
end
push!(stages, Chain(blocks...))
end
return Chain(stages...)
end

function resnet(img_dims, stem, builders, block_repeats::AbstractVector{<:Integer},
function resnet(img_dims, stem, get_layers, block_repeats::AbstractVector{<:Integer},
connection, classifier_fn)
# Build stages of the ResNet
stage_blocks = resnet_stages(builders, block_repeats, connection)
stage_blocks = resnet_stages(get_layers, block_repeats, connection)
backbone = Chain(stem, stage_blocks)
# Add classifier to the backbone
nfeaturemaps = Flux.outputsize(backbone, img_dims; padbatch = true)[3]
Expand All @@ -308,17 +308,19 @@ function resnet(block_type, block_repeats::AbstractVector{<:Integer},
if block_type == basicblock
@assert cardinality==1 "Cardinality must be 1 for `basicblock`"
@assert base_width==64 "Base width must be 64 for `basicblock`"
builder = basicblock_builder(block_repeats; inplanes, reduction_factor,
activation, norm_layer, revnorm, attn_fn,
drop_block_rate, drop_path_rate,
stride_fn = resnet_stride, planes_fn = resnet_planes,
downsample_tuple = downsample_opt, kwargs...)
get_layers = basicblock_builder(block_repeats; inplanes, reduction_factor,
activation, norm_layer, revnorm, attn_fn,
drop_block_rate, drop_path_rate,
stride_fn = resnet_stride,
planes_fn = resnet_planes,
downsample_tuple = downsample_opt, kwargs...)
elseif block_type == bottleneck
builder = bottleneck_builder(block_repeats; inplanes, cardinality,
base_width, reduction_factor, activation, norm_layer,
revnorm, attn_fn, drop_block_rate, drop_path_rate,
stride_fn = resnet_stride, planes_fn = resnet_planes,
downsample_tuple = downsample_opt, kwargs...)
get_layers = bottleneck_builder(block_repeats; inplanes, cardinality, base_width,
reduction_factor, activation, norm_layer, revnorm,
attn_fn, drop_block_rate, drop_path_rate,
stride_fn = resnet_stride,
planes_fn = resnet_planes,
downsample_tuple = downsample_opt, kwargs...)
elseif block_type == bottle2neck
@assert isnothing(drop_block_rate)
"DropBlock not supported for `bottle2neck`. Set `drop_block_rate` to nothing"
Expand All @@ -337,8 +339,8 @@ function resnet(block_type, block_repeats::AbstractVector{<:Integer},
end
classifier_fn = nfeatures -> create_classifier(nfeatures, nclasses; dropout_rate,
pool_layer, use_conv)
return resnet((imsize..., inchannels), stem, fill(builder, length(block_repeats)),
block_repeats, connection$activation, classifier_fn)
return resnet((imsize..., inchannels), stem, get_layers, block_repeats,
connection$activation, classifier_fn)
end
function resnet(block_fn, block_repeats, downsample_opt::Symbol = :B; kwargs...)
return resnet(block_fn, block_repeats, RESNET_SHORTCUTS[downsample_opt]; kwargs...)
Expand Down
1 change: 1 addition & 0 deletions src/convnets/resnets/res2net.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ function bottle2neck_builder(block_repeats::AbstractVector{<:Integer};
attn_fn = planes -> identity,
stride_fn = resnet_stride, planes_fn = resnet_planes,
downsample_tuple = (downsample_conv, downsample_identity))
planes_vec = collect(planes_fn(block_repeats))
# closure over `idxs`
function get_layers(stage_idx::Integer, block_idx::Integer)
# This is needed for block `inplanes` and `planes` calculations
Expand Down

0 comments on commit fc03d70

Please sign in to comment.