Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
pedromxavier committed Nov 23, 2023
1 parent 7929d58 commit af533a0
Show file tree
Hide file tree
Showing 8 changed files with 66 additions and 46 deletions.
6 changes: 3 additions & 3 deletions src/compiler/build.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,23 +29,23 @@ function objective_function(model::Virtual.Model{T}, ::AbstractArchitecture) whe
end

for (ci, g) in model.g
ρ = model.ρ[ci]
ρ = MOI.get(model, Attributes.ConstraintEncodingPenalty(), ci)

for (ω, c) in g
model.H[ω] += ρ * c
end
end

for (vi, h) in model.h
θ = model.θ[vi]
θ = MOI.get(model, Attributes.VariableEncodingPenalty(), vi)

for (ω, c) in h
model.H[ω] += θ * c
end
end

for (ci, s) in model.s
η = model.η[ci]
η = MOI.get(model, Attributes.SlackVariableEncodingPenalty(), ci)

for (ω, c) in s
model.H[ω] += η * c
Expand Down
3 changes: 0 additions & 3 deletions src/compiler/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,6 @@ function compile!(model::Virtual.Model{T}, arch::AbstractArchitecture) where {T}
# Add Regular Constraints
constraints!(model, arch)

# Add Encoding Constraints
encoding_constraints!(model, arch)

# Compute penalties
penalties!(model, arch)

Expand Down
20 changes: 3 additions & 17 deletions src/compiler/constraints.jl
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,9 @@ function constraint(
end

Check warning on line 468 in src/compiler/constraints.jl

View check run for this annotation

Codecov / codecov/patch

src/compiler/constraints.jl#L466-L468

Added lines #L466 - L468 were not covered by tests

g[w] = one(T)

Check warning on line 470 in src/compiler/constraints.jl

View check run for this annotation

Codecov / codecov/patch

src/compiler/constraints.jl#L470

Added line #L470 was not covered by tests

# Tell the compiler that quadratization is necessary
MOI.set(model, Attributes.Quadratize(), true)

Check warning on line 473 in src/compiler/constraints.jl

View check run for this annotation

Codecov / codecov/patch

src/compiler/constraints.jl#L473

Added line #L473 was not covered by tests
end
end

Expand All @@ -483,20 +486,3 @@ function constraint(

return g^2 + h
end

function encoding_constraints!(model::Virtual.Model{T}, ::AbstractArchitecture) where {T}
for v in model.variables
i = Virtual.source(v)
χ = Virtual.penaltyfn(v)

if !isnothing(χ)
if i isa VI
model.h[i] = χ
elseif i isa CI
model.s[i] = χ
end
end
end

return nothing
end
6 changes: 3 additions & 3 deletions src/compiler/penalties.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ function penalties!(model::Virtual.Model{T}, ::AbstractArchitecture) where {T}
ρ = σ */ ϵ + β)
end

model.ρ[ci] = ρ
MOI.set(model, Attributes.ConstraintEncodingPenalty(), ci, ρ)
end

for (vi, h) in model.h
Expand All @@ -24,7 +24,7 @@ function penalties!(model::Virtual.Model{T}, ::AbstractArchitecture) where {T}
θ = σ */ ϵ + β)

Check warning on line 24 in src/compiler/penalties.jl

View check run for this annotation

Codecov / codecov/patch

src/compiler/penalties.jl#L24

Added line #L24 was not covered by tests
end

model.θ[vi] = θ
MOI.set(model, Attributes.VariableEncodingPenalty(), vi, θ)
end

for (ci, s) in model.s
Expand All @@ -35,7 +35,7 @@ function penalties!(model::Virtual.Model{T}, ::AbstractArchitecture) where {T}
η = σ */ ϵ + β)

Check warning on line 35 in src/compiler/penalties.jl

View check run for this annotation

Codecov / codecov/patch

src/compiler/penalties.jl#L34-L35

Added lines #L34 - L35 were not covered by tests
end

model.η[ci] = η
MOI.set(model, Attributes.SlackVariableEncodingPenalty(), ci, η)
end

return nothing
Expand Down
6 changes: 5 additions & 1 deletion src/encoding/variables/set/one_hot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,11 @@ function encode(

a, b = S

Γ = collect(range(a, b; length = p))
Γ = if p == 1
T[(a + b) / 2]

Check warning on line 103 in src/encoding/variables/set/one_hot.jl

View check run for this annotation

Codecov / codecov/patch

src/encoding/variables/set/one_hot.jl#L103

Added line #L103 was not covered by tests
else
collect(T, range(a, b; length = p))
end

return encode(var, e, Γ)
end
9 changes: 9 additions & 0 deletions src/virtual/encoding.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,19 @@
function Encoding.encode!(model::Model{T}, v::Variable{T}) where {T}
x = source(v)
χ = penaltyfn(v)

if x isa VI
model.source[x] = v

if !isnothing(χ)
model.h[x] = χ
end
elseif x isa CI
model.slack[x] = v

if !isnothing(χ)
model.s[x] = χ
end
end

for y in target(v)
Expand Down
2 changes: 1 addition & 1 deletion src/wrapper.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,14 @@ function MOI.optimize!(model::Optimizer)

# De facto JuMP to QUBO Compilation
let t = @elapsed ToQUBO.Compiler.compile!(model)
MOI.set(model, Attributes.CompilationStatus(), MOI.LOCALLY_SOLVED)
MOI.set(model, Attributes.CompilationTime(), t)
end

if !isnothing(model.optimizer)
MOI.optimize!(model.optimizer, model.target_model)
MOI.set(model, MOI.RawStatusString(), MOI.get(model.optimizer, MOI.RawStatusString()))
else
MOI.set(model, Attributes.CompilationStatus(), MOI.LOCALLY_SOLVED)
MOI.set(model, MOI.RawStatusString(), "Compilation complete without an internal solver")

Check warning on line 40 in src/wrapper.jl

View check run for this annotation

Codecov / codecov/patch

src/wrapper.jl#L40

Added line #L40 was not covered by tests
end

Expand Down
60 changes: 42 additions & 18 deletions test/integration/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,16 +89,16 @@ function test_interface_moi()
# max x1 + x2 + x3
# st x1 + x2 <= 1 (c1)
# x2 + x3 <= 1 (c2)
# x1 ∈ {0, 1}
# x2 ∈ {0, 1}
# x3 ∈ {0, 1}
# 0 <= x1 <= 1
# 0 <= x2 <= 1
# 0 <= x3 <= 1

model = MOI.instantiate(
() -> ToQUBO.Optimizer(RandomSampler.Optimizer);
with_bridge_type = Float64,
)

x, _ = MOI.add_constrained_variables(model, fill(MOI.ZeroOne(), 3))
x, _ = MOI.add_constrained_variables(model, fill(MOI.Interval{Float64}(0.0, 1.0), 3))

MOI.set(model, MOI.ObjectiveSense(), MOI.MAX_SENSE)

Expand Down Expand Up @@ -205,10 +205,10 @@ function test_interface_moi()
@test MOI.get(model, Attributes.VariableEncodingMethod(), x[1]) === nothing
@test MOI.get(model, Attributes.VariableEncodingMethod(), x[2]) === nothing

MOI.set(model, Attributes.VariableEncodingMethod(), x[1], Encoding.Arithmetic())
MOI.set(model, Attributes.VariableEncodingMethod(), x[1], Encoding.OneHot())
MOI.set(model, Attributes.VariableEncodingMethod(), x[2], Encoding.Arithmetic())

@test MOI.get(model, Attributes.VariableEncodingMethod(), x[1]) isa Encoding.Arithmetic
@test MOI.get(model, Attributes.VariableEncodingMethod(), x[1]) isa Encoding.OneHot
@test MOI.get(model, Attributes.VariableEncodingMethod(), x[2]) isa Encoding.Arithmetic

# Variable Encoding ATol
Expand All @@ -233,11 +233,11 @@ function test_interface_moi()
@test MOI.get(model, Attributes.VariableEncodingBits(), x[1]) === nothing
@test MOI.get(model, Attributes.VariableEncodingBits(), x[2]) === nothing

MOI.set(model, Attributes.VariableEncodingBits(), x[1], 1)
MOI.set(model, Attributes.VariableEncodingBits(), x[2], 2)
MOI.set(model, Attributes.VariableEncodingBits(), x[1], 10)
MOI.set(model, Attributes.VariableEncodingBits(), x[2], 20)

@test MOI.get(model, Attributes.VariableEncodingBits(), x[1]) == 1
@test MOI.get(model, Attributes.VariableEncodingBits(), x[2]) == 2
@test MOI.get(model, Attributes.VariableEncodingBits(), x[1]) == 10
@test MOI.get(model, Attributes.VariableEncodingBits(), x[2]) == 20

# Variable Encoding Penalty
@test MOI.get(model, Attributes.VariableEncodingPenaltyHint(), x[1]) === nothing
Expand Down Expand Up @@ -297,20 +297,28 @@ function test_interface_moi()
MOI.optimize!(model)

let virtual_model = model.model.optimizer
@test MOI.get(virtual_model, Attributes.Architecture()) isa SuperArchitecture
@test MOI.get(virtual_model, Attributes.Architecture()).super === true

@test MOI.get(virtual_model, Attributes.Optimization()) === 3
@test MOI.get(virtual_model, Attributes.Optimization()) == 3
@test Attributes.optimization(virtual_model) == 3

@test MOI.get(virtual_model, Attributes.Discretize()) === true
@test Attributes.discretize(virtual_model) === true

@test MOI.get(virtual_model, Attributes.Quadratize()) === true
@test Attributes.quadratize(virtual_model) === true

@test MOI.get(virtual_model, Attributes.Warnings()) === false
@test Attributes.warnings(virtual_model) === false

@test MOI.get(virtual_model, Attributes.Architecture()) isa SuperArchitecture
@test MOI.get(virtual_model, Attributes.Architecture()).super === true
@test Attributes.architecture(virtual_model) isa SuperArchitecture
@test Attributes.architecture(virtual_model).super === true

@test MOI.get(virtual_model, Attributes.QuadratizationMethod()) isa PBO.PTR_BG
@test MOI.get(virtual_model, Attributes.StableQuadratization()) === true

@test MOI.get(virtual_model, Attributes.DefaultVariableEncodingMethod()) isa Encoding.Unary
@test MOI.get(virtual_model, Attributes.VariableEncodingMethod(), x[1]) isa Encoding.Arithmetic
@test MOI.get(virtual_model, Attributes.VariableEncodingMethod(), x[1]) isa Encoding.OneHot
@test MOI.get(virtual_model, Attributes.VariableEncodingMethod(), x[2]) isa Encoding.Arithmetic
@test MOI.get(virtual_model, Attributes.VariableEncodingMethod(), x[3]) === nothing

Expand All @@ -320,21 +328,37 @@ function test_interface_moi()
@test MOI.get(virtual_model, Attributes.VariableEncodingATol(), x[3]) === nothing

@test MOI.get(virtual_model, Attributes.DefaultVariableEncodingBits()) == 3
@test MOI.get(virtual_model, Attributes.VariableEncodingBits(), x[1]) == 1
@test MOI.get(virtual_model, Attributes.VariableEncodingBits(), x[2]) == 2
@test MOI.get(virtual_model, Attributes.VariableEncodingBits(), x[1]) == 10
@test MOI.get(virtual_model, Attributes.VariableEncodingBits(), x[2]) == 20
@test MOI.get(virtual_model, Attributes.VariableEncodingBits(), x[3]) === nothing

@test MOI.get(virtual_model, Attributes.VariableEncodingPenaltyHint(), x[1]) == -1.0
@test Attributes.variable_encoding_penalty_hint(virtual_model, x[1]) == -1.0
@test MOI.get(virtual_model, Attributes.VariableEncodingPenaltyHint(), x[2]) === nothing
@test Attributes.variable_encoding_penalty_hint(virtual_model, x[2]) === nothing
@test MOI.get(virtual_model, Attributes.VariableEncodingPenaltyHint(), x[3]) === nothing
@test Attributes.variable_encoding_penalty_hint(virtual_model, x[3]) === nothing

@test MOI.get(virtual_model, Attributes.VariableEncodingPenalty(), x[1]) == -1.0
@test Attributes.variable_encoding_penalty(virtual_model, x[1]) == -1.0
@test MOI.get(virtual_model, Attributes.VariableEncodingPenalty(), x[2]) === nothing
@test Attributes.variable_encoding_penalty(virtual_model, x[2]) === nothing
@test MOI.get(virtual_model, Attributes.VariableEncodingPenalty(), x[3]) === nothing
@test Attributes.variable_encoding_penalty(virtual_model, x[3]) === nothing

@test MOI.get(virtual_model, Attributes.ConstraintEncodingPenaltyHint(), c[1]) == -10.0
@test MOI.get(virtual_model, Attributes.ConstraintEncodingPenaltyHint(), c[2]) === nothing

@test MOI.get(virtual_model, Attributes.ConstraintEncodingPenalty(), c[1]) == -10.0
@test MOI.get(virtual_model, Attributes.ConstraintEncodingPenalty(), c[2]) == -4.0
@test MOI.get(virtual_model, Attributes.ConstraintEncodingPenalty(), c[2]) <= 0.0

@test MOI.get(model, Attributes.SlackVariableEncodingPenalty(), c[1]) == -100.0

@test MOI.get(virtual_model, Attributes.CompilationStatus()) === MOI.LOCALLY_SOLVED
@test Attributes.compilation_status(virtual_model) === MOI.LOCALLY_SOLVED

@test MOI.get(virtual_model, Attributes.CompilationTime()) > 0.0
@test Attributes.compilation_time(virtual_model) > 0.0
end
end
end
Expand Down

0 comments on commit af533a0

Please sign in to comment.