diff --git a/src/macros.jl b/src/macros.jl index 60fc8ae17e5..620ddceef8c 100644 --- a/src/macros.jl +++ b/src/macros.jl @@ -343,9 +343,12 @@ julia> call :(f(1, a = 2, \$(Expr(:escape, :(\$(Expr(:kw, :b, 3))))), \$(Expr(:escape, :(\$(Expr(:kw, :c, 4))))))) ``` """ -function _add_kw_args(call, kw_args) +function _add_kw_args(call, kw_args; exclude = Symbol[]) for kw in kw_args @assert Meta.isexpr(kw, :(=)) + if kw.args[1] in exclude + continue + end push!(call.args, esc(Expr(:kw, kw.args...))) end return @@ -662,6 +665,15 @@ function _wrap_let(model, code) return code end +function _get_kwarg_value(kwargs, key::Symbol; default = nothing) + for kwarg in kwargs + if kwarg.args[1] == key + return esc(kwarg.args[2]) + end + end + return default +end + include("macros/@objective.jl") include("macros/@expression.jl") include("macros/@constraint.jl") diff --git a/src/macros/@constraint.jl b/src/macros/@constraint.jl index e48f2c0b26b..765669c1acf 100644 --- a/src/macros/@constraint.jl +++ b/src/macros/@constraint.jl @@ -4,11 +4,11 @@ # file, You can obtain one at https://mozilla.org/MPL/2.0/. """ - @constraint(m::GenericModel, expr, kw_args...) + @constraint(model::GenericModel, expr, kwargs...) Add a constraint described by the expression `expr`. - @constraint(m::GenericModel, ref[i=..., j=..., ...], expr, kw_args...) + @constraint(model::GenericModel, ref[i=..., j=..., ...], expr, kwargs...) Add a group of constraints described by the expression `expr` parametrized by `i`, `j`, ... @@ -21,24 +21,29 @@ The expression `expr` can either be [`RotatedSecondOrderCone`](@ref) and [`PSDCone`](@ref), e.g. `@constraint(model, [1, x-1, y-2] in SecondOrderCone())` constrains the norm of `[x-1, y-2]` be less than 1; + * of the form `a sign b`, where `sign` is one of `==`, `≥`, `>=`, `≤` and `<=` building the single constraint enforcing the comparison to hold for the - expression `a` and `b`, e.g. `@constraint(m, x^2 + y^2 == 1)` constrains `x` - and `y` to lie on the unit circle; + expression `a` and `b`, e.g. `@constraint(model, x^2 + y^2 == 1)` constrains + `x` and `y` to lie on the unit circle; + * of the form `a ≤ b ≤ c` or `a ≥ b ≥ c` (where `≤` and `<=` (resp. `≥` and `>=`) can be used interchangeably) constraining the paired the expression `b` to lie between `a` and `c`; + * of the forms `@constraint(m, a .sign b)` or `@constraint(m, a .sign b .sign c)` which broadcast the constraint creation to each element of the vectors. -The recognized keyword arguments in `kw_args` are the following: +The recognized keyword arguments in `kwargs` are the following: * `base_name`: Sets the name prefix used to generate constraint names. It corresponds to the constraint name for scalar constraints, otherwise, the constraint names are set to `base_name[...]` for each index `...` of the axes `axes`. + * `container`: Specify the container type. + * `set_string_name::Bool = true`: control whether to set the [`MOI.ConstraintName`](@ref) attribute. Passing `set_string_name = false` can improve performance. @@ -47,8 +52,10 @@ The recognized keyword arguments in `kw_args` are the following: Each constraint will be created using `add_constraint(m, build_constraint(error_fn, func, set))` where + * `error_fn` is an error function showing the constraint call in addition to the error message given as argument, + * `func` is the expression that is constrained * and `set` is the set in which it is constrained to belong. @@ -68,8 +75,8 @@ For extensions that need to create constraints with more information than just `func` and `set`, an additional positional argument can be specified to `@constraint` that will then be passed on `build_constraint`. Hence, we can enable this syntax by defining extensions of -`build_constraint(error_fn, func, set, my_arg; kw_args...)`. This produces the -user syntax: `@constraint(model, ref[...], expr, my_arg, kw_args...)`. +`build_constraint(error_fn, func, set, my_arg; kwargs...)`. This produces the +user syntax: `@constraint(model, ref[...], expr, my_arg, kwargs...)`. """ macro constraint(args...) return _constraint_macro(args, :constraint, parse_constraint, __source__) @@ -139,46 +146,38 @@ julia> @build_constraint(2x >= 1) ScalarConstraint{AffExpr, MathOptInterface.GreaterThan{Float64}}(2 x, MathOptInterface.GreaterThan{Float64}(1.0)) ``` """ -macro build_constraint(constraint_expr) +macro build_constraint(arg) function error_fn(str...) - return _macro_error( - :build_constraint, - (constraint_expr,), - __source__, - str..., - ) + return _macro_error(:build_constraint, (arg,), __source__, str...) end - - if isa(constraint_expr, Symbol) + if arg isa Symbol error_fn( - "Incomplete constraint specification $constraint_expr. " * + "Incomplete constraint specification $arg. " * "Are you missing a comparison (<=, >=, or ==)?", ) end - - is_vectorized, parse_code, build_call = - parse_constraint(error_fn, constraint_expr) - result_variable = gensym() - code = quote + _, parse_code, build_call = parse_constraint(error_fn, arg) + return quote $parse_code - $result_variable = $build_call + $build_call end - - return code end """ _constraint_macro( - args, macro_name::Symbol, parsefun::Function, source::LineNumberNode + args, + macro_name::Symbol, + parse_fn::Function, + source::LineNumberNode, ) Returns the code for the macro `@constraint args...` of syntax ```julia -@constraint(model, con, extra_arg, kw_args...) # single constraint -@constraint(model, ref, con, extra_arg, kw_args...) # group of constraints +@constraint(model, con, extra_arg, kwargs...) # single constraint +@constraint(model, ref, con, extra_arg, kwargs...) # group of constraints ``` -The expression `con` is parsed by `parsefun` which returns a `build_constraint` +The expression `con` is parsed by `parse_fn` which returns a `build_constraint` call code that, when executed, returns an `AbstractConstraint`. The macro keyword arguments (except the `container` keyword argument which is used to determine the container type) are added to the `build_constraint` call. The @@ -192,33 +191,24 @@ called from in the user's code. One way of generating this is via the hidden variable `__source__`. """ function _constraint_macro( - args, + input_args, macro_name::Symbol, - parsefun::Function, + parse_fn::Function, source::LineNumberNode, ) - error_fn(str...) = _macro_error(macro_name, args, source, str...) - - # The positional args can't be `args` otherwise `error_fn` excludes keyword args - pos_args, kw_args, requested_container = Containers._extract_kw_args(args) - - # Initial check of the positional arguments and get the model - if length(pos_args) < 2 - if length(kw_args) > 0 - error_fn( - "No constraint expression detected. If you are trying to " * - "construct an equality constraint, use `==` instead of `=`.", - ) - else - error_fn("Not enough arguments") - end - end - model = esc(pos_args[1]) - y = pos_args[2] - extra = pos_args[3:end] - if Meta.isexpr(args[2], :block) + error_fn(str...) = _macro_error(macro_name, input_args, source, str...) + args, kwargs, container = Containers._extract_kw_args(input_args) + if length(args) < 2 && !isempty(kwargs) + error_fn( + "No constraint expression detected. If you are trying to " * + "construct an equality constraint, use `==` instead of `=`.", + ) + elseif length(args) < 2 + error_fn("Not enough arguments") + elseif Meta.isexpr(args[2], :block) error_fn("Invalid syntax. Did you mean to use `@$(macro_name)s`?") end + model, y, extra = esc(args[1]), args[2], args[3:end] # Determine if a reference/container argument was given by the user # There are six cases to consider: # y | type of y | y.head @@ -230,88 +220,72 @@ function _constraint_macro( # [i = 1:2, j = 1:2; i + j >= 3] | Expr | :vcat # a constraint expression | Expr | :call or :comparison if isa(y, Symbol) || Meta.isexpr(y, (:vect, :vcat, :ref, :typed_vcat)) - length(extra) >= 1 || error_fn("No constraint expression was given.") + if length(extra) == 0 + error_fn("No constraint expression was given.") + end c = y x = popfirst!(extra) - anonvar = Meta.isexpr(y, (:vect, :vcat)) + is_anonymous = Meta.isexpr(y, (:vect, :vcat)) else c = gensym() x = y - anonvar = true + is_anonymous = true end - - # Enforce that only one extra positional argument can be given if length(extra) > 1 error_fn("Cannot specify more than 1 additional positional argument.") end - - # Prepare the keyword arguments - extra_kw_args = filter(kw_args) do kw - return kw.args[1] != :base_name && kw.args[1] != :set_string_name - end - base_name_kw_args = filter(kw -> kw.args[1] == :base_name, kw_args) - set_string_name_kw_args = - filter(kw -> kw.args[1] == :set_string_name, kw_args) - # Set the base name - name = Containers._get_name(c) - if isempty(base_name_kw_args) - base_name = anonvar ? "" : string(name) - else - base_name = esc(base_name_kw_args[1].args[2]) - end - - # Strategy: build up the code for add_constraint, and if needed we will wrap - # in a function returning `ConstraintRef`s and give it to `Containers.container`. - idxvars, indices = Containers.build_ref_sets(error_fn, c) - if pos_args[1] in idxvars + index_vars, indices = Containers.build_ref_sets(error_fn, c) + if args[1] in index_vars error_fn( - "Index $(pos_args[1]) is the same symbol as the model. Use a " * + "Index $(args[1]) is the same symbol as the model. Use a " * "different name for the index.", ) end - vectorized, parsecode, buildcall = parsefun(error_fn, x) - _add_positional_args(buildcall, extra) - _add_kw_args(buildcall, extra_kw_args) - if vectorized - buildcall = :(model_convert.($model, $buildcall)) - else - buildcall = :(model_convert($model, $buildcall)) - end - name_expr = _name_call(base_name, idxvars) - new_name_expr = if isempty(set_string_name_kw_args) - Expr(:if, :(set_string_names_on_creation($model)), name_expr, "") - else - Expr(:if, esc(set_string_name_kw_args[1].args[2]), name_expr, "") - end - if vectorized - # For vectorized constraints, we set every constraint to have the same - # name. - constraintcall = :(add_constraint.($model, $buildcall, $new_name_expr)) + is_vectorized, parse_code, build_call = parse_fn(error_fn, x) + _add_positional_args(build_call, extra) + _add_kw_args(build_call, kwargs; exclude = [:base_name, :set_string_name]) + base_name = _get_kwarg_value( + kwargs, + :base_name; + default = is_anonymous ? "" : string(Containers._get_name(c)), + ) + set_name_flag = _get_kwarg_value( + kwargs, + :set_string_name; + default = :(set_string_names_on_creation($model)), + ) + name_expr = Expr(:if, set_name_flag, _name_call(base_name, index_vars), "") + code = if is_vectorized + quote + $parse_code + # These broadcast calls need to be nested so that the operators + # are fused. Some broadcasted errors result if you put them on + # different lines. + add_constraint.( + $model, + model_convert.($model, $build_call), + $name_expr, + ) + end else - constraintcall = :(add_constraint($model, $buildcall, $new_name_expr)) - end - code = quote - $parsecode - $constraintcall + quote + $parse_code + build = model_convert($model, $build_call) + add_constraint($model, build, $name_expr) + end end creation_code = - Containers.container_code(idxvars, indices, code, requested_container) + Containers.container_code(index_vars, indices, code, container) # Wrap the entire code block in a let statement to make the model act as # a type stable local variable. creation_code = _wrap_let(model, creation_code) - if anonvar - # Anonymous constraint, no need to register it in the model-level - # dictionary nor to assign it to a variable in the user scope. - # We simply return the constraint reference - macro_code = creation_code + macro_code = if is_anonymous + creation_code else - # We register the constraint reference to its name and - # we assign it to a variable in the local scope of this name - variable = gensym() - macro_code = _macro_assign_and_return( + _macro_assign_and_return( creation_code, - variable, - name; + gensym(), + Containers._get_name(c); model_for_registering = model, ) end