Skip to content

Commit

Permalink
Add parse_macro_arguments to unify how we handle macro inputs (#3616)
Browse files Browse the repository at this point in the history
  • Loading branch information
odow authored Dec 11, 2023
1 parent df46ad1 commit 65115f3
Show file tree
Hide file tree
Showing 9 changed files with 206 additions and 200 deletions.
62 changes: 41 additions & 21 deletions src/Containers/macro.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,28 +15,46 @@ function _get_name(c::Expr)
return error("Expression $c cannot be used as a name.")
end

function _reorder_parameters(args)
if !Meta.isexpr(args[1], :parameters)
return args
end
args = collect(args)
p = popfirst!(args)
for arg in p.args
@assert arg.head == :kw
push!(args, Expr(:(=), arg.args[1], arg.args[2]))
end
return args
end

"""
_extract_kw_args(args)
parse_macro_arguments(error_fn::Function, args)
Process the arguments to a macro, separating out the keyword arguments.
Returns a `Tuple{Vector{Any},Dict{Symbol,Any}}` containing the ordered
positional arguments and a dictionary mapping the keyword arguments.
Return a tuple of (flat_arguments, keyword arguments, and requested_container),
where `requested_container` is a symbol to be passed to `container_code`.
This specially handles the distinction of `@foo(key = value)` and
`@foo(; key = value)` in macros.
Throws an error if mulitple keyword arguments are passed with the same name.
"""
function _extract_kw_args(args)
flat_args, kw_args, requested_container = Any[], Any[], :Auto
for arg in args
if Meta.isexpr(arg, :(=))
if arg.args[1] == :container
requested_container = arg.args[2]
else
push!(kw_args, arg)
function parse_macro_arguments(error_fn::Function, args)
pos_args, kw_args = Any[], Dict{Symbol,Any}()
for arg in _reorder_parameters(args)
if Meta.isexpr(arg, :(=), 2)
if haskey(kw_args, arg.args[1])
error_fn(
"the keyword argument `$(arg.args[1])` was given " *
"multiple times.",
)
end
kw_args[arg.args[1]] = arg.args[2]
else
push!(flat_args, arg)
push!(pos_args, arg)
end
end
return flat_args, kw_args, requested_container
return pos_args, kw_args
end

"""
Expand Down Expand Up @@ -381,14 +399,16 @@ SparseAxisArray{Int64, 2, Tuple{Int64, Int64}} with 6 entries:
[3, 3] = 6
```
"""
macro container(args...)
args, kw_args, requested_container = _extract_kw_args(args)
macro container(input_args...)
args, kw_args = parse_macro_arguments(error, input_args)
container = get(kw_args, :container, :Auto)
@assert length(args) == 2
@assert isempty(kw_args)
var, value = args
index_vars, indices = build_ref_sets(error, var)
code = container_code(index_vars, indices, esc(value), requested_container)
name = _get_name(var)
for key in keys(kw_args)
@assert key == :container
end
index_vars, indices = build_ref_sets(error, args[1])
code = container_code(index_vars, indices, esc(args[2]), container)
name = _get_name(args[1])
if name === nothing
return code
end
Expand Down
81 changes: 10 additions & 71 deletions src/macros.jl
Original file line number Diff line number Diff line change
Expand Up @@ -323,45 +323,24 @@ function model_convert(
return model_convert.(model, x)
end

"""
_add_kw_args(call, kw_args)
Add the keyword arguments `kw_args` to the function call expression `call`,
escaping the expressions. The elements of `kw_args` should be expressions of the
form `:(key = value)`. The `kw_args` vector can be extracted from the arguments
of a macro with [`Containers._extract_kw_args`](@ref).
## Example
```jldoctest
julia> call = :(f(1, a=2))
:(f(1, a = 2))
julia> JuMP._add_kw_args(call, [:(b=3), :(c=4)])
julia> call
:(f(1, a = 2, \$(Expr(:escape, :(\$(Expr(:kw, :b, 3))))), \$(Expr(:escape, :(\$(Expr(:kw, :c, 4)))))))
```
"""
function _add_kw_args(call, kw_args; exclude = Symbol[])
for kw in kw_args
@assert Meta.isexpr(kw, :(=))
if kw.args[1] in exclude
function _add_keyword_args(call::Expr, kwargs::Dict; exclude = Symbol[])
for (key, value) in kwargs
if key in exclude
continue
end
push!(call.args, esc(Expr(:kw, kw.args...)))
push!(call.args, esc(Expr(:kw, key, value)))
end
return
end

"""
_add_positional_args(call, args)::Nothing
_add_positional_args(call::Expr, args::Vector{Any})::Nothing
Add the positional arguments `args` to the function call expression `call`,
escaping each argument expression. The elements of `args` should be ones that
were extracted via [`Containers._extract_kw_args`](@ref) and had appropriate
arguments filtered out (e.g., the model argument). This is able to incorporate
additional positional arguments to `call`s that already have keyword arguments.
escaping each argument expression.
This function is able to incorporate additional positional arguments to `call`s
that already have keyword arguments.
## Example
Expand All @@ -375,7 +354,7 @@ julia> call
:(f(1, $(Expr(:escape, :x)), a = 2))
```
"""
function _add_positional_args(call, args)
function _add_positional_args(call::Expr, args::Vector)
call_args = call.args
if Meta.isexpr(call, :.)
# call is broadcasted
Expand All @@ -392,19 +371,6 @@ function _add_positional_args(call, args)
return
end

function _reorder_parameters(args)
if !Meta.isexpr(args[1], :parameters)
return args
end
args = collect(args)
p = popfirst!(args)
for arg in p.args
@assert arg.head == :kw
push!(args, Expr(:(=), arg.args[1], arg.args[2]))
end
return args
end

_valid_model(::AbstractModel, ::Any) = nothing

function _valid_model(m::M, name) where {M}
Expand Down Expand Up @@ -654,33 +620,6 @@ function _wrap_let(model, code)
return code
end

function _get_kwarg_value(
error_fn,
kwargs,
key::Symbol;
default = nothing,
escape::Bool = true,
)
index, count = 0, 0
for (i, kwarg) in enumerate(kwargs)
if kwarg.args[1] == key
count += 1
index = i
end
end
if count == 0
return default
elseif count == 1
if escape
return esc(kwargs[index].args[2])
else
return kwargs[index].args[2]
end
else
error_fn("`$key` keyword argument was given $count times.")
end
end

include("macros/@objective.jl")
include("macros/@expression.jl")
include("macros/@constraint.jl")
Expand Down
32 changes: 28 additions & 4 deletions src/macros/@NL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,30 @@ function _parse_generator_expression(code, x, operators)
return y
end

"""
_extract_kw_args(args)
Process the arguments to a macro, separating out the keyword arguments.
Return a tuple of (flat_arguments, keyword arguments, and requested_container),
where `requested_container` is a symbol to be passed to `container_code`.
"""
function _extract_kw_args(args)
flat_args, kw_args, requested_container = Any[], Any[], :Auto
for arg in args
if Meta.isexpr(arg, :(=))
if arg.args[1] == :container
requested_container = arg.args[2]
else
push!(kw_args, arg)
end
else
push!(flat_args, arg)
end
end
return flat_args, kw_args, requested_container
end

###
### @NLobjective(s)
###
Expand Down Expand Up @@ -252,7 +276,7 @@ macro NLobjective(model, sense, x)
function error_fn(str...)
return _macro_error(:NLobjective, (model, sense, x), __source__, str...)
end
sense_expr = _moi_sense(error_fn, sense)
sense_expr = _parse_moi_sense(error_fn, sense)
esc_model = esc(model)
parsing_code, expr = _parse_nonlinear_expression(esc_model, x)
code = quote
Expand Down Expand Up @@ -299,7 +323,7 @@ macro NLconstraint(m, x, args...)
# Two formats:
# - @NLconstraint(m, a*x <= 5)
# - @NLconstraint(m, myref[a=1:5], sin(x^a) <= 5)
extra, kw_args, requested_container = Containers._extract_kw_args(args)
extra, kw_args, requested_container = _extract_kw_args(args)
if length(extra) > 1 || length(kw_args) > 0
error_fn("too many arguments.")
end
Expand Down Expand Up @@ -413,7 +437,7 @@ subexpression[5]: log(1.0 + (exp(subexpression[2]) + exp(subexpression[3])))
"""
macro NLexpression(args...)
error_fn(str...) = _macro_error(:NLexpression, args, __source__, str...)
args, kw_args, requested_container = Containers._extract_kw_args(args)
args, kw_args, requested_container = _extract_kw_args(args)
if length(args) <= 1
error_fn(
"To few arguments ($(length(args))); must pass the model and nonlinear expression as arguments.",
Expand Down Expand Up @@ -577,7 +601,7 @@ macro NLparameter(model, args...)
function error_fn(str...)
return _macro_error(:NLparameter, (model, args...), __source__, str...)
end
pos_args, kw_args, requested_container = Containers._extract_kw_args(args)
pos_args, kw_args, requested_container = _extract_kw_args(args)
value = missing
for arg in kw_args
if arg.args[1] == :value
Expand Down
45 changes: 27 additions & 18 deletions src/macros/@constraint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ Other keyword arguments may be supported by JuMP extensions.
"""
macro constraint(input_args...)
error_fn(str...) = _macro_error(:constraint, input_args, __source__, str...)
args, kwargs, container = Containers._extract_kw_args(input_args)
args, kwargs = Containers.parse_macro_arguments(error_fn, input_args)
if length(args) < 2 && !isempty(kwargs)
error_fn(
"No constraint expression detected. If you are trying to " *
Expand All @@ -82,13 +82,12 @@ macro constraint(input_args...)
# [1:2] | Expr | :vect
# [i = 1:2, j = 1:2; i + j >= 3] | Expr | :vcat
# a constraint expression | Expr | :call or :comparison
c, x = if y isa Symbol || Meta.isexpr(y, (:vect, :vcat, :ref, :typed_vcat))
c, x = nothing, y
if y isa Symbol || Meta.isexpr(y, (:vect, :vcat, :ref, :typed_vcat))
if length(extra) == 0
error_fn("No constraint expression was given.")
end
y, popfirst!(extra)
else
nothing, y
c, x = y, popfirst!(extra)
end
if length(extra) > 1
error_fn("Cannot specify more than 1 additional positional argument.")
Expand All @@ -102,20 +101,30 @@ macro constraint(input_args...)
end
is_vectorized, parse_code, build_call = parse_constraint(error_fn, x)
_add_positional_args(build_call, extra)
_add_kw_args(build_call, kwargs; exclude = [:base_name, :set_string_name])
base_name = _get_kwarg_value(
error_fn,
kwargs,
:base_name;
default = string(something(Containers._get_name(c), "")),
)
set_name_flag = _get_kwarg_value(
error_fn,
kwargs,
:set_string_name;
default = :(set_string_names_on_creation($model)),
_add_keyword_args(
build_call,
kwargs;
exclude = [:base_name, :container, :set_string_name],
)
name_expr = :($set_name_flag ? $(_name_call(base_name, index_vars)) : "")
# ; base_name
default_base_name = string(something(Containers._get_name(c), ""))
base_name = get(kwargs, :base_name, default_base_name)
if base_name isa Expr
base_name = esc(base_name)
end
# ; container
# There is no need to escape this one.
container = get(kwargs, :container, :Auto)
# ; set_string_name
name_expr = _name_call(base_name, index_vars)
if name_expr != ""
set_string_name = if haskey(kwargs, :set_string_name)
esc(kwargs[:set_string_name])
else
:(set_string_names_on_creation($model))
end
name_expr = :($set_string_name ? $name_expr : "")
end
code = if is_vectorized
quote
$parse_code
Expand Down
17 changes: 11 additions & 6 deletions src/macros/@expression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,13 +65,17 @@ julia> expr = @expression(model, [i in 1:3], i * sum(x[j] for j in 1:3))
"""
macro expression(input_args...)
error_fn(str...) = _macro_error(:expression, input_args, __source__, str...)
args, kw_args, container = Containers._extract_kw_args(input_args)
args, kwargs = Containers.parse_macro_arguments(error_fn, input_args)
if !(2 <= length(args) <= 3)
error_fn("needs at least two arguments.")
elseif !isempty(kw_args)
error_fn("unrecognized keyword argument")
error_fn("expected 2 or 3 positional arguments, got $(length(args)).")
elseif Meta.isexpr(args[2], :block)
error_fn("Invalid syntax. Did you mean to use `@expressions`?")
elseif !isempty(kwargs)
for key in keys(kwargs)
if key != :container
error_fn("unsupported keyword argument `$key`.")
end
end
end
name_expr = length(args) == 3 ? args[2] : nothing
index_vars, indices = Containers.build_ref_sets(error_fn, name_expr)
Expand All @@ -81,14 +85,15 @@ macro expression(input_args...)
"different name for the index.",
)
end
expr_var, build_code = _rewrite_expression(args[end])
model = esc(args[1])
expr, build_code = _rewrite_expression(args[end])
code = quote
$build_code
# Don't leak a `_MA.Zero` if the expression is an empty summation, or
# other structure that returns `_MA.Zero()`.
_replace_zero($model, $expr_var)
_replace_zero($model, $expr)
end
container = get(kwargs, :container, :Auto)
return _finalize_macro(
model,
Containers.container_code(index_vars, indices, code, container),
Expand Down
Loading

0 comments on commit 65115f3

Please sign in to comment.