From 7a081f433b8ccd3fc41fcdb133a2da846ab81867 Mon Sep 17 00:00:00 2001 From: Oscar Dowson Date: Fri, 17 Jan 2025 08:49:38 +1300 Subject: [PATCH] Fix Base.copy for ScalarNonlinearFunction (#2612) --- src/functions.jl | 24 +++++++++++++++++++++++- test/functions.jl | 42 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 65 insertions(+), 1 deletion(-) diff --git a/src/functions.jl b/src/functions.jl index bc658737c4..b6193e0192 100644 --- a/src/functions.jl +++ b/src/functions.jl @@ -360,8 +360,30 @@ struct ScalarNonlinearFunction <: AbstractScalarFunction end end +# copy() doesn't recursively copy the children, and deepcopy seems to have a +# performance problem for deeply nested structs. function Base.copy(f::ScalarNonlinearFunction) - return ScalarNonlinearFunction(f.head, copy(f.args)) + stack, result_stack = Any[f], Any[] + while !isempty(stack) + arg = pop!(stack) + if arg isa ScalarNonlinearFunction + # We need some sort of hint so that the next time we see this on the + # stack we evaluate it using the args in `result_stack`. One option + # would be a custom type. Or we can just wrap in (,) and then check + # for a Tuple, which isn't (curretly) a valid argument. + push!(stack, (arg,)) + for child in arg.args + push!(stack, child) + end + elseif arg isa Tuple{<:ScalarNonlinearFunction} + result = only(arg) + args = Any[pop!(result_stack) for i in 1:length(result.args)] + push!(result_stack, ScalarNonlinearFunction(result.head, args)) + else + push!(result_stack, copy(arg)) + end + end + return only(result_stack) end constant(f::ScalarNonlinearFunction, ::Type{T} = Float64) where {T} = zero(T) diff --git a/test/functions.jl b/test/functions.jl index 503dc50da5..d6143d161b 100644 --- a/test/functions.jl +++ b/test/functions.jl @@ -469,6 +469,48 @@ function test_convert_VectorAffineFunction_VectorQuadraticFunction() return end +function test_copy_ScalarNonlinearFunction() + N = 10_000 + x = MOI.VariableIndex.(1:N) + f1 = MOI.ScalarNonlinearFunction(:^, Any[x[1], 1]) + for i in 2:N + g = MOI.ScalarNonlinearFunction(:^, Any[x[i], 1]) + f1 = MOI.ScalarNonlinearFunction(:+, Any[f1, g]) + end + f2 = MOI.ScalarNonlinearFunction(:^, Any[x[1], 1]) + for i in 2:N + g = MOI.ScalarNonlinearFunction(:^, Any[x[i], 1]) + f2 = MOI.ScalarNonlinearFunction(:+, Any[f2, g]) + end + f_copy = copy(f1) + @test ≈(f_copy, f2) + f1.args[2].args[2] = 2.0 # x[1]^1 --> x[1]^2 + @test !isapprox(f_copy, f1) + @test isapprox(f_copy, f2) + return +end + +function test_copy_ScalarNonlinearFunction_with_arg() + N = 10_000 + x = MOI.VariableIndex.(1:N) + f1 = 1.0 * x[1] + 1.0 + for i in 2:N + g = f1 = Float64(i) * x[i] + Float64(i) + f1 = MOI.ScalarNonlinearFunction(:+, Any[f1, g]) + end + f2 = 1.0 * x[1] + 1.0 + for i in 2:N + g = f2 = Float64(i) * x[i] + Float64(i) + f2 = MOI.ScalarNonlinearFunction(:+, Any[f2, g]) + end + f_copy = copy(f1) + @test ≈(f_copy, f2) + f1.args[2].constant += 1 + @test !isapprox(f_copy, f1) + @test isapprox(f_copy, f2) + return +end + function runtests() for name in names(@__MODULE__; all = true) if startswith("$name", "test_")