Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
odow committed Aug 1, 2024
1 parent b5ceaac commit f1faddd
Showing 1 changed file with 20 additions and 14 deletions.
34 changes: 20 additions & 14 deletions src/macros/@constraint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -601,6 +601,19 @@ struct Nonnegatives end

operator_to_set(::Function, ::Union{Val{:(>=)},Val{:(≥)}}) = Nonnegatives()

"""
_OpGreaterThan()
A struct used to intercept `>=` when used in macros.
"""
struct _OpGreaterThan end

operator_to_set(f::Function, sym::Val, ::Val{true}) = operator_to_set(f, sym)

function operator_to_set(::Function, ::Union{Val{:(>=)},Val{:(≥)}}, ::Val{true})
return _OpGreaterThan()
end

"""
Nonpositives()
Expand Down Expand Up @@ -632,6 +645,12 @@ struct Nonpositives end

operator_to_set(::Function, ::Union{Val{:(<=)},Val{:(≤)}}) = Nonpositives()

struct _OpLessThan end

function operator_to_set(::Function, ::Union{Val{:(<=)},Val{:(≤)}}, ::Val{true})
return _OpLessThan()
end

"""
Zeros()
Expand Down Expand Up @@ -780,10 +799,7 @@ function parse_constraint_call(
)
func = vectorized ? :($lhs .- $rhs) : :($lhs - $rhs)
f, parse_code = _rewrite_expression(func)
set = operator_to_set(error_fn, operator)
# So that we can call a special method to intercept the ambiguous cases of
# `x >= y` and `x <= y` with arrays.
set = _intercept_operator(set)
set = operator_to_set(error_fn, operator, Val{true}())
# `_functionize` deals with the pathological case where the `lhs` is a
# `VariableRef` and the `rhs` is a summation with no terms.
f = :(_functionize($f))
Expand All @@ -795,12 +811,6 @@ function parse_constraint_call(
return parse_code, build_call
end

_intercept_operator(x) = x

struct _OpGreaterThan end

_intercept_operator(::Nonnegatives) = _OpGreaterThan()

function build_constraint(
error_fn::Function,
f,
Expand Down Expand Up @@ -841,10 +851,6 @@ function build_constraint(
)
end

struct _OpLessThan end

_intercept_operator(::Nonpositives) = _OpLessThan()

function build_constraint(
error_fn::Function,
f,
Expand Down

0 comments on commit f1faddd

Please sign in to comment.