Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tidy macros/@constraint.jl #3612

Merged
merged 3 commits into from
Dec 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion src/macros.jl
Original file line number Diff line number Diff line change
Expand Up @@ -343,9 +343,12 @@ julia> call
:(f(1, a = 2, \$(Expr(:escape, :(\$(Expr(:kw, :b, 3))))), \$(Expr(:escape, :(\$(Expr(:kw, :c, 4)))))))
```
"""
function _add_kw_args(call, kw_args)
function _add_kw_args(call, kw_args; exclude = Symbol[])
for kw in kw_args
@assert Meta.isexpr(kw, :(=))
if kw.args[1] in exclude
continue
end
push!(call.args, esc(Expr(:kw, kw.args...)))
end
return
Expand Down Expand Up @@ -662,6 +665,15 @@ function _wrap_let(model, code)
return code
end

function _get_kwarg_value(kwargs, key::Symbol; default = nothing)
for kwarg in kwargs
if kwarg.args[1] == key
return esc(kwarg.args[2])
end
end
return default
end

include("macros/@objective.jl")
include("macros/@expression.jl")
include("macros/@constraint.jl")
Expand Down
198 changes: 86 additions & 112 deletions src/macros/@constraint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
# file, You can obtain one at https://mozilla.org/MPL/2.0/.

"""
@constraint(m::GenericModel, expr, kw_args...)
@constraint(model::GenericModel, expr, kwargs...)
Add a constraint described by the expression `expr`.
@constraint(m::GenericModel, ref[i=..., j=..., ...], expr, kw_args...)
@constraint(model::GenericModel, ref[i=..., j=..., ...], expr, kwargs...)
Add a group of constraints described by the expression `expr` parametrized by
`i`, `j`, ...
Expand All @@ -21,24 +21,29 @@ The expression `expr` can either be
[`RotatedSecondOrderCone`](@ref) and [`PSDCone`](@ref), e.g.
`@constraint(model, [1, x-1, y-2] in SecondOrderCone())` constrains the norm
of `[x-1, y-2]` be less than 1;
* of the form `a sign b`, where `sign` is one of `==`, `≥`, `>=`, `≤` and
`<=` building the single constraint enforcing the comparison to hold for the
expression `a` and `b`, e.g. `@constraint(m, x^2 + y^2 == 1)` constrains `x`
and `y` to lie on the unit circle;
expression `a` and `b`, e.g. `@constraint(model, x^2 + y^2 == 1)` constrains
`x` and `y` to lie on the unit circle;
* of the form `a ≤ b ≤ c` or `a ≥ b ≥ c` (where `≤` and `<=` (resp. `≥` and
`>=`) can be used interchangeably) constraining the paired the expression
`b` to lie between `a` and `c`;
* of the forms `@constraint(m, a .sign b)` or
`@constraint(m, a .sign b .sign c)` which broadcast the constraint creation to
each element of the vectors.
The recognized keyword arguments in `kw_args` are the following:
The recognized keyword arguments in `kwargs` are the following:
* `base_name`: Sets the name prefix used to generate constraint names. It
corresponds to the constraint name for scalar constraints, otherwise, the
constraint names are set to `base_name[...]` for each index `...` of the axes
`axes`.
* `container`: Specify the container type.
* `set_string_name::Bool = true`: control whether to set the
[`MOI.ConstraintName`](@ref) attribute. Passing `set_string_name = false` can
improve performance.
Expand All @@ -47,8 +52,10 @@ The recognized keyword arguments in `kw_args` are the following:
Each constraint will be created using
`add_constraint(m, build_constraint(error_fn, func, set))` where
* `error_fn` is an error function showing the constraint call in addition to the
error message given as argument,
* `func` is the expression that is constrained
* and `set` is the set in which it is constrained to belong.
Expand All @@ -68,8 +75,8 @@ For extensions that need to create constraints with more information than just
`func` and `set`, an additional positional argument can be specified to
`@constraint` that will then be passed on `build_constraint`. Hence, we can
enable this syntax by defining extensions of
`build_constraint(error_fn, func, set, my_arg; kw_args...)`. This produces the
user syntax: `@constraint(model, ref[...], expr, my_arg, kw_args...)`.
`build_constraint(error_fn, func, set, my_arg; kwargs...)`. This produces the
user syntax: `@constraint(model, ref[...], expr, my_arg, kwargs...)`.
"""
macro constraint(args...)
return _constraint_macro(args, :constraint, parse_constraint, __source__)
Expand Down Expand Up @@ -139,46 +146,38 @@ julia> @build_constraint(2x >= 1)
ScalarConstraint{AffExpr, MathOptInterface.GreaterThan{Float64}}(2 x, MathOptInterface.GreaterThan{Float64}(1.0))
```
"""
macro build_constraint(constraint_expr)
macro build_constraint(arg)
function error_fn(str...)
return _macro_error(
:build_constraint,
(constraint_expr,),
__source__,
str...,
)
return _macro_error(:build_constraint, (arg,), __source__, str...)
end

if isa(constraint_expr, Symbol)
if arg isa Symbol
error_fn(
"Incomplete constraint specification $constraint_expr. " *
"Incomplete constraint specification $arg. " *
"Are you missing a comparison (<=, >=, or ==)?",
)
end

is_vectorized, parse_code, build_call =
parse_constraint(error_fn, constraint_expr)
result_variable = gensym()
code = quote
_, parse_code, build_call = parse_constraint(error_fn, arg)
return quote
$parse_code
$result_variable = $build_call
$build_call
end

return code
end

"""
_constraint_macro(
args, macro_name::Symbol, parsefun::Function, source::LineNumberNode
args,
macro_name::Symbol,
parse_fn::Function,
source::LineNumberNode,
)
Returns the code for the macro `@constraint args...` of syntax
```julia
@constraint(model, con, extra_arg, kw_args...) # single constraint
@constraint(model, ref, con, extra_arg, kw_args...) # group of constraints
@constraint(model, con, extra_arg, kwargs...) # single constraint
@constraint(model, ref, con, extra_arg, kwargs...) # group of constraints
```
The expression `con` is parsed by `parsefun` which returns a `build_constraint`
The expression `con` is parsed by `parse_fn` which returns a `build_constraint`
call code that, when executed, returns an `AbstractConstraint`. The macro
keyword arguments (except the `container` keyword argument which is used to
determine the container type) are added to the `build_constraint` call. The
Expand All @@ -192,33 +191,24 @@ called from in the user's code. One way of generating this is via the hidden
variable `__source__`.
"""
function _constraint_macro(
args,
input_args,
macro_name::Symbol,
parsefun::Function,
parse_fn::Function,
source::LineNumberNode,
)
error_fn(str...) = _macro_error(macro_name, args, source, str...)

# The positional args can't be `args` otherwise `error_fn` excludes keyword args
pos_args, kw_args, requested_container = Containers._extract_kw_args(args)

# Initial check of the positional arguments and get the model
if length(pos_args) < 2
if length(kw_args) > 0
error_fn(
"No constraint expression detected. If you are trying to " *
"construct an equality constraint, use `==` instead of `=`.",
)
else
error_fn("Not enough arguments")
end
end
model = esc(pos_args[1])
y = pos_args[2]
extra = pos_args[3:end]
if Meta.isexpr(args[2], :block)
error_fn(str...) = _macro_error(macro_name, input_args, source, str...)
args, kwargs, container = Containers._extract_kw_args(input_args)
if length(args) < 2 && !isempty(kwargs)
error_fn(
"No constraint expression detected. If you are trying to " *
"construct an equality constraint, use `==` instead of `=`.",
)
elseif length(args) < 2
error_fn("Not enough arguments")
elseif Meta.isexpr(args[2], :block)
error_fn("Invalid syntax. Did you mean to use `@$(macro_name)s`?")
end
model, y, extra = esc(args[1]), args[2], args[3:end]
# Determine if a reference/container argument was given by the user
# There are six cases to consider:
# y | type of y | y.head
Expand All @@ -230,88 +220,72 @@ function _constraint_macro(
# [i = 1:2, j = 1:2; i + j >= 3] | Expr | :vcat
# a constraint expression | Expr | :call or :comparison
if isa(y, Symbol) || Meta.isexpr(y, (:vect, :vcat, :ref, :typed_vcat))
length(extra) >= 1 || error_fn("No constraint expression was given.")
if length(extra) == 0
error_fn("No constraint expression was given.")
end
c = y
x = popfirst!(extra)
anonvar = Meta.isexpr(y, (:vect, :vcat))
is_anonymous = Meta.isexpr(y, (:vect, :vcat))
else
c = gensym()
x = y
anonvar = true
is_anonymous = true
end

# Enforce that only one extra positional argument can be given
if length(extra) > 1
error_fn("Cannot specify more than 1 additional positional argument.")
end

# Prepare the keyword arguments
extra_kw_args = filter(kw_args) do kw
return kw.args[1] != :base_name && kw.args[1] != :set_string_name
end
base_name_kw_args = filter(kw -> kw.args[1] == :base_name, kw_args)
set_string_name_kw_args =
filter(kw -> kw.args[1] == :set_string_name, kw_args)
# Set the base name
name = Containers._get_name(c)
if isempty(base_name_kw_args)
base_name = anonvar ? "" : string(name)
else
base_name = esc(base_name_kw_args[1].args[2])
end

# Strategy: build up the code for add_constraint, and if needed we will wrap
# in a function returning `ConstraintRef`s and give it to `Containers.container`.
idxvars, indices = Containers.build_ref_sets(error_fn, c)
if pos_args[1] in idxvars
index_vars, indices = Containers.build_ref_sets(error_fn, c)
if args[1] in index_vars
error_fn(
"Index $(pos_args[1]) is the same symbol as the model. Use a " *
"Index $(args[1]) is the same symbol as the model. Use a " *
"different name for the index.",
)
end
vectorized, parsecode, buildcall = parsefun(error_fn, x)
_add_positional_args(buildcall, extra)
_add_kw_args(buildcall, extra_kw_args)
if vectorized
buildcall = :(model_convert.($model, $buildcall))
else
buildcall = :(model_convert($model, $buildcall))
end
name_expr = _name_call(base_name, idxvars)
new_name_expr = if isempty(set_string_name_kw_args)
Expr(:if, :(set_string_names_on_creation($model)), name_expr, "")
else
Expr(:if, esc(set_string_name_kw_args[1].args[2]), name_expr, "")
end
if vectorized
# For vectorized constraints, we set every constraint to have the same
# name.
constraintcall = :(add_constraint.($model, $buildcall, $new_name_expr))
is_vectorized, parse_code, build_call = parse_fn(error_fn, x)
_add_positional_args(build_call, extra)
_add_kw_args(build_call, kwargs; exclude = [:base_name, :set_string_name])
base_name = _get_kwarg_value(
kwargs,
:base_name;
default = is_anonymous ? "" : string(Containers._get_name(c)),
)
set_name_flag = _get_kwarg_value(
kwargs,
:set_string_name;
default = :(set_string_names_on_creation($model)),
)
name_expr = Expr(:if, set_name_flag, _name_call(base_name, index_vars), "")
code = if is_vectorized
quote
$parse_code
# These broadcast calls need to be nested so that the operators
# are fused. Some broadcasted errors result if you put them on
# different lines.
add_constraint.(
$model,
model_convert.($model, $build_call),
$name_expr,
)
end
else
constraintcall = :(add_constraint($model, $buildcall, $new_name_expr))
end
code = quote
$parsecode
$constraintcall
quote
$parse_code
build = model_convert($model, $build_call)
add_constraint($model, build, $name_expr)
end
end
creation_code =
Containers.container_code(idxvars, indices, code, requested_container)
Containers.container_code(index_vars, indices, code, container)
# Wrap the entire code block in a let statement to make the model act as
# a type stable local variable.
creation_code = _wrap_let(model, creation_code)
if anonvar
# Anonymous constraint, no need to register it in the model-level
# dictionary nor to assign it to a variable in the user scope.
# We simply return the constraint reference
macro_code = creation_code
macro_code = if is_anonymous
creation_code
else
# We register the constraint reference to its name and
# we assign it to a variable in the local scope of this name
variable = gensym()
macro_code = _macro_assign_and_return(
_macro_assign_and_return(
creation_code,
variable,
name;
gensym(),
Containers._get_name(c);
model_for_registering = model,
)
end
Expand Down
Loading