From 4c5d5a4457acae15629c4fbbf79aaf31fd15f404 Mon Sep 17 00:00:00 2001 From: nathanemac <91251698+nathanemac@users.noreply.github.com> Date: Fri, 6 Dec 2024 09:05:58 -0500 Subject: [PATCH 1/4] exporting useful objects in main --- src/ProxTV.jl | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/ProxTV.jl b/src/ProxTV.jl index 6a529eb..07e81f9 100644 --- a/src/ProxTV.jl +++ b/src/ProxTV.jl @@ -16,5 +16,13 @@ end include("libproxtv.jl") include("proxtv_utils.jl") - +export InexactShiftedProximableFunction, + NormLp, ShiftedNormLp, + NormTVp, ShiftedNormTVp, + prox!, shifted, shift!, + TVp_norm, + fun_name, fun_expr, fun_params, + AlgorithmContextCallback, + julia_callback, + callback_pointer end From ec6ba0be210c530649de71932c2360f6fca1f282 Mon Sep 17 00:00:00 2001 From: nathanemac <91251698+nathanemac@users.noreply.github.com> Date: Fri, 6 Dec 2024 09:46:05 -0500 Subject: [PATCH 2/4] exporting useful objects in main --- src/ProxTV.jl | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/src/ProxTV.jl b/src/ProxTV.jl index 07e81f9..3133d8f 100644 --- a/src/ProxTV.jl +++ b/src/ProxTV.jl @@ -12,17 +12,14 @@ function __init__() end end +export InexactShiftedProximableFunction +export NormLp, ShiftedNormLp, NormTVp, ShiftedNormTVp +export prox!, shifted, shift!, TVp_norm +export fun_name, fun_expr, fun_params +export AlgorithmContextCallback, julia_callback, callback_pointer + # main library functions include("libproxtv.jl") include("proxtv_utils.jl") -export InexactShiftedProximableFunction, - NormLp, ShiftedNormLp, - NormTVp, ShiftedNormTVp, - prox!, shifted, shift!, - TVp_norm, - fun_name, fun_expr, fun_params, - AlgorithmContextCallback, - julia_callback, - callback_pointer end From 5a9c5967501486749a8c18be3a8ee6d648662c92 Mon Sep 17 00:00:00 2001 From: nathanemac <91251698+nathanemac@users.noreply.github.com> Date: Fri, 6 Dec 2024 17:00:37 -0500 Subject: [PATCH 3/4] improve compilation, still work to do --- src/ProxTV.jl | 2 +- src/libproxtv.jl | 12 +++---- src/proxtv_utils.jl | 79 +++++++++++++++++++++------------------------ 3 files changed, 43 insertions(+), 50 deletions(-) diff --git a/src/ProxTV.jl b/src/ProxTV.jl index 3133d8f..17b28b1 100644 --- a/src/ProxTV.jl +++ b/src/ProxTV.jl @@ -12,11 +12,11 @@ function __init__() end end +export AlgorithmContextCallback export InexactShiftedProximableFunction export NormLp, ShiftedNormLp, NormTVp, ShiftedNormTVp export prox!, shifted, shift!, TVp_norm export fun_name, fun_expr, fun_params -export AlgorithmContextCallback, julia_callback, callback_pointer # main library functions include("libproxtv.jl") diff --git a/src/libproxtv.jl b/src/libproxtv.jl index ea6028d..0fdedfa 100644 --- a/src/libproxtv.jl +++ b/src/libproxtv.jl @@ -55,8 +55,8 @@ end # original PN_LPp function function PN_LPp(y, lambda, x, info, n, p, ws, positive, ctx, callback) - context = unsafe_pointer_to_objref(ctx)::AlgorithmContextCallback - objGap = context.dualGap + objGap = ctx.dualGap + ctx_ptr = Ptr{Cvoid}(pointer_from_objref(ctx)) @ccall libproxtv.PN_LPp( y::Ptr{Float64}, lambda::Float64, @@ -67,7 +67,7 @@ function PN_LPp(y, lambda, x, info, n, p, ws, positive, ctx, callback) ws::Ptr{Workspace}, positive::Int32, objGap::Float64, - Ref(ctx)::Ptr{Cvoid}, + ctx_ptr::Ptr{Cvoid}, callback::Ptr{Cvoid}, )::Int32 end @@ -139,8 +139,8 @@ end # original TV function function TV(y, lambda, x, info, n, p, ws, ctx, callback) - context = unsafe_pointer_to_objref(ctx)::AlgorithmContextCallback - objGap = context.dualGap + objGap = ctx.dualGap + ctx_ptr = Ptr{Cvoid}(pointer_from_objref(ctx)) @ccall libproxtv.TV( y::Ptr{Float64}, lambda::Float64, @@ -150,7 +150,7 @@ function TV(y, lambda, x, info, n, p, ws, ctx, callback) p::Float64, ws::Ptr{Workspace}, objGap::Float64, - Ref(ctx)::Ptr{Cvoid}, + ctx_ptr::Ptr{Cvoid}, callback::Ptr{Cvoid}, )::Int32 end diff --git a/src/proxtv_utils.jl b/src/proxtv_utils.jl index f23aed4..572c433 100644 --- a/src/proxtv_utils.jl +++ b/src/proxtv_utils.jl @@ -1,5 +1,20 @@ abstract type InexactShiftedProximableFunction end + +## Structure for callback function in iR2N +mutable struct AlgorithmContextCallback + hk::Float64 + mk::Function + κξ::Float64 + shift::AbstractVector{Float64} + s_k_unshifted::Vector{Float64} + dualGap::Float64 +end +function AlgorithmContextCallback(;hk=0.0, mk = x -> x, κξ = 0.0, shift = zeros(0), s_k_unshifted = zeros(0), dualGap = 0.0) + AlgorithmContextCallback(hk, mk, κξ, shift, s_k_unshifted, dualGap) +end + + ### NormLp and ShiftedNormLp Implementation """ @@ -45,7 +60,7 @@ function prox!( h::NormLp, q::AbstractArray, ν::Real, - ctx_ptr::Ptr{Cvoid}, + context::AlgorithmContextCallback, callback::Ptr{Cvoid}; ) @@ -60,7 +75,9 @@ function prox!( positive = Int32(all(v -> v >= 0, y) ? 1 : 0) - PN_LPp(q, lambda_scaled, y, info, n, h.p, ws, positive, ctx_ptr, callback) + PN_LPp(q, lambda_scaled, y, info, n, h.p, ws, positive, context, callback) + + freeWorkspace(ws) return y end @@ -169,11 +186,11 @@ function prox!( ψ::ShiftedNormLp, q::AbstractArray, ν::Real, - context, + context::AlgorithmContextCallback, callback::Ptr{Cvoid}; ) n = length(y) - ws = C_NULL # to avoid unexplained memory leaks + ws = newWorkspace(n) # to avoid unexplained memory leaks # Allocate info array (based on C++ code) info = zeros(Float64, 3) @@ -188,7 +205,6 @@ function prox!( x = zeros(n) positive = Int32(all(v -> v >= 0, y_shifted) ? 1 : 0) - if ψ.h.p == 1 PN_LP1(y_shifted, lambda_scaled, x, info, n) elseif ψ.h.p == 2 @@ -204,6 +220,8 @@ function prox!( # Store the result in y y .= s + freeWorkspace(ws) + return y end @@ -268,11 +286,11 @@ function prox!( h::NormTVp, q::AbstractArray, ν::Real, - ctx_ptr::Ptr{Cvoid}, + context::AlgorithmContextCallback, callback::Ptr{Cvoid}) n = length(y) - ws = C_NULL + ws = newWorkspace(n) # Allocate info array (based on C++ code) info = zeros(Float64, 3) @@ -280,7 +298,9 @@ function prox!( # Adjust λ by ν lambda_scaled = h.λ * ν - TV(q, lambda_scaled, y, info, n, h.p, ws, ctx_ptr, callback) + TV(q, lambda_scaled, y, info, n, h.p, ws, context, callback) + + freeWorkspace(ws) return y end @@ -384,9 +404,9 @@ Inputs: - `ctx_ptr`: Pointer to the context object. - `callback`: Pointer to the callback function. """ -function prox!(y::AbstractArray, ψ::ShiftedNormTVp, q::AbstractArray, ν::Real, context, callback::Ptr{Cvoid}) +function prox!(y::AbstractArray, ψ::ShiftedNormTVp, q::AbstractArray, ν::Real, context::AlgorithmContextCallback, callback::Ptr{Cvoid}) n = length(y) - ws = C_NULL + ws = newWorkspace(n) # Allocate info array (based on C++ code) info = zeros(Float64, 3) @@ -409,6 +429,8 @@ function prox!(y::AbstractArray, ψ::ShiftedNormTVp, q::AbstractArray, ν::Real, # Store the result in y y .= s + freeWorkspace(ws) + return y end @@ -457,44 +479,15 @@ Errors: function prox!(y, ψ::Union{InexactShiftedProximableFunction, ShiftedProximableFunction }, q, ν; ctx_ptr, callback) if ψ isa ShiftedProximableFunction - # Call to exact prox!() if dualGap is not defined + # Call to exact prox!() return prox!(y, ψ, q, ν) elseif ψ isa InexactShiftedProximableFunction - # Call to inexact prox!() if dualGap is defined + # Call to inexact prox!() + println("Inexact prox!() called") return prox!(y, ψ, q, ν, ctx_ptr, callback) + else error("Combination of ψ::$(typeof(ψ)) presence/lack of pointers is not a valid call to prox!. Please provide pointers for InexactShiftedProximableFunction or omit them for ShiftedProximableFunction.") end end - -############################################################################################################ -########################## CALLBACK FUNCTION ########################## -# Structure for callback function in iR2N -mutable struct AlgorithmContextCallback - hk::Float64 - mk::Function - κξ::Float64 - shift::AbstractVector{Float64} - s_k_unshifted::Vector{Float64} - dualGap::Float64 -end -function AlgorithmContextCallback(;hk=0.0, mk = x -> x, κξ = 0.0, shift = zeros(0), s_k_unshifted = zeros(0), dualGap = 0.0) - AlgorithmContextCallback(hk, mk, κξ, shift, s_k_unshifted, dualGap) -end - -function julia_callback(s_ptr::Ptr{Cdouble}, s_length::Csize_t, delta_k::Cdouble, ctx_ptr::Ptr{Cvoid})::Cint - s_k = unsafe_wrap(Vector{Float64}, s_ptr, s_length; own = false) - context = unsafe_pointer_to_objref(ctx_ptr)::AlgorithmContextCallback - - # In-place operation to avoid memory allocations - @. context.s_k_unshifted = s_k - context.shift - - # Computations without allocations - ξk = context.hk - context.mk(context.s_k_unshifted) + max(1, abs(context.hk)) * 10 * eps() - condition = delta_k ≤ (1 - context.κξ) / context.κξ * ξk - - return condition ? Int32(1) : Int32(0) -end - -callback_pointer = @cfunction(julia_callback, Cint, (Ptr{Cdouble}, Csize_t, Cdouble, Ptr{Cvoid})) From c15b02d961981f20d226fb8c8a13633b38513e2b Mon Sep 17 00:00:00 2001 From: nathanemac <91251698+nathanemac@users.noreply.github.com> Date: Mon, 9 Dec 2024 09:59:14 -0500 Subject: [PATCH 4/4] release v1.0.0 --- Project.toml | 2 +- src/ProxTV.jl | 6 ++++++ test/runtests.jl | 5 ++--- 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/Project.toml b/Project.toml index 6106afd..a68c060 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ProxTV" uuid = "925ea013-038b-5ab6-a1ab-e0849925e528" authors = ["Nathan Allaire and contributors"] -version = "0.3.0" +version = "1.0.0" [deps] JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899" diff --git a/src/ProxTV.jl b/src/ProxTV.jl index 17b28b1..7b2da58 100644 --- a/src/ProxTV.jl +++ b/src/ProxTV.jl @@ -12,6 +12,12 @@ function __init__() end end +# import functions that we extend from ShiftedProximalOperators +import ShiftedProximalOperators.shift! +import ShiftedProximalOperators.shifted +import ShiftedProximalOperators.prox! + +# export our new functions export AlgorithmContextCallback export InexactShiftedProximableFunction export NormLp, ShiftedNormLp, NormTVp, ShiftedNormTVp diff --git a/test/runtests.jl b/test/runtests.jl index 6711623..cb6e315 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -31,10 +31,9 @@ end ctx = AlgorithmContextCallback(dualGap=dualGap) - ctx_ptr = Ptr{Cvoid}(pointer_from_objref(ctx)) callback_pointer = @cfunction(simple_callback, Cint, (Ptr{Cdouble}, Csize_t, Cdouble, Ptr{Cvoid})) - @test PN_LPp(y, lambda, x, info, n, p, ws, positive, ctx_ptr, callback_pointer) == 1 # 1 is the expected return value of the function. This means that the function has been executed successfully. + @test PN_LPp(y, lambda, x, info, n, p, ws, positive, ctx, callback_pointer) == 1 # 1 is the expected return value of the function. This means that the function has been executed successfully. - @test TV(y, lambda, x, info, n, p, ws, ctx_ptr, callback_pointer)== 1 + @test TV(y, lambda, x, info, n, p, ws, ctx, callback_pointer) == 1 end