diff --git a/Project.toml b/Project.toml index 6106afd..a68c060 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ProxTV" uuid = "925ea013-038b-5ab6-a1ab-e0849925e528" authors = ["Nathan Allaire and contributors"] -version = "0.3.0" +version = "1.0.0" [deps] JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899" diff --git a/src/ProxTV.jl b/src/ProxTV.jl index 6a529eb..7b2da58 100644 --- a/src/ProxTV.jl +++ b/src/ProxTV.jl @@ -12,9 +12,20 @@ function __init__() end end +# import functions that we extend from ShiftedProximalOperators +import ShiftedProximalOperators.shift! +import ShiftedProximalOperators.shifted +import ShiftedProximalOperators.prox! + +# export our new functions +export AlgorithmContextCallback +export InexactShiftedProximableFunction +export NormLp, ShiftedNormLp, NormTVp, ShiftedNormTVp +export prox!, shifted, shift!, TVp_norm +export fun_name, fun_expr, fun_params + # main library functions include("libproxtv.jl") include("proxtv_utils.jl") - end diff --git a/src/libproxtv.jl b/src/libproxtv.jl index ea6028d..0fdedfa 100644 --- a/src/libproxtv.jl +++ b/src/libproxtv.jl @@ -55,8 +55,8 @@ end # original PN_LPp function function PN_LPp(y, lambda, x, info, n, p, ws, positive, ctx, callback) - context = unsafe_pointer_to_objref(ctx)::AlgorithmContextCallback - objGap = context.dualGap + objGap = ctx.dualGap + ctx_ptr = Ptr{Cvoid}(pointer_from_objref(ctx)) @ccall libproxtv.PN_LPp( y::Ptr{Float64}, lambda::Float64, @@ -67,7 +67,7 @@ function PN_LPp(y, lambda, x, info, n, p, ws, positive, ctx, callback) ws::Ptr{Workspace}, positive::Int32, objGap::Float64, - Ref(ctx)::Ptr{Cvoid}, + ctx_ptr::Ptr{Cvoid}, callback::Ptr{Cvoid}, )::Int32 end @@ -139,8 +139,8 @@ end # original TV function function TV(y, lambda, x, info, n, p, ws, ctx, callback) - context = unsafe_pointer_to_objref(ctx)::AlgorithmContextCallback - objGap = context.dualGap + objGap = ctx.dualGap + ctx_ptr = Ptr{Cvoid}(pointer_from_objref(ctx)) @ccall libproxtv.TV( y::Ptr{Float64}, lambda::Float64, @@ -150,7 +150,7 @@ function TV(y, lambda, x, info, n, p, ws, ctx, callback) p::Float64, ws::Ptr{Workspace}, objGap::Float64, - Ref(ctx)::Ptr{Cvoid}, + ctx_ptr::Ptr{Cvoid}, callback::Ptr{Cvoid}, )::Int32 end diff --git a/src/proxtv_utils.jl b/src/proxtv_utils.jl index f23aed4..572c433 100644 --- a/src/proxtv_utils.jl +++ b/src/proxtv_utils.jl @@ -1,5 +1,20 @@ abstract type InexactShiftedProximableFunction end + +## Structure for callback function in iR2N +mutable struct AlgorithmContextCallback + hk::Float64 + mk::Function + κξ::Float64 + shift::AbstractVector{Float64} + s_k_unshifted::Vector{Float64} + dualGap::Float64 +end +function AlgorithmContextCallback(;hk=0.0, mk = x -> x, κξ = 0.0, shift = zeros(0), s_k_unshifted = zeros(0), dualGap = 0.0) + AlgorithmContextCallback(hk, mk, κξ, shift, s_k_unshifted, dualGap) +end + + ### NormLp and ShiftedNormLp Implementation """ @@ -45,7 +60,7 @@ function prox!( h::NormLp, q::AbstractArray, ν::Real, - ctx_ptr::Ptr{Cvoid}, + context::AlgorithmContextCallback, callback::Ptr{Cvoid}; ) @@ -60,7 +75,9 @@ function prox!( positive = Int32(all(v -> v >= 0, y) ? 1 : 0) - PN_LPp(q, lambda_scaled, y, info, n, h.p, ws, positive, ctx_ptr, callback) + PN_LPp(q, lambda_scaled, y, info, n, h.p, ws, positive, context, callback) + + freeWorkspace(ws) return y end @@ -169,11 +186,11 @@ function prox!( ψ::ShiftedNormLp, q::AbstractArray, ν::Real, - context, + context::AlgorithmContextCallback, callback::Ptr{Cvoid}; ) n = length(y) - ws = C_NULL # to avoid unexplained memory leaks + ws = newWorkspace(n) # to avoid unexplained memory leaks # Allocate info array (based on C++ code) info = zeros(Float64, 3) @@ -188,7 +205,6 @@ function prox!( x = zeros(n) positive = Int32(all(v -> v >= 0, y_shifted) ? 1 : 0) - if ψ.h.p == 1 PN_LP1(y_shifted, lambda_scaled, x, info, n) elseif ψ.h.p == 2 @@ -204,6 +220,8 @@ function prox!( # Store the result in y y .= s + freeWorkspace(ws) + return y end @@ -268,11 +286,11 @@ function prox!( h::NormTVp, q::AbstractArray, ν::Real, - ctx_ptr::Ptr{Cvoid}, + context::AlgorithmContextCallback, callback::Ptr{Cvoid}) n = length(y) - ws = C_NULL + ws = newWorkspace(n) # Allocate info array (based on C++ code) info = zeros(Float64, 3) @@ -280,7 +298,9 @@ function prox!( # Adjust λ by ν lambda_scaled = h.λ * ν - TV(q, lambda_scaled, y, info, n, h.p, ws, ctx_ptr, callback) + TV(q, lambda_scaled, y, info, n, h.p, ws, context, callback) + + freeWorkspace(ws) return y end @@ -384,9 +404,9 @@ Inputs: - `ctx_ptr`: Pointer to the context object. - `callback`: Pointer to the callback function. """ -function prox!(y::AbstractArray, ψ::ShiftedNormTVp, q::AbstractArray, ν::Real, context, callback::Ptr{Cvoid}) +function prox!(y::AbstractArray, ψ::ShiftedNormTVp, q::AbstractArray, ν::Real, context::AlgorithmContextCallback, callback::Ptr{Cvoid}) n = length(y) - ws = C_NULL + ws = newWorkspace(n) # Allocate info array (based on C++ code) info = zeros(Float64, 3) @@ -409,6 +429,8 @@ function prox!(y::AbstractArray, ψ::ShiftedNormTVp, q::AbstractArray, ν::Real, # Store the result in y y .= s + freeWorkspace(ws) + return y end @@ -457,44 +479,15 @@ Errors: function prox!(y, ψ::Union{InexactShiftedProximableFunction, ShiftedProximableFunction }, q, ν; ctx_ptr, callback) if ψ isa ShiftedProximableFunction - # Call to exact prox!() if dualGap is not defined + # Call to exact prox!() return prox!(y, ψ, q, ν) elseif ψ isa InexactShiftedProximableFunction - # Call to inexact prox!() if dualGap is defined + # Call to inexact prox!() + println("Inexact prox!() called") return prox!(y, ψ, q, ν, ctx_ptr, callback) + else error("Combination of ψ::$(typeof(ψ)) presence/lack of pointers is not a valid call to prox!. Please provide pointers for InexactShiftedProximableFunction or omit them for ShiftedProximableFunction.") end end - -############################################################################################################ -########################## CALLBACK FUNCTION ########################## -# Structure for callback function in iR2N -mutable struct AlgorithmContextCallback - hk::Float64 - mk::Function - κξ::Float64 - shift::AbstractVector{Float64} - s_k_unshifted::Vector{Float64} - dualGap::Float64 -end -function AlgorithmContextCallback(;hk=0.0, mk = x -> x, κξ = 0.0, shift = zeros(0), s_k_unshifted = zeros(0), dualGap = 0.0) - AlgorithmContextCallback(hk, mk, κξ, shift, s_k_unshifted, dualGap) -end - -function julia_callback(s_ptr::Ptr{Cdouble}, s_length::Csize_t, delta_k::Cdouble, ctx_ptr::Ptr{Cvoid})::Cint - s_k = unsafe_wrap(Vector{Float64}, s_ptr, s_length; own = false) - context = unsafe_pointer_to_objref(ctx_ptr)::AlgorithmContextCallback - - # In-place operation to avoid memory allocations - @. context.s_k_unshifted = s_k - context.shift - - # Computations without allocations - ξk = context.hk - context.mk(context.s_k_unshifted) + max(1, abs(context.hk)) * 10 * eps() - condition = delta_k ≤ (1 - context.κξ) / context.κξ * ξk - - return condition ? Int32(1) : Int32(0) -end - -callback_pointer = @cfunction(julia_callback, Cint, (Ptr{Cdouble}, Csize_t, Cdouble, Ptr{Cvoid})) diff --git a/test/runtests.jl b/test/runtests.jl index 6711623..cb6e315 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -31,10 +31,9 @@ end ctx = AlgorithmContextCallback(dualGap=dualGap) - ctx_ptr = Ptr{Cvoid}(pointer_from_objref(ctx)) callback_pointer = @cfunction(simple_callback, Cint, (Ptr{Cdouble}, Csize_t, Cdouble, Ptr{Cvoid})) - @test PN_LPp(y, lambda, x, info, n, p, ws, positive, ctx_ptr, callback_pointer) == 1 # 1 is the expected return value of the function. This means that the function has been executed successfully. + @test PN_LPp(y, lambda, x, info, n, p, ws, positive, ctx, callback_pointer) == 1 # 1 is the expected return value of the function. This means that the function has been executed successfully. - @test TV(y, lambda, x, info, n, p, ws, ctx_ptr, callback_pointer)== 1 + @test TV(y, lambda, x, info, n, p, ws, ctx, callback_pointer) == 1 end