diff --git a/docs/src/manual/nonlinear.md b/docs/src/manual/nonlinear.md index afd11e58876..f71b5f03c07 100644 --- a/docs/src/manual/nonlinear.md +++ b/docs/src/manual/nonlinear.md @@ -153,6 +153,54 @@ julia> sin(sin(1.0)) 0.7456241416655579 ``` +## Common subexpressions + +JuMP does not perform [common subexpression elimination](https://en.wikipedia.org/wiki/Common_subexpression_elimination). +Instead, if you re-use +an expression in multiple places, JuMP will insert a copy of the expression. + +JuMP's lack of common subexpression elimination is a common cause of performance +problems, particularly in nonlinear models with a pattern like +`sum(t / common_term for t in terms)`. One example is the logistic loss: + +```jldoctest +julia> model = Model(); + +julia> @variable(model, x[1:2]); + +julia> @expression(model, expr, sum(exp.(x))) +0.0 + exp(x[2]) + exp(x[1]) + +julia> @objective(model, Min, sum(exp(x[i]) / expr for i in 1:2)) +(exp(x[1]) / (0.0 + exp(x[2]) + exp(x[1]))) + (exp(x[2]) / (0.0 + exp(x[2]) + exp(x[1]))) +``` +In this model, JuMP will compute the value (and derivatives) of the denominator +twice, without realizing that it is the same expression. + +As a work-around, create a new [`@variable`](@ref) and use an `==` +[`@constraint`](@ref) to constrain the value of the variable to the +subexpression. + +```jldoctest +julia> model = Model(); + +julia> @variable(model, x[1:2]); + +julia> @variable(model, expr); + +julia> @constraint(model, expr == sum(exp.(x))) +expr - (0.0 + exp(x[2]) + exp(x[1])) = 0 + +julia> @objective(model, Min, sum(exp(x[i]) / expr for i in 1:2)) +(exp(x[1]) / expr) + (exp(x[2]) / expr) +``` + +The reason JuMP does not perform common subexpression elimination automatically +is for simplicity, and because there is a trade-off: for simple expressions, the +extra complexity of detecting and merging common subexpressions may outweigh +the cost of computing them independently. Instead, we leave it to the user to +decide which expressions to extract as common subexpressions. + ## Automatic differentiation JuMP computes first- and second-order derivatives using sparse reverse-mode