diff --git a/ext/MathOptAIPythonCallExt.jl b/ext/MathOptAIPythonCallExt.jl index 29e40de..3e9b4fc 100644 --- a/ext/MathOptAIPythonCallExt.jl +++ b/ext/MathOptAIPythonCallExt.jl @@ -43,9 +43,6 @@ function MathOptAI.add_predictor( ) inner_predictor = MathOptAI.build_predictor(predictor; config) if reduced_space - # If config maps to a ReducedSpace predictor, we'll get a MethodError - # when trying to add the nested redcued space predictors. - # TODO: raise a nicer error or try to handle this gracefully. inner_predictor = MathOptAI.ReducedSpace(inner_predictor) end return MathOptAI.add_predictor(model, inner_predictor, x) diff --git a/src/MathOptAI.jl b/src/MathOptAI.jl index c244cbf..60dbabb 100644 --- a/src/MathOptAI.jl +++ b/src/MathOptAI.jl @@ -128,6 +128,8 @@ struct ReducedSpace{P<:AbstractPredictor} <: AbstractPredictor predictor::P end +ReducedSpace(predictor::ReducedSpace) = predictor + include("utilities.jl") for file in filter( diff --git a/test/test_predictors.jl b/test/test_predictors.jl index 3aed416..d0c74a0 100644 --- a/test/test_predictors.jl +++ b/test/test_predictors.jl @@ -331,6 +331,12 @@ function test_ReducedSpace_Tanh() return end +function test_ReducedSpace_ReducedSpace() + predictor = MathOptAI.ReducedSpace(MathOptAI.Tanh()) + @test MathOptAI.ReducedSpace(predictor) === predictor + return +end + end # module TestPredictors.runtests()