From 0cd6a037bac059a01e3cfcce88ee87aff1196c5d Mon Sep 17 00:00:00 2001 From: Oscar Dowson Date: Thu, 2 Nov 2023 16:33:59 +1300 Subject: [PATCH] Fix variable_ref_type for unsupported types and GenericNonlinearExpr (#3556) --- src/nlp_expr.jl | 2 +- src/variables.jl | 8 ++++++++ test/test_nlp_expr.jl | 7 +++++++ test/test_variable.jl | 14 ++++++++++++++ 4 files changed, 30 insertions(+), 1 deletion(-) diff --git a/src/nlp_expr.jl b/src/nlp_expr.jl index e77e9cd09da..ac4fc79aea8 100644 --- a/src/nlp_expr.jl +++ b/src/nlp_expr.jl @@ -133,7 +133,7 @@ Alias for `GenericNonlinearExpr{VariableRef}`, the specific """ const NonlinearExpr = GenericNonlinearExpr{VariableRef} -variable_ref_type(::GenericNonlinearExpr{V}) where {V} = V +variable_ref_type(::Type{GenericNonlinearExpr{V}}) where {V} = V const _PREFIX_OPERATORS = (:+, :-, :*, :/, :^, :||, :&&, :>, :<, :(<=), :(>=), :(==)) diff --git a/src/variables.jl b/src/variables.jl index efdea16ddee..df70cf9e9a3 100644 --- a/src/variables.jl +++ b/src/variables.jl @@ -224,6 +224,14 @@ variable type associated with the model or expression type `F`. """ variable_ref_type(::F) where {F} = variable_ref_type(F) +function variable_ref_type(::Type{F}) where {F} + return error( + "Unable to compute the `variable_ref_type` of the type `$F`. If you " * + "are developing a JuMP extension, define a new method for " * + "`JuMP.variable_ref_type(::Type{$F})`", + ) +end + variable_ref_type(::Type{V}) where {V<:AbstractVariableRef} = V value_type(::Type{<:AbstractVariableRef}) = Float64 diff --git a/test/test_nlp_expr.jl b/test/test_nlp_expr.jl index e07dc1ed036..def51c63ff1 100644 --- a/test/test_nlp_expr.jl +++ b/test/test_nlp_expr.jl @@ -971,4 +971,11 @@ function test_operator_max() return end +function test_variable_ref_type() + for V in (GenericVariableRef{Int}, VariableRef) + @test variable_ref_type(GenericNonlinearExpr{V}) == V + end + return +end + end # module diff --git a/test/test_variable.jl b/test/test_variable.jl index 1029c10a20f..4e4a2221f91 100644 --- a/test/test_variable.jl +++ b/test/test_variable.jl @@ -1555,4 +1555,18 @@ function test_parameter_arrays() return end +function test_variable_ref_type_unsupported() + for F in (Vector{VariableRef}, Vector{Int}) + @test_throws( + ErrorException( + "Unable to compute the `variable_ref_type` of the type `$F`. If you " * + "are developing a JuMP extension, define a new method for " * + "`JuMP.variable_ref_type(::Type{$F})`", + ), + variable_ref_type(F), + ) + end + return +end + end # module TestVariable