diff --git a/src/nlp_expr.jl b/src/nlp_expr.jl index 60763939f6a..b76cbcd706b 100644 --- a/src/nlp_expr.jl +++ b/src/nlp_expr.jl @@ -81,15 +81,11 @@ struct GenericNonlinearExpr{V<:AbstractVariableRef} <: AbstractJuMPScalar head::Symbol args::Vector{Any} - function GenericNonlinearExpr(head::Symbol, args::Vector{Any}) - index = findfirst(Base.Fix2(isa, AbstractJuMPScalar), args) - if index === nothing - error( - "Unable to create a nonlinear expression because it did not " * - "contain any JuMP scalars. head = $head, args = $args.", - ) - end - return new{variable_ref_type(args[index])}(head, args) + function GenericNonlinearExpr{V}( + head::Symbol, + args::Vararg{Any}, + ) where {V<:AbstractVariableRef} + return new{V}(head, Any[a for a in args]) end function GenericNonlinearExpr{V}( @@ -100,6 +96,35 @@ struct GenericNonlinearExpr{V<:AbstractVariableRef} <: AbstractJuMPScalar end end +variable_ref_type(::Type{GenericNonlinearExpr}, ::Any) = nothing + +function variable_ref_type(::Type{GenericNonlinearExpr}, x::AbstractJuMPScalar) + return variable_ref_type(x) +end + +function _has_variable_ref_type(a) + return variable_ref_type(GenericNonlinearExpr, a) !== nothing +end + +function _variable_ref_type(head, args) + if (i = findfirst(_has_variable_ref_type, args)) !== nothing + V = variable_ref_type(GenericNonlinearExpr, args[i]) + return V::Type{<:AbstractVariableRef} + end + return error( + "Unable to create a nonlinear expression because it did not contain " * + "any JuMP scalars. head = `:$head`, args = `$args`.", + ) +end + +function GenericNonlinearExpr(head::Symbol, args::Vector{Any}) + return GenericNonlinearExpr{_variable_ref_type(head, args)}(head, args) +end + +function GenericNonlinearExpr(head::Symbol, args::Vararg{Any,N}) where {N} + return GenericNonlinearExpr{_variable_ref_type(head, args)}(head, args...) +end + """ NonlinearExpr @@ -110,15 +135,6 @@ const NonlinearExpr = GenericNonlinearExpr{VariableRef} variable_ref_type(::GenericNonlinearExpr{V}) where {V} = V -# We include this method so that we can refactor the internal representation of -# GenericNonlinearExpr without having to rewrite the method overloads. -function GenericNonlinearExpr{V}( - head::Symbol, - args..., -) where {V<:AbstractVariableRef} - return GenericNonlinearExpr{V}(head, Any[args...]) -end - const _PREFIX_OPERATORS = (:+, :-, :*, :/, :^, :||, :&&, :>, :<, :(<=), :(>=), :(==)) @@ -527,6 +543,8 @@ function moi_function(f::GenericNonlinearExpr{V}) where {V} return ret end +jump_function(::GenericModel{T}, x::Number) where {T} = convert(T, x) + function jump_function(model::GenericModel, f::MOI.ScalarNonlinearFunction) V = variable_ref_type(typeof(model)) ret = GenericNonlinearExpr{V}(f.head, Any[]) @@ -542,8 +560,6 @@ function jump_function(model::GenericModel, f::MOI.ScalarNonlinearFunction) for child in reverse(arg.args) push!(stack, (new_ret, child)) end - elseif arg isa Number - push!(parent.args, arg) else push!(parent.args, jump_function(model, arg)) end @@ -833,33 +849,10 @@ function Base.show(io::IO, f::NonlinearOperator) return print(io, "NonlinearOperator($(f.func), :$(f.head))") end -# Fast overload for unary calls - -(f::NonlinearOperator)(x) = f.func(x) - -(f::NonlinearOperator)(x::AbstractJuMPScalar) = NonlinearExpr(f.head, Any[x]) - -# Fast overload for binary calls - -(f::NonlinearOperator)(x, y) = f.func(x, y) - -function (f::NonlinearOperator)(x::AbstractJuMPScalar, y) - return GenericNonlinearExpr(f.head, Any[x, y]) -end - -function (f::NonlinearOperator)(x, y::AbstractJuMPScalar) - return GenericNonlinearExpr(f.head, Any[x, y]) -end - -function (f::NonlinearOperator)(x::AbstractJuMPScalar, y::AbstractJuMPScalar) - return GenericNonlinearExpr(f.head, Any[x, y]) -end - -# Fallback for more arguments -function (f::NonlinearOperator)(x, y, z...) - args = (x, y, z...) - if any(Base.Fix2(isa, AbstractJuMPScalar), args) - return GenericNonlinearExpr(f.head, Any[a for a in args]) +function (f::NonlinearOperator)(args::Vararg{Any,N}) where {N} + types = variable_ref_type.(GenericNonlinearExpr, args) + if (i = findfirst(!isnothing, types)) !== nothing + return GenericNonlinearExpr{types[i]}(f.head, args...) end return f.func(args...) end diff --git a/src/operators.jl b/src/operators.jl index 7c0ec273ae4..5310234ddc0 100644 --- a/src/operators.jl +++ b/src/operators.jl @@ -195,20 +195,19 @@ function Base.:/(lhs::GenericAffExpr, rhs::_Constant) return map_coefficients(c -> c / rhs, lhs) end -function Base.:^(lhs::AbstractVariableRef, rhs::Integer) - T = value_type(typeof(lhs)) +function Base.:^(lhs::V, rhs::Integer) where {V<:AbstractVariableRef} if rhs == 0 - return one(T) + return one(value_type(V)) elseif rhs == 1 return lhs elseif rhs == 2 return lhs * lhs else - return GenericNonlinearExpr(:^, Any[lhs, rhs]) + return GenericNonlinearExpr{V}(:^, Any[lhs, rhs]) end end -function Base.:^(lhs::GenericAffExpr{T}, rhs::Integer) where {T} +function Base.:^(lhs::GenericAffExpr{T,V}, rhs::Integer) where {T,V} if rhs == 0 return one(T) elseif rhs == 1 @@ -216,7 +215,7 @@ function Base.:^(lhs::GenericAffExpr{T}, rhs::Integer) where {T} elseif rhs == 2 return lhs * lhs else - return GenericNonlinearExpr(:^, Any[lhs, rhs]) + return GenericNonlinearExpr{V}(:^, Any[lhs, rhs]) end end diff --git a/test/test_nlp.jl b/test/test_nlp.jl index d4f34cfe56a..ff8be02a7bc 100644 --- a/test/test_nlp.jl +++ b/test/test_nlp.jl @@ -1605,7 +1605,7 @@ function test_parse_expression_nonlinearexpr_call() model = Model() @variable(model, x) @variable(model, y) - f = GenericNonlinearExpr(:ifelse, Any[x, 0, y]) + f = NonlinearExpr(:ifelse, Any[x, 0, y]) @NLexpression(model, ref, f) nlp = nonlinear_model(model) expr = :(ifelse($x, 0, $y)) @@ -1617,7 +1617,7 @@ function test_parse_expression_nonlinearexpr_or() model = Model() @variable(model, x) @variable(model, y) - f = GenericNonlinearExpr(:||, Any[x, y]) + f = NonlinearExpr(:||, Any[x, y]) @NLexpression(model, ref, f) nlp = nonlinear_model(model) expr = :($x || $y) @@ -1629,7 +1629,7 @@ function test_parse_expression_nonlinearexpr_and() model = Model() @variable(model, x) @variable(model, y) - f = GenericNonlinearExpr(:&&, Any[x, y]) + f = NonlinearExpr(:&&, Any[x, y]) @NLexpression(model, ref, f) nlp = nonlinear_model(model) expr = :($x && $y) @@ -1641,7 +1641,7 @@ function test_parse_expression_nonlinearexpr_unsupported() model = Model() @variable(model, x) @variable(model, y) - f = GenericNonlinearExpr(:foo, Any[x, y]) + f = NonlinearExpr(:foo, Any[x, y]) @test_throws( MOI.UnsupportedNonlinearOperator, @NLexpression(model, ref, f), @@ -1653,8 +1653,8 @@ function test_parse_expression_nonlinearexpr_nested_comparison() model = Model() @variable(model, x) @variable(model, y) - f = GenericNonlinearExpr(:||, Any[x, y]) - g = GenericNonlinearExpr(:&&, Any[f, x]) + f = NonlinearExpr(:||, Any[x, y]) + g = NonlinearExpr(:&&, Any[f, x]) @NLexpression(model, ref, g) nlp = nonlinear_model(model) expr = :(($x || $y) && $x) diff --git a/test/test_nlp_expr.jl b/test/test_nlp_expr.jl index 092c2691b80..bd309f4c1d8 100644 --- a/test/test_nlp_expr.jl +++ b/test/test_nlp_expr.jl @@ -447,46 +447,46 @@ function test_extension_nl_macro( @variable(model, x) @test isequal_canonical( @expression(model, ifelse(x, 1, 2)), - GenericNonlinearExpr(:ifelse, Any[x, 1, 2]), + GenericNonlinearExpr{VariableRefType}(:ifelse, Any[x, 1, 2]), ) @test isequal_canonical( @expression(model, x || 1), - GenericNonlinearExpr(:||, Any[x, 1]), + GenericNonlinearExpr{VariableRefType}(:||, Any[x, 1]), ) @test isequal_canonical( @expression(model, x && 1), - GenericNonlinearExpr(:&&, Any[x, 1]), + GenericNonlinearExpr{VariableRefType}(:&&, Any[x, 1]), ) @test isequal_canonical( @expression(model, x < 0), - GenericNonlinearExpr(:<, Any[x, 0]), + GenericNonlinearExpr{VariableRefType}(:<, Any[x, 0]), ) @test isequal_canonical( @expression(model, x > 0), - GenericNonlinearExpr(:>, Any[x, 0]), + GenericNonlinearExpr{VariableRefType}(:>, Any[x, 0]), ) @test isequal_canonical( @expression(model, x <= 0), - GenericNonlinearExpr(:<=, Any[x, 0]), + GenericNonlinearExpr{VariableRefType}(:<=, Any[x, 0]), ) @test isequal_canonical( @expression(model, x >= 0), - GenericNonlinearExpr(:>=, Any[x, 0]), + GenericNonlinearExpr{VariableRefType}(:>=, Any[x, 0]), ) @test isequal_canonical( @expression(model, x == 0), - GenericNonlinearExpr(:(==), Any[x, 0]), + GenericNonlinearExpr{VariableRefType}(:(==), Any[x, 0]), ) @test isequal_canonical( @expression(model, 0 < x <= 1), - GenericNonlinearExpr( + GenericNonlinearExpr{VariableRefType}( :&&, Any[@expression(model, 0 < x), @expression(model, x <= 1)], ), ) @test isequal_canonical( @expression(model, ifelse(x > 0, x^2, sin(x))), - GenericNonlinearExpr( + GenericNonlinearExpr{VariableRefType}( :ifelse, Any[@expression(model, x > 0), x^2, sin(x)], ), @@ -501,7 +501,7 @@ function test_register_univariate() @test f isa NonlinearOperator @test sprint(show, f) == "NonlinearOperator($(f.func), :f)" @test isequal_canonical(@expression(model, f(x)), f(x)) - @test isequal_canonical(f(x), GenericNonlinearExpr(:f, Any[x])) + @test isequal_canonical(f(x), NonlinearExpr(:f, Any[x])) attrs = MOI.get(model, MOI.ListOfModelAttributesSet()) @test MOI.UserDefinedFunction(:f, 1) in attrs return @@ -522,7 +522,7 @@ function test_register_univariate_gradient() @variable(model, x) @operator(model, f, 1, x -> x^2, x -> 2 * x) @test isequal_canonical(@expression(model, f(x)), f(x)) - @test isequal_canonical(f(x), GenericNonlinearExpr(:f, Any[x])) + @test isequal_canonical(f(x), NonlinearExpr(:f, Any[x])) attrs = MOI.get(model, MOI.ListOfModelAttributesSet()) @test MOI.UserDefinedFunction(:f, 1) in attrs return @@ -533,7 +533,7 @@ function test_register_univariate_gradient_hessian() @variable(model, x) @operator(model, f, 1, x -> x^2, x -> 2 * x, x -> 2.0) @test isequal_canonical(@expression(model, f(x)), f(x)) - @test isequal_canonical(f(x), GenericNonlinearExpr(:f, Any[x])) + @test isequal_canonical(f(x), NonlinearExpr(:f, Any[x])) attrs = MOI.get(model, MOI.ListOfModelAttributesSet()) @test MOI.UserDefinedFunction(:f, 1) in attrs return @@ -545,7 +545,7 @@ function test_register_multivariate() f = (x...) -> sum(x .^ 2) @operator(model, foo, 2, f) @test isequal_canonical(@expression(model, foo(x...)), foo(x...)) - @test isequal_canonical(foo(x...), GenericNonlinearExpr(:foo, Any[x...])) + @test isequal_canonical(foo(x...), NonlinearExpr(:foo, Any[x...])) attrs = MOI.get(model, MOI.ListOfModelAttributesSet()) @test MOI.UserDefinedFunction(:foo, 2) in attrs return @@ -558,7 +558,7 @@ function test_register_multivariate_gradient() ∇f = (g, x...) -> (g .= 2 .* x) @operator(model, foo, 2, f, ∇f) @test isequal_canonical(@expression(model, foo(x...)), foo(x...)) - @test isequal_canonical(foo(x...), GenericNonlinearExpr(:foo, Any[x...])) + @test isequal_canonical(foo(x...), NonlinearExpr(:foo, Any[x...])) attrs = MOI.get(model, MOI.ListOfModelAttributesSet()) @test MOI.UserDefinedFunction(:foo, 2) in attrs return @@ -576,7 +576,7 @@ function test_register_multivariate_gradient_hessian() end @operator(model, foo, 2, f, ∇f, ∇²f) @test isequal_canonical(@expression(model, foo(x...)), foo(x...)) - @test isequal_canonical(foo(x...), GenericNonlinearExpr(:foo, Any[x...])) + @test isequal_canonical(foo(x...), NonlinearExpr(:foo, Any[x...])) attrs = MOI.get(model, MOI.ListOfModelAttributesSet()) @test MOI.UserDefinedFunction(:foo, 2) in attrs return @@ -587,7 +587,7 @@ function test_register_multivariate_many_args() @variable(model, x[1:10]) f = (x...) -> sum(x .^ 2) @operator(model, foo, 10, f) - @test isequal_canonical(foo(x...), GenericNonlinearExpr(:foo, Any[x...])) + @test isequal_canonical(foo(x...), NonlinearExpr(:foo, Any[x...])) @test foo((1:10)...) == 385 return end @@ -607,18 +607,6 @@ function test_register_errors() return end -function test_expression_no_variable() - head, args = :sin, Any[1] - @test_throws( - ErrorException( - "Unable to create a nonlinear expression because it did not " * - "contain any JuMP scalars. head = $head, args = $args.", - ), - GenericNonlinearExpr(head, args), - ) - return -end - function test_value_expression() model = Model() @variable(model, x) @@ -676,7 +664,7 @@ end function test_nonlinear_expr_owner_model() model = Model() @variable(model, x) - f = GenericNonlinearExpr(:sin, Any[x]) + f = NonlinearExpr(:sin, Any[x]) # This shouldn't happen in regular code, but let's test against it to check # we get something similar to AffExpr and QuadExpr. empty!(f.args) @@ -900,4 +888,42 @@ function test_ma_zero_in_operate!!() return end +function test_nonlinear_operator_inferred() + model = Model() + @variable(model, x) + @inferred op_less_than_or_equal_to(x, 1) + @test @inferred(op_less_than_or_equal_to(1, 2)) == true + return +end + +function test_generic_nonlinear_expr_infer_variable_type() + model = Model() + @variable(model, x) + @inferred GenericNonlinearExpr(:sin, x) + @inferred GenericNonlinearExpr GenericNonlinearExpr(:sin, Any[x]) + f = sin(x) + @test isequal_canonical(GenericNonlinearExpr(:sin, x), f) + @test isequal_canonical(GenericNonlinearExpr(:sin, Any[x]), f) + g = @expression(model, 1 <= x) + @inferred GenericNonlinearExpr(:<=, 1, x) + @inferred GenericNonlinearExpr GenericNonlinearExpr(:<=, Any[1, x]) + @test isequal_canonical(GenericNonlinearExpr(:<=, 1, x), g) + @test isequal_canonical(GenericNonlinearExpr(:<=, Any[1, x]), g) + @test_throws( + ErrorException( + "Unable to create a nonlinear expression because it did not " * + "contain any JuMP scalars. head = `:sin`, args = `(1,)`.", + ), + GenericNonlinearExpr(:sin, 1), + ) + @test_throws( + ErrorException( + "Unable to create a nonlinear expression because it did not " * + "contain any JuMP scalars. head = `:sin`, args = `Any[1]`.", + ), + GenericNonlinearExpr(:sin, Any[1]), + ) + return +end + end # module diff --git a/test/test_operator.jl b/test/test_operator.jl index ebda3be21b5..967fb8b666d 100644 --- a/test/test_operator.jl +++ b/test/test_operator.jl @@ -621,7 +621,7 @@ function test_complex_pow() @test y^0 == (1.0 + 0im) @test y^1 == y @test y^2 == y * y - @test isequal_canonical(y^3, GenericNonlinearExpr(:^, Any[y, 3])) + @test isequal_canonical(y^3, NonlinearExpr(:^, Any[y, 3])) return end