From 5a9c5967501486749a8c18be3a8ee6d648662c92 Mon Sep 17 00:00:00 2001 From: nathanemac <91251698+nathanemac@users.noreply.github.com> Date: Fri, 6 Dec 2024 17:00:37 -0500 Subject: [PATCH] improve compilation, still work to do --- src/ProxTV.jl | 2 +- src/libproxtv.jl | 12 +++---- src/proxtv_utils.jl | 79 +++++++++++++++++++++------------------------ 3 files changed, 43 insertions(+), 50 deletions(-) diff --git a/src/ProxTV.jl b/src/ProxTV.jl index 3133d8f..17b28b1 100644 --- a/src/ProxTV.jl +++ b/src/ProxTV.jl @@ -12,11 +12,11 @@ function __init__() end end +export AlgorithmContextCallback export InexactShiftedProximableFunction export NormLp, ShiftedNormLp, NormTVp, ShiftedNormTVp export prox!, shifted, shift!, TVp_norm export fun_name, fun_expr, fun_params -export AlgorithmContextCallback, julia_callback, callback_pointer # main library functions include("libproxtv.jl") diff --git a/src/libproxtv.jl b/src/libproxtv.jl index ea6028d..0fdedfa 100644 --- a/src/libproxtv.jl +++ b/src/libproxtv.jl @@ -55,8 +55,8 @@ end # original PN_LPp function function PN_LPp(y, lambda, x, info, n, p, ws, positive, ctx, callback) - context = unsafe_pointer_to_objref(ctx)::AlgorithmContextCallback - objGap = context.dualGap + objGap = ctx.dualGap + ctx_ptr = Ptr{Cvoid}(pointer_from_objref(ctx)) @ccall libproxtv.PN_LPp( y::Ptr{Float64}, lambda::Float64, @@ -67,7 +67,7 @@ function PN_LPp(y, lambda, x, info, n, p, ws, positive, ctx, callback) ws::Ptr{Workspace}, positive::Int32, objGap::Float64, - Ref(ctx)::Ptr{Cvoid}, + ctx_ptr::Ptr{Cvoid}, callback::Ptr{Cvoid}, )::Int32 end @@ -139,8 +139,8 @@ end # original TV function function TV(y, lambda, x, info, n, p, ws, ctx, callback) - context = unsafe_pointer_to_objref(ctx)::AlgorithmContextCallback - objGap = context.dualGap + objGap = ctx.dualGap + ctx_ptr = Ptr{Cvoid}(pointer_from_objref(ctx)) @ccall libproxtv.TV( y::Ptr{Float64}, lambda::Float64, @@ -150,7 +150,7 @@ function TV(y, lambda, x, info, n, p, ws, ctx, callback) p::Float64, ws::Ptr{Workspace}, objGap::Float64, - Ref(ctx)::Ptr{Cvoid}, + ctx_ptr::Ptr{Cvoid}, callback::Ptr{Cvoid}, )::Int32 end diff --git a/src/proxtv_utils.jl b/src/proxtv_utils.jl index f23aed4..572c433 100644 --- a/src/proxtv_utils.jl +++ b/src/proxtv_utils.jl @@ -1,5 +1,20 @@ abstract type InexactShiftedProximableFunction end + +## Structure for callback function in iR2N +mutable struct AlgorithmContextCallback + hk::Float64 + mk::Function + κξ::Float64 + shift::AbstractVector{Float64} + s_k_unshifted::Vector{Float64} + dualGap::Float64 +end +function AlgorithmContextCallback(;hk=0.0, mk = x -> x, κξ = 0.0, shift = zeros(0), s_k_unshifted = zeros(0), dualGap = 0.0) + AlgorithmContextCallback(hk, mk, κξ, shift, s_k_unshifted, dualGap) +end + + ### NormLp and ShiftedNormLp Implementation """ @@ -45,7 +60,7 @@ function prox!( h::NormLp, q::AbstractArray, ν::Real, - ctx_ptr::Ptr{Cvoid}, + context::AlgorithmContextCallback, callback::Ptr{Cvoid}; ) @@ -60,7 +75,9 @@ function prox!( positive = Int32(all(v -> v >= 0, y) ? 1 : 0) - PN_LPp(q, lambda_scaled, y, info, n, h.p, ws, positive, ctx_ptr, callback) + PN_LPp(q, lambda_scaled, y, info, n, h.p, ws, positive, context, callback) + + freeWorkspace(ws) return y end @@ -169,11 +186,11 @@ function prox!( ψ::ShiftedNormLp, q::AbstractArray, ν::Real, - context, + context::AlgorithmContextCallback, callback::Ptr{Cvoid}; ) n = length(y) - ws = C_NULL # to avoid unexplained memory leaks + ws = newWorkspace(n) # to avoid unexplained memory leaks # Allocate info array (based on C++ code) info = zeros(Float64, 3) @@ -188,7 +205,6 @@ function prox!( x = zeros(n) positive = Int32(all(v -> v >= 0, y_shifted) ? 1 : 0) - if ψ.h.p == 1 PN_LP1(y_shifted, lambda_scaled, x, info, n) elseif ψ.h.p == 2 @@ -204,6 +220,8 @@ function prox!( # Store the result in y y .= s + freeWorkspace(ws) + return y end @@ -268,11 +286,11 @@ function prox!( h::NormTVp, q::AbstractArray, ν::Real, - ctx_ptr::Ptr{Cvoid}, + context::AlgorithmContextCallback, callback::Ptr{Cvoid}) n = length(y) - ws = C_NULL + ws = newWorkspace(n) # Allocate info array (based on C++ code) info = zeros(Float64, 3) @@ -280,7 +298,9 @@ function prox!( # Adjust λ by ν lambda_scaled = h.λ * ν - TV(q, lambda_scaled, y, info, n, h.p, ws, ctx_ptr, callback) + TV(q, lambda_scaled, y, info, n, h.p, ws, context, callback) + + freeWorkspace(ws) return y end @@ -384,9 +404,9 @@ Inputs: - `ctx_ptr`: Pointer to the context object. - `callback`: Pointer to the callback function. """ -function prox!(y::AbstractArray, ψ::ShiftedNormTVp, q::AbstractArray, ν::Real, context, callback::Ptr{Cvoid}) +function prox!(y::AbstractArray, ψ::ShiftedNormTVp, q::AbstractArray, ν::Real, context::AlgorithmContextCallback, callback::Ptr{Cvoid}) n = length(y) - ws = C_NULL + ws = newWorkspace(n) # Allocate info array (based on C++ code) info = zeros(Float64, 3) @@ -409,6 +429,8 @@ function prox!(y::AbstractArray, ψ::ShiftedNormTVp, q::AbstractArray, ν::Real, # Store the result in y y .= s + freeWorkspace(ws) + return y end @@ -457,44 +479,15 @@ Errors: function prox!(y, ψ::Union{InexactShiftedProximableFunction, ShiftedProximableFunction }, q, ν; ctx_ptr, callback) if ψ isa ShiftedProximableFunction - # Call to exact prox!() if dualGap is not defined + # Call to exact prox!() return prox!(y, ψ, q, ν) elseif ψ isa InexactShiftedProximableFunction - # Call to inexact prox!() if dualGap is defined + # Call to inexact prox!() + println("Inexact prox!() called") return prox!(y, ψ, q, ν, ctx_ptr, callback) + else error("Combination of ψ::$(typeof(ψ)) presence/lack of pointers is not a valid call to prox!. Please provide pointers for InexactShiftedProximableFunction or omit them for ShiftedProximableFunction.") end end - -############################################################################################################ -########################## CALLBACK FUNCTION ########################## -# Structure for callback function in iR2N -mutable struct AlgorithmContextCallback - hk::Float64 - mk::Function - κξ::Float64 - shift::AbstractVector{Float64} - s_k_unshifted::Vector{Float64} - dualGap::Float64 -end -function AlgorithmContextCallback(;hk=0.0, mk = x -> x, κξ = 0.0, shift = zeros(0), s_k_unshifted = zeros(0), dualGap = 0.0) - AlgorithmContextCallback(hk, mk, κξ, shift, s_k_unshifted, dualGap) -end - -function julia_callback(s_ptr::Ptr{Cdouble}, s_length::Csize_t, delta_k::Cdouble, ctx_ptr::Ptr{Cvoid})::Cint - s_k = unsafe_wrap(Vector{Float64}, s_ptr, s_length; own = false) - context = unsafe_pointer_to_objref(ctx_ptr)::AlgorithmContextCallback - - # In-place operation to avoid memory allocations - @. context.s_k_unshifted = s_k - context.shift - - # Computations without allocations - ξk = context.hk - context.mk(context.s_k_unshifted) + max(1, abs(context.hk)) * 10 * eps() - condition = delta_k ≤ (1 - context.κξ) / context.κξ * ξk - - return condition ? Int32(1) : Int32(0) -end - -callback_pointer = @cfunction(julia_callback, Cint, (Ptr{Cdouble}, Csize_t, Cdouble, Ptr{Cvoid}))