From 9704d2bea50c4404e336d640c2fd6549887738e1 Mon Sep 17 00:00:00 2001 From: Oscar Dowson Date: Mon, 11 Dec 2023 20:36:13 +1300 Subject: [PATCH] More refactoring of src/macros.jl (#3619) --- src/macros.jl | 179 ++++++++------------------------------ src/macros/@constraint.jl | 78 ++++++++++++++++- src/macros/@variable.jl | 25 +++++- test/test_macros.jl | 21 +++-- 4 files changed, 146 insertions(+), 157 deletions(-) diff --git a/src/macros.jl b/src/macros.jl index e0e68c4503e..884de13748c 100644 --- a/src/macros.jl +++ b/src/macros.jl @@ -323,38 +323,26 @@ function model_convert( return model_convert.(model, x) end -function _add_keyword_args(call::Expr, kwargs::Dict; exclude = Symbol[]) - for (key, value) in kwargs - if key in exclude - continue - end - push!(call.args, esc(Expr(:kw, key, value))) - end - return -end - """ - _add_positional_args(call::Expr, args::Vector{Any})::Nothing + _add_additional_args( + call::Expr, + args::Vector, + kwargs::Dict{Symbol,Any}; + kwarg_exclude::Vector{Symbol} = Symbol[], + ) Add the positional arguments `args` to the function call expression `call`, escaping each argument expression. This function is able to incorporate additional positional arguments to `call`s that already have keyword arguments. - -## Example - -```jldoctest -julia> call = :(f(1, a=2)) -:(f(1, a = 2)) - -julia> JuMP._add_positional_args(call, [:(x)]) - -julia> call -:(f(1, $(Expr(:escape, :x)), a = 2)) -``` """ -function _add_positional_args(call::Expr, args::Vector) +function _add_additional_args( + call::Expr, + args::Vector, + kwargs::Dict{Symbol,Any}; + kwarg_exclude::Vector{Symbol} = Symbol[], +) call_args = call.args if Meta.isexpr(call, :.) # call is broadcasted @@ -368,6 +356,11 @@ function _add_positional_args(call::Expr, args::Vector) append!(call_args, esc.(args)) # Re-add the cached keyword arguments back to the end append!(call_args, kw_args) + for (key, value) in kwargs + if !(key in kwarg_exclude) + push!(call_args, esc(Expr(:kw, key, value))) + end + end return end @@ -406,8 +399,12 @@ function _finalize_macro( wrap_let::Bool = false, ) @assert Meta.isexpr(model, :escape) - if wrap_let - code = _wrap_let(model, code) + if wrap_let && model.args[1] isa Symbol + code = quote + let $model = $model + $code + end + end end if register_name !== nothing sym_name = Meta.quot(register_name) @@ -416,12 +413,8 @@ function _finalize_macro( $(esc(register_name)) = $model[$sym_name] = $code end end - return Expr( - :block, - source, - :(_valid_model($model, $(Meta.quot(model.args[1])))), - code, - ) + is_valid_code = :(_valid_model($model, $(Meta.quot(model.args[1])))) + return Expr(:block, source, is_valid_code, code) end function _error_if_cannot_register(model::AbstractModel, name::Symbol) @@ -442,89 +435,6 @@ function _error_if_cannot_register(model::AbstractModel, name::Symbol) return end -function _check_vectorized(sense::Symbol) - sense_str = string(sense) - if startswith(sense_str, '.') - return Symbol(sense_str[2:end]), true - end - return sense, false -end - -""" - _desparsify(x) - -If `x` is an `AbstractSparseArray`, return the dense equivalent, otherwise just -return `x`. - -This function is used in `_build_constraint`. - -## Why is this needed? - -When broadcasting `f.(x)` over an `AbstractSparseArray` `x`, Julia first calls -the equivalent of `f(zero(eltype(x))`. Here's an example: - -```jldoctest -julia> import SparseArrays - -julia> foo(x) = (println("Calling \$(x)"); x) -foo (generic function with 1 method) - -julia> foo.(SparseArrays.sparsevec([1, 2], [1, 2])) -Calling 1 -Calling 2 -2-element SparseArrays.SparseVector{Int64, Int64} with 2 stored entries: - [1] = 1 - [2] = 2 -``` - -However, if `f` is mutating, this can have serious consequences! In our case, -broadcasting `build_constraint` will add a new `0 = 0` constraint. - -Sparse arrays most-often arise when some input data to the constraint is sparse -(e.g., a constant vector or matrix). Due to promotion and arithmetic, this -results in a constraint function that is represented by an `AbstractSparseArray`, -but is actually dense. Thus, we can safely `collect` the matrix into a dense -array. - -If the function is sparse, it's not obvious what to do. What is the "zero" -element of the result? What does it mean to broadcast `build_constraint` over a -sparse array adding scalar constraints? This likely means that the user is using -the wrong data structure. For simplicity, let's also call `collect` into a dense -array, and wait for complaints. -""" -_desparsify(x::SparseArrays.AbstractSparseArray) = collect(x) - -_desparsify(x) = x - -function _functionize(v::V) where {V<:AbstractVariableRef} - return convert(GenericAffExpr{value_type(V),V}, v) -end - -_functionize(v::AbstractArray{<:AbstractVariableRef}) = _functionize.(v) - -function _functionize( - v::LinearAlgebra.Symmetric{V}, -) where {V<:AbstractVariableRef} - return LinearAlgebra.Symmetric(_functionize(v.data)) -end - -_functionize(x) = x - -_functionize(::_MA.Zero) = false - -""" - reverse_sense(::Val{T}) where {T} - -Given an (in)equality symbol `T`, return a new `Val` object with the opposite -(in)equality symbol. -""" -function reverse_sense end -reverse_sense(::Val{:<=}) = Val(:>=) -reverse_sense(::Val{:≤}) = Val(:≥) -reverse_sense(::Val{:>=}) = Val(:<=) -reverse_sense(::Val{:≥}) = Val(:≤) -reverse_sense(::Val{:(==)}) = Val(:(==)) - # This method is needed because Julia v1.10 prints LineNumberNode in the string # representation of an expression. function _strip_LineNumberNode(x::Expr) @@ -536,42 +446,36 @@ end _strip_LineNumberNode(x) = x -function _macro_error(macroname, args, source, str...) +function _macro_error(macro_name, args, source, str...) str_args = join(_strip_LineNumberNode.(args), ", ") return error( - "At $(source.file):$(source.line): `@$macroname($str_args)`: ", + "At $(source.file):$(source.line): `@$macro_name($str_args)`: ", str..., ) end -# Given a base_name and idxvars, returns an expression that constructs the name -# of the object. For use within macros only. -function _name_call(base_name, idxvars) - if isempty(idxvars) || base_name == "" +function _base_name_with_indices(base_name, index_vars::Vector) + if isempty(index_vars) || base_name == "" return base_name end - ex = Expr(:call, :string, base_name, "[") - for i in 1:length(idxvars) + expr = Expr(:call, :string, base_name, "[") + for index in index_vars # Converting the arguments to strings before concatenating is faster: # https://github.com/JuliaLang/julia/issues/29550. - esc_idxvar = esc(idxvars[i]) - push!(ex.args, :(string($esc_idxvar))) - i < length(idxvars) && push!(ex.args, ",") + push!(expr.args, :(string($(esc(index))))) + push!(expr.args, ",") end - push!(ex.args, "]") - return ex + expr.args[end] = "]" + return expr end -_esc_non_constant(x::Number) = x -_esc_non_constant(x::Expr) = Meta.isexpr(x, :quote) ? x : esc(x) -_esc_non_constant(x) = esc(x) - """ _replace_zero(model::M, x) where {M<:AbstractModel} Replaces `_MA.Zero` with a floating point `zero(value_type(M))`. """ _replace_zero(::M, ::_MA.Zero) where {M<:AbstractModel} = zero(value_type(M)) + _replace_zero(::AbstractModel, x::Any) = x function _plural_macro_code(model, block, macro_sym) @@ -609,17 +513,6 @@ function _plural_macro_code(model, block, macro_sym) return code end -function _wrap_let(model, code) - if Meta.isexpr(model, :escape) && model.args[1] isa Symbol - return quote - let $model = $model - $code - end - end - end - return code -end - include("macros/@objective.jl") include("macros/@expression.jl") include("macros/@constraint.jl") diff --git a/src/macros/@constraint.jl b/src/macros/@constraint.jl index 07559097ae5..2e128567080 100644 --- a/src/macros/@constraint.jl +++ b/src/macros/@constraint.jl @@ -100,11 +100,11 @@ macro constraint(input_args...) ) end is_vectorized, parse_code, build_call = parse_constraint(error_fn, x) - _add_positional_args(build_call, extra) - _add_keyword_args( + _add_additional_args( build_call, + extra, kwargs; - exclude = [:base_name, :container, :set_string_name], + kwarg_exclude = [:base_name, :container, :set_string_name], ) # ; base_name default_base_name = string(something(Containers._get_name(c), "")) @@ -116,7 +116,7 @@ macro constraint(input_args...) # There is no need to escape this one. container = get(kwargs, :container, :Auto) # ; set_string_name - name_expr = _name_call(base_name, index_vars) + name_expr = _base_name_with_indices(base_name, index_vars) if name_expr != "" set_string_name = if haskey(kwargs, :set_string_name) esc(kwargs[:set_string_name]) @@ -276,6 +276,60 @@ function parse_constraint(error_fn::Function, arg) ) end +function _check_vectorized(sense::Symbol) + sense_str = string(sense) + if startswith(sense_str, '.') + return Symbol(sense_str[2:end]), true + end + return sense, false +end + +""" + _desparsify(x) + +If `x` is an `AbstractSparseArray`, return the dense equivalent, otherwise just +return `x`. + +This function is used in `_build_constraint`. + +## Why is this needed? + +When broadcasting `f.(x)` over an `AbstractSparseArray` `x`, Julia first calls +the equivalent of `f(zero(eltype(x))`. Here's an example: + +```jldoctest +julia> import SparseArrays + +julia> foo(x) = (println("Calling \$(x)"); x) +foo (generic function with 1 method) + +julia> foo.(SparseArrays.sparsevec([1, 2], [1, 2])) +Calling 1 +Calling 2 +2-element SparseArrays.SparseVector{Int64, Int64} with 2 stored entries: + [1] = 1 + [2] = 2 +``` + +However, if `f` is mutating, this can have serious consequences! In our case, +broadcasting `build_constraint` will add a new `0 = 0` constraint. + +Sparse arrays most-often arise when some input data to the constraint is sparse +(e.g., a constant vector or matrix). Due to promotion and arithmetic, this +results in a constraint function that is represented by an `AbstractSparseArray`, +but is actually dense. Thus, we can safely `collect` the matrix into a dense +array. + +If the function is sparse, it's not obvious what to do. What is the "zero" +element of the result? What does it mean to broadcast `build_constraint` over a +sparse array adding scalar constraints? This likely means that the user is using +the wrong data structure. For simplicity, let's also call `collect` into a dense +array, and wait for complaints. +""" +_desparsify(x::SparseArrays.AbstractSparseArray) = collect(x) + +_desparsify(x) = x + """ parse_constraint_head(error_fn::Function, ::Val{head}, args...) @@ -621,6 +675,22 @@ function parse_constraint_call( return parse_code, build_call end +function _functionize(v::V) where {V<:AbstractVariableRef} + return convert(GenericAffExpr{value_type(V),V}, v) +end + +_functionize(v::AbstractArray{<:AbstractVariableRef}) = _functionize.(v) + +function _functionize( + v::LinearAlgebra.Symmetric{V}, +) where {V<:AbstractVariableRef} + return LinearAlgebra.Symmetric(_functionize(v.data)) +end + +_functionize(x) = x + +_functionize(::_MA.Zero) = false + """ parse_constraint_call( error_fn::Function, diff --git a/src/macros/@variable.jl b/src/macros/@variable.jl index 1e6dbc88197..09d148ac63c 100644 --- a/src/macros/@variable.jl +++ b/src/macros/@variable.jl @@ -220,7 +220,7 @@ macro variable(input_args...) # There is no need to escape this one. container = get(kwargs, :container, :Auto) # ; set_string_name - name_expr = _name_call(base_name, index_vars) + name_expr = _base_name_with_indices(base_name, index_vars) if name_expr != "" set_string_name = if haskey(kwargs, :set_string_name) esc(kwargs[:set_string_name]) @@ -268,11 +268,11 @@ macro variable(input_args...) end filter!(ex -> !(ex in (:Int, :Bin, :PSD, :Symmetric, :Hermitian)), args) build_code = :(build_variable($error_fn, $(_constructor_expr(info_expr)))) - _add_positional_args(build_code, args) - _add_keyword_args( + _add_additional_args( build_code, + args, kwargs; - exclude = vcat( + kwarg_exclude = vcat( _INFO_KWARGS, [:base_name, :container, :variable_type, :set, :set_string_name], ), @@ -366,6 +366,10 @@ macro variables(model, block) return _plural_macro_code(model, block, Symbol("@variable")) end +_esc_non_constant(x::Number) = x +_esc_non_constant(x::Expr) = Meta.isexpr(x, :quote) ? x : esc(x) +_esc_non_constant(x) = esc(x) + """ parse_variable(error_fn::Function, ::_VariableInfoExpr, args...) @@ -403,6 +407,19 @@ function parse_variable( return var, set end +""" + reverse_sense(::Val{T}) where {T} + +Given an (in)equality symbol `T`, return a new `Val` object with the opposite +(in)equality symbol. +""" +function reverse_sense end +reverse_sense(::Val{:<=}) = Val(:>=) +reverse_sense(::Val{:≤}) = Val(:≥) +reverse_sense(::Val{:>=}) = Val(:<=) +reverse_sense(::Val{:≥}) = Val(:≤) +reverse_sense(::Val{:(==)}) = Val(:(==)) + # If the lhs is a number and not the rhs, we can deduce that the rhs is # the variable. function parse_variable( diff --git a/test/test_macros.jl b/test/test_macros.jl index d40ba967165..fea1c412eaa 100644 --- a/test/test_macros.jl +++ b/test/test_macros.jl @@ -166,22 +166,31 @@ function test_Check_Julia_condition_expression_parsing() return end -function test_add_positional_args() +function test_add_additional_args() call = :(f(1; a = 2)) - @test JuMP._add_positional_args(call, [:(MyObject)]) isa Nothing + kwargs = Dict{Symbol,Any}() + @test JuMP._add_additional_args(call, [:(MyObject)], kwargs) isa Nothing @test call == :(f(1, $(Expr(:escape, :MyObject)); a = 2)) call = :(f(1)) - JuMP._add_positional_args(call, [2, 3]) + JuMP._add_additional_args(call, [2, 3], kwargs) @test call == :(f(1, $(esc(2)), $(esc(3)))) call = :(f.(1)) - JuMP._add_positional_args(call, [2, 3]) + JuMP._add_additional_args(call, [2, 3], kwargs) @test call == :(f.(1, $(esc(2)), $(esc(3)))) call = :(f(1; a = 4)) - JuMP._add_positional_args(call, [2, 3]) + JuMP._add_additional_args(call, [2, 3], kwargs) @test call == :(f(1, $(esc(2)), $(esc(3)); a = 4)) call = :(f.(1; a = 4)) - JuMP._add_positional_args(call, [2, 3]) + JuMP._add_additional_args(call, [2, 3], kwargs) @test call == :(f.(1, $(esc(2)), $(esc(3)); a = 4)) + call = :(f.(1, a = 4)) + kwargs = Dict{Symbol,Any}(:b => 4, :c => false) + JuMP._add_additional_args(call, Any[2], kwargs; kwarg_exclude = [:b]) + @test call == Expr( + :., + :f, + Expr(:tuple, 1, esc(2), Expr(:kw, :a, 4), esc(Expr(:kw, :c, false))), + ) return end