Skip to content

Commit

Permalink
Improve error message for non-broadcasted addition and subtraction (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
odow authored Nov 6, 2023
1 parent c831cc6 commit ec3c779
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 47 deletions.
8 changes: 8 additions & 0 deletions src/nlp_expr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,14 @@ function _MA.operate!!(
return +(x, *(args...))
end

function _MA.operate!!(
::typeof(_MA.add_mul),
::GenericNonlinearExpr,
x::AbstractArray,
)
return _throw_operator_error(_MA.add_mul, x)
end

"""
flatten!(expr::GenericNonlinearExpr)
Expand Down
59 changes: 45 additions & 14 deletions src/operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -452,22 +452,53 @@ function LinearAlgebra.issymmetric(x::Matrix{T}) where {T<:_JuMPTypes}
return true
end

function Base.:+(A::AbstractMatrix, x::AbstractJuMPScalar)
return error(
"Addition between a Matrix and a JuMP variable is not supported: instead of `A + x`, " *
"do `A .+ x` for element-wise addition, or if you are modifying the diagonal entries of the matrix " *
"do `A + x * LinearAlgebra.I(n)`, where `n` is the diagonal length.",
)
function _throw_operator_error(
::Union{typeof(+),typeof(_MA.add_mul)},
x::AbstractArray,
)
msg =
"Addition between an array and a JuMP scalar is not supported: " *
"instead of `x + y`, do `x .+ y` for element-wise addition."
if ndims(x) == 2 && size(x, 1) == size(x, 2)
msg *=
" If you are modifying the diagonal entries of a square matrix, " *
"do `x + y * LinearAlgebra.I(n)`, where `n` is the side length."
end
return error(msg)
end

function _throw_operator_error(
::Union{typeof(-),typeof(_MA.sub_mul)},
x::AbstractArray,
)
msg =
"Subtraction between an array and a JuMP scalar is not supported: " *
"instead of `x - y`, do `x .- y` for element-wise subtraction."
if ndims(x) == 2 && size(x, 1) == size(x, 2)
msg *=
" If you are modifying the diagonal entries of a square matrix, " *
"do `x - y * LinearAlgebra.I(n)`, where `n` is the side length."
end
return error(msg)
end

Base.:+(x::AbstractJuMPScalar, A::AbstractMatrix) = A + x
Base.:+(::AbstractJuMPScalar, x::AbstractArray) = _throw_operator_error(+, x)
Base.:+(x::AbstractArray, ::AbstractJuMPScalar) = _throw_operator_error(+, x)
Base.:-(::AbstractJuMPScalar, x::AbstractArray) = _throw_operator_error(-, x)
Base.:-(x::AbstractArray, ::AbstractJuMPScalar) = _throw_operator_error(-, x)

function Base.:-(A::AbstractMatrix, x::AbstractJuMPScalar)
return error(
"Subtraction between a Matrix and a JuMP variable is not supported: instead of `A - x`, " *
"do `A .- x` for element-wise subtraction, or if you are modifying the diagonal entries of the matrix " *
"do `A - x * LinearAlgebra.I(n)`, where `n` is the diagonal length.",
)
function _MA.operate!!(
op::Union{typeof(_MA.add_mul),typeof(_MA.sub_mul)},
x::AbstractArray,
::AbstractJuMPScalar,
)
return _throw_operator_error(op, x)
end

Base.:-(x::AbstractJuMPScalar, A::AbstractMatrix) = A - x
function _MA.operate!!(
op::Union{typeof(_MA.add_mul),typeof(_MA.sub_mul)},
::AbstractJuMPScalar,
x::AbstractArray,
)
return _throw_operator_error(op, x)
end
76 changes: 43 additions & 33 deletions test/test_operator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -628,39 +628,49 @@ end
function test_matrix_abstractscalar_add()
model = Model()
@variable(model, x)
A = rand(Float64, 2, 2)
@test_throws(
ErrorException(
"Addition between a Matrix and a JuMP variable is not supported: instead of `A + x`, " *
"do `A .+ x` for element-wise addition, or if you are modifying the diagonal entries of the matrix " *
"do `A + x * LinearAlgebra.I(n)`, where `n` is the diagonal length.",
),
A + x
),
@test_throws(
ErrorException(
"Addition between a Matrix and a JuMP variable is not supported: instead of `A + x`, " *
"do `A .+ x` for element-wise addition, or if you are modifying the diagonal entries of the matrix " *
"do `A + x * LinearAlgebra.I(n)`, where `n` is the diagonal length.",
),
x + A
),
@test_throws(
ErrorException(
"Subtraction between a Matrix and a JuMP variable is not supported: instead of `A - x`, " *
"do `A .- x` for element-wise subtraction, or if you are modifying the diagonal entries of the matrix " *
"do `A - x * LinearAlgebra.I(n)`, where `n` is the diagonal length.",
),
A - x
),
@test_throws(
ErrorException(
"Subtraction between a Matrix and a JuMP variable is not supported: instead of `A - x`, " *
"do `A .- x` for element-wise subtraction, or if you are modifying the diagonal entries of the matrix " *
"do `A - x * LinearAlgebra.I(n)`, where `n` is the diagonal length.",
),
x - A
),
A = rand(Float64, 3, 2)
B = rand(Float64, 3)
err_add = ErrorException(
"Addition between an array and a JuMP scalar is not supported: " *
"instead of `x + y`, do `x .+ y` for element-wise addition.",
)
err_sub = ErrorException(
"Subtraction between an array and a JuMP scalar is not supported: " *
"instead of `x - y`, do `x .- y` for element-wise subtraction.",
)
for lhs in (A, A', B, B'), rhs in (x, 1.0 * x, x^2, sin(x))
@test_throws(err_add, lhs + rhs)
@test_throws(err_add, rhs + lhs)
@test_throws(err_add, @expression(model, lhs + rhs))
@test_throws(err_add, @expression(model, rhs + lhs))
@test_throws(err_sub, lhs - rhs)
@test_throws(err_sub, rhs - lhs)
@test_throws(err_sub, @expression(model, lhs - rhs))
@test_throws(err_sub, @expression(model, rhs - lhs))
end
C = rand(Float64, 2, 2)
err_add = ErrorException(
"Addition between an array and a JuMP scalar is not supported: " *
"instead of `x + y`, do `x .+ y` for element-wise addition." *
" If you are modifying the diagonal entries of a square matrix, " *
"do `x + y * LinearAlgebra.I(n)`, where `n` is the side length.",
)
err_sub = ErrorException(
"Subtraction between an array and a JuMP scalar is not supported: " *
"instead of `x - y`, do `x .- y` for element-wise subtraction." *
" If you are modifying the diagonal entries of a square matrix, " *
"do `x - y * LinearAlgebra.I(n)`, where `n` is the side length.",
)
for lhs in (C, C'), rhs in (x, 1.0 * x, x^2, sin(x))
@test_throws(err_add, lhs + rhs)
@test_throws(err_add, rhs + lhs)
@test_throws(err_add, @expression(model, lhs + rhs))
@test_throws(err_add, @expression(model, rhs + lhs))
@test_throws(err_sub, lhs - rhs)
@test_throws(err_sub, rhs - lhs)
@test_throws(err_sub, @expression(model, lhs - rhs))
@test_throws(err_sub, @expression(model, rhs - lhs))
end
return
end

Expand Down

0 comments on commit ec3c779

Please sign in to comment.