Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
hdavid16 committed Oct 16, 2023
1 parent 0daa695 commit c788b10
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 9 deletions.
19 changes: 12 additions & 7 deletions src/hull.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ function _disaggregate_variable(model::Model, lvref::LogicalVariableRef, vref::V
#get binary indicator variable
bvref = _indicator_to_binary(model)[lvref]
#temp storage
if !haskey(method.disjunction_variables, vref)
if !haskey(method.disjunction_variables, vref) #NOTE: not needed because _Hull disjunction_variables is initialized with all the variables in the disjunction
method.disjunction_variables[vref] = Vector{VariableRef}()
end
push!(method.disjunction_variables[vref], dvref)
Expand Down Expand Up @@ -66,7 +66,7 @@ end
# variable
function _disaggregate_expression(model::Model, vref::VariableRef, bvref::VariableRef, method::_Hull)
if is_binary(vref) || !haskey(method.disjunct_variables, (vref, bvref)) #keep any binary variables or nested disaggregated variables unchanged
return vref
return vref #NOTE: not needed because nested constraint of the form `vref in MOI.AbstractScalarSet` gets reformulated to an affine expression.
else #replace with disaggregated form
return method.disjunct_variables[vref, bvref]
end
Expand Down Expand Up @@ -105,17 +105,20 @@ function _disaggregate_nl_expression(model::Model, c::Number, ::VariableRef, met
end
# variable in NonlinearExpr
function _disaggregate_nl_expression(model::Model, vref::VariableRef, bvref::VariableRef, method::_Hull)
ϵ = method.value
dvref = method.disjunct_variables[vref, bvref]
new_var = dvref / ((1-ϵ)*bvref+ϵ)
return new_var
if is_binary(vref) || !haskey(method.disjunct_variables, (vref, bvref)) #keep any binary variables or nested disaggregated variables unchanged
return vref
else #replace with disaggregated form
ϵ = method.value
dvref = method.disjunct_variables[vref, bvref]
return dvref / ((1-ϵ)*bvref+ϵ)
end
end
# affine expression in NonlinearExpr
function _disaggregate_nl_expression(model::Model, aff::AffExpr, bvref::VariableRef, method::_Hull)
new_expr = aff.constant
ϵ = method.value
for (vref, coeff) in aff.terms
if is_binary(vref) #keep any binary variables undisaggregated
if is_binary(vref) || !haskey(method.disjunct_variables, (vref, bvref)) #keep any binary variables or nested disaggregated variables unchanged
dvref = vref
else #replace other vars with disaggregated form
dvref = method.disjunct_variables[vref, bvref]
Expand All @@ -125,6 +128,8 @@ function _disaggregate_nl_expression(model::Model, aff::AffExpr, bvref::Variable
return new_expr
end
# quadratic expression in NonlinearExpr
# TODO review what happens when there are bilinear terms with binary variables involved since these are not being disaggregated
# (e.g., complementarity constraints; though likely irrelevant)...
function _disaggregate_nl_expression(model::Model, quad::QuadExpr, bvref::VariableRef, method::_Hull)
#get affine part
new_expr = _disaggregate_nl_expression(model, quad.aff, bvref, method)
Expand Down
2 changes: 1 addition & 1 deletion src/logic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ function _reformulate_proposition(model::Model, lexpr::_LogicalExpr)
end
elseif expr.head in (:||, :!) && all(_isa_literal.(expr.args))
_add_reformulated_proposition(model, expr)
else
else #NOTE: should never enter the `else` section
error("Expression $expr was not converted to proper Conjunctive Normal Form.")
end
end
Expand Down
53 changes: 52 additions & 1 deletion test/constraints/hull.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,10 @@ function test_disaggregate_variables()
@variable(model, 10 <= x <= 100)
@variable(model, y, Bin)
@variable(model, z, Logical)
vrefs = Set([x,y])
vrefs = Set{VariableRef}() #initialize empty set to check if method.disjunct_variables has variables added to it in _disaggregate_variable call
DP._reformulate_logical_variables(model)
method = DP._Hull(Hull(1e-3, Dict(x => (0., 100.))), vrefs)
vrefs = Set([x,y])
DP._disaggregate_variables(model, z, vrefs, method)

refvars = DP._reformulation_variables(model)
Expand Down Expand Up @@ -76,6 +77,38 @@ function test_aggregate_variable()
@test refcons[1].set == MOI.EqualTo(0.)
end

function test_disaggregate_expression_var_binary()
model = GDPModel()
@variable(model, x, Bin)
@variable(model, z, Logical)
DP._reformulate_logical_variables(model)
bvrefs = DP._indicator_to_binary(model)

vrefs = Set([x])
method = DP._Hull(Hull(1e-3, Dict(x => (0., 1.))), vrefs)
DP._disaggregate_variables(model, z, vrefs, method)
@test isnothing(variable_by_name(model, "x_z"))

refexpr = DP._disaggregate_expression(model, x, bvrefs[z], method)
@test refexpr == x
end

function test_disaggregate_expression_var()
model = GDPModel()
@variable(model, 10 <= x <= 100)
@variable(model, z, Logical)
DP._reformulate_logical_variables(model)
bvrefs = DP._indicator_to_binary(model)

vrefs = Set([x])
method = DP._Hull(Hull(1e-3, Dict(x => (0., 100.))), vrefs)
DP._disaggregate_variables(model, z, vrefs, method)

refexpr = DP._disaggregate_expression(model, x, bvrefs[z], method)
x_z = variable_by_name(model, "x_z")
@test refexpr == x_z
end

function test_disaggregate_expression_affine()
model = GDPModel()
@variable(model, 10 <= x <= 100)
Expand Down Expand Up @@ -131,6 +164,21 @@ function test_disaggregate_nl_expression_c()
@test refexpr == 1
end

function test_disaggregate_nl_expression_var_binary()
model = GDPModel()
@variable(model, x, Bin)
@variable(model, z, Logical)
DP._reformulate_logical_variables(model)
bvrefs = DP._indicator_to_binary(model)

vrefs = Set([x])
method = DP._Hull(Hull(1e-3, Dict(x => (0., 1.))), vrefs)
DP._disaggregate_variables(model, z, vrefs, method)

refexpr = DP._disaggregate_nl_expression(model, x, bvrefs[z], method)
@test refexpr == x
end

function test_disaggregate_nl_expression_var()
model = GDPModel()
@variable(model, 10 <= x <= 100)
Expand Down Expand Up @@ -548,9 +596,12 @@ end
test_query_variable_bounds_error2()
test_disaggregate_variables()
test_aggregate_variable()
test_disaggregate_expression_var_binary()
test_disaggregate_expression_var()
test_disaggregate_expression_affine()
test_disaggregate_expression_quadratic()
test_disaggregate_nl_expression_c()
test_disaggregate_nl_expression_var_binary()
test_disaggregate_nl_expression_var()
test_disaggregate_nl_expression_aff()
test_disaggregate_nl_expression_quad()
Expand Down

0 comments on commit c788b10

Please sign in to comment.