Skip to content

Commit

Permalink
Fix bug which mutates user expressions in constraint macro
Browse files Browse the repository at this point in the history
  • Loading branch information
odow committed Nov 14, 2024
1 parent 7111683 commit eda964d
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 2 deletions.
11 changes: 9 additions & 2 deletions src/macros/@constraint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -475,7 +475,9 @@ function parse_constraint_head(
"`$ub >= ... >= $lb`.",
)
end
new_aff, parse_aff = _rewrite_expression(aff)
# Add +0 so that it creates a copy that we can mutate in future callers.
# We should fix this in MutableArithmetics.
new_aff, parse_aff = _rewrite_expression(:($aff + 0))
new_lb, parse_lb = _rewrite_expression(lb)
new_ub, parse_ub = _rewrite_expression(ub)
parse_code = quote
Expand Down Expand Up @@ -765,7 +767,9 @@ function parse_constraint_call(
func,
set,
)
f, parse_code = _rewrite_expression(func)
# Add +0 so that it creates a copy that we can mutate in future callers.
# We should fix this in MutableArithmetics.
f, parse_code = _rewrite_expression(:($func + 0))
build_call = if vectorized
:(build_constraint.($error_fn, _desparsify($f), $(esc(set))))
else
Expand Down Expand Up @@ -1020,6 +1024,9 @@ function _clear_constant!(α::Number)
return zero(α), α
end

# !!! warning
# This method assumes that we can mutate `expr`. Ensure that this is the
# case upstream of this call site.
function build_constraint(
::Function,
expr::Union{Number,GenericAffExpr,GenericQuadExpr},
Expand Down
27 changes: 27 additions & 0 deletions test/test_macros.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2487,4 +2487,31 @@ function test_array_scalar_sets()
return
end

function test_do_not_mutate_expression_double_sided_comparison()
model = Model()
@variable(model, x)
@expression(model, a[1:1], x+1)
@constraint(model, -1 <= a[1] <= 1)
@test isequal_canonical(a[1], x + 1)
return
end

function test_do_not_mutate_expression_single_sided_comparison()
model = Model()
@variable(model, x)
@expression(model, a[1:1], x+1)
@constraint(model, a[1] >= 1)
@test isequal_canonical(a[1], x + 1)
return
end

function test_do_not_mutate_expression_in_set()
model = Model()
@variable(model, x)
@expression(model, a[1:1], x+1)
@constraint(model, a[1] in MOI.Interval(-1, 1))
@test isequal_canonical(a[1], x + 1)
return
end

end # module

0 comments on commit eda964d

Please sign in to comment.