Skip to content

Commit

Permalink
Fix Base.copy for ScalarNonlinearFunction (#2612)
Browse files Browse the repository at this point in the history
  • Loading branch information
odow authored Jan 16, 2025
1 parent 48ac449 commit 7a081f4
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 1 deletion.
24 changes: 23 additions & 1 deletion src/functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -360,8 +360,30 @@ struct ScalarNonlinearFunction <: AbstractScalarFunction
end
end

# copy() doesn't recursively copy the children, and deepcopy seems to have a
# performance problem for deeply nested structs.
function Base.copy(f::ScalarNonlinearFunction)
return ScalarNonlinearFunction(f.head, copy(f.args))
stack, result_stack = Any[f], Any[]
while !isempty(stack)
arg = pop!(stack)
if arg isa ScalarNonlinearFunction
# We need some sort of hint so that the next time we see this on the
# stack we evaluate it using the args in `result_stack`. One option
# would be a custom type. Or we can just wrap in (,) and then check
# for a Tuple, which isn't (curretly) a valid argument.
push!(stack, (arg,))
for child in arg.args
push!(stack, child)
end
elseif arg isa Tuple{<:ScalarNonlinearFunction}
result = only(arg)
args = Any[pop!(result_stack) for i in 1:length(result.args)]
push!(result_stack, ScalarNonlinearFunction(result.head, args))
else
push!(result_stack, copy(arg))
end
end
return only(result_stack)
end

constant(f::ScalarNonlinearFunction, ::Type{T} = Float64) where {T} = zero(T)
Expand Down
42 changes: 42 additions & 0 deletions test/functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,48 @@ function test_convert_VectorAffineFunction_VectorQuadraticFunction()
return
end

function test_copy_ScalarNonlinearFunction()
N = 10_000
x = MOI.VariableIndex.(1:N)
f1 = MOI.ScalarNonlinearFunction(:^, Any[x[1], 1])
for i in 2:N
g = MOI.ScalarNonlinearFunction(:^, Any[x[i], 1])
f1 = MOI.ScalarNonlinearFunction(:+, Any[f1, g])
end
f2 = MOI.ScalarNonlinearFunction(:^, Any[x[1], 1])
for i in 2:N
g = MOI.ScalarNonlinearFunction(:^, Any[x[i], 1])
f2 = MOI.ScalarNonlinearFunction(:+, Any[f2, g])
end
f_copy = copy(f1)
@test (f_copy, f2)
f1.args[2].args[2] = 2.0 # x[1]^1 --> x[1]^2
@test !isapprox(f_copy, f1)
@test isapprox(f_copy, f2)
return
end

function test_copy_ScalarNonlinearFunction_with_arg()
N = 10_000
x = MOI.VariableIndex.(1:N)
f1 = 1.0 * x[1] + 1.0
for i in 2:N
g = f1 = Float64(i) * x[i] + Float64(i)
f1 = MOI.ScalarNonlinearFunction(:+, Any[f1, g])
end
f2 = 1.0 * x[1] + 1.0
for i in 2:N
g = f2 = Float64(i) * x[i] + Float64(i)
f2 = MOI.ScalarNonlinearFunction(:+, Any[f2, g])
end
f_copy = copy(f1)
@test (f_copy, f2)
f1.args[2].constant += 1
@test !isapprox(f_copy, f1)
@test isapprox(f_copy, f2)
return
end

function runtests()
for name in names(@__MODULE__; all = true)
if startswith("$name", "test_")
Expand Down

0 comments on commit 7a081f4

Please sign in to comment.