diff --git a/ext/MathOptAIFluxExt.jl b/ext/MathOptAIFluxExt.jl
index c3df772..94e940b 100644
--- a/ext/MathOptAIFluxExt.jl
+++ b/ext/MathOptAIFluxExt.jl
@@ -141,6 +141,10 @@ function MathOptAI.build_predictor(
     return inner_predictor
 end
 
+function _add_predictor(::MathOptAI.Pipeline, layer::Any, ::Dict)
+    return error("Unsupported layer: $layer")
+end
+
 _default(::typeof(identity)) = nothing
 _default(::Any) = missing
 _default(::typeof(Flux.relu)) = MathOptAI.ReLU()
diff --git a/ext/MathOptAILuxExt.jl b/ext/MathOptAILuxExt.jl
index 57dc99f..9690337 100644
--- a/ext/MathOptAILuxExt.jl
+++ b/ext/MathOptAILuxExt.jl
@@ -160,6 +160,10 @@ function MathOptAI.build_predictor(
     return inner_predictor
 end
 
+function _add_predictor(::MathOptAI.Pipeline, layer::Any, ::Any, ::Dict)
+    return error("Unsupported layer: $layer")
+end
+
 _default(::typeof(identity)) = nothing
 _default(::Any) = missing
 _default(::typeof(Lux.relu)) = MathOptAI.ReLU()
@@ -182,7 +186,7 @@ end
 function _add_predictor(
     predictor::MathOptAI.Pipeline,
     layer::Lux.Dense,
-    p,
+    p::Any,
     config::Dict,
 )
     push!(predictor.layers, MathOptAI.Affine(p.weight, vec(p.bias)))
@@ -193,7 +197,7 @@ end
 function _add_predictor(
     predictor::MathOptAI.Pipeline,
     layer::Lux.Scale,
-    p,
+    p::Any,
     config::Dict,
 )
     push!(predictor.layers, MathOptAI.Scale(p.weight, p.bias))
diff --git a/test/test_Flux.jl b/test/test_Flux.jl
index 2b9c3d4..86088a3 100644
--- a/test/test_Flux.jl
+++ b/test/test_Flux.jl
@@ -190,6 +190,17 @@ function test_end_to_end_Tanh()
     return
 end
 
+function test_unsupported_layer()
+    layer = Flux.Conv((5, 5), 3 => 7)
+    model = Model()
+    @variable(model, x[1:2])
+    @test_throws(
+        ErrorException("Unsupported layer: $layer"),
+        MathOptAI.add_predictor(model, Flux.Chain(layer), x),
+    )
+    return
+end
+
 end  # module
 
 TestFluxExt.runtests()
diff --git a/test/test_Lux.jl b/test/test_Lux.jl
index 42b79cf..38e507f 100644
--- a/test/test_Lux.jl
+++ b/test/test_Lux.jl
@@ -198,6 +198,20 @@ function test_end_to_end_Tanh()
     return
 end
 
+function test_unsupported_layer()
+    layer = Lux.Conv((5, 5), 3 => 7)
+    rng = Random.MersenneTwister()
+    ml_model = Lux.Chain(layer, layer)
+    parameters, state = Lux.setup(rng, ml_model)
+    model = Model()
+    @variable(model, x[1:2])
+    @test_throws(
+        ErrorException("Unsupported layer: $layer"),
+        MathOptAI.add_predictor(model, (ml_model, parameters, state), x),
+    )
+    return
+end
+
 end  # module
 
 TestLuxExt.runtests()