Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Post fit selection from asp path #88

Merged
merged 3 commits into from
Sep 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 31 additions & 9 deletions src/asp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,19 +83,28 @@ function solve(solver::ASP, A, y, Aval=A, yval=y)
for p in new_tracer ]
end

xs, in = select_solution(new_post, solver, Aval, yval)
tracer_final = _add_errors(new_post, Aval, yval)
xs, in = asp_select(tracer_final, solver.select)

return Dict( "C" => xs,
"path" => new_post,
"nnzs" => length( (new_post[in][:solution]).nzind) )
return Dict( "C" => xs,
"path" => tracer_final, )
end


function select_solution(tracer, solver, A, y)
if solver.select == :final
function _add_errors(tracer, A, y)
rtN = sqrt(length(y))
return [ ( solution = t.solution, λ = t.λ, σ = t.σ,
rmse = norm(A * t.solution - y) / rtN )
for t in tracer ]
end

asp_select(D::Dict, select) = asp_select(D["path"], select)

function asp_select(tracer, select)
if select == :final
criterion = :final
else
criterion, p = solver.select
criterion, p = select
end

if criterion == :final
Expand All @@ -108,12 +117,12 @@ function select_solution(tracer, solver, A, y)
elseif criterion == :bysize
maxind = findfirst(t -> length((t[:solution]).nzind) > p,
tracer) - 1
threshold = 1.0
threshold = 1.0
else
error("Unknown selection criterion: $criterion")
end

errors = [ norm(A * t[:solution] - y) for t in tracer[1:maxind] ]
errors = [ t.rmse for t in tracer[1:maxind] ]
min_error = minimum(errors)
for (i, error) in enumerate(errors)
if error <= threshold * min_error
Expand All @@ -140,3 +149,16 @@ function post_asp_tsvd(path, At, yt, Av, yv)

return _post.(path)
end

# TODO: revisit this idea. Maybe we do want to keep this, not as `select`
# but as `solve`. But if we do, then it might be nice to be able to
# extend the path somehow. For now I'm removing it since I don't see
# the immediate need yet. Just calling asp_select is how I would normally
# use this.
#
# function select(tracer, solver, A, y) #can be called by the user to warm-start the selection
# xs, in = select_solution(tracer, solver, A, y)
# return Dict("C" => xs,
# "path" => tracer,
# "nnzs" => length( (tracer[in][:solution]).nzind) )
# end
43 changes: 38 additions & 5 deletions test/test_asp.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using ACEfit
using LinearAlgebra, Random, Test
using LinearAlgebra, Random, Test

##

Expand Down Expand Up @@ -47,7 +47,7 @@ for (select, tolr, tolc) in [ (:final, 10*epsn, 1),
results = ACEfit.solve(solver, A, y)
C = results["C"]
full_path = results["path"]
@show results["nnzs"]
# @show results["nnzs"]
@show norm(A * C - y)
@show norm(C)
@show norm(C - c_ref)
Expand All @@ -60,7 +60,7 @@ for (select, tolr, tolc) in [ (:final, 10*epsn, 1),
results = ACEfit.solve(solver, At, yt, Av, yv)
C = results["C"]
full_path = results["path"]
@show results["nnzs"]
# @show results["nnzs"]
@show norm(Av * C - yv)
@show norm(C)
@show norm(C - c_ref)
Expand Down Expand Up @@ -91,7 +91,7 @@ for (select, tolr, tolc) in [ (:final, 20*epsn, 1.5),
C_tsvd = results_tsvd["C"]
C = results["C"]

@show results["nnzs"]
# @show results["nnzs"]
@show norm(A * C - y)
@show norm(A * C_tsvd - y)
if norm(A * C_tsvd - y)< norm(A * C - y)
Expand All @@ -106,7 +106,7 @@ for (select, tolr, tolc) in [ (:final, 20*epsn, 1.5),
results = ACEfit.solve(solver, At, yt, Av, yv)
C_tsvd = results_tsvd["C"]
C = results["C"]
@show results["nnzs"]
# @show results["nnzs"]
@show norm(A * C - y)
@show norm(A * C_tsvd - y)

Expand All @@ -117,3 +117,36 @@ for (select, tolr, tolc) in [ (:final, 20*epsn, 1.5),
end
end

##

# Testing the "select" function
solver_final = ACEfit.ASP(
P = I,
select = :final,
tsvd = false,
nstore = 100,
loglevel = 0
)

results_final = ACEfit.solve(solver_final, At, yt, Av, yv)
tracer_final = results_final["path"]

# Warm-start the solver using the tracer from the final iteration
# select best solution with <= 73 non-zero entries
select = (:bysize, 73)
C_select, _ = ACEfit.asp_select(tracer_final, select)
@test( length(C_select.nzind) <= 73 )

# Check if starting the solver initially with (:bysize, 73) gives the same result
solver_bysize = ACEfit.ASP(
P = I,
select = (:bysize, 73),
tsvd = false,
nstore = 100,
loglevel = 0
)

results_bysize = ACEfit.solve(solver_bysize, At, yt, Av, yv)
@test results_bysize["C"] == C_select # works


Loading