TSVD Postprocessing of ASP #83

Merged: 3 commits, Sep 9, 2024
1 change: 1 addition & 0 deletions Project.toml
@@ -13,6 +13,7 @@ Optim = "429524aa-4258-5aef-a3af-852621145aeb"
ParallelDataTransfer = "2dcacdae-9679-587a-88bb-8b444fb7085b"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
SharedArrays = "1a1011a3-84de-559e-8e89-a11a2f7dc383"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

[weakdeps]
116 changes: 69 additions & 47 deletions src/asp.jl
@@ -48,35 +58,45 @@ solve(solver::ASP, A, y, Aval=A, yval=y)
If independent `Aval` and `yval` are provided (instead of the defaults `A, y`),
then the solver will use this separate validation set instead of the training
set to select the best solution along the model path.
"""
# """

struct ASP
P
select
mode::Symbol
tsvd::Bool
nstore::Integer
params
end

function ASP(; P = I, select, mode=:train, params...)
return ASP(P, select, params)
function ASP(; P = I, select, mode=:train, tsvd=false, nstore=100, params...)
return ASP(P, select, mode, tsvd, nstore, params)
end

function solve(solver::ASP, A, y, Aval=A, yval=y)
# Apply preconditioning
AP = A / solver.P

tracer = asp_homotopy(AP, y; solver.params...)
q = length(tracer)
new_tracer = Vector{NamedTuple{(:solution, :λ), Tuple{Any, Any}}}(undef, q)

for i in 1:q
new_tracer[i] = (solution = solver.P \ tracer[i][1], λ = tracer[i][2])
q = length(tracer)
every = max(1, q ÷ solver.nstore)
istore = unique([1:every:q; q])
new_tracer = [ (solution = solver.P \ tracer[i][1], λ = tracer[i][2], σ = 0.0 )
for i in istore ]

if solver.tsvd # Post-processing if tsvd is true
post = post_asp_tsvd(new_tracer, A, y, Aval, yval)
new_post = [ (solution = p.θ, λ = p.λ, σ = p.σ) for p in post ]
else
new_post = new_tracer
end

xs, in = select_solution(new_tracer, solver, Aval, yval)
xs, in = select_solution(new_post, solver, Aval, yval)

# println("done.")
return Dict( "C" => xs,
"path" => new_tracer,
"nnzs" => length((new_tracer[in][:solution]).nzind) )
return Dict( "C" => xs,
"path" => new_post,
"nnzs" => length( (new_tracer[in][:solution]).nzind) )
end
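
As a quick orientation for the new keyword arguments, here is a minimal usage sketch of the extended interface. The problem setup (matrix sizes, train/validation split, the particular select tuple) is illustrative only; the ASP and solve calls and the result keys are the ones defined in this file.

using ACEfit, LinearAlgebra, Random

Random.seed!(0)
A = randn(500, 50) / sqrt(500)
y = A * randn(50) + 1e-4 * randn(500)
It, Iv = 1:400, 401:500    # training / validation rows

solver = ACEfit.ASP(P = I, select = (:byerror, 1.3), mode = :train,
                    tsvd = true, nstore = 100, loglevel = 0, traceFlag = true)
result = ACEfit.solve(solver, A[It, :], y[It], A[Iv, :], y[Iv])

result["C"]      # coefficients of the selected solution
result["path"]   # stored (solution, λ, σ) named tuples along the homotopy path
result["nnzs"]   # number of nonzeros in the selected solution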


@@ -114,44 +124,56 @@ function select_solution(tracer, solver, A, y)
end


#=
function select_smart(tracer, solver, Aval, yval)

   best_metric = Inf
   best_iteration = 0
   validation_metric = 0
   q = length(tracer)
   errors = [norm(Aval * t[:solution] - yval) for t in tracer]
   nnzss = [(t[:solution]).nzind for t in tracer]
   best_iteration = argmin(errors)
   validation_metric = errors[best_iteration]
   validation_end = norm(Aval * tracer[end][:solution] - yval)

   if validation_end < validation_metric   # make sure to check the last one too in case q << 100
      best_iteration = q
   end

   criterion, threshold = solver.select

   if criterion == :val
      return tracer[best_iteration][:solution], best_iteration

   elseif criterion == :byerror
      for (i, error) in enumerate(errors)
         if error <= threshold * validation_metric
            return tracer[i][:solution], i
         end
      end

   elseif criterion == :bysize
      first_index = findfirst(sublist -> threshold in sublist, nnzss)
      relevant_errors = errors[1:first_index - 1]
      min_error = minimum(relevant_errors)
      min_error_index = findfirst(==(min_error), relevant_errors)
      return tracer[min_error_index][:solution], min_error_index

   else
      @error("Unknown selection criterion: $criterion")
   end
end
=#

using SparseArrays

function solve_tsvd(At, yt, Av, yv)
   Ut, Σt, Vt = svd(At); zt = Ut' * yt
   Qv, Rv = qr(Av); zv = Matrix(Qv)' * yv
   @assert issorted(Σt, rev=true)

   Rv_Vt = Rv * Vt

   θv = zeros(size(Av, 2))
   θv[1] = zt[1] / Σt[1]
   rv = Rv_Vt[:, 1] * θv[1] - zv

   tsvd_errs = Float64[]
   push!(tsvd_errs, norm(rv))

   for k = 2:length(Σt)
      θv[k] = zt[k] / Σt[k]
      rv += Rv_Vt[:, k] * θv[k]
      push!(tsvd_errs, norm(rv))
   end

   imin = argmin(tsvd_errs)
   θv[imin+1:end] .= 0
   return Vt * θv, Σt[imin]
end

function post_asp_tsvd(path, At, yt, Av, yv)
   Qt, Rt = qr(At); zt = Matrix(Qt)' * yt
   Qv, Rv = qr(Av); zv = Matrix(Qv)' * yv

   function _post(θλ)
      (θ, λ) = θλ
      if isempty(θ.nzind); return (θ = θ, λ = λ, σ = Inf); end
      inz = θ.nzind
      θ1, σ = solve_tsvd(Rt[:, inz], zt, Rv[:, inz], zv)
      θ2 = copy(θ); θ2[inz] .= θ1
      return (θ = θ2, λ = λ, σ = σ)
   end

   return _post.(path)

   # post = []
   # for (θ, λ) in path
   #    if isempty(θ.nzind); push!(post, (θ = θ, λ = λ, σ = Inf)); continue; end
   #    inz = θ.nzind
   #    θ1, σ = solve_tsvd(Rt[:, inz], zt, Rv[:, inz], zv)
   #    θ2 = copy(θ); θ2[inz] .= θ1
   #    push!(post, (θ = θ2, λ = λ, σ = σ))
   # end
   # return identity.(post)
end
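
For reference, the truncation rule implemented by solve_tsvd can be restated as a naive, slower but more transparent sketch: for each rank k, form the truncated-SVD solution on the training block and keep the rank whose residual on the validation block is smallest. The function name below is illustrative and not part of the package; solve_tsvd above computes the same quantities incrementally via the QR factor of the validation block.

using LinearAlgebra

# Naive reference version of the truncated-SVD rank selection.
function tsvd_by_validation(At, yt, Av, yv)
   U, Σ, V = svd(At)
   errs = map(1:length(Σ)) do k
      θk = V[:, 1:k] * ((U[:, 1:k]' * yt) ./ Σ[1:k])   # rank-k least-squares solution on the training block
      norm(Av * θk - yv)                               # validation residual at rank k
   end
   kbest = argmin(errs)
   θ = V[:, 1:kbest] * ((U[:, 1:kbest]' * yt) ./ Σ[1:kbest])
   return θ, Σ[kbest]   # same return convention as solve_tsvd: (solution, cutoff singular value)
end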
2 changes: 2 additions & 0 deletions test/runtests.jl
@@ -8,5 +8,7 @@ using Test

@testset "Linear Solvers" begin include("test_linearsolvers.jl") end

@testset "ASP" begin include("test_asp.jl") end

@testset "MLJ Solvers" begin include("test_mlj.jl") end
end
111 changes: 111 additions & 0 deletions test/test_asp.jl
@@ -0,0 +1,111 @@
using ACEfit
using LinearAlgebra, Random, Test
using Random

##

@info("Test Solver on overdetermined system")

Random.seed!(1234)
Nobs = 10_000
Nfeat = 100
A1 = randn(Nobs, Nfeat) / sqrt(Nobs)
U, S1, V = svd(A1)
S = 1e-4 .+ ((S1 .- S1[end]) / (S1[1] - S1[end])).^2
A = U * Diagonal(S) * V'
c_ref = randn(Nfeat)
epsn = 1e-5
y = A * c_ref + epsn * randn(Nobs) / sqrt(Nobs)
P = Diagonal(1.0 .+ rand(Nfeat))

##

@info(" ... ASP")
shuffled_indices = shuffle(1:length(y))
train_indices = shuffled_indices[1:round(Int, 0.85 * length(y))]
val_indices = shuffled_indices[round(Int, 0.85 * length(y)) + 1:end]
At = A[train_indices,:]
Av = A[val_indices,:]
yt = y[train_indices]
yv = y[val_indices]

for (select, tolr, tolc) in [ (:final, 10*epsn, 1),
( (:byerror,1.3), 10*epsn, 1),
( (:bysize,73), 1, 10) ]
@show select
local solver, results, C
solver = ACEfit.ASP(P=I, select = select, loglevel=0, traceFlag=true)
# without validation
results = ACEfit.solve(solver, A, y)
C = results["C"]
full_path = results["path"]
@show results["nnzs"]
@show norm(A * C - y)
@show norm(C)
@show norm(C - c_ref)

@test norm(A * C - y) < tolr
@test norm(C - c_ref) < tolc


# with validation
results = ACEfit.solve(solver, At, yt, Av, yv)
C = results["C"]
full_path = results["path"]
@show results["nnzs"]
@show norm(Av * C - yv)
@show norm(C)
@show norm(C - c_ref)

@test norm(Av * C - yv) < tolr
@test norm(C - c_ref) < tolc
end

##


# I didn't want to add more TSVD tests to the existing ones, so I just wrote this one;
# it only naïvely demonstrates that the TSVD post-processing actually does make a difference! :)

for (select, tolr, tolc) in [ (:final, 20*epsn, 1.5),
( (:byerror,1.3), 20*epsn, 1.5),
( (:bysize,73), 1, 10) ]
@show select
local solver, results, C
solver_tsvd = ACEfit.ASP(P=I, select=select, mode=:train, tsvd=true,
nstore=100, loglevel=0, traceFlag=true)

solver = ACEfit.ASP(P=I, select=select, mode=:train, tsvd=false,
nstore=100, loglevel=0, traceFlag=true)
# without validation
results_tsvd = ACEfit.solve(solver_tsvd, A, y)
results = ACEfit.solve(solver, A, y)
C_tsvd = results_tsvd["C"]
C = results["C"]

@show results["nnzs"]
@show norm(A * C - y)
@show norm(A * C_tsvd - y)
if norm(A * C_tsvd - y) < norm(A * C - y)
@info "tsvd made improvements!"
else
@warn "tsvd did NOT make any improvements!"
end


# with validation
results_tsvd = ACEfit.solve(solver_tsvd, At, yt, Av, yv)
results = ACEfit.solve(solver, At, yt, Av, yv)
C_tsvd = results_tsvd["C"]
C = results["C"]
@show results["nnzs"]
@show norm(A * C - y)
@show norm(A * C_tsvd - y)

if norm(A * C_tsvd - y) < norm(A * C - y)
@info "tsvd made improvements!"
else
@warn "tsvd did NOT make any improvements!"
end
end

42 changes: 0 additions & 42 deletions test/test_linearsolvers.jl
@@ -168,45 +168,3 @@ C = results["C"]
@test norm(A * C - y) < 10 * epsn
@test norm(C - c_ref) < 1

##

@info(" ... ASP")
shuffled_indices = shuffle(1:length(y))
train_indices = shuffled_indices[1:round(Int, 0.85 * length(y))]
val_indices = shuffled_indices[round(Int, 0.85 * length(y)) + 1:end]
At = A[train_indices,:]
Av = A[val_indices,:]
yt = y[train_indices]
yv = y[val_indices]

for (select, tolr, tolc) in [ (:final, 10*epsn, 1),
( (:byerror,1.3), 10*epsn, 1),
( (:bysize,73), 1, 10) ]
@show select
local solver, results, C
solver = ACEfit.ASP(P=I, select = select, loglevel=0, traceFlag=true)
# without validation
results = ACEfit.solve(solver, A, y)
C = results["C"]
full_path = results["path"]
@show results["nnzs"]
@show norm(A * C - y)
@show norm(C)
@show norm(C - c_ref)

@test norm(A * C - y) < tolr
@test norm(C - c_ref) < tolc


# with validation
results = ACEfit.solve(solver, At, yt, Av, yv)
C = results["C"]
full_path = results["path"]
@show results["nnzs"]
@show norm(Av * C - yv)
@show norm(C)
@show norm(C - c_ref)

@test norm(Av * C - yv) < tolr
@test norm(C - c_ref) < tolc
end