Skip to content

Commit

Permalink
Merge pull request #84 from ACEsuit/co/valsvd
Browse files Browse the repository at this point in the history
Add svd with validation set
  • Loading branch information
cortner authored Sep 11, 2024
2 parents 1b58c75 + 56dcdcc commit d146148
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 49 deletions.
53 changes: 9 additions & 44 deletions src/asp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,20 +66,23 @@ end
function solve(solver::ASP, A, y, Aval=A, yval=y)
# Apply preconditioning
AP = A / solver.P
AvalP = Aval / solver.P

tracer = asp_homotopy(AP, y; solver.params...)

q = length(tracer)
every = max(1, q ÷ solver.nstore)
istore = unique([1:every:q; q])
new_tracer = [ (solution = solver.P \ tracer[i][1], λ = tracer[i][2], σ = 0.0 )
every = max(1, q / solver.nstore)
istore = unique(round.(Int, [1:every:q; q]))
new_tracer = [ (solution = tracer[i][1], λ = tracer[i][2], σ = 0.0 )
for i in istore ]

if solver.tsvd # Post-processing if tsvd is true
post = post_asp_tsvd(new_tracer, A, y, Aval, yval)
new_post = [ (solution = p.θ, λ = p.λ, σ = p.σ) for p in post ]
post = post_asp_tsvd(new_tracer, AP, y, AvalP, yval)
new_post = [ (solution = solver.P \ p.θ, λ = p.λ, σ = p.σ)
for p in post ]
else
new_post = new_tracer
new_post = [ (solution = solver.P \ p.solution, λ = p.λ, σ = 0.0)
for p in new_tracer ]
end

xs, in = select_solution(new_post, solver, Aval, yval)
Expand Down Expand Up @@ -124,34 +127,6 @@ function select_solution(tracer, solver, A, y)
end



using SparseArrays

function solve_tsvd(At, yt, Av, yv)
Ut, Σt, Vt = svd(At); zt = Ut' * yt
Qv, Rv = qr(Av); zv = Matrix(Qv)' * yv
@assert issorted(Σt, rev=true)

Rv_Vt = Rv * Vt

θv = zeros(size(Av, 2))
θv[1] = zt[1] / Σt[1]
rv = Rv_Vt[:, 1] * θv[1] - zv

tsvd_errs = Float64[]
push!(tsvd_errs, norm(rv))

for k = 2:length(Σt)
θv[k] = zt[k] / Σt[k]
rv += Rv_Vt[:, k] * θv[k]
push!(tsvd_errs, norm(rv))
end

imin = argmin(tsvd_errs)
θv[imin+1:end] .= 0
return Vt * θv, Σt[imin]
end

function post_asp_tsvd(path, At, yt, Av, yv)
Qt, Rt = qr(At); zt = Matrix(Qt)' * yt
Qv, Rv = qr(Av); zv = Matrix(Qv)' * yv
Expand All @@ -166,14 +141,4 @@ function post_asp_tsvd(path, At, yt, Av, yv)
end

return _post.(path)

# post = []
# for (θ, λ) in path
# if isempty(θ.nzind); push!(post, (θ = θ, λ = λ, σ = Inf)); continue; end
# inz = θ.nzind
# θ1, σ = solve_tsvd(Rt[:, inz], zt, Rv[:, inz], zv)
# θ2 = copy(θ); θ2[inz] .= θ1
# push!(post, (θ = θ2, λ = λ, σ = σ))
# end
# return identity.(post)
end
34 changes: 34 additions & 0 deletions src/solvers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -196,3 +196,37 @@ function solve(solver::TruncatedSVD, A, y)
return Dict{String, Any}("C" => solver.P \ θP)
end


# ------------ Truncated SVD with tol specified by validation set ------------

function solve_tsvd(At, yt, Av, yv)
Ut, Σt, Vt = svd(At); zt = Ut' * yt
Qv, Rv = qr(Av); zv = Matrix(Qv)' * yv
@assert issorted(Σt, rev=true)

Rv_Vt = Rv * Vt

θv = zeros(size(Av, 2))
θv[1] = zt[1] / Σt[1]
rv = Rv_Vt[:, 1] * θv[1] - zv

tsvd_errs = Float64[]
push!(tsvd_errs, norm(rv))

for k = 2:length(Σt)
θv[k] = zt[k] / Σt[k]
rv += Rv_Vt[:, k] * θv[k]
push!(tsvd_errs, norm(rv))
end

imin = argmin(tsvd_errs)
θv[imin+1:end] .= 0
return Vt * θv, Σt[imin]
end


function solve(solver::TruncatedSVD, At, yt, Av, yv)
# make a function barrier because solver.P is not inferred
θ, σ = solve_tsvd(At / solver.P, yt, Av / solver.P, yv)
return Dict{String, Any}("C" => solver.P \ θ, "σ" => σ)
end
7 changes: 6 additions & 1 deletion test/test_asp.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
using ACEfit
using LinearAlgebra, Random, Test
using Random

##

Expand Down Expand Up @@ -29,6 +28,12 @@ Av = A[val_indices,:]
yt = y[train_indices]
yv = y[val_indices]

for (nstore, n1) in [ (20, 21), (100, 101), (200, 165)]
solver = ACEfit.ASP(P=I, select = :final, nstore = nstore, loglevel=0, traceFlag=true)
results = ACEfit.solve(solver, A, y)
@test length(results["path"]) == n1
end

for (select, tolr, tolc) in [ (:final, 10*epsn, 1),
( (:byerror,1.3), 10*epsn, 1),
( (:bysize,73), 1, 10) ]
Expand Down
21 changes: 17 additions & 4 deletions test/test_linearsolvers.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@

using ACEfit
using LinearAlgebra, Random, Test
using Random
using PythonCall
using ACEfit, LinearAlgebra, Random, Test, PythonCall

##

Expand Down Expand Up @@ -168,3 +165,19 @@ C = results["C"]
@test norm(A * C - y) < 10 * epsn
@test norm(C - c_ref) < 1


##

@info("Truncated SVD with validation")
solver = ACEfit.TruncatedSVD(; rtol = 0.0)
At = A[1:8000, :]
yt = y[1:8000]
Av = A[8001:end, :]
yv = y[8001:end]
results_v = ACEfit.solve(solver, At, yt, Av, yv)
@show err_v = norm(Av * results_v["C"] - yv)
@show err = norm(Av * results["C"] - yv)
@test err_v <= err
@show norm(results_v["C"] - c_ref)
@show norm(results["C"] - c_ref)
@test norm(results_v["C"] - c_ref) < 1e-2

0 comments on commit d146148

Please sign in to comment.