diff --git a/README.md b/README.md index fb8b215f5..c4e909b45 100644 --- a/README.md +++ b/README.md @@ -28,6 +28,7 @@ | [MobileNetv1](https://arxiv.org/abs/1704.04861) | [`MobileNetv1`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.MobileNetv1.html) | N | | [MobileNetv2](https://arxiv.org/abs/1801.04381) | [`MobileNetv2`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.MobileNetv2.html) | N | | [MobileNetv3](https://arxiv.org/abs/1905.02244) | [`MobileNetv3`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.MobileNetv3.html) | N | +| [EfficientNet](https://arxiv.org/abs/1905.11946) | [`EfficientNet`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.EfficientNet.html) | N | | [MLPMixer](https://arxiv.org/pdf/2105.01601) | [`MLPMixer`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.MLPMixer.html) | N | | [ResMLP](https://arxiv.org/abs/2105.03404) | [`ResMLP`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.ResMLP.html) | N | | [gMLP](https://arxiv.org/abs/2105.08050) | [`gMLP`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.gMLP.html) | N | diff --git a/src/Metalhead.jl b/src/Metalhead.jl index b489b2fca..f391c0c66 100644 --- a/src/Metalhead.jl +++ b/src/Metalhead.jl @@ -27,6 +27,7 @@ include("convnets/resnext.jl") include("convnets/densenet.jl") include("convnets/squeezenet.jl") include("convnets/mobilenet.jl") +include("convnets/efficientnet.jl") include("convnets/convnext.jl") include("convnets/convmixer.jl") @@ -42,7 +43,7 @@ export AlexNet, VGG, VGG11, VGG13, VGG16, VGG19, ResNet, ResNet18, ResNet34, ResNet50, ResNet101, ResNet152, ResNeXt, DenseNet, DenseNet121, DenseNet161, DenseNet169, DenseNet201, GoogLeNet, Inception3, Inceptionv3, Inceptionv4, InceptionResNetv2, Xception, - SqueezeNet, MobileNetv1, MobileNetv2, MobileNetv3, + SqueezeNet, MobileNetv1, MobileNetv2, MobileNetv3, EfficientNet, MLPMixer, ResMLP, gMLP, ViT, ConvMixer, ConvNeXt diff --git a/src/convnets/efficientnet.jl b/src/convnets/efficientnet.jl new file mode 100644 index 000000000..1465eb238 --- /dev/null +++ b/src/convnets/efficientnet.jl @@ -0,0 +1,156 @@ +""" + efficientnet(scalings, block_config; + inchannels = 3, nclasses = 1000, max_width = 1280) + +Create an EfficientNet model ([reference](https://arxiv.org/abs/1905.11946v5)). + +# Arguments + +- `scalings`: global width and depth scaling (given as a tuple) +- `block_config`: configuration for each inverted residual block, + given as a vector of tuples with elements: + - `n`: number of block repetitions (will be scaled by global depth scaling) + - `k`: kernel size + - `s`: kernel stride + - `e`: expansion ratio + - `i`: block input channels (will be scaled by global width scaling) + - `o`: block output channels (will be scaled by global width scaling) +- `inchannels`: number of input channels +- `nclasses`: number of output classes +- `max_width`: maximum number of output channels before the fully connected + classification blocks +""" +function efficientnet(scalings, block_config; + inchannels = 3, nclasses = 1000, max_width = 1280) + wscale, dscale = scalings + scalew(w) = wscale ≈ 1 ? w : ceil(Int64, wscale * w) + scaled(d) = dscale ≈ 1 ? 
d : ceil(Int64, dscale * d)
+
+    out_channels = _round_channels(scalew(32), 8)
+    stem = conv_bn((3, 3), inchannels, out_channels, swish;
+                   bias = false, stride = 2, pad = SamePad())
+
+    blocks = []
+    for (n, k, s, e, i, o) in block_config
+        in_channels = _round_channels(scalew(i), 8)
+        out_channels = _round_channels(scalew(o), 8)
+        repeats = scaled(n)
+
+        push!(blocks,
+              invertedresidual(k, in_channels, in_channels * e, out_channels, swish;
+                               stride = s, reduction = 4))
+        for _ in 1:(repeats - 1)
+            push!(blocks,
+                  invertedresidual(k, out_channels, out_channels * e, out_channels, swish;
+                                   stride = 1, reduction = 4))
+        end
+    end
+    blocks = Chain(blocks...)
+
+    head_out_channels = _round_channels(max_width, 8)
+    head = conv_bn((1, 1), out_channels, head_out_channels, swish;
+                   bias = false, pad = SamePad())
+
+    top = Dense(head_out_channels, nclasses)
+
+    return Chain(Chain([stem..., blocks, head...]),
+                 Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, top))
+end
+
+# n: number of block repetitions
+# k: kernel size k x k
+# s: stride
+# e: expansion ratio
+# i: block input channels
+# o: block output channels
+const efficientnet_block_configs = [
+#   (n, k, s, e,   i,   o)
+    (1, 3, 1, 1,  32,  16),
+    (2, 3, 2, 6,  16,  24),
+    (2, 5, 2, 6,  24,  40),
+    (3, 3, 2, 6,  40,  80),
+    (3, 5, 1, 6,  80, 112),
+    (4, 5, 2, 6, 112, 192),
+    (1, 3, 1, 6, 192, 320)
+]
+
+# w: width scaling
+# d: depth scaling
+# r: image resolution
+const efficientnet_global_configs = Dict(
+#          (  r, (  w,   d))
+    :b0 => (224, (1.0, 1.0)),
+    :b1 => (240, (1.0, 1.1)),
+    :b2 => (260, (1.1, 1.2)),
+    :b3 => (300, (1.2, 1.4)),
+    :b4 => (380, (1.4, 1.8)),
+    :b5 => (456, (1.6, 2.2)),
+    :b6 => (528, (1.8, 2.6)),
+    :b7 => (600, (2.0, 3.1)),
+    :b8 => (672, (2.2, 3.6))
+)
+
+struct EfficientNet
+    layers::Any
+end
+
+"""
+    EfficientNet(scalings, block_config;
+                 inchannels = 3, nclasses = 1000, max_width = 1280)
+
+Create an EfficientNet model ([reference](https://arxiv.org/abs/1905.11946v5)).
+See also [`efficientnet`](#).
+
+# Arguments
+
+- `scalings`: global width and depth scaling (given as a tuple)
+- `block_config`: configuration for each inverted residual block,
+  given as a vector of tuples with elements:
+    - `n`: number of block repetitions (will be scaled by global depth scaling)
+    - `k`: kernel size
+    - `s`: kernel stride
+    - `e`: expansion ratio
+    - `i`: block input channels (will be scaled by global width scaling)
+    - `o`: block output channels (will be scaled by global width scaling)
+- `inchannels`: number of input channels
+- `nclasses`: number of output classes
+- `max_width`: maximum number of output channels before the fully connected
+  classification blocks
+"""
+function EfficientNet(scalings, block_config;
+                      inchannels = 3, nclasses = 1000, max_width = 1280)
+    layers = efficientnet(scalings, block_config;
+                          inchannels = inchannels,
+                          nclasses = nclasses,
+                          max_width = max_width)
+    return EfficientNet(layers)
+end
+
+@functor EfficientNet
+
+(m::EfficientNet)(x) = m.layers(x)
+
+backbone(m::EfficientNet) = m.layers[1]
+classifier(m::EfficientNet) = m.layers[2]
+
+"""
+    EfficientNet(name::Symbol; pretrain = false)
+
+Create an EfficientNet model ([reference](https://arxiv.org/abs/1905.11946v5)).
+See also [`efficientnet`](#).
+
+# Arguments
+
+- `name`: name of default configuration
+  (can be `:b0`, `:b1`, `:b2`, `:b3`, `:b4`, `:b5`, `:b6`, `:b7`, `:b8`)
+- `pretrain`: set to `true` to load the pre-trained weights for ImageNet
+"""
+function EfficientNet(name::Symbol; pretrain = false)
+    @assert name in keys(efficientnet_global_configs) "`name` must be one of $(sort(collect(keys(efficientnet_global_configs))))"
+
+    model = EfficientNet(efficientnet_global_configs[name][2], efficientnet_block_configs)
+    pretrain && loadpretrain!(model, string("efficientnet-", name))
+
+    return model
+end
diff --git a/test/convnets.jl b/test/convnets.jl
index 4949e34e6..9d1645865 100644
--- a/test/convnets.jl
+++ b/test/convnets.jl
@@ -70,6 +70,27 @@ end
 GC.safepoint()
 GC.gc()
 
+@testset "EfficientNet" begin
+    @testset "EfficientNet($name)" for name in [:b0, :b1, :b2, :b3, :b4] #, :b5, :b6, :b7, :b8]
+        # preferred image resolution for this configuration
+        r = Metalhead.efficientnet_global_configs[name][1]
+        x = rand(Float32, r, r, 3, 1)
+        m = EfficientNet(name)
+        @test size(m(x)) == (1000, 1)
+        if (EfficientNet, name) in PRETRAINED_MODELS
+            @test acctest(EfficientNet(name, pretrain = true))
+        else
+            @test_throws ArgumentError EfficientNet(name, pretrain = true)
+        end
+        @test gradtest(m, x)
+        GC.safepoint()
+        GC.gc()
+    end
+end
+
+GC.safepoint()
+GC.gc()
+
 @testset "GoogLeNet" begin
     m = GoogLeNet()
     @test size(m(x_224)) == (1000, 1)
@@ -215,7 +236,7 @@ GC.safepoint()
 GC.gc()
 
 @testset "ConvNeXt" verbose = true begin
-    @testset for mode in [:small, :base, :large] # :tiny, #, :xlarge]
+    @testset for mode in [:small, :base] #, :large # :tiny, #, :xlarge]
         @testset for drop_path_rate in [0.0, 0.5]
             m = ConvNeXt(mode; drop_path_rate)
             @test size(m(x_224)) == (1000, 1)
@@ -230,7 +251,7 @@ GC.safepoint()
 GC.gc()
 
 @testset "ConvMixer" verbose = true begin
-    @testset for mode in [:small, :base, :large]
+    @testset for mode in [:small, :base] #, :large]
         m = ConvMixer(mode)
         @test size(m(x_224)) == (1000, 1)
diff --git a/test/other.jl b/test/other.jl
index 769539720..3c1752f3a 100644
--- a/test/other.jl
+++ b/test/other.jl
@@ -1,5 +1,5 @@
 @testset "MLPMixer" begin
-    @testset for mode in [:small, :base, :large] # :huge]
+    @testset for mode in [:small, :base] # :large, # :huge]
         @testset for drop_path_rate in [0.0, 0.5]
             m = MLPMixer(mode; drop_path_rate)
             @test size(m(x_224)) == (1000, 1)
@@ -11,7 +11,7 @@ end
 
 @testset "ResMLP" begin
-    @testset for mode in [:small, :base, :large] # :huge]
+    @testset for mode in [:small, :base] # :large, # :huge]
         @testset for drop_path_rate in [0.0, 0.5]
             m = ResMLP(mode; drop_path_rate)
             @test size(m(x_224)) == (1000, 1)
@@ -23,7 +23,7 @@ end
 end
 
 @testset "gMLP" begin
-    @testset for mode in [:small, :base, :large] # :huge]
+    @testset for mode in [:small, :base] # :large, # :huge]
         @testset for drop_path_rate in [0.0, 0.5]
             m = gMLP(mode; drop_path_rate)
             @test size(m(x_224)) == (1000, 1)
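
Usage sketch for reviewers, based only on the API added in this diff. It is not part of the patch: the variable names, the `(1.1, 1.2)` scalings, and `nclasses = 10` are illustrative assumptions; the config tables and `backbone` are unexported, hence the `Metalhead.`-qualified access, mirroring the tests.

```julia
using Metalhead

# Named configuration: :b0 prefers 224×224 inputs (the resolution entry in
# Metalhead.efficientnet_global_configs[:b0]).
m = EfficientNet(:b0)
x = rand(Float32, 224, 224, 3, 1)  # WHCN layout, one image
size(m(x))  # (1000, 1)

# Custom width/depth scalings with the standard block configuration; the
# (1.1, 1.2) scalings and nclasses = 10 are arbitrary example values.
m10 = EfficientNet((1.1, 1.2), Metalhead.efficientnet_block_configs; nclasses = 10)

# The model is a two-stage Chain, so the feature extractor can be used alone:
feats = Metalhead.backbone(m)(x)  # pre-pooling feature maps
```

The two-stage `Chain` split (stem + blocks + head, then pooling + dense) is what makes the `backbone`/`classifier` accessors above line up with `m.layers[1]` and `m.layers[2]`.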