diff --git a/src/macros/@expression.jl b/src/macros/@expression.jl index d3f06239d4f..9342030b6b7 100644 --- a/src/macros/@expression.jl +++ b/src/macros/@expression.jl @@ -52,59 +52,47 @@ julia> expr = @expression(model, [i in 1:3], i * sum(x[j] for j in 1:3)) 3 x[1] + 3 x[2] + 3 x[3] ``` """ -macro expression(args...) - error_fn(str...) = _macro_error(:expression, args, __source__, str...) - args, kw_args, requested_container = Containers._extract_kw_args(args) - if length(args) == 3 - m = esc(args[1]) - c = args[2] - x = args[3] - elseif length(args) == 2 - m = esc(args[1]) - c = gensym() - x = args[2] - else +macro expression(input_args...) + error_fn(str...) = _macro_error(:expression, input_args, __source__, str...) + args, kw_args, container = Containers._extract_kw_args(input_args) + if !(2 <= length(args) <= 3) error_fn("needs at least two arguments.") - end - length(kw_args) == 0 || error_fn("unrecognized keyword argument") - if Meta.isexpr(args[2], :block) + elseif !isempty(kw_args) + error_fn("unrecognized keyword argument") + elseif Meta.isexpr(args[2], :block) error_fn("Invalid syntax. Did you mean to use `@expressions`?") end - anonvar = - Meta.isexpr(c, :vect) || Meta.isexpr(c, :vcat) || length(args) == 2 - variable = gensym() - - idxvars, indices = Containers.build_ref_sets(error_fn, c) - if args[1] in idxvars + name_expr = length(args) == 3 ? args[2] : gensym() + index_vars, indices = Containers.build_ref_sets(error_fn, name_expr) + if args[1] in index_vars error_fn( "Index $(args[1]) is the same symbol as the model. Use a " * "different name for the index.", ) end - expr_var, build_code = _rewrite_expression(x) + expr_var, build_code = _rewrite_expression(args[end]) + model = esc(args[1]) code = quote $build_code # Don't leak a `_MA.Zero` if the expression is an empty summation, or # other structure that returns `_MA.Zero()`. - _replace_zero($m, $expr_var) + _replace_zero($model, $expr_var) end - code = - Containers.container_code(idxvars, indices, code, requested_container) + code = Containers.container_code(index_vars, indices, code, container) # Wrap the entire code block in a let statement to make the model act as # a type stable local variable. - code = _wrap_let(m, code) - # don't do anything with the model, but check that it's valid anyway - if anonvar - macro_code = code + code = _wrap_let(model, code) + macro_code = if Meta.isexpr(name_expr, (:vect, :vcat)) || length(args) == 2 + code else - macro_code = _macro_assign_and_return( + _macro_assign_and_return( code, - variable, - Containers._get_name(c); - model_for_registering = m, + gensym(), + Containers._get_name(name_expr); + model_for_registering = model, ) end - return _finalize_macro(m, macro_code, __source__) + return _finalize_macro(model, macro_code, __source__) end """ diff --git a/test/test_macros.jl b/test/test_macros.jl index 90eeb52c88b..57d963721fa 100644 --- a/test/test_macros.jl +++ b/test/test_macros.jl @@ -1712,6 +1712,18 @@ function test_expression_not_enough_arguments() return end +function test_expression_keyword_arguments() + model = Model() + @variable(model, x) + @test_macro_throws( + ErrorException( + "In `@expression(model, x, foo = 1)`: unrecognized keyword argument", + ), + @expression(model, x, foo = 1), + ) + return +end + function test_build_constraint_invalid() model = Model() @variable(model, x)