From eda964d2201fe708b092c3d0dd8de3aa64fc92a4 Mon Sep 17 00:00:00 2001 From: Oscar Dowson Date: Fri, 15 Nov 2024 09:15:03 +1300 Subject: [PATCH] Fix bug which mutates user expressions in constraint macro --- src/macros/@constraint.jl | 11 +++++++++-- test/test_macros.jl | 27 +++++++++++++++++++++++++++ 2 files changed, 36 insertions(+), 2 deletions(-) diff --git a/src/macros/@constraint.jl b/src/macros/@constraint.jl index f3dcfe27551..2f5f835046a 100644 --- a/src/macros/@constraint.jl +++ b/src/macros/@constraint.jl @@ -475,7 +475,9 @@ function parse_constraint_head( "`$ub >= ... >= $lb`.", ) end - new_aff, parse_aff = _rewrite_expression(aff) + # Add +0 so that it creates a copy that we can mutate in future callers. + # We should fix this in MutableArithmetics. + new_aff, parse_aff = _rewrite_expression(:($aff + 0)) new_lb, parse_lb = _rewrite_expression(lb) new_ub, parse_ub = _rewrite_expression(ub) parse_code = quote @@ -765,7 +767,9 @@ function parse_constraint_call( func, set, ) - f, parse_code = _rewrite_expression(func) + # Add +0 so that it creates a copy that we can mutate in future callers. + # We should fix this in MutableArithmetics. + f, parse_code = _rewrite_expression(:($func + 0)) build_call = if vectorized :(build_constraint.($error_fn, _desparsify($f), $(esc(set)))) else @@ -1020,6 +1024,9 @@ function _clear_constant!(α::Number) return zero(α), α end +# !!! warning +# This method assumes that we can mutate `expr`. Ensure that this is the +# case upstream of this call site. function build_constraint( ::Function, expr::Union{Number,GenericAffExpr,GenericQuadExpr}, diff --git a/test/test_macros.jl b/test/test_macros.jl index 3ec99b5f498..905a7aa39b4 100644 --- a/test/test_macros.jl +++ b/test/test_macros.jl @@ -2487,4 +2487,31 @@ function test_array_scalar_sets() return end +function test_do_not_mutate_expression_double_sided_comparison() + model = Model() + @variable(model, x) + @expression(model, a[1:1], x+1) + @constraint(model, -1 <= a[1] <= 1) + @test isequal_canonical(a[1], x + 1) + return +end + +function test_do_not_mutate_expression_single_sided_comparison() + model = Model() + @variable(model, x) + @expression(model, a[1:1], x+1) + @constraint(model, a[1] >= 1) + @test isequal_canonical(a[1], x + 1) + return +end + +function test_do_not_mutate_expression_in_set() + model = Model() + @variable(model, x) + @expression(model, a[1:1], x+1) + @constraint(model, a[1] in MOI.Interval(-1, 1)) + @test isequal_canonical(a[1], x + 1) + return +end + end # module