Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More refactoring of src/macros.jl #3619

Merged
merged 3 commits into from
Dec 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
179 changes: 36 additions & 143 deletions src/macros.jl
Original file line number Diff line number Diff line change
Expand Up @@ -323,38 +323,26 @@ function model_convert(
return model_convert.(model, x)
end

function _add_keyword_args(call::Expr, kwargs::Dict; exclude = Symbol[])
for (key, value) in kwargs
if key in exclude
continue
end
push!(call.args, esc(Expr(:kw, key, value)))
end
return
end

"""
_add_positional_args(call::Expr, args::Vector{Any})::Nothing
_add_additional_args(
call::Expr,
args::Vector,
kwargs::Dict{Symbol,Any};
kwarg_exclude::Vector{Symbol} = Symbol[],
)

Add the positional arguments `args` to the function call expression `call`,
escaping each argument expression.

This function is able to incorporate additional positional arguments to `call`s
that already have keyword arguments.

## Example

```jldoctest
julia> call = :(f(1, a=2))
:(f(1, a = 2))

julia> JuMP._add_positional_args(call, [:(x)])

julia> call
:(f(1, $(Expr(:escape, :x)), a = 2))
```
"""
function _add_positional_args(call::Expr, args::Vector)
function _add_additional_args(
call::Expr,
args::Vector,
kwargs::Dict{Symbol,Any};
kwarg_exclude::Vector{Symbol} = Symbol[],
)
call_args = call.args
if Meta.isexpr(call, :.)
# call is broadcasted
Expand All @@ -368,6 +356,11 @@ function _add_positional_args(call::Expr, args::Vector)
append!(call_args, esc.(args))
# Re-add the cached keyword arguments back to the end
append!(call_args, kw_args)
for (key, value) in kwargs
if !(key in kwarg_exclude)
push!(call_args, esc(Expr(:kw, key, value)))
end
end
return
end

Expand Down Expand Up @@ -406,8 +399,12 @@ function _finalize_macro(
wrap_let::Bool = false,
)
@assert Meta.isexpr(model, :escape)
if wrap_let
code = _wrap_let(model, code)
if wrap_let && model.args[1] isa Symbol
code = quote
let $model = $model
$code
end
end
end
if register_name !== nothing
sym_name = Meta.quot(register_name)
Expand All @@ -416,12 +413,8 @@ function _finalize_macro(
$(esc(register_name)) = $model[$sym_name] = $code
end
end
return Expr(
:block,
source,
:(_valid_model($model, $(Meta.quot(model.args[1])))),
code,
)
is_valid_code = :(_valid_model($model, $(Meta.quot(model.args[1]))))
return Expr(:block, source, is_valid_code, code)
end

function _error_if_cannot_register(model::AbstractModel, name::Symbol)
Expand All @@ -442,89 +435,6 @@ function _error_if_cannot_register(model::AbstractModel, name::Symbol)
return
end

function _check_vectorized(sense::Symbol)
sense_str = string(sense)
if startswith(sense_str, '.')
return Symbol(sense_str[2:end]), true
end
return sense, false
end

"""
_desparsify(x)

If `x` is an `AbstractSparseArray`, return the dense equivalent, otherwise just
return `x`.

This function is used in `_build_constraint`.

## Why is this needed?

When broadcasting `f.(x)` over an `AbstractSparseArray` `x`, Julia first calls
the equivalent of `f(zero(eltype(x))`. Here's an example:

```jldoctest
julia> import SparseArrays

julia> foo(x) = (println("Calling \$(x)"); x)
foo (generic function with 1 method)

julia> foo.(SparseArrays.sparsevec([1, 2], [1, 2]))
Calling 1
Calling 2
2-element SparseArrays.SparseVector{Int64, Int64} with 2 stored entries:
[1] = 1
[2] = 2
```

However, if `f` is mutating, this can have serious consequences! In our case,
broadcasting `build_constraint` will add a new `0 = 0` constraint.

Sparse arrays most-often arise when some input data to the constraint is sparse
(e.g., a constant vector or matrix). Due to promotion and arithmetic, this
results in a constraint function that is represented by an `AbstractSparseArray`,
but is actually dense. Thus, we can safely `collect` the matrix into a dense
array.

If the function is sparse, it's not obvious what to do. What is the "zero"
element of the result? What does it mean to broadcast `build_constraint` over a
sparse array adding scalar constraints? This likely means that the user is using
the wrong data structure. For simplicity, let's also call `collect` into a dense
array, and wait for complaints.
"""
_desparsify(x::SparseArrays.AbstractSparseArray) = collect(x)

_desparsify(x) = x

function _functionize(v::V) where {V<:AbstractVariableRef}
return convert(GenericAffExpr{value_type(V),V}, v)
end

_functionize(v::AbstractArray{<:AbstractVariableRef}) = _functionize.(v)

function _functionize(
v::LinearAlgebra.Symmetric{V},
) where {V<:AbstractVariableRef}
return LinearAlgebra.Symmetric(_functionize(v.data))
end

_functionize(x) = x

_functionize(::_MA.Zero) = false

"""
reverse_sense(::Val{T}) where {T}

Given an (in)equality symbol `T`, return a new `Val` object with the opposite
(in)equality symbol.
"""
function reverse_sense end
reverse_sense(::Val{:<=}) = Val(:>=)
reverse_sense(::Val{:≤}) = Val(:≥)
reverse_sense(::Val{:>=}) = Val(:<=)
reverse_sense(::Val{:≥}) = Val(:≤)
reverse_sense(::Val{:(==)}) = Val(:(==))

# This method is needed because Julia v1.10 prints LineNumberNode in the string
# representation of an expression.
function _strip_LineNumberNode(x::Expr)
Expand All @@ -536,42 +446,36 @@ end

_strip_LineNumberNode(x) = x

function _macro_error(macroname, args, source, str...)
function _macro_error(macro_name, args, source, str...)
str_args = join(_strip_LineNumberNode.(args), ", ")
return error(
"At $(source.file):$(source.line): `@$macroname($str_args)`: ",
"At $(source.file):$(source.line): `@$macro_name($str_args)`: ",
str...,
)
end

# Given a base_name and idxvars, returns an expression that constructs the name
# of the object. For use within macros only.
function _name_call(base_name, idxvars)
if isempty(idxvars) || base_name == ""
function _base_name_with_indices(base_name, index_vars::Vector)
if isempty(index_vars) || base_name == ""
return base_name
end
ex = Expr(:call, :string, base_name, "[")
for i in 1:length(idxvars)
expr = Expr(:call, :string, base_name, "[")
for index in index_vars
# Converting the arguments to strings before concatenating is faster:
# https://github.com/JuliaLang/julia/issues/29550.
esc_idxvar = esc(idxvars[i])
push!(ex.args, :(string($esc_idxvar)))
i < length(idxvars) && push!(ex.args, ",")
push!(expr.args, :(string($(esc(index)))))
push!(expr.args, ",")
end
push!(ex.args, "]")
return ex
expr.args[end] = "]"
return expr
end

_esc_non_constant(x::Number) = x
_esc_non_constant(x::Expr) = Meta.isexpr(x, :quote) ? x : esc(x)
_esc_non_constant(x) = esc(x)

"""
_replace_zero(model::M, x) where {M<:AbstractModel}

Replaces `_MA.Zero` with a floating point `zero(value_type(M))`.
"""
_replace_zero(::M, ::_MA.Zero) where {M<:AbstractModel} = zero(value_type(M))

_replace_zero(::AbstractModel, x::Any) = x

function _plural_macro_code(model, block, macro_sym)
Expand Down Expand Up @@ -609,17 +513,6 @@ function _plural_macro_code(model, block, macro_sym)
return code
end

function _wrap_let(model, code)
if Meta.isexpr(model, :escape) && model.args[1] isa Symbol
return quote
let $model = $model
$code
end
end
end
return code
end

include("macros/@objective.jl")
include("macros/@expression.jl")
include("macros/@constraint.jl")
Expand Down
78 changes: 74 additions & 4 deletions src/macros/@constraint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,11 +100,11 @@ macro constraint(input_args...)
)
end
is_vectorized, parse_code, build_call = parse_constraint(error_fn, x)
_add_positional_args(build_call, extra)
_add_keyword_args(
_add_additional_args(
build_call,
extra,
kwargs;
exclude = [:base_name, :container, :set_string_name],
kwarg_exclude = [:base_name, :container, :set_string_name],
)
# ; base_name
default_base_name = string(something(Containers._get_name(c), ""))
Expand All @@ -116,7 +116,7 @@ macro constraint(input_args...)
# There is no need to escape this one.
container = get(kwargs, :container, :Auto)
# ; set_string_name
name_expr = _name_call(base_name, index_vars)
name_expr = _base_name_with_indices(base_name, index_vars)
if name_expr != ""
set_string_name = if haskey(kwargs, :set_string_name)
esc(kwargs[:set_string_name])
Expand Down Expand Up @@ -276,6 +276,60 @@ function parse_constraint(error_fn::Function, arg)
)
end

function _check_vectorized(sense::Symbol)
sense_str = string(sense)
if startswith(sense_str, '.')
return Symbol(sense_str[2:end]), true
end
return sense, false
end

"""
_desparsify(x)

If `x` is an `AbstractSparseArray`, return the dense equivalent, otherwise just
return `x`.

This function is used in `_build_constraint`.

## Why is this needed?

When broadcasting `f.(x)` over an `AbstractSparseArray` `x`, Julia first calls
the equivalent of `f(zero(eltype(x))`. Here's an example:

```jldoctest
julia> import SparseArrays

julia> foo(x) = (println("Calling \$(x)"); x)
foo (generic function with 1 method)

julia> foo.(SparseArrays.sparsevec([1, 2], [1, 2]))
Calling 1
Calling 2
2-element SparseArrays.SparseVector{Int64, Int64} with 2 stored entries:
[1] = 1
[2] = 2
```

However, if `f` is mutating, this can have serious consequences! In our case,
broadcasting `build_constraint` will add a new `0 = 0` constraint.

Sparse arrays most-often arise when some input data to the constraint is sparse
(e.g., a constant vector or matrix). Due to promotion and arithmetic, this
results in a constraint function that is represented by an `AbstractSparseArray`,
but is actually dense. Thus, we can safely `collect` the matrix into a dense
array.

If the function is sparse, it's not obvious what to do. What is the "zero"
element of the result? What does it mean to broadcast `build_constraint` over a
sparse array adding scalar constraints? This likely means that the user is using
the wrong data structure. For simplicity, let's also call `collect` into a dense
array, and wait for complaints.
"""
_desparsify(x::SparseArrays.AbstractSparseArray) = collect(x)

_desparsify(x) = x

"""
parse_constraint_head(error_fn::Function, ::Val{head}, args...)

Expand Down Expand Up @@ -621,6 +675,22 @@ function parse_constraint_call(
return parse_code, build_call
end

function _functionize(v::V) where {V<:AbstractVariableRef}
return convert(GenericAffExpr{value_type(V),V}, v)
end

_functionize(v::AbstractArray{<:AbstractVariableRef}) = _functionize.(v)

function _functionize(
v::LinearAlgebra.Symmetric{V},
) where {V<:AbstractVariableRef}
return LinearAlgebra.Symmetric(_functionize(v.data))
end

_functionize(x) = x

_functionize(::_MA.Zero) = false

"""
parse_constraint_call(
error_fn::Function,
Expand Down
Loading
Loading