diff --git a/src/Plasmo.jl b/src/Plasmo.jl index 1e192e2..ee624f3 100644 --- a/src/Plasmo.jl +++ b/src/Plasmo.jl @@ -110,7 +110,9 @@ export OptiGraph, # other functions set_jump_model, - extract_variables + extract_variables, + is_separable, + extract_separable_terms include("core_types.jl") @@ -140,7 +142,7 @@ include("graph_functions/topology.jl") include("graph_functions/partition.jl") -include("utilities.jl") +include("utils.jl") # extensions function __init__() diff --git a/src/optigraph.jl b/src/optigraph.jl index 4ee979b..8a05171 100644 --- a/src/optigraph.jl +++ b/src/optigraph.jl @@ -1091,27 +1091,100 @@ function has_node_objective(graph::OptiGraph) return false end +""" + node_objective_type(graph::OptiGraph) + +Return the most complex objective type among nodes in the given `graph`. The order of +complexity is: Nonlinear, Quadratic, Linear. +""" +function node_objective_type(graph::OptiGraph) + if !(has_node_objective(graph)) + return nothing + end + + obj_types = JuMP.objective_function_type.(all_nodes(graph)) + if JuMP.GenericNonlinearExpr{NodeVariableRef} in obj_types + return JuMP.GenericNonlinearExpr{NodeVariableRef} + elseif JuMP.GenericQuadExpr{Float64,NodeVariableRef} in obj_types + return JuMP.GenericQuadExpr{Float64,NodeVariableRef} + elseif JuMP.GenericAffExpr{Float64,NodeVariableRef} in obj_types + return JuMP.GenericAffExpr{Float64,NodeVariableRef} + elseif NodeVariableRef in obj_types + return JuMP.GenericAffExpr{Float64,NodeVariableRef} + else + error("Could not determine node objective type") + end +end + """ set_to_node_objectives(graph::OptiGraph) Set the `graph` objective to the summation of all of its optinode objectives. Assumes the -objective sense is an MOI.MIN_SENSE and adjusts the signs of node objective functions -accordingly. +objective sense is an MOI.MIN_SENSE and accounts for the sense of node objectives +accordingly. + +Note that building nonlinear objective functions is much slower than +linear or quadratic because nonlienar expressions cannot be updated in place. """ function set_to_node_objectives(graph::OptiGraph) - obj = 0 + if has_node_objective(graph) + node_obj_type = node_objective_type(graph) + _set_to_node_objectives(graph, node_obj_type) + end + return nothing +end + +function _set_to_node_objectives( + graph::OptiGraph, + obj_type::Type{T} where T <: Union{ + JuMP.GenericAffExpr{Float64, NodeVariableRef}, + JuMP.GenericQuadExpr{Float64, NodeVariableRef} + } +) + objective = zero(obj_type) for node in all_nodes(graph) if has_objective(node) sense = JuMP.objective_sense(node) == MOI.MAX_SENSE ? -1 : 1 - obj += sense * JuMP.objective_function(node) + JuMP.add_to_expression!(objective, JuMP.objective_function(node), sense) end end - if obj != 0 - @objective(graph, Min, obj) + @objective(graph, Min, objective) + return +end + +function _set_to_node_objectives( + graph::OptiGraph, + obj_type::Type{T} where T <: JuMP.GenericNonlinearExpr{NodeVariableRef} +) + objective = zero(obj_type) + for node in all_nodes(graph) + if has_objective(node) + sense = JuMP.objective_sense(node) == MOI.MAX_SENSE ? -1 : 1 + objective += *(sense, objective_function(node)) + end end - return nothing + @objective(graph, Min, objective) + return end +# TODO +""" + set_node_objectives_from_graph(graph::OptiGraph) + +Set the objective of each node within `graph` by parsing and separating the graph objective +function. Note this only works if the objective function is separable over the nodes in +`graph`. +""" +# function set_node_objectives_from_graph(graph::OptiGraph) +# obj = objective_function(graph) +# if !(is_separable(obj)) +# error("Cannot set node objectives from graph. It is not separable across nodes.") +# end +# sense = objective_sense(graph) +# _set_node_objectives_from_graph(obj, sense) +# return nothing +# end + """ JuMP.objective_function(graph::OptiGraph) diff --git a/src/optinode.jl b/src/optinode.jl index 11a0e11..b413799 100644 --- a/src/optinode.jl +++ b/src/optinode.jl @@ -241,7 +241,15 @@ function JuMP.set_objective_sense(node::OptiNode, sense::MOI.OptimizationSense) end function JuMP.objective_function(node::OptiNode) - return JuMP.object_dictionary(node)[(node, :objective_function)] + if haskey(JuMP.object_dictionary(node), (node,:objective_function)) + return JuMP.object_dictionary(node)[(node, :objective_function)] + else + return nothing + end +end + +function JuMP.objective_function_type(node::OptiNode) + return typeof(objective_function(node)) end function JuMP.objective_sense(node::OptiNode) diff --git a/src/utilities.jl b/src/utilities.jl deleted file mode 100644 index 5b0cb8a..0000000 --- a/src/utilities.jl +++ /dev/null @@ -1,39 +0,0 @@ -### Utilities for querying variables used in constraints - -function extract_variables(func) - return _extract_variables(func) -end - -function _extract_variables(func::NodeVariableRef) - return [func] -end - -function _extract_variables(ref::EdgeConstraintRef) - func = JuMP.jump_function(JuMP.constraint_object(ref)) - return _extract_variables(func) -end - -function _extract_variables(func::JuMP.GenericAffExpr) - return collect(keys(func.terms)) -end - -function _extract_variables(func::JuMP.GenericQuadExpr) - quad_vars = vcat([[term[2]; term[3]] for term in JuMP.quad_terms(func)]...) - aff_vars = _extract_variables(func.aff) - return union(quad_vars, aff_vars) -end - -function _extract_variables(func::JuMP.GenericNonlinearExpr) - vars = NodeVariableRef[] - for i in 1:length(func.args) - func_arg = func.args[i] - if func_arg isa Number - continue - elseif typeof(func_arg) == NodeVariableRef - push!(vars, func_arg) - else - append!(vars, _extract_variables(func_arg)) - end - end - return vars -end diff --git a/src/utils.jl b/src/utils.jl new file mode 100644 index 0000000..0f961ea --- /dev/null +++ b/src/utils.jl @@ -0,0 +1,225 @@ +""" + extract_variables(func) + +Return the variables contained within the given expression or reference. +""" +function extract_variables(func) + return _extract_variables(func) +end + +function _extract_variables(func::NodeVariableRef) + return [func] +end + +function _extract_variables(ref::EdgeConstraintRef) + func = JuMP.jump_function(JuMP.constraint_object(ref)) + return _extract_variables(func) +end + +function _extract_variables(func::JuMP.GenericAffExpr) + return collect(keys(func.terms)) +end + +function _extract_variables(func::JuMP.GenericQuadExpr) + quad_vars = vcat([[term[2]; term[3]] for term in JuMP.quad_terms(func)]...) + aff_vars = _extract_variables(func.aff) + return union(quad_vars, aff_vars) +end + +function _extract_variables(func::JuMP.GenericNonlinearExpr) + vars = NodeVariableRef[] + for i in 1:length(func.args) + func_arg = func.args[i] + if func_arg isa Number + continue + elseif typeof(func_arg) == NodeVariableRef + push!(vars, func_arg) + else + append!(vars, _extract_variables(func_arg)) + end + end + return vars +end + +function _first_variable(func::JuMP.GenericNonlinearExpr) + for i in 1:length(func.args) + func_arg = func.args[i] + if func_arg isa Number + continue + elseif typeof(func_arg) == NodeVariableRef + return func_arg + else + return _first_variable(func_arg) + end + end +end + +""" + is_separable(func) + +Return whether the given function is separable across optinodes. +""" +function is_separable(func::Union{Number,JuMP.AbstractJuMPScalar}) + return _is_separable(func) +end + +function _is_separable(::Number) + return true +end + +function _is_separable(::NodeVariableRef) + return true +end + +function _is_separable(::JuMP.GenericAffExpr{<:Number,NodeVariableRef}) + return true +end + +function _is_separable(func::JuMP.GenericQuadExpr{<:Number,NodeVariableRef}) + # check each term; make sure they are all on the same subproblem + for term in Plasmo.quad_terms(func) + # term = (coefficient, variable_1, variable_2) + node1 = get_node(term[2]) + node2 = get_node(term[3]) + + # if any term is split across nodes, the objective is not separable + if node1 != node2 + return false + end + end + return true +end + +function _is_separable(func::JuMP.GenericNonlinearExpr{NodeVariableRef}) + # check for a constant multiplier + if func.head == :* + if !(func.args[1] isa Number) + return false + end + end + + # if not additive, check if term is separable + if func.head != :+ && func.head != :- + vars = extract_variables(func) + nodes = get_node.(vars) + if length(unique(nodes)) > 1 + return false + end + end + + # check each argument + for arg in func.args + if !(is_separable(arg)) + return false + end + end + return true +end + +""" + extract_separable_terms(func::JuMP.AbstractJuMPScalar,graph::OptiGraph) + +Extract the separable terms contained within `graph`. +NOTE: Nonlinear objectives are not completely tested and may return incorrect results. +""" +function extract_separable_terms(func::JuMP.AbstractJuMPScalar, graph::OptiGraph) + !is_separable(func) && error("Cannont extract terms. Function is not separable.") + return _extract_separable_terms(func, graph) +end + +function _extract_separable_terms( + func::Union{Number,Plasmo.NodeVariableRef}, + graph::OptiGraph +) + return func +end + +function _extract_separable_terms( + func::JuMP.GenericAffExpr{<:Number,NodeVariableRef}, + graph::OptiGraph +) + node_terms = OrderedDict{OptiNode,Vector{JuMP.GenericAffExpr{<:Number,NodeVariableRef}}}() + nodes = Plasmo.collect_nodes(func) + nodes = intersect(nodes, all_nodes(graph)) + for node in nodes + node_terms[node] = Vector{JuMP.GenericAffExpr{<:Number,NodeVariableRef}}() + end + + for term in Plasmo.linear_terms(func) + node = get_node(term[2]) + push!(node_terms[node], term[1]*term[2]) + end + + return node_terms +end + +function _extract_separable_terms( + func::JuMP.GenericQuadExpr{<:Number,NodeVariableRef}, + graph::OptiGraph +) + node_terms = OrderedDict{OptiNode,Vector{JuMP.GenericQuadExpr{<:Number,NodeVariableRef}}}() + nodes = collect_nodes(func) + nodes = intersect(nodes, all_nodes(graph)) + for node in nodes + node_terms[node] = Vector{JuMP.GenericQuadExpr{<:Number,NodeVariableRef}}() + end + + for term in JuMP.quad_terms(func) + node = get_node(term[2]) + push!(node_terms[node], term[1]*term[2]*term[3]) + end + + for term in JuMP.linear_terms(func) + node = get_node(term[2]) + push!(node_terms[node], term[1]*term[2]) + end + + return node_terms +end + +# NOTE: method needs improvement. does not cover all separable cases. +function _extract_separable_terms( + func::JuMP.GenericNonlinearExpr{NodeVariableRef}, + graph::OptiGraph +) + node_terms = OrderedDict{OptiNode,Vector{JuMP.GenericNonlinearExpr{NodeVariableRef}}}() + nodes = collect_nodes(func) + nodes = intersect(nodes, all_nodes(graph)) + for node in nodes + node_terms[node] = Vector{JuMP.GenericNonlinearExpr{NodeVariableRef}}() + end + + _extract_separable_terms(func, node_terms) + + return node_terms +end + +function _extract_separable_terms( + func::JuMP.GenericNonlinearExpr{NodeVariableRef}, + node_terms::OrderedDict{OptiNode,Vector{JuMP.GenericNonlinearExpr{NodeVariableRef}}} +) + # check for a constant multiplier + multiplier = 1.0 + if func.head == :* + if func.args[1] isa Number + multiplier = func.args[1] + end + end + + # if not additive, get node for this term + if func.head != :+ && func.head != :- + var = _first_variable(func) + node = get_node(var) + push!(node_terms[node], multiplier*func) + else + # check each argument + for arg in func.args + if arg isa Number + continue + end + _extract_separable_terms(arg, node_terms) + end + end + + return nothing +end \ No newline at end of file