From f1fadddff02190389ada69e811337262e9496220 Mon Sep 17 00:00:00 2001 From: odow Date: Wed, 31 Jul 2024 12:17:55 +1200 Subject: [PATCH] Update --- src/macros/@constraint.jl | 34 ++++++++++++++++++++-------------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/src/macros/@constraint.jl b/src/macros/@constraint.jl index b38ad11d2f7..c8e4c873553 100644 --- a/src/macros/@constraint.jl +++ b/src/macros/@constraint.jl @@ -601,6 +601,19 @@ struct Nonnegatives end operator_to_set(::Function, ::Union{Val{:(>=)},Val{:(≥)}}) = Nonnegatives() +""" + _OpGreaterThan() + +A struct used to intercept `>=` when used in macros. +""" +struct _OpGreaterThan end + +operator_to_set(f::Function, sym::Val, ::Val{true}) = operator_to_set(f, sym) + +function operator_to_set(::Function, ::Union{Val{:(>=)},Val{:(≥)}}, ::Val{true}) + return _OpGreaterThan() +end + """ Nonpositives() @@ -632,6 +645,12 @@ struct Nonpositives end operator_to_set(::Function, ::Union{Val{:(<=)},Val{:(≤)}}) = Nonpositives() +struct _OpLessThan end + +function operator_to_set(::Function, ::Union{Val{:(<=)},Val{:(≤)}}, ::Val{true}) + return _OpLessThan() +end + """ Zeros() @@ -780,10 +799,7 @@ function parse_constraint_call( ) func = vectorized ? :($lhs .- $rhs) : :($lhs - $rhs) f, parse_code = _rewrite_expression(func) - set = operator_to_set(error_fn, operator) - # So that we can call a special method to intercept the ambiguous cases of - # `x >= y` and `x <= y` with arrays. - set = _intercept_operator(set) + set = operator_to_set(error_fn, operator, Val{true}()) # `_functionize` deals with the pathological case where the `lhs` is a # `VariableRef` and the `rhs` is a summation with no terms. f = :(_functionize($f)) @@ -795,12 +811,6 @@ function parse_constraint_call( return parse_code, build_call end -_intercept_operator(x) = x - -struct _OpGreaterThan end - -_intercept_operator(::Nonnegatives) = _OpGreaterThan() - function build_constraint( error_fn::Function, f, @@ -841,10 +851,6 @@ function build_constraint( ) end -struct _OpLessThan end - -_intercept_operator(::Nonpositives) = _OpLessThan() - function build_constraint( error_fn::Function, f,