diff --git a/src/nlp_expr.jl b/src/nlp_expr.jl index 6fe42fcb5ee..4f1bf651ec0 100644 --- a/src/nlp_expr.jl +++ b/src/nlp_expr.jl @@ -506,6 +506,8 @@ function moi_function(f::GenericNonlinearExpr{V}) where {V} for i in length(f.args):-1:1 if f.args[i] isa GenericNonlinearExpr{V} push!(stack, (ret, i, f.args[i])) + elseif f.args[i] isa AbstractArray + ret.args[i] = moi_function.(f.args[i]) else ret.args[i] = moi_function(f.args[i]) end @@ -517,6 +519,8 @@ function moi_function(f::GenericNonlinearExpr{V}) where {V} for j in length(arg.args):-1:1 if arg.args[j] isa GenericNonlinearExpr{V} push!(stack, (child, j, arg.args[j])) + elseif arg.args[j] isa AbstractArray + child.args[j] = moi_function.(arg.args[j]) else child.args[j] = moi_function(arg.args[j]) end @@ -542,6 +546,8 @@ function jump_function(model::GenericModel, f::MOI.ScalarNonlinearFunction) end elseif arg isa Number push!(parent.args, arg) + elseif arg isa AbstractArray + push!(parent.args, jump_function.(model, arg)) else push!(parent.args, jump_function(model, arg)) end @@ -817,32 +823,34 @@ function Base.show(io::IO, f::NonlinearOperator) return print(io, "NonlinearOperator(:$(f.head), $(f.func))") end +const AbstractJuMPScalarOrArray = Union{AbstractJuMPScalar, AbstractArray{<:AbstractJuMPScalar}} + # Fast overload for unary calls (f::NonlinearOperator)(x) = f.func(x) -(f::NonlinearOperator)(x::AbstractJuMPScalar) = NonlinearExpr(f.head, Any[x]) +(f::NonlinearOperator)(x::AbstractJuMPScalarOrArray) = NonlinearExpr(f.head, Any[x]) # Fast overload for binary calls (f::NonlinearOperator)(x, y) = f.func(x, y) -function (f::NonlinearOperator)(x::AbstractJuMPScalar, y) +function (f::NonlinearOperator)(x::AbstractJuMPScalarOrArray, y) return GenericNonlinearExpr(f.head, Any[x, y]) end -function (f::NonlinearOperator)(x, y::AbstractJuMPScalar) +function (f::NonlinearOperator)(x, y::AbstractJuMPScalarOrArray) return GenericNonlinearExpr(f.head, Any[x, y]) end -function (f::NonlinearOperator)(x::AbstractJuMPScalar, y::AbstractJuMPScalar) +function (f::NonlinearOperator)(x::AbstractJuMPScalarOrArray, y::AbstractJuMPScalarOrArray) return GenericNonlinearExpr(f.head, Any[x, y]) end # Fallback for more arguments function (f::NonlinearOperator)(x, y, z...) args = (x, y, z...) - if any(Base.Fix2(isa, AbstractJuMPScalar), args) + if any(Base.Fix2(isa, AbstractJuMPScalarOrArray), args) return GenericNonlinearExpr(f.head, Any[a for a in args]) end return f.func(args...) diff --git a/test/test_nlp_expr.jl b/test/test_nlp_expr.jl index 53b6704073e..d56673fb683 100644 --- a/test/test_nlp_expr.jl +++ b/test/test_nlp_expr.jl @@ -6,6 +6,7 @@ module TestNLPExpr using JuMP +using LinearAlgebra using Test function test_extension_univariate_operators( @@ -828,4 +829,14 @@ function test_redefinition_of_function() return end +function test_array() + model = Model() + @variable(model, x) + op_norm = NonlinearOperator(:det, det) + @objective(model, Min, op_norm([x])) + f = MOI.get(model, MOI.ObjectiveFunction{MOI.ScalarNonlinearFunction}()) + @test f.head == :norm + @test f.args == [[index(x)]] +end + end # module