Skip to content

Commit

Permalink
Add a few missing methods for AbstractJuMPScalar to support e.g. `D…
Browse files Browse the repository at this point in the history
…istances.jl` (#3585)
  • Loading branch information
LebedevRI authored Nov 27, 2023
1 parent 417c58d commit f7fb42b
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 0 deletions.
4 changes: 4 additions & 0 deletions src/JuMP.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1049,6 +1049,9 @@ function owner_model end
Base.ndims(::Type{<:AbstractJuMPScalar}) = 0
Base.ndims(::AbstractJuMPScalar) = 0

Base.IteratorEltype(::Type{<:AbstractJuMPScalar}) = Base.HasEltype()
Base.eltype(::Type{T}) where {T<:AbstractJuMPScalar} = T

# These are required to create symmetric containers of AbstractJuMPScalars.
LinearAlgebra.symmetric_type(::Type{T}) where {T<:AbstractJuMPScalar} = T
LinearAlgebra.hermitian_type(::Type{T}) where {T<:AbstractJuMPScalar} = T
Expand All @@ -1059,6 +1062,7 @@ LinearAlgebra.adjoint(scalar::AbstractJuMPScalar) = conj(scalar)
Base.iterate(x::AbstractJuMPScalar) = (x, true)
Base.iterate(::AbstractJuMPScalar, state) = nothing
Base.isempty(::AbstractJuMPScalar) = false
Base.length(::AbstractJuMPScalar) = 1

# Check if two arrays of AbstractJuMPScalars are equal. Useful for testing.
function isequal_canonical(
Expand Down
6 changes: 6 additions & 0 deletions src/aff_expr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -204,8 +204,14 @@ function Base.one(::Type{GenericAffExpr{C,V}}) where {C,V}
return GenericAffExpr{C,V}(one(C), OrderedDict{V,C}())
end

function Base.oneunit(::Type{GenericAffExpr{C,V}}) where {C,V}
return GenericAffExpr{C,V}(oneunit(C), OrderedDict{V,C}())
end

Base.one(a::GenericAffExpr) = one(typeof(a))

Base.oneunit(a::GenericAffExpr) = oneunit(typeof(a))

Base.copy(a::GenericAffExpr) = GenericAffExpr(copy(a.constant), copy(a.terms))

Base.broadcastable(a::GenericAffExpr) = Ref(a)
Expand Down
6 changes: 6 additions & 0 deletions src/variables.jl
Original file line number Diff line number Diff line change
Expand Up @@ -361,10 +361,16 @@ end

Base.one(v::AbstractVariableRef) = one(typeof(v))

Base.oneunit(v::AbstractVariableRef) = oneunit(typeof(v))

function Base.one(::Type{V}) where {V<:AbstractVariableRef}
return one(GenericAffExpr{value_type(V),V})
end

function Base.oneunit(::Type{V}) where {V<:AbstractVariableRef}
return oneunit(GenericAffExpr{value_type(V),V})
end

"""
coefficient(v1::GenericVariableRef{T}, v2::GenericVariableRef{T}) where {T}
Expand Down
25 changes: 25 additions & 0 deletions test/test_variable.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1601,4 +1601,29 @@ function test_bad_bound_types()
return
end

function test_variable_length()
model = Model()
@variable(model, x)
@test length(x) == 1
return
end

function test_variable_eltype()
model = Model()
@variable(model, x)
@test Base.IteratorEltype(x) == Base.HasEltype()
@test Base.eltype(typeof(x)) == typeof(x)
return
end

function test_variable_one()
model = Model()
@variable(model, x)
@test one(x) == AffExpr(1.0)
@test one(2 * x) == AffExpr(1.0)
@test oneunit(x) == AffExpr(1.0)
@test oneunit(2 * x) == AffExpr(1.0)
return
end

end # module TestVariable

0 comments on commit f7fb42b

Please sign in to comment.