Skip to content

Commit

Permalink
Merge pull request #83 from ACEsuit/asp_svd
Browse files Browse the repository at this point in the history
TSVD Postprocessing of ASP
  • Loading branch information
cortner authored Sep 9, 2024
2 parents e0fea2e + 83dbee5 commit b021f6a
Show file tree
Hide file tree
Showing 5 changed files with 183 additions and 89 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ Optim = "429524aa-4258-5aef-a3af-852621145aeb"
ParallelDataTransfer = "2dcacdae-9679-587a-88bb-8b444fb7085b"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
SharedArrays = "1a1011a3-84de-559e-8e89-a11a2f7dc383"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

[weakdeps]
Expand Down
116 changes: 69 additions & 47 deletions src/asp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,35 +48,45 @@ solve(solver::ASP, A, y, Aval=A, yval=y)
If independent `Aval` and `yval` are provided (instead of the defaults `A, y`),
then the solver will use this separate validation set instead of the training
set to select the best solution along the model path.
"""
# """

struct ASP
# Configuration object for the ASP solver; consumed by `solve(solver::ASP, ...)`.
# Preconditioner: `solve` runs the homotopy on `A / P` and maps every path
# solution back with `P \ x`.
P
# Selection criterion handed (through the solver object) to `select_solution`;
# the tests use `:final`, `(:byerror, 1.3)` and `(:bysize, 73)`.
select
# NOTE(review): defaults to `:train` in the keyword constructor; where it is
# consumed is not visible in this section -- confirm against `select_solution`.
mode::Symbol
# When `true`, `solve` post-processes the stored path with `post_asp_tsvd`.
tsvd::Bool
# Target number of stored path points: `solve` keeps every
# `max(1, q ÷ nstore)`-th of the `q` homotopy steps, plus the last one.
nstore::Integer
# Remaining keyword arguments; splatted into `asp_homotopy` as keywords.
params
end

# Keyword constructor for `ASP`.
#
# Keywords
# - `P`      : preconditioner, identity by default.
# - `select` : selection criterion (required), e.g. `:final`, `(:byerror, 1.3)`.
# - `mode`   : stored in the `mode` field, default `:train`.
# - `tsvd`   : enable the truncated-SVD post-processing of the path.
# - `nstore` : target number of path points kept by `solve`; must be positive
#              because `solve` divides the path length by it.
# - `params` : every other keyword is stored and forwarded to `asp_homotopy`.
#
# Throws an `ArgumentError` if `nstore` is not positive.
function ASP(; P = I, select, mode = :train, tsvd = false, nstore = 100, params...)
    nstore > 0 ||
        throw(ArgumentError("nstore must be a positive integer, got $nstore"))
    return ASP(P, select, mode, tsvd, nstore, params)
end

# Solve the sparse least-squares problem `A * x ≈ y` with the ASP homotopy.
#
# Arguments
# - `solver`       : `ASP` configuration (preconditioner, selection rule, ...).
# - `A`, `y`       : training design matrix and observations.
# - `Aval`, `yval` : data used to select a solution along the path (and to
#                    choose the truncation in the TSVD post-processing);
#                    default to the training data.
#
# Returns a `Dict` with
# - `"C"`    : the selected coefficient vector,
# - `"path"` : the stored path, a vector of `(solution, λ, σ)` named tuples,
# - `"nnzs"` : number of stored non-zeros of the selected path entry.
function solve(solver::ASP, A, y, Aval=A, yval=y)
    # Apply preconditioning
    AP = A / solver.P

    tracer = asp_homotopy(AP, y; solver.params...)

    # Keep roughly `nstore` evenly spaced path points and always the last one.
    q = length(tracer)
    every = max(1, q ÷ solver.nstore)
    istore = unique([1:every:q; q])

    # Undo the preconditioning; σ = 0.0 is a placeholder that is only replaced
    # by the TSVD post-processing below.
    new_tracer = [ (solution = solver.P \ tracer[i][1], λ = tracer[i][2], σ = 0.0)
                   for i in istore ]

    if solver.tsvd   # Post-processing if tsvd is true
        post = post_asp_tsvd(new_tracer, A, y, Aval, yval)
        new_post = [ (solution = p.θ, λ = p.λ, σ = p.σ) for p in post ]
    else
        new_post = new_tracer
    end

    xs, isel = select_solution(new_post, solver, Aval, yval)

    # Read the sparsity from the same path that "C" and "path" come from.
    return Dict( "C" => xs,
                 "path" => new_post,
                 "nnzs" => length(new_post[isel][:solution].nzind) )
end


Expand Down Expand Up @@ -114,44 +124,56 @@ function select_solution(tracer, solver, A, y)
end


using SparseArrays

"""
    solve_tsvd(At, yt, Av, yv)

Solve the least-squares problem `At * θ ≈ yt` by a truncated SVD whose
truncation level is chosen on the validation data `(Av, yv)`.

The singular components of `At` are added one at a time, in order of
decreasing singular value. After each addition the norm of the validation
residual (reduced through the QR factorisation of `Av`) is recorded, and the
truncation with the smallest recorded norm is kept.

# Returns
A tuple `(θ, σ)`: the coefficient vector, and the smallest singular value that
was retained.
"""
function solve_tsvd(At, yt, Av, yv)
    Ut, Σt, Vt = svd(At)
    zt = Ut' * yt
    Qv, Rv = qr(Av)
    zv = Matrix(Qv)' * yv
    # internal invariant: `svd` returns the singular values in decreasing order
    @assert issorted(Σt, rev=true)

    Rv_Vt = Rv * Vt

    # One coefficient per singular component, so that `Vt * θv` is well defined
    # even when `At` has fewer rows than columns.
    θv = zeros(length(Σt))
    rv = -zv
    tsvd_errs = Float64[]
    sizehint!(tsvd_errs, length(Σt))

    for k in eachindex(Σt)
        θv[k] = zt[k] / Σt[k]
        rv += Rv_Vt[:, k] * θv[k]
        push!(tsvd_errs, norm(rv))
    end

    # Drop every component after the one with the smallest validation error.
    imin = argmin(tsvd_errs)
    θv[imin+1:end] .= 0
    return Vt * θv, Σt[imin]
end

"""
    post_asp_tsvd(path, At, yt, Av, yv)

Post-process every entry of an ASP `path`: keep the sparsity pattern of each
solution, but recompute its non-zero coefficients with [`solve_tsvd`](@ref)
restricted to the active columns.

Each element of `path` must destructure as `(θ, λ, ...)`, where `θ` is a
sparse vector exposing `nzind`.

# Returns
A vector of named tuples `(θ, λ, σ)`. `σ` is the smallest retained singular
value, or `Inf` for entries with an empty support, which are returned as is.
"""
function post_asp_tsvd(path, At, yt, Av, yv)
    # Reduce both problems once through QR; every path entry then works with
    # the small triangular factors only.
    Qt, Rt = qr(At)
    zt = Matrix(Qt)' * yt
    Qv, Rv = qr(Av)
    zv = Matrix(Qv)' * yv

    function _post(θλ)
        (θ, λ) = θλ
        if isempty(θ.nzind)
            return (θ = θ, λ = λ, σ = Inf)
        end
        inz = θ.nzind
        θ1, σ = solve_tsvd(Rt[:, inz], zt, Rv[:, inz], zv)
        θ2 = copy(θ)
        θ2[inz] .= θ1
        return (θ = θ2, λ = λ, σ = σ)
    end

    return _post.(path)
end
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,7 @@ using Test

@testset "Linear Solvers" begin include("test_linearsolvers.jl") end

@testset "ASP" begin include("test_asp.jl") end

@testset "MLJ Solvers" begin include("test_mlj.jl") end
end
111 changes: 111 additions & 0 deletions test/test_asp.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
using ACEfit
using LinearAlgebra, Random, Test
using Random

##

@info("Test Solver on overdetermined system")

# Reproducible synthetic problem: a random tall matrix whose singular values
# are reshaped to decay quadratically down to 1e-4.
Random.seed!(1234)
Nobs = 10_000
Nfeat = 100
A1 = randn(Nobs, Nfeat) ./ sqrt(Nobs)
U, S1, V = svd(A1)
S = ((S1 .- S1[end]) ./ (S1[1] - S1[end])) .^ 2 .+ 1e-4
A = U * Diagonal(S) * V'

# Reference coefficients and slightly noisy observations.
c_ref = randn(Nfeat)
epsn = 1e-5
y = A * c_ref .+ (epsn .* randn(Nobs)) ./ sqrt(Nobs)
P = Diagonal(rand(Nfeat) .+ 1.0)

##

@info(" ... ASP")
# Random 85% / 15% split of the rows into training and validation sets.
shuffled_indices = shuffle(1:length(y))
train_indices, val_indices = let ntrain = round(Int, 0.85 * length(y))
    shuffled_indices[1:ntrain], shuffled_indices[(ntrain + 1):end]
end
At, yt = A[train_indices, :], y[train_indices]
Av, yv = A[val_indices, :], y[val_indices]

# Run ASP once per selection criterion. Each case is
# (select, residual tolerance `tolr`, coefficient-error tolerance `tolc`).
for (select, tolr, tolc) in [ (:final, 10*epsn, 1),
( (:byerror,1.3), 10*epsn, 1),
( (:bysize,73), 1, 10) ]
@show select
local solver, results, C
# `loglevel` and `traceFlag` are not named keywords of the ASP constructor;
# they are collected as extra params and forwarded to `asp_homotopy`.
solver = ACEfit.ASP(P=I, select = select, loglevel=0, traceFlag=true)
# without validation: the full data set is used for both fitting and selection
results = ACEfit.solve(solver, A, y)
C = results["C"]
# NOTE(review): `full_path` is assigned but never read in this loop.
full_path = results["path"]
@show results["nnzs"]
@show norm(A * C - y)
@show norm(C)
@show norm(C - c_ref)

@test norm(A * C - y) < tolr
@test norm(C - c_ref) < tolc


# with validation: fit on (At, yt), select on (Av, yv); the residual is
# then checked on the validation rows
results = ACEfit.solve(solver, At, yt, Av, yv)
C = results["C"]
full_path = results["path"]
@show results["nnzs"]
@show norm(Av * C - yv)
@show norm(C)
@show norm(C - c_ref)

@test norm(Av * C - yv) < tolr
@test norm(C - c_ref) < tolc
end

##


# Compare ASP with and without the TSVD post-processing, for every selection
# criterion. The comparison of the residuals is only reported (@info / @warn),
# because TSVD is not guaranteed to improve the full-data residual; the @test
# lines are sanity checks so that a broken post-processing step fails the suite.
# NOTE(review): `tolr` / `tolc` are carried in the case list but not enforced
# here -- confirm suitable tolerances before turning them into @test checks.
for (select, tolr, tolc) in [ (:final, 20*epsn, 1.5),
                              ( (:byerror,1.3), 20*epsn, 1.5),
                              ( (:bysize,73), 1, 10) ]
    @show select
    local solver, solver_tsvd
    solver_tsvd = ACEfit.ASP(P=I, select=select, mode=:train, tsvd=true,
                             nstore=100, loglevel=0, traceFlag=true)

    solver = ACEfit.ASP(P=I, select=select, mode=:train, tsvd=false,
                        nstore=100, loglevel=0, traceFlag=true)

    # First scenario: without validation (`solve(solver, A, y)` uses the
    # training data for selection). Second scenario: with a validation set.
    for (Afit, yfit, Asel, ysel) in [ (A, y, A, y), (At, yt, Av, yv) ]
        local results, results_tsvd, C, C_tsvd
        results_tsvd = ACEfit.solve(solver_tsvd, Afit, yfit, Asel, ysel)
        results = ACEfit.solve(solver, Afit, yfit, Asel, ysel)
        C_tsvd = results_tsvd["C"]
        C = results["C"]

        @show results["nnzs"]
        @show norm(A * C - y)
        @show norm(A * C_tsvd - y)

        @test length(C_tsvd) == length(C)
        @test all(isfinite, C_tsvd)

        if norm(A * C_tsvd - y) < norm(A * C - y)
            @info "tsvd made improvements!"
        else
            @warn "tsvd did NOT make any improvements!"
        end
    end
end

42 changes: 0 additions & 42 deletions test/test_linearsolvers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -168,45 +168,3 @@ C = results["C"]
@test norm(A * C - y) < 10 * epsn
@test norm(C - c_ref) < 1

##

@info(" ... ASP")
# Random 85% / 15% split of the rows of (A, y) into training and validation
# sets. `A`, `y`, `epsn` and `c_ref` come from earlier in this test file.
shuffled_indices = shuffle(1:length(y))
train_indices = shuffled_indices[1:round(Int, 0.85 * length(y))]
val_indices = shuffled_indices[round(Int, 0.85 * length(y)) + 1:end]
At = A[train_indices,:]
Av = A[val_indices,:]
yt = y[train_indices]
yv = y[val_indices]

# Each case is (select, residual tolerance `tolr`, coefficient-error
# tolerance `tolc`).
for (select, tolr, tolc) in [ (:final, 10*epsn, 1),
( (:byerror,1.3), 10*epsn, 1),
( (:bysize,73), 1, 10) ]
@show select
local solver, results, C
solver = ACEfit.ASP(P=I, select = select, loglevel=0, traceFlag=true)
# without validation: the full data set is used for both fitting and selection
results = ACEfit.solve(solver, A, y)
C = results["C"]
# NOTE(review): `full_path` is assigned but never read in this loop.
full_path = results["path"]
@show results["nnzs"]
@show norm(A * C - y)
@show norm(C)
@show norm(C - c_ref)

@test norm(A * C - y) < tolr
@test norm(C - c_ref) < tolc


# with validation: fit on (At, yt), select on (Av, yv)
results = ACEfit.solve(solver, At, yt, Av, yv)
C = results["C"]
full_path = results["path"]
@show results["nnzs"]
@show norm(Av * C - yv)
@show norm(C)
@show norm(C - c_ref)

@test norm(Av * C - yv) < tolr
@test norm(C - c_ref) < tolc
end

0 comments on commit b021f6a

Please sign in to comment.