Skip to content

Commit

Permalink
Tidy macros/@constraint.jl (#3612)
Browse files Browse the repository at this point in the history
  • Loading branch information
odow authored Dec 9, 2023
1 parent e199181 commit 4b1269f
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 113 deletions.
14 changes: 13 additions & 1 deletion src/macros.jl
Original file line number Diff line number Diff line change
Expand Up @@ -343,9 +343,12 @@ julia> call
:(f(1, a = 2, \$(Expr(:escape, :(\$(Expr(:kw, :b, 3))))), \$(Expr(:escape, :(\$(Expr(:kw, :c, 4)))))))
```
"""
function _add_kw_args(call, kw_args)
function _add_kw_args(call, kw_args; exclude = Symbol[])
for kw in kw_args
@assert Meta.isexpr(kw, :(=))
if kw.args[1] in exclude
continue
end
push!(call.args, esc(Expr(:kw, kw.args...)))
end
return
Expand Down Expand Up @@ -662,6 +665,15 @@ function _wrap_let(model, code)
return code
end

function _get_kwarg_value(kwargs, key::Symbol; default = nothing)
for kwarg in kwargs
if kwarg.args[1] == key
return esc(kwarg.args[2])
end
end
return default
end

include("macros/@objective.jl")
include("macros/@expression.jl")
include("macros/@constraint.jl")
Expand Down
198 changes: 86 additions & 112 deletions src/macros/@constraint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
# file, You can obtain one at https://mozilla.org/MPL/2.0/.

"""
@constraint(m::GenericModel, expr, kw_args...)
@constraint(model::GenericModel, expr, kwargs...)
Add a constraint described by the expression `expr`.
@constraint(m::GenericModel, ref[i=..., j=..., ...], expr, kw_args...)
@constraint(model::GenericModel, ref[i=..., j=..., ...], expr, kwargs...)
Add a group of constraints described by the expression `expr` parametrized by
`i`, `j`, ...
Expand All @@ -21,24 +21,29 @@ The expression `expr` can either be
[`RotatedSecondOrderCone`](@ref) and [`PSDCone`](@ref), e.g.
`@constraint(model, [1, x-1, y-2] in SecondOrderCone())` constrains the norm
of `[x-1, y-2]` be less than 1;
* of the form `a sign b`, where `sign` is one of `==`, `≥`, `>=`, `≤` and
`<=` building the single constraint enforcing the comparison to hold for the
expression `a` and `b`, e.g. `@constraint(m, x^2 + y^2 == 1)` constrains `x`
and `y` to lie on the unit circle;
expression `a` and `b`, e.g. `@constraint(model, x^2 + y^2 == 1)` constrains
`x` and `y` to lie on the unit circle;
* of the form `a ≤ b ≤ c` or `a ≥ b ≥ c` (where `≤` and `<=` (resp. `≥` and
`>=`) can be used interchangeably) constraining the paired the expression
`b` to lie between `a` and `c`;
* of the forms `@constraint(m, a .sign b)` or
`@constraint(m, a .sign b .sign c)` which broadcast the constraint creation to
each element of the vectors.
The recognized keyword arguments in `kw_args` are the following:
The recognized keyword arguments in `kwargs` are the following:
* `base_name`: Sets the name prefix used to generate constraint names. It
corresponds to the constraint name for scalar constraints, otherwise, the
constraint names are set to `base_name[...]` for each index `...` of the axes
`axes`.
* `container`: Specify the container type.
* `set_string_name::Bool = true`: control whether to set the
[`MOI.ConstraintName`](@ref) attribute. Passing `set_string_name = false` can
improve performance.
Expand All @@ -47,8 +52,10 @@ The recognized keyword arguments in `kw_args` are the following:
Each constraint will be created using
`add_constraint(m, build_constraint(error_fn, func, set))` where
* `error_fn` is an error function showing the constraint call in addition to the
error message given as argument,
* `func` is the expression that is constrained
* and `set` is the set in which it is constrained to belong.
Expand All @@ -68,8 +75,8 @@ For extensions that need to create constraints with more information than just
`func` and `set`, an additional positional argument can be specified to
`@constraint` that will then be passed on `build_constraint`. Hence, we can
enable this syntax by defining extensions of
`build_constraint(error_fn, func, set, my_arg; kw_args...)`. This produces the
user syntax: `@constraint(model, ref[...], expr, my_arg, kw_args...)`.
`build_constraint(error_fn, func, set, my_arg; kwargs...)`. This produces the
user syntax: `@constraint(model, ref[...], expr, my_arg, kwargs...)`.
"""
macro constraint(args...)
return _constraint_macro(args, :constraint, parse_constraint, __source__)
Expand Down Expand Up @@ -139,46 +146,38 @@ julia> @build_constraint(2x >= 1)
ScalarConstraint{AffExpr, MathOptInterface.GreaterThan{Float64}}(2 x, MathOptInterface.GreaterThan{Float64}(1.0))
```
"""
macro build_constraint(constraint_expr)
macro build_constraint(arg)
function error_fn(str...)
return _macro_error(
:build_constraint,
(constraint_expr,),
__source__,
str...,
)
return _macro_error(:build_constraint, (arg,), __source__, str...)
end

if isa(constraint_expr, Symbol)
if arg isa Symbol
error_fn(
"Incomplete constraint specification $constraint_expr. " *
"Incomplete constraint specification $arg. " *
"Are you missing a comparison (<=, >=, or ==)?",
)
end

is_vectorized, parse_code, build_call =
parse_constraint(error_fn, constraint_expr)
result_variable = gensym()
code = quote
_, parse_code, build_call = parse_constraint(error_fn, arg)
return quote
$parse_code
$result_variable = $build_call
$build_call
end

return code
end

"""
_constraint_macro(
args, macro_name::Symbol, parsefun::Function, source::LineNumberNode
args,
macro_name::Symbol,
parse_fn::Function,
source::LineNumberNode,
)
Returns the code for the macro `@constraint args...` of syntax
```julia
@constraint(model, con, extra_arg, kw_args...) # single constraint
@constraint(model, ref, con, extra_arg, kw_args...) # group of constraints
@constraint(model, con, extra_arg, kwargs...) # single constraint
@constraint(model, ref, con, extra_arg, kwargs...) # group of constraints
```
The expression `con` is parsed by `parsefun` which returns a `build_constraint`
The expression `con` is parsed by `parse_fn` which returns a `build_constraint`
call code that, when executed, returns an `AbstractConstraint`. The macro
keyword arguments (except the `container` keyword argument which is used to
determine the container type) are added to the `build_constraint` call. The
Expand All @@ -192,33 +191,24 @@ called from in the user's code. One way of generating this is via the hidden
variable `__source__`.
"""
function _constraint_macro(
args,
input_args,
macro_name::Symbol,
parsefun::Function,
parse_fn::Function,
source::LineNumberNode,
)
error_fn(str...) = _macro_error(macro_name, args, source, str...)

# The positional args can't be `args` otherwise `error_fn` excludes keyword args
pos_args, kw_args, requested_container = Containers._extract_kw_args(args)

# Initial check of the positional arguments and get the model
if length(pos_args) < 2
if length(kw_args) > 0
error_fn(
"No constraint expression detected. If you are trying to " *
"construct an equality constraint, use `==` instead of `=`.",
)
else
error_fn("Not enough arguments")
end
end
model = esc(pos_args[1])
y = pos_args[2]
extra = pos_args[3:end]
if Meta.isexpr(args[2], :block)
error_fn(str...) = _macro_error(macro_name, input_args, source, str...)
args, kwargs, container = Containers._extract_kw_args(input_args)
if length(args) < 2 && !isempty(kwargs)
error_fn(
"No constraint expression detected. If you are trying to " *
"construct an equality constraint, use `==` instead of `=`.",
)
elseif length(args) < 2
error_fn("Not enough arguments")
elseif Meta.isexpr(args[2], :block)
error_fn("Invalid syntax. Did you mean to use `@$(macro_name)s`?")
end
model, y, extra = esc(args[1]), args[2], args[3:end]
# Determine if a reference/container argument was given by the user
# There are six cases to consider:
# y | type of y | y.head
Expand All @@ -230,88 +220,72 @@ function _constraint_macro(
# [i = 1:2, j = 1:2; i + j >= 3] | Expr | :vcat
# a constraint expression | Expr | :call or :comparison
if isa(y, Symbol) || Meta.isexpr(y, (:vect, :vcat, :ref, :typed_vcat))
length(extra) >= 1 || error_fn("No constraint expression was given.")
if length(extra) == 0
error_fn("No constraint expression was given.")
end
c = y
x = popfirst!(extra)
anonvar = Meta.isexpr(y, (:vect, :vcat))
is_anonymous = Meta.isexpr(y, (:vect, :vcat))
else
c = gensym()
x = y
anonvar = true
is_anonymous = true
end

# Enforce that only one extra positional argument can be given
if length(extra) > 1
error_fn("Cannot specify more than 1 additional positional argument.")
end

# Prepare the keyword arguments
extra_kw_args = filter(kw_args) do kw
return kw.args[1] != :base_name && kw.args[1] != :set_string_name
end
base_name_kw_args = filter(kw -> kw.args[1] == :base_name, kw_args)
set_string_name_kw_args =
filter(kw -> kw.args[1] == :set_string_name, kw_args)
# Set the base name
name = Containers._get_name(c)
if isempty(base_name_kw_args)
base_name = anonvar ? "" : string(name)
else
base_name = esc(base_name_kw_args[1].args[2])
end

# Strategy: build up the code for add_constraint, and if needed we will wrap
# in a function returning `ConstraintRef`s and give it to `Containers.container`.
idxvars, indices = Containers.build_ref_sets(error_fn, c)
if pos_args[1] in idxvars
index_vars, indices = Containers.build_ref_sets(error_fn, c)
if args[1] in index_vars
error_fn(
"Index $(pos_args[1]) is the same symbol as the model. Use a " *
"Index $(args[1]) is the same symbol as the model. Use a " *
"different name for the index.",
)
end
vectorized, parsecode, buildcall = parsefun(error_fn, x)
_add_positional_args(buildcall, extra)
_add_kw_args(buildcall, extra_kw_args)
if vectorized
buildcall = :(model_convert.($model, $buildcall))
else
buildcall = :(model_convert($model, $buildcall))
end
name_expr = _name_call(base_name, idxvars)
new_name_expr = if isempty(set_string_name_kw_args)
Expr(:if, :(set_string_names_on_creation($model)), name_expr, "")
else
Expr(:if, esc(set_string_name_kw_args[1].args[2]), name_expr, "")
end
if vectorized
# For vectorized constraints, we set every constraint to have the same
# name.
constraintcall = :(add_constraint.($model, $buildcall, $new_name_expr))
is_vectorized, parse_code, build_call = parse_fn(error_fn, x)
_add_positional_args(build_call, extra)
_add_kw_args(build_call, kwargs; exclude = [:base_name, :set_string_name])
base_name = _get_kwarg_value(
kwargs,
:base_name;
default = is_anonymous ? "" : string(Containers._get_name(c)),
)
set_name_flag = _get_kwarg_value(
kwargs,
:set_string_name;
default = :(set_string_names_on_creation($model)),
)
name_expr = Expr(:if, set_name_flag, _name_call(base_name, index_vars), "")
code = if is_vectorized
quote
$parse_code
# These broadcast calls need to be nested so that the operators
# are fused. Some broadcasted errors result if you put them on
# different lines.
add_constraint.(
$model,
model_convert.($model, $build_call),
$name_expr,
)
end
else
constraintcall = :(add_constraint($model, $buildcall, $new_name_expr))
end
code = quote
$parsecode
$constraintcall
quote
$parse_code
build = model_convert($model, $build_call)
add_constraint($model, build, $name_expr)
end
end
creation_code =
Containers.container_code(idxvars, indices, code, requested_container)
Containers.container_code(index_vars, indices, code, container)
# Wrap the entire code block in a let statement to make the model act as
# a type stable local variable.
creation_code = _wrap_let(model, creation_code)
if anonvar
# Anonymous constraint, no need to register it in the model-level
# dictionary nor to assign it to a variable in the user scope.
# We simply return the constraint reference
macro_code = creation_code
macro_code = if is_anonymous
creation_code
else
# We register the constraint reference to its name and
# we assign it to a variable in the local scope of this name
variable = gensym()
macro_code = _macro_assign_and_return(
_macro_assign_and_return(
creation_code,
variable,
name;
gensym(),
Containers._get_name(c);
model_for_registering = model,
)
end
Expand Down

0 comments on commit 4b1269f

Please sign in to comment.