From eda964d2201fe708b092c3d0dd8de3aa64fc92a4 Mon Sep 17 00:00:00 2001 From: Oscar Dowson Date: Fri, 15 Nov 2024 09:15:03 +1300 Subject: [PATCH 1/3] Fix bug which mutates user expressions in constraint macro --- src/macros/@constraint.jl | 11 +++++++++-- test/test_macros.jl | 27 +++++++++++++++++++++++++++ 2 files changed, 36 insertions(+), 2 deletions(-) diff --git a/src/macros/@constraint.jl b/src/macros/@constraint.jl index f3dcfe27551..2f5f835046a 100644 --- a/src/macros/@constraint.jl +++ b/src/macros/@constraint.jl @@ -475,7 +475,9 @@ function parse_constraint_head( "`$ub >= ... >= $lb`.", ) end - new_aff, parse_aff = _rewrite_expression(aff) + # Add +0 so that it creates a copy that we can mutate in future callers. + # We should fix this in MutableArithmetics. + new_aff, parse_aff = _rewrite_expression(:($aff + 0)) new_lb, parse_lb = _rewrite_expression(lb) new_ub, parse_ub = _rewrite_expression(ub) parse_code = quote @@ -765,7 +767,9 @@ function parse_constraint_call( func, set, ) - f, parse_code = _rewrite_expression(func) + # Add +0 so that it creates a copy that we can mutate in future callers. + # We should fix this in MutableArithmetics. + f, parse_code = _rewrite_expression(:($func + 0)) build_call = if vectorized :(build_constraint.($error_fn, _desparsify($f), $(esc(set)))) else @@ -1020,6 +1024,9 @@ function _clear_constant!(α::Number) return zero(α), α end +# !!! warning +# This method assumes that we can mutate `expr`. Ensure that this is the +# case upstream of this call site. function build_constraint( ::Function, expr::Union{Number,GenericAffExpr,GenericQuadExpr}, diff --git a/test/test_macros.jl b/test/test_macros.jl index 3ec99b5f498..905a7aa39b4 100644 --- a/test/test_macros.jl +++ b/test/test_macros.jl @@ -2487,4 +2487,31 @@ function test_array_scalar_sets() return end +function test_do_not_mutate_expression_double_sided_comparison() + model = Model() + @variable(model, x) + @expression(model, a[1:1], x+1) + @constraint(model, -1 <= a[1] <= 1) + @test isequal_canonical(a[1], x + 1) + return +end + +function test_do_not_mutate_expression_single_sided_comparison() + model = Model() + @variable(model, x) + @expression(model, a[1:1], x+1) + @constraint(model, a[1] >= 1) + @test isequal_canonical(a[1], x + 1) + return +end + +function test_do_not_mutate_expression_in_set() + model = Model() + @variable(model, x) + @expression(model, a[1:1], x+1) + @constraint(model, a[1] in MOI.Interval(-1, 1)) + @test isequal_canonical(a[1], x + 1) + return +end + end # module From 88453fb179a46569cc28d430ddd56f657b0b83be Mon Sep 17 00:00:00 2001 From: Oscar Dowson Date: Fri, 15 Nov 2024 09:56:31 +1300 Subject: [PATCH 2/3] Update --- src/macros.jl | 10 ++++++++++ src/macros/@constraint.jl | 8 ++------ 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/src/macros.jl b/src/macros.jl index 209c3ed3f41..c9749d84cc6 100644 --- a/src/macros.jl +++ b/src/macros.jl @@ -253,6 +253,16 @@ function _rewrite_expression(expr::Expr) new_expr = MacroTools.postwalk(_rewrite_to_jump_logic, expr) new_aff, parse_aff = _MA.rewrite(new_expr; move_factors_into_sums = false) ret = gensym() + has_copy_if_mutable = Ref(false) + MacroTools.postwalk(parse_aff) do x + if x === MutableArithmetics.copy_if_mutable + has_copy_if_mutable[] = true + end + return x + end + if !has_copy_if_mutable[] + new_aff = :($_MA.copy_if_mutable($new_aff)) + end code = quote $parse_aff $ret = $flatten!($new_aff) diff --git a/src/macros/@constraint.jl b/src/macros/@constraint.jl index 2f5f835046a..f1f6d289ecd 100644 --- a/src/macros/@constraint.jl +++ b/src/macros/@constraint.jl @@ -475,9 +475,7 @@ function parse_constraint_head( "`$ub >= ... >= $lb`.", ) end - # Add +0 so that it creates a copy that we can mutate in future callers. - # We should fix this in MutableArithmetics. - new_aff, parse_aff = _rewrite_expression(:($aff + 0)) + new_aff, parse_aff = _rewrite_expression(aff) new_lb, parse_lb = _rewrite_expression(lb) new_ub, parse_ub = _rewrite_expression(ub) parse_code = quote @@ -767,9 +765,7 @@ function parse_constraint_call( func, set, ) - # Add +0 so that it creates a copy that we can mutate in future callers. - # We should fix this in MutableArithmetics. - f, parse_code = _rewrite_expression(:($func + 0)) + f, parse_code = _rewrite_expression(func) build_call = if vectorized :(build_constraint.($error_fn, _desparsify($f), $(esc(set)))) else From 29ff799c6aa63b2d41bd3c88b6c0061f03147c88 Mon Sep 17 00:00:00 2001 From: Oscar Dowson Date: Fri, 15 Nov 2024 10:18:32 +1300 Subject: [PATCH 3/3] Apply suggestions from code review --- test/test_macros.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/test_macros.jl b/test/test_macros.jl index 905a7aa39b4..d923dfedacf 100644 --- a/test/test_macros.jl +++ b/test/test_macros.jl @@ -2490,7 +2490,7 @@ end function test_do_not_mutate_expression_double_sided_comparison() model = Model() @variable(model, x) - @expression(model, a[1:1], x+1) + @expression(model, a[1:1], x + 1) @constraint(model, -1 <= a[1] <= 1) @test isequal_canonical(a[1], x + 1) return @@ -2499,7 +2499,7 @@ end function test_do_not_mutate_expression_single_sided_comparison() model = Model() @variable(model, x) - @expression(model, a[1:1], x+1) + @expression(model, a[1:1], x + 1) @constraint(model, a[1] >= 1) @test isequal_canonical(a[1], x + 1) return @@ -2508,7 +2508,7 @@ end function test_do_not_mutate_expression_in_set() model = Model() @variable(model, x) - @expression(model, a[1:1], x+1) + @expression(model, a[1:1], x + 1) @constraint(model, a[1] in MOI.Interval(-1, 1)) @test isequal_canonical(a[1], x + 1) return