Skip to content

Commit

Permalink
More refactoring of src/macros.jl (#3619)
Browse files Browse the repository at this point in the history
  • Loading branch information
odow authored Dec 11, 2023
1 parent 65115f3 commit 9704d2b
Show file tree
Hide file tree
Showing 4 changed files with 146 additions and 157 deletions.
179 changes: 36 additions & 143 deletions src/macros.jl
Original file line number Diff line number Diff line change
Expand Up @@ -323,38 +323,26 @@ function model_convert(
return model_convert.(model, x)
end

function _add_keyword_args(call::Expr, kwargs::Dict; exclude = Symbol[])
for (key, value) in kwargs
if key in exclude
continue
end
push!(call.args, esc(Expr(:kw, key, value)))
end
return
end

"""
_add_positional_args(call::Expr, args::Vector{Any})::Nothing
_add_additional_args(
call::Expr,
args::Vector,
kwargs::Dict{Symbol,Any};
kwarg_exclude::Vector{Symbol} = Symbol[],
)
Add the positional arguments `args` to the function call expression `call`,
escaping each argument expression.
This function is able to incorporate additional positional arguments to `call`s
that already have keyword arguments.
## Example
```jldoctest
julia> call = :(f(1, a=2))
:(f(1, a = 2))
julia> JuMP._add_positional_args(call, [:(x)])
julia> call
:(f(1, $(Expr(:escape, :x)), a = 2))
```
"""
function _add_positional_args(call::Expr, args::Vector)
function _add_additional_args(
call::Expr,
args::Vector,
kwargs::Dict{Symbol,Any};
kwarg_exclude::Vector{Symbol} = Symbol[],
)
call_args = call.args
if Meta.isexpr(call, :.)
# call is broadcasted
Expand All @@ -368,6 +356,11 @@ function _add_positional_args(call::Expr, args::Vector)
append!(call_args, esc.(args))
# Re-add the cached keyword arguments back to the end
append!(call_args, kw_args)
for (key, value) in kwargs
if !(key in kwarg_exclude)
push!(call_args, esc(Expr(:kw, key, value)))
end
end
return
end

Expand Down Expand Up @@ -406,8 +399,12 @@ function _finalize_macro(
wrap_let::Bool = false,
)
@assert Meta.isexpr(model, :escape)
if wrap_let
code = _wrap_let(model, code)
if wrap_let && model.args[1] isa Symbol
code = quote
let $model = $model
$code
end
end
end
if register_name !== nothing
sym_name = Meta.quot(register_name)
Expand All @@ -416,12 +413,8 @@ function _finalize_macro(
$(esc(register_name)) = $model[$sym_name] = $code
end
end
return Expr(
:block,
source,
:(_valid_model($model, $(Meta.quot(model.args[1])))),
code,
)
is_valid_code = :(_valid_model($model, $(Meta.quot(model.args[1]))))
return Expr(:block, source, is_valid_code, code)
end

function _error_if_cannot_register(model::AbstractModel, name::Symbol)
Expand All @@ -442,89 +435,6 @@ function _error_if_cannot_register(model::AbstractModel, name::Symbol)
return
end

function _check_vectorized(sense::Symbol)
sense_str = string(sense)
if startswith(sense_str, '.')
return Symbol(sense_str[2:end]), true
end
return sense, false
end

"""
_desparsify(x)
If `x` is an `AbstractSparseArray`, return the dense equivalent, otherwise just
return `x`.
This function is used in `_build_constraint`.
## Why is this needed?
When broadcasting `f.(x)` over an `AbstractSparseArray` `x`, Julia first calls
the equivalent of `f(zero(eltype(x))`. Here's an example:
```jldoctest
julia> import SparseArrays
julia> foo(x) = (println("Calling \$(x)"); x)
foo (generic function with 1 method)
julia> foo.(SparseArrays.sparsevec([1, 2], [1, 2]))
Calling 1
Calling 2
2-element SparseArrays.SparseVector{Int64, Int64} with 2 stored entries:
[1] = 1
[2] = 2
```
However, if `f` is mutating, this can have serious consequences! In our case,
broadcasting `build_constraint` will add a new `0 = 0` constraint.
Sparse arrays most-often arise when some input data to the constraint is sparse
(e.g., a constant vector or matrix). Due to promotion and arithmetic, this
results in a constraint function that is represented by an `AbstractSparseArray`,
but is actually dense. Thus, we can safely `collect` the matrix into a dense
array.
If the function is sparse, it's not obvious what to do. What is the "zero"
element of the result? What does it mean to broadcast `build_constraint` over a
sparse array adding scalar constraints? This likely means that the user is using
the wrong data structure. For simplicity, let's also call `collect` into a dense
array, and wait for complaints.
"""
_desparsify(x::SparseArrays.AbstractSparseArray) = collect(x)

_desparsify(x) = x

function _functionize(v::V) where {V<:AbstractVariableRef}
return convert(GenericAffExpr{value_type(V),V}, v)
end

_functionize(v::AbstractArray{<:AbstractVariableRef}) = _functionize.(v)

function _functionize(
v::LinearAlgebra.Symmetric{V},
) where {V<:AbstractVariableRef}
return LinearAlgebra.Symmetric(_functionize(v.data))
end

_functionize(x) = x

_functionize(::_MA.Zero) = false

"""
reverse_sense(::Val{T}) where {T}
Given an (in)equality symbol `T`, return a new `Val` object with the opposite
(in)equality symbol.
"""
function reverse_sense end
reverse_sense(::Val{:<=}) = Val(:>=)
reverse_sense(::Val{:≤}) = Val(:)
reverse_sense(::Val{:>=}) = Val(:<=)
reverse_sense(::Val{:≥}) = Val(:)
reverse_sense(::Val{:(==)}) = Val(:(==))

# This method is needed because Julia v1.10 prints LineNumberNode in the string
# representation of an expression.
function _strip_LineNumberNode(x::Expr)
Expand All @@ -536,42 +446,36 @@ end

_strip_LineNumberNode(x) = x

function _macro_error(macroname, args, source, str...)
function _macro_error(macro_name, args, source, str...)
str_args = join(_strip_LineNumberNode.(args), ", ")
return error(
"At $(source.file):$(source.line): `@$macroname($str_args)`: ",
"At $(source.file):$(source.line): `@$macro_name($str_args)`: ",
str...,
)
end

# Given a base_name and idxvars, returns an expression that constructs the name
# of the object. For use within macros only.
function _name_call(base_name, idxvars)
if isempty(idxvars) || base_name == ""
function _base_name_with_indices(base_name, index_vars::Vector)
if isempty(index_vars) || base_name == ""
return base_name
end
ex = Expr(:call, :string, base_name, "[")
for i in 1:length(idxvars)
expr = Expr(:call, :string, base_name, "[")
for index in index_vars
# Converting the arguments to strings before concatenating is faster:
# https://github.com/JuliaLang/julia/issues/29550.
esc_idxvar = esc(idxvars[i])
push!(ex.args, :(string($esc_idxvar)))
i < length(idxvars) && push!(ex.args, ",")
push!(expr.args, :(string($(esc(index)))))
push!(expr.args, ",")
end
push!(ex.args, "]")
return ex
expr.args[end] = "]"
return expr
end

_esc_non_constant(x::Number) = x
_esc_non_constant(x::Expr) = Meta.isexpr(x, :quote) ? x : esc(x)
_esc_non_constant(x) = esc(x)

"""
_replace_zero(model::M, x) where {M<:AbstractModel}
Replaces `_MA.Zero` with a floating point `zero(value_type(M))`.
"""
_replace_zero(::M, ::_MA.Zero) where {M<:AbstractModel} = zero(value_type(M))

_replace_zero(::AbstractModel, x::Any) = x

function _plural_macro_code(model, block, macro_sym)
Expand Down Expand Up @@ -609,17 +513,6 @@ function _plural_macro_code(model, block, macro_sym)
return code
end

function _wrap_let(model, code)
if Meta.isexpr(model, :escape) && model.args[1] isa Symbol
return quote
let $model = $model
$code
end
end
end
return code
end

include("macros/@objective.jl")
include("macros/@expression.jl")
include("macros/@constraint.jl")
Expand Down
78 changes: 74 additions & 4 deletions src/macros/@constraint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,11 +100,11 @@ macro constraint(input_args...)
)
end
is_vectorized, parse_code, build_call = parse_constraint(error_fn, x)
_add_positional_args(build_call, extra)
_add_keyword_args(
_add_additional_args(
build_call,
extra,
kwargs;
exclude = [:base_name, :container, :set_string_name],
kwarg_exclude = [:base_name, :container, :set_string_name],
)
# ; base_name
default_base_name = string(something(Containers._get_name(c), ""))
Expand All @@ -116,7 +116,7 @@ macro constraint(input_args...)
# There is no need to escape this one.
container = get(kwargs, :container, :Auto)
# ; set_string_name
name_expr = _name_call(base_name, index_vars)
name_expr = _base_name_with_indices(base_name, index_vars)
if name_expr != ""
set_string_name = if haskey(kwargs, :set_string_name)
esc(kwargs[:set_string_name])
Expand Down Expand Up @@ -276,6 +276,60 @@ function parse_constraint(error_fn::Function, arg)
)
end

function _check_vectorized(sense::Symbol)
sense_str = string(sense)
if startswith(sense_str, '.')
return Symbol(sense_str[2:end]), true
end
return sense, false
end

"""
_desparsify(x)
If `x` is an `AbstractSparseArray`, return the dense equivalent, otherwise just
return `x`.
This function is used in `_build_constraint`.
## Why is this needed?
When broadcasting `f.(x)` over an `AbstractSparseArray` `x`, Julia first calls
the equivalent of `f(zero(eltype(x))`. Here's an example:
```jldoctest
julia> import SparseArrays
julia> foo(x) = (println("Calling \$(x)"); x)
foo (generic function with 1 method)
julia> foo.(SparseArrays.sparsevec([1, 2], [1, 2]))
Calling 1
Calling 2
2-element SparseArrays.SparseVector{Int64, Int64} with 2 stored entries:
[1] = 1
[2] = 2
```
However, if `f` is mutating, this can have serious consequences! In our case,
broadcasting `build_constraint` will add a new `0 = 0` constraint.
Sparse arrays most-often arise when some input data to the constraint is sparse
(e.g., a constant vector or matrix). Due to promotion and arithmetic, this
results in a constraint function that is represented by an `AbstractSparseArray`,
but is actually dense. Thus, we can safely `collect` the matrix into a dense
array.
If the function is sparse, it's not obvious what to do. What is the "zero"
element of the result? What does it mean to broadcast `build_constraint` over a
sparse array adding scalar constraints? This likely means that the user is using
the wrong data structure. For simplicity, let's also call `collect` into a dense
array, and wait for complaints.
"""
_desparsify(x::SparseArrays.AbstractSparseArray) = collect(x)

_desparsify(x) = x

"""
parse_constraint_head(error_fn::Function, ::Val{head}, args...)
Expand Down Expand Up @@ -621,6 +675,22 @@ function parse_constraint_call(
return parse_code, build_call
end

function _functionize(v::V) where {V<:AbstractVariableRef}
return convert(GenericAffExpr{value_type(V),V}, v)
end

_functionize(v::AbstractArray{<:AbstractVariableRef}) = _functionize.(v)

function _functionize(
v::LinearAlgebra.Symmetric{V},
) where {V<:AbstractVariableRef}
return LinearAlgebra.Symmetric(_functionize(v.data))
end

_functionize(x) = x

_functionize(::_MA.Zero) = false

"""
parse_constraint_call(
error_fn::Function,
Expand Down
Loading

0 comments on commit 9704d2b

Please sign in to comment.