From 0f7c083a0daf7b0e572e64860df642788b4d9cbf Mon Sep 17 00:00:00 2001 From: nathanemac <91251698+nathanemac@users.noreply.github.com> Date: Thu, 19 Dec 2024 20:39:03 +0100 Subject: [PATCH] modifs to be compatible with RegularizedOptimization.jl --- src/libproxtv.jl | 1 - src/proxtv_utils.jl | 18 +++++++++++++++--- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/src/libproxtv.jl b/src/libproxtv.jl index 0fdedfa..6e0f53e 100644 --- a/src/libproxtv.jl +++ b/src/libproxtv.jl @@ -72,7 +72,6 @@ function PN_LPp(y, lambda, x, info, n, p, ws, positive, ctx, callback) )::Int32 end - # overloaded PN_LPp function with less inputs function PN_LPp(y, lambda, x, p, objGap) n = length(y) # works for nD signals diff --git a/src/proxtv_utils.jl b/src/proxtv_utils.jl index 572c433..4099527 100644 --- a/src/proxtv_utils.jl +++ b/src/proxtv_utils.jl @@ -9,9 +9,10 @@ mutable struct AlgorithmContextCallback shift::AbstractVector{Float64} s_k_unshifted::Vector{Float64} dualGap::Float64 + prox_stats # for total number of iterations in ir2n, ir2 and prox. end -function AlgorithmContextCallback(;hk=0.0, mk = x -> x, κξ = 0.0, shift = zeros(0), s_k_unshifted = zeros(0), dualGap = 0.0) - AlgorithmContextCallback(hk, mk, κξ, shift, s_k_unshifted, dualGap) +function AlgorithmContextCallback(;hk=0.0, mk = x -> x, κξ = 0.0, shift = zeros(0), s_k_unshifted = zeros(0), dualGap = 0.0, prox_stats = [0.0, [], []]) + AlgorithmContextCallback(hk, mk, κξ, shift, s_k_unshifted, dualGap, prox_stats) end @@ -79,6 +80,9 @@ function prox!( freeWorkspace(ws) + # add the number of iterations in prox to the context object + push!(context.prox_stats[3], info[1]) + return y end @@ -222,6 +226,9 @@ function prox!( freeWorkspace(ws) + # add the number of iterations in prox to the context object + push!(context.prox_stats[3], info[1]) + return y end @@ -302,6 +309,9 @@ function prox!( freeWorkspace(ws) + # add the number of iterations in prox to the context object + push!(context.prox_stats[3], info[1]) + return y end @@ -431,6 +441,9 @@ function prox!(y::AbstractArray, ψ::ShiftedNormTVp, q::AbstractArray, ν::Real, freeWorkspace(ws) + # add the number of iterations in prox to the context object + push!(context.prox_stats[3], info[1]) + return y end @@ -483,7 +496,6 @@ function prox!(y, ψ::Union{InexactShiftedProximableFunction, ShiftedProximableF return prox!(y, ψ, q, ν) elseif ψ isa InexactShiftedProximableFunction # Call to inexact prox!() - println("Inexact prox!() called") return prox!(y, ψ, q, ν, ctx_ptr, callback) else