diff --git a/src/macros.jl b/src/macros.jl index 209c3ed3f41..88d7646997a 100644 --- a/src/macros.jl +++ b/src/macros.jl @@ -474,3 +474,30 @@ for file in readdir(joinpath(@__DIR__, "macros")) include(joinpath(@__DIR__, "macros", file)) end end + +# These methods must come after the macros are included, because they use both +# `@variable` and `@constraint`. + +function _build_subexpression( + ::Function, + model::AbstractModel, + expr::AbstractJuMPScalar, + name::String, +) + y = @variable(model) + set_name(y, name) + @constraint(model, y == expr) + return y +end + +function _build_subexpression( + ::Function, + model::AbstractModel, + expr::Array{<:AbstractJuMPScalar}, + name::String, +) + y = [@variable(model) for _ in expr] + set_name.(y, name) + @constraint(model, y .== expr) + return y +end diff --git a/src/macros/@expression.jl b/src/macros/@expression.jl index 7250a01f9cf..cda19ca40b1 100644 --- a/src/macros/@expression.jl +++ b/src/macros/@expression.jl @@ -69,24 +69,30 @@ macro expression(input_args...) error_fn, input_args; num_positional_args = 2:3, - valid_kwargs = [:container], + valid_kwargs = [:container, :subexpression], ) if Meta.isexpr(args[2], :block) error_fn("Invalid syntax. Did you mean to use `@expressions`?") end + is_subexpression = get(kwargs, :subexpression, false) name_expr = length(args) == 3 ? args[2] : nothing name, index_vars, indices = Containers.parse_ref_sets( error_fn, name_expr; invalid_index_variables = [args[1]], ) + name_expr = Containers.build_name_expr(name, index_vars, kwargs) model = esc(args[1]) expr, build_code = _rewrite_expression(args[end]) code = quote $build_code # Don't leak a `_MA.Zero` if the expression is an empty summation, or # other structure that returns `_MA.Zero()`. - _replace_zero($model, $expr) + if $is_subexpression + _build_subexpression($error_fn, $model, $expr, $name_expr) + else + _replace_zero($model, $expr) + end end return _finalize_macro( model, @@ -97,6 +103,17 @@ macro expression(input_args...) ) end +function _build_subexpression( + error_fn::Function, + ::AbstractModel, + expr::Any, + ::String, +) + return error_fn( + "Unable to build a subexpression for the type $(typeof(expr))", + ) +end + """ @expressions(model, args...) diff --git a/test/test_macros.jl b/test/test_macros.jl index 3ec99b5f498..ec39bc9edb9 100644 --- a/test/test_macros.jl +++ b/test/test_macros.jl @@ -2487,4 +2487,92 @@ function test_array_scalar_sets() return end +function test_subexpression_kwarg() + model = Model() + @variable(model, x) + @expression(model, ex, sin(x), subexpression = true) + @test ex isa VariableRef + @test model[:ex] isa VariableRef + @test model[:ex] === ex + @test occursin(r"ex - sin\(x\) ==? 0", sprint(print, model)) + @test num_variables(model) == 2 + return +end + +function test_subexpression_kwarg_array() + model = Model() + @variable(model, x[1:2]) + @expression(model, ex[i in 1:2], sin(x[i]), subexpression = true) + @test ex isa Vector{VariableRef} + @test model[:ex] === ex + @test occursin(r"ex\[1\] - sin\(x\[1\]\) ==? 0", sprint(print, model)) + @test occursin(r"ex\[2\] - sin\(x\[2\]\) ==? 0", sprint(print, model)) + @test num_variables(model) == 4 + return +end + +function test_subexpression_kwarg_dense_axis_array() + model = Model() + @variable(model, x[2:3]) + @expression(model, ex[i in 2:3], sin(x[i]), subexpression = true) + @test ex isa Containers.DenseAxisArray{VariableRef} + @test model[:ex] === ex + @test occursin(r"ex\[2\] - sin\(x\[2\]\) ==? 0", sprint(print, model)) + @test occursin(r"ex\[3\] - sin\(x\[3\]\) ==? 0", sprint(print, model)) + @test num_variables(model) == 4 + return +end + +function test_subexpression_kwarg_dense_axis_array() + model = Model() + @variable(model, x[i in 1:3; isodd(i)]) + @expression(model, ex[i in 1:3; isodd(i)], sin(x[i]), subexpression = true) + @test ex isa Containers.SparseAxisArray{VariableRef} + @test model[:ex] === ex + @test occursin(r"ex\[1\] - sin\(x\[1\]\) ==? 0", sprint(print, model)) + @test occursin(r"ex\[3\] - sin\(x\[3\]\) ==? 0", sprint(print, model)) + @test num_variables(model) == 4 + return +end + +function test_subexpression_kwarg_vector_element() + model = Model() + @variable(model, x[i in 1:2]) + @expression(model, ex, sin.(x), subexpression = true) + @test ex isa Vector{VariableRef} + @test model[:ex] === ex + @test occursin(r"ex - sin\(x\[1\]\) ==? 0", sprint(print, model)) + @test occursin(r"ex - sin\(x\[2\]\) ==? 0", sprint(print, model)) + @test num_variables(model) == 4 + return +end + +function test_subexpression_kwarg_no_name() + model = Model() + @variable(model, x) + ex = @expression(model, sin(x), subexpression = true) + @test ex isa VariableRef + @test !haskey(model, :ex) + @test occursin(r"\_\[2\] - sin\(x\) ==? 0", sprint(print, model)) + @test num_variables(model) == 2 + return +end + +function test_subexpression_kwarg_dict_element() + model = Model() + @variable(model, x[i in 1:2]) + @test_throws_runtime( + ErrorException( + "In `@expression(model, ex, Dict((i => x[i] for i = 1:2)), subexpression = true)`: Unable to build a subexpression for the type $(Dict{Int,VariableRef})", + ), + @expression( + model, + ex, + Dict(i => x[i] for i in 1:2), + subexpression = true, + ), + ) + return +end + end # module