Skip to content

Commit

Permalink
improve compilation, still work to do
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanemac committed Dec 6, 2024
1 parent ec6ba0b commit 5a9c596
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 50 deletions.
2 changes: 1 addition & 1 deletion src/ProxTV.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@ function __init__()
end
end

export AlgorithmContextCallback
export InexactShiftedProximableFunction
export NormLp, ShiftedNormLp, NormTVp, ShiftedNormTVp
export prox!, shifted, shift!, TVp_norm
export fun_name, fun_expr, fun_params
export AlgorithmContextCallback, julia_callback, callback_pointer

# main library functions
include("libproxtv.jl")
Expand Down
12 changes: 6 additions & 6 deletions src/libproxtv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ end

# original PN_LPp function
function PN_LPp(y, lambda, x, info, n, p, ws, positive, ctx, callback)
context = unsafe_pointer_to_objref(ctx)::AlgorithmContextCallback
objGap = context.dualGap
objGap = ctx.dualGap
ctx_ptr = Ptr{Cvoid}(pointer_from_objref(ctx))
@ccall libproxtv.PN_LPp(
y::Ptr{Float64},
lambda::Float64,
Expand All @@ -67,7 +67,7 @@ function PN_LPp(y, lambda, x, info, n, p, ws, positive, ctx, callback)
ws::Ptr{Workspace},
positive::Int32,
objGap::Float64,
Ref(ctx)::Ptr{Cvoid},
ctx_ptr::Ptr{Cvoid},
callback::Ptr{Cvoid},
)::Int32
end
Expand Down Expand Up @@ -139,8 +139,8 @@ end

# original TV function
function TV(y, lambda, x, info, n, p, ws, ctx, callback)
context = unsafe_pointer_to_objref(ctx)::AlgorithmContextCallback
objGap = context.dualGap
objGap = ctx.dualGap
ctx_ptr = Ptr{Cvoid}(pointer_from_objref(ctx))
@ccall libproxtv.TV(
y::Ptr{Float64},
lambda::Float64,
Expand All @@ -150,7 +150,7 @@ function TV(y, lambda, x, info, n, p, ws, ctx, callback)
p::Float64,
ws::Ptr{Workspace},
objGap::Float64,
Ref(ctx)::Ptr{Cvoid},
ctx_ptr::Ptr{Cvoid},
callback::Ptr{Cvoid},
)::Int32
end
Expand Down
79 changes: 36 additions & 43 deletions src/proxtv_utils.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,20 @@
abstract type InexactShiftedProximableFunction end


## Structure for callback function in iR2N
mutable struct AlgorithmContextCallback
hk::Float64
mk::Function
κξ::Float64
shift::AbstractVector{Float64}
s_k_unshifted::Vector{Float64}
dualGap::Float64
end
function AlgorithmContextCallback(;hk=0.0, mk = x -> x, κξ = 0.0, shift = zeros(0), s_k_unshifted = zeros(0), dualGap = 0.0)
AlgorithmContextCallback(hk, mk, κξ, shift, s_k_unshifted, dualGap)
end


### NormLp and ShiftedNormLp Implementation

"""
Expand Down Expand Up @@ -45,7 +60,7 @@ function prox!(
h::NormLp,
q::AbstractArray,
ν::Real,
ctx_ptr::Ptr{Cvoid},
context::AlgorithmContextCallback,
callback::Ptr{Cvoid};
)

Expand All @@ -60,7 +75,9 @@ function prox!(

positive = Int32(all(v -> v >= 0, y) ? 1 : 0)

PN_LPp(q, lambda_scaled, y, info, n, h.p, ws, positive, ctx_ptr, callback)
PN_LPp(q, lambda_scaled, y, info, n, h.p, ws, positive, context, callback)

freeWorkspace(ws)

return y
end
Expand Down Expand Up @@ -169,11 +186,11 @@ function prox!(
ψ::ShiftedNormLp,
q::AbstractArray,
ν::Real,
context,
context::AlgorithmContextCallback,
callback::Ptr{Cvoid};
)
n = length(y)
ws = C_NULL # to avoid unexplained memory leaks
ws = newWorkspace(n) # to avoid unexplained memory leaks

# Allocate info array (based on C++ code)
info = zeros(Float64, 3)
Expand All @@ -188,7 +205,6 @@ function prox!(
x = zeros(n)

positive = Int32(all(v -> v >= 0, y_shifted) ? 1 : 0)

if ψ.h.p == 1
PN_LP1(y_shifted, lambda_scaled, x, info, n)
elseif ψ.h.p == 2
Expand All @@ -204,6 +220,8 @@ function prox!(
# Store the result in y
y .= s

freeWorkspace(ws)

return y
end

Expand Down Expand Up @@ -268,19 +286,21 @@ function prox!(
h::NormTVp,
q::AbstractArray,
ν::Real,
ctx_ptr::Ptr{Cvoid},
context::AlgorithmContextCallback,
callback::Ptr{Cvoid})

n = length(y)
ws = C_NULL
ws = newWorkspace(n)

# Allocate info array (based on C++ code)
info = zeros(Float64, 3)

# Adjust λ by ν
lambda_scaled = h.λ * ν

TV(q, lambda_scaled, y, info, n, h.p, ws, ctx_ptr, callback)
TV(q, lambda_scaled, y, info, n, h.p, ws, context, callback)

freeWorkspace(ws)

return y
end
Expand Down Expand Up @@ -384,9 +404,9 @@ Inputs:
- `ctx_ptr`: Pointer to the context object.
- `callback`: Pointer to the callback function.
"""
function prox!(y::AbstractArray, ψ::ShiftedNormTVp, q::AbstractArray, ν::Real, context, callback::Ptr{Cvoid})
function prox!(y::AbstractArray, ψ::ShiftedNormTVp, q::AbstractArray, ν::Real, context::AlgorithmContextCallback, callback::Ptr{Cvoid})
n = length(y)
ws = C_NULL
ws = newWorkspace(n)

# Allocate info array (based on C++ code)
info = zeros(Float64, 3)
Expand All @@ -409,6 +429,8 @@ function prox!(y::AbstractArray, ψ::ShiftedNormTVp, q::AbstractArray, ν::Real,
# Store the result in y
y .= s

freeWorkspace(ws)

return y
end

Expand Down Expand Up @@ -457,44 +479,15 @@ Errors:
function prox!(y, ψ::Union{InexactShiftedProximableFunction, ShiftedProximableFunction
}, q, ν; ctx_ptr, callback)
if ψ isa ShiftedProximableFunction
# Call to exact prox!() if dualGap is not defined
# Call to exact prox!()
return prox!(y, ψ, q, ν)
elseif ψ isa InexactShiftedProximableFunction
# Call to inexact prox!() if dualGap is defined
# Call to inexact prox!()
println("Inexact prox!() called")
return prox!(y, ψ, q, ν, ctx_ptr, callback)

else
error("Combination of ψ::$(typeof(ψ)) presence/lack of pointers is not a valid call to prox!.
Please provide pointers for InexactShiftedProximableFunction or omit them for ShiftedProximableFunction.")
end
end

############################################################################################################
########################## CALLBACK FUNCTION ##########################
# Structure for callback function in iR2N
mutable struct AlgorithmContextCallback
hk::Float64
mk::Function
κξ::Float64
shift::AbstractVector{Float64}
s_k_unshifted::Vector{Float64}
dualGap::Float64
end
function AlgorithmContextCallback(;hk=0.0, mk = x -> x, κξ = 0.0, shift = zeros(0), s_k_unshifted = zeros(0), dualGap = 0.0)
AlgorithmContextCallback(hk, mk, κξ, shift, s_k_unshifted, dualGap)
end

function julia_callback(s_ptr::Ptr{Cdouble}, s_length::Csize_t, delta_k::Cdouble, ctx_ptr::Ptr{Cvoid})::Cint
s_k = unsafe_wrap(Vector{Float64}, s_ptr, s_length; own = false)
context = unsafe_pointer_to_objref(ctx_ptr)::AlgorithmContextCallback

# In-place operation to avoid memory allocations
@. context.s_k_unshifted = s_k - context.shift

# Computations without allocations
ξk = context.hk - context.mk(context.s_k_unshifted) + max(1, abs(context.hk)) * 10 * eps()
condition = delta_k (1 - context.κξ) / context.κξ * ξk

return condition ? Int32(1) : Int32(0)
end

callback_pointer = @cfunction(julia_callback, Cint, (Ptr{Cdouble}, Csize_t, Cdouble, Ptr{Cvoid}))

0 comments on commit 5a9c596

Please sign in to comment.