diff --git a/src/asp.jl b/src/asp.jl index 6006128..7e0460b 100644 --- a/src/asp.jl +++ b/src/asp.jl @@ -83,19 +83,28 @@ function solve(solver::ASP, A, y, Aval=A, yval=y) for p in new_tracer ] end - xs, in = select_solution(new_post, solver, Aval, yval) + tracer_final = _add_errors(new_post, Aval, yval) + xs, in = asp_select(tracer_final, solver.select) - return Dict( "C" => xs, - "path" => new_post, - "nnzs" => length( (new_post[in][:solution]).nzind) ) + return Dict( "C" => xs, + "path" => tracer_final, ) end -function select_solution(tracer, solver, A, y) - if solver.select == :final +function _add_errors(tracer, A, y) + rtN = sqrt(length(y)) + return [ ( solution = t.solution, λ = t.λ, σ = t.σ, + rmse = norm(A * t.solution - y) / rtN ) + for t in tracer ] +end + +asp_select(D::Dict, select) = asp_select(D["path"], select) + +function asp_select(tracer, select) + if select == :final criterion = :final else - criterion, p = solver.select + criterion, p = select end if criterion == :final @@ -108,12 +117,12 @@ function select_solution(tracer, solver, A, y) elseif criterion == :bysize maxind = findfirst(t -> length((t[:solution]).nzind) > p, tracer) - 1 - threshold = 1.0 + threshold = 1.0 else error("Unknown selection criterion: $criterion") end - errors = [ norm(A * t[:solution] - y) for t in tracer[1:maxind] ] + errors = [ t.rmse for t in tracer[1:maxind] ] min_error = minimum(errors) for (i, error) in enumerate(errors) if error <= threshold * min_error @@ -140,3 +149,16 @@ function post_asp_tsvd(path, At, yt, Av, yv) return _post.(path) end + +# TODO: revisit this idea. Maybe we do want to keep this, not as `select` +# but as `solve`. But if we do, then it might be nice to be able to +# extend the path somehow. For now I'm removing it since I don't see +# the immediate need yet. Just calling asp_select is how I would normally +# use this. +# +# function select(tracer, solver, A, y) #can be called by the user to warm-start the selection +# xs, in = select_solution(tracer, solver, A, y) +# return Dict("C" => xs, +# "path" => tracer, +# "nnzs" => length( (tracer[in][:solution]).nzind) ) +# end diff --git a/test/test_asp.jl b/test/test_asp.jl index 6930ca6..420780c 100644 --- a/test/test_asp.jl +++ b/test/test_asp.jl @@ -1,5 +1,5 @@ using ACEfit -using LinearAlgebra, Random, Test +using LinearAlgebra, Random, Test ## @@ -47,7 +47,7 @@ for (select, tolr, tolc) in [ (:final, 10*epsn, 1), results = ACEfit.solve(solver, A, y) C = results["C"] full_path = results["path"] - @show results["nnzs"] + # @show results["nnzs"] @show norm(A * C - y) @show norm(C) @show norm(C - c_ref) @@ -60,7 +60,7 @@ for (select, tolr, tolc) in [ (:final, 10*epsn, 1), results = ACEfit.solve(solver, At, yt, Av, yv) C = results["C"] full_path = results["path"] - @show results["nnzs"] + # @show results["nnzs"] @show norm(Av * C - yv) @show norm(C) @show norm(C - c_ref) @@ -91,7 +91,7 @@ for (select, tolr, tolc) in [ (:final, 20*epsn, 1.5), C_tsvd = results_tsvd["C"] C = results["C"] - @show results["nnzs"] + # @show results["nnzs"] @show norm(A * C - y) @show norm(A * C_tsvd - y) if norm(A * C_tsvd - y)< norm(A * C - y) @@ -106,7 +106,7 @@ for (select, tolr, tolc) in [ (:final, 20*epsn, 1.5), results = ACEfit.solve(solver, At, yt, Av, yv) C_tsvd = results_tsvd["C"] C = results["C"] - @show results["nnzs"] + # @show results["nnzs"] @show norm(A * C - y) @show norm(A * C_tsvd - y) @@ -117,3 +117,36 @@ for (select, tolr, tolc) in [ (:final, 20*epsn, 1.5), end end +## + +# Testing the "select" function +solver_final = ACEfit.ASP( + P = I, + select = :final, + tsvd = false, + nstore = 100, + loglevel = 0 +) + +results_final = ACEfit.solve(solver_final, At, yt, Av, yv) +tracer_final = results_final["path"] + +# Warm-start the solver using the tracer from the final iteration +# select best solution with <= 73 non-zero entries +select = (:bysize, 73) +C_select, _ = ACEfit.asp_select(tracer_final, select) +@test( length(C_select.nzind) <= 73 ) + +# Check if starting the solver initially with (:bysize, 73) gives the same result +solver_bysize = ACEfit.ASP( + P = I, + select = (:bysize, 73), + tsvd = false, + nstore = 100, + loglevel = 0 +) + +results_bysize = ACEfit.solve(solver_bysize, At, yt, Av, yv) +@test results_bysize["C"] == C_select # works + +