Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Metalhead.jl models to model registry #269

Draft
wants to merge 21 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
60b4bea
Add a feature registry for models
lorenzoh Nov 26, 2022
936cd1c
Use 1.6 supported syntax
lorenzoh Nov 26, 2022
89c8a61
Fix model variant printing
lorenzoh Nov 26, 2022
3919a1d
WIP: Add Metalhead.jl models to registry
lorenzoh Nov 26, 2022
4a0e6e0
Use correct `load` function in model registry.
lorenzoh Nov 27, 2022
4c9bdbc
Merge branch 'lo/model-registry-2' into lo/metalhead-models
lorenzoh Nov 27, 2022
377daac
Merge branch 'master' of github.com:FluxML/FastAI.jl into lo/metalhea…
lorenzoh Nov 27, 2022
0a9353e
Add all Metalhead models
lorenzoh Nov 27, 2022
9aef6d2
Change `ModelVariant` API
lorenzoh Dec 4, 2022
4ce6294
Merge branch 'lo/model-registry-2' into lo/metalhead-models
lorenzoh Dec 4, 2022
bc7acd0
WIP: adapt based on changes to model variant API
lorenzoh Dec 11, 2022
85c88c3
Model registry now has a field :loadfn
lorenzoh Feb 3, 2023
b0bf160
Merge branch 'lo/model-registry-2' into lo/metalhead-models
lorenzoh Feb 3, 2023
ea4e153
Use base `loadfn` for default model instead of variant
lorenzoh Feb 4, 2023
16715ae
Remove dead code
lorenzoh Feb 4, 2023
c7d1073
Remove unneeded deps
lorenzoh Feb 4, 2023
f3e9bb5
Add `ConvFeatures` block to represent bakbone outputs
lorenzoh Feb 4, 2023
7d69cef
Finish docstring.
lorenzoh Feb 4, 2023
90a2f88
Merge branch 'master' into lo/model-registry-2
lorenzoh Feb 4, 2023
95570d7
Merge branch 'lo/model-registry-2' into lo/metalhead-models
lorenzoh Feb 4, 2023
dfad189
Merge branch 'lo/conv-features' into lo/metalhead-models
lorenzoh Feb 4, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions FastVision/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ IndirectArrays = "9b13fd28-a010-5f03-acff-a1bbcff69959"
InlineTest = "bd334432-b1e7-49c7-a2dc-dd9149e4ebd6"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
MakieCore = "20f20a25-4f0e-4fdf-b5d1-57303727442b"
Metalhead = "dbeba491-748d-5e0e-a39e-b530a07fa0cc"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
ShowCases = "605ecd9f-84a6-4c9e-81e2-4798472b76a3"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Expand All @@ -33,14 +34,15 @@ FastAI = "0.5"
FixedPointNumbers = "0.8"
Flux = "0.12, 0.13"
ImageIO = "0.6"
ImageInTerminal = "0.4"
ImageInTerminal = "0.4, 0.5"
IndirectArrays = "0.5, 1"
InlineTest = "0.2"
MLUtils = "0.2, 0.3, 0.4"
MakieCore = "0.3, 0.4, 0.5, 0.6"
Metalhead = "0.8"
ProgressMeter = "1"
ShowCases = "0.1"
StaticArrays = "1.1"
UnicodePlots = "2"
UnicodePlots = "2, 3"
Zygote = "0.6"
julia = "1.6"
11 changes: 10 additions & 1 deletion FastVision/src/FastVision.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,15 @@ using FastAI: # blocks
Context, Training, Validation, Inference,
Datasets
using FastAI.Datasets
import FastAI.Registries: ModelVariant, compatibleblocks, loadvariant

# extending
import FastAI:
blockmodel, blockbackbone, blocklossfn, encode, decode, checkblock,
encodedblock, decodedblock, showblock!, mockblock, setup, encodestate,
decodestate

import Flux
using Flux: Flux, Chain, Conv, Dense
import MLUtils: getobs, numobs, mapobs, eachobs
import Colors: colormaps_sequential, Colorant, Color, Gray, Normed, RGB,
alphacolor, deuteranopic, distinguishable_colors
Expand All @@ -63,6 +64,7 @@ import IndirectArrays: IndirectArray
import MakieCore
import MakieCore: @recipe
import MakieCore.Observables: @map
import Metalhead: Metalhead
import ProgressMeter: Progress, next!
import StaticArrays: SVector
import Statistics: mean, std
Expand All @@ -76,6 +78,7 @@ include("blocks/bounded.jl")
include("blocks/image.jl")
include("blocks/mask.jl")
include("blocks/keypoints.jl")
include("blocks/convfeatures.jl")

include("encodings/onehot.jl")
include("encodings/imagepreprocessing.jl")
Expand All @@ -93,6 +96,7 @@ include("tasks/keypointregression.jl")
include("datasets.jl")
include("recipes.jl")
include("makie.jl")
include("modelregistry.jl")

include("tests.jl")

Expand All @@ -103,6 +107,11 @@ function __init__()
push!(FastAI.learningtasks(), t)
end
end
foreach(values(_models)) do t
if !haskey(FastAI.models(), t.id)
push!(FastAI.models(), t)
end
end
end

export Image, Mask, Keypoints, Bounded,
Expand Down
52 changes: 52 additions & 0 deletions FastVision/src/blocks/convfeatures.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@

"""
ConvFeatures{N}(n) <: Block
ConvFeatures(n, size)

Block representing features from a convolutional neural network backbone
with `n` feature channels and `N` spatial dimensions.

For example, a 2D ResNet's convolutional layers may produce a `h`x`w`x`ch` output
that is passed further to the classifier head.

## Examples

A feature block with 512 channels and variable spatial dimensions:

```julia
FastVision.ConvFeatures{2}(512)
# or equivalently
FastVision.ConvFeatures(512, (:, :))
```

A feature block with 512 channels and fixed spatial dimensions:

```julia
FastVision.ConvFeatures(512, (4, 4))
```

"""
struct ConvFeatures{N} <: Block
n::Int
size::NTuple{N, DimSize}
end

ConvFeatures{N}(n) where {N} = ConvFeatures{N}(n, ntuple(_ -> :, N))

function FastAI.checkblock(block::ConvFeatures{N}, a::AbstractArray{T, M}) where {M, N, T}
M == N + 1 || return false
return checksize(block.size, size(a)[begin:N])
end

function FastAI.mockblock(block::ConvFeatures)
rand(Float32, map(l -> l isa Colon ? 8 : l, block.size)..., block.n)
end


@testset "ConvFeatures [block]" begin
@test ConvFeatures(16, (:, :)) == ConvFeatures{2}(16)
@test checkblock(ConvFeatures(16, (:, :)), rand(Float32, 2, 2, 16))
@test checkblock(ConvFeatures(16, (:, :)), rand(Float32, 3, 2, 16))
@test checkblock(ConvFeatures(16, (2, 2)), rand(Float32, 2, 2, 16))
@test !checkblock(ConvFeatures(16, (2, :)), rand(Float32, 3, 2, 16))
end
133 changes: 133 additions & 0 deletions FastVision/src/modelregistry.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@

# ## Model variants for Metalhead.jl models

struct MetalheadClassifierVariant <: ModelVariant
fn
end
compatibleblocks(::MetalheadClassifierVariant) = (ImageTensor{2}, FastAI.OneHotTensor{0})
function loadvariant(v::MetalheadClassifierVariant, xblock::ImageTensor{2}, yblock::FastAI.OneHotTensor{0}, checkpoint; kwargs...)
return v.fn(; pretrain = checkpoint == "imagenet1k", inchannels=xblock.nchannels,
nclasses=length(yblock.classes), kwargs...)
end
function loadvariant(v::MetalheadClassifierVariant, xblock, yblock, checkpoint; kwargs...)
return v.fn(; pretrain = checkpoint == "imagenet1k", kwargs...)
end

struct MetalheadBackboneVariant <: ModelVariant
fn
nfeatures::Int
end
compatibleblocks(variant::MetalheadBackboneVariant) = (ImageTensor{2}, ConvFeatures{2}(variant.nfeatures))
function loadvariant(v::MetalheadBackboneVariant, xblock::ImageTensor{2}, yblock::ConvFeatures{2}, checkpoint; kwargs...)
model = v.fn(; pretrain = checkpoint == "imagenet1k", inchannels=xblock.nchannels,
kwargs...)
return model.layers[1]
end
function loadvariant(v::MetalheadBackboneVariant, xblock, yblock, checkpoint; kwargs...)
model = v.fn(; pretrain = checkpoint == "imagenet1k", kwargs...)
return model.layers[1]
end

function metalheadvariants(modelfn, nfeatures)
return [
"classifier" => MetalheadClassifierVariant(modelfn),
"backbone" => MetalheadBackboneVariant(modelfn, nfeatures),
]
end


const _models = Dict{String, Any}()


fix(fn, args...; kwargs...) = (_args...; _kwargs...) -> fn(args..., _args...; kwargs..., _kwargs...)



# model config: id, description, basefn, variant, hasweights, nfeatures
const METALHEAD_CONFIGS = [
("metalhead/resnet18", "ResNet18", fix(Metalhead.ResNet, 18), true, 512),
("metalhead/resnet34", "ResNet34", fix(Metalhead.ResNet, 34), true, 512),
("metalhead/resnet50", "ResNet50", fix(Metalhead.ResNet, 50), true, 2048),
("metalhead/resnet101", "ResNet101", fix(Metalhead.ResNet, 101), true,
2048),
("metalhead/resnet152", "ResNet152", fix(Metalhead.ResNet, 152), true,
2048),
("metalhead/wideresnet50", "WideResNet50", fix(Metalhead.WideResNet, 50),
true, 2048),
("metalhead/wideresnet101", "WideResNet101",
fix(Metalhead.WideResNet, 101), true, 2048),
("metalhead/wideresnet152", "WideResNet152",
fix(Metalhead.WideResNet, 152), true, 2048),
("metalhead/googlenet", "GoogLeNet", Metalhead.GoogLeNet, false,
1024),
("metalhead/inceptionv3", "InceptionV3", Metalhead.Inceptionv3, false,
2048),
("metalhead/inceptionv4", "InceptionV4", Metalhead.Inceptionv4, false,
1536),
("metalhead/squeezenet", "SqueezeNet", Metalhead.SqueezeNet, true,
512),
("metalhead/densenet-121", "DenseNet121", fix(Metalhead.DenseNet, 121),
false, 1024),
("metalhead/densenet-161", "DenseNet161", fix(Metalhead.DenseNet, 161),
false, 1472),
("metalhead/densenet-169", "DenseNet169", fix(Metalhead.DenseNet, 169),
false, 1664),
("metalhead/densenet-201", "DenseNet201", fix(Metalhead.DenseNet, 201),
false, 1920),
("metalhead/resnext50", "ResNeXt50", fix(Metalhead.ResNeXt, 50), true,
2048),
("metalhead/resnext101", "ResNeXt101", fix(Metalhead.ResNeXt, 101), true,
2048),
("metalhead/resnext152", "ResNeXt152", fix(Metalhead.ResNeXt, 152), true,
2048),
("metalhead/mobilenetv1", "MobileNetV1", Metalhead.MobileNetv1, false,
1024),
("metalhead/mobilenetv2", "MobileNetV2", Metalhead.MobileNetv2, false,
1280),
("metalhead/mobilenetv3-small", "MobileNetV3 Small",
fix(Metalhead.MobileNetv3, :small), false, 576),
("metalhead/mobilenetv3-large", "MobileNetV3 Large",
fix(Metalhead.MobileNetv3, :large), false, 960),
("metalhead/efficientnet-b0", "EfficientNet-B0",
fix(Metalhead.EfficientNet, :b0), false, 1280),
("metalhead/efficientnet-b0", "EfficientNet-B0",
fix(Metalhead.EfficientNet, :b0), false, 1280),
("metalhead/efficientnet-b1", "EfficientNet-B1",
fix(Metalhead.EfficientNet, :b1), false, 1280),
("metalhead/efficientnet-b2", "EfficientNet-B2",
fix(Metalhead.EfficientNet, :b2), false, 1280),
("metalhead/efficientnet-b3", "EfficientNet-B3",
fix(Metalhead.EfficientNet, :b3), false, 1280),
("metalhead/efficientnet-b4", "EfficientNet-B4",
fix(Metalhead.EfficientNet, :b4), false, 1280),
("metalhead/efficientnet-b5", "EfficientNet-B5",
fix(Metalhead.EfficientNet, :b5), false, 1280),
("metalhead/efficientnet-b6", "EfficientNet-B6",
fix(Metalhead.EfficientNet, :b6), false, 1280),
("metalhead/efficientnet-b7", "EfficientNet-B7",
fix(Metalhead.EfficientNet, :b7), false, 1280),
("metalhead/efficientnet-b8", "EfficientNet-B8",
fix(Metalhead.EfficientNet, :b8), false, 1280),
]

metalheadloadfn(fn, hasweights) = function loadfn(ckpt; kwargs...)
hasweights ? fn(; pretrain = ckpt !== nothing, kwargs...) : fn(; kwargs...)
end

for (id, description, loadfn, hasweights, nfeatures) in METALHEAD_CONFIGS
_models[id] = (;
id, description,
loadfn = metalheadloadfn(loadfn, hasweights),
variants = metalheadvariants(loadfn, nfeatures),
checkpoints = hasweights ? ["imagenet1k"] : String[],
backend = :flux)
end


@testset "Metalhead models" begin
@test_nowarn load(models()["metalhead/resnet18"]; variant = "backbone")
@test_nowarn load(models()["metalhead/resnet18"]; variant = "classifier")
@test_nowarn load(models()["metalhead/resnet18"]; output = FastAI.OneHotLabel)
@test_nowarn load(models()["metalhead/resnet18"]; input = FastVision.ImageTensor)
@test_throws FastAI.Registries.NoModelVariantFoundError load(models()["metalhead/resnet18"]; output = FastAI.Label)
end
4 changes: 3 additions & 1 deletion src/Registries/Registries.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module Registries

using ..FastAI
using ..FastAI: FastAI, BlockLike, Label, LabelMulti, issubblock
using ..FastAI.Datasets
using ..FastAI.Datasets: DatasetLoader, DataDepLoader, isavailable, loaddata, typify

Expand Down Expand Up @@ -48,10 +48,12 @@ end
include("datasets.jl")
include("tasks.jl")
include("recipes.jl")
include("models.jl")

export datasets,
learningtasks,
datarecipes,
models,
find,
info,
load
Expand Down
Loading