Skip to content

Commit

Permalink
Merge pull request #78 from tinatorabi/main
Browse files Browse the repository at this point in the history
ASP
  • Loading branch information
cortner authored Aug 29, 2024
2 parents d96fa8d + b706e79 commit bb9da84
Show file tree
Hide file tree
Showing 3 changed files with 128 additions and 5 deletions.
11 changes: 6 additions & 5 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ authors = ["William C Witt <[email protected]>, Christoph Ortner <christophor
version = "0.2.1"

[deps]
ActiveSetPursuit = "d86c1dc8-ba26-4c98-b330-3a8efc174d20"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand All @@ -21,25 +22,25 @@ MLJScikitLearnInterface = "5ae90465-5518-4432-b9d2-8a1def2f0cab"
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"

[extensions]
ACEfit_PythonCall_ext = "PythonCall"
ACEfit_MLJLinearModels_ext = [ "MLJ", "MLJLinearModels" ]
ACEfit_MLJLinearModels_ext = ["MLJ", "MLJLinearModels"]
ACEfit_MLJScikitLearnInterface_ext = ["MLJScikitLearnInterface", "PythonCall", "MLJ"]
ACEfit_PythonCall_ext = "PythonCall"

[compat]
julia = "1.9"
IterativeSolvers = "0.9.2"
LowRankApprox = "0.5.3"
MLJ = "0.19"
MLJLinearModels = "0.9"
MLJScikitLearnInterface = "0.7"
LowRankApprox = "0.5.3"
Optim = "1.7"
ParallelDataTransfer = "0.5.0"
ProgressMeter = "1.7"
PythonCall = "0.9"
StaticArrays = "1.5"
julia = "1.9"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", ]
test = ["Test"]
93 changes: 93 additions & 0 deletions src/solvers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ using LowRankApprox: pqrfact
using IterativeSolvers
using .BayesianLinear
using LinearAlgebra: SVD, svd
using ActiveSetPursuit

@doc raw"""
`struct QR` : linear least squares solver, using standard QR factorisation;
Expand Down Expand Up @@ -195,3 +196,95 @@ function solve(solver::TruncatedSVD, A, y)
return Dict{String, Any}("C" => solver.P \ θP)
end


@doc raw"""
`struct ASP` : Active Set Pursuit sparse solver
solves the following optimization problem using the homotopy approach:
```math
\max_{y} \left( b^T y - \frac{1}{2} λ y^T y \right)
```
subject to
```math
\|A^T y\|_{\infty} \leq 1.
```
* Input
* `A` : `m`-by-`n` explicit matrix or linear operator.
* `b` : `m`-vector.
* Solver parameters
* `min_lambda` : Minimum value for `λ`. Defaults to zero if not provided.
* `loglevel` : Logging level.
* `itnMax` : Maximum number of iterations.
* `actMax` : Maximum number of active constraints.
Constructor
```julia
ACEfit.ASP(; P = I, select, params)
```
where
- `P` : right-preconditioner / tychonov operator
- `select`: Selection mode for the final solution.
- `(:byerror, th)`: Selects the smallest active set fit within a factor `th` of the smallest fit error.
- `(:final, nothing)`: Returns the final iterate.
- `params`: The solver parameters, passed as named arguments.
"""
struct ASP
P::Any
select::Tuple
params::NamedTuple
end

function ASP(; P = I, select, params...)
params_tuple = NamedTuple(params)
return ASP(P, select, params_tuple)
end

function solve(solver::ASP, A, y)
# Apply preconditioning
AP = A / solver.P

tracer = asp_homotopy(AP, y; solver.params[1]...)

new_tracer = Vector{NamedTuple{(:solution, :λ), Tuple{Any, Any}}}(undef, length(tracer))

for i in 1:length(tracer)
new_tracer[i] = (solution = solver.P \ tracer[i][1], λ = tracer[i][2])
end

# Select the final solution based on the criterion
xs, in = select_solution(new_tracer, solver, A, y)

println("done.")
return Dict("C" => xs, "path" => new_tracer, "nnzs" => length((tracer[in][1]).nzind) )
end

function select_solution(tracer, solver, A, y)
criterion, threshold = solver.select

if criterion == :final
return tracer[end][1], length(tracer)

elseif criterion == :byerror
errors = [norm(A * t[1] - y) for t in tracer]
min_error = minimum(errors)

# Find the solution with the smallest error within the threshold
for (i, error) in enumerate(errors)
if error <= threshold * min_error
return tracer[i][1], i
end
end
elseif criterion == :bysize
for i in 1:length(tracer)
if length((tracer[i][1]).nzind) == threshold
return tracer[i][1], i
end
end
else
@error("Unknown selection criterion: $criterion")
end
end

29 changes: 29 additions & 0 deletions test/test_linearsolvers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -111,3 +111,32 @@ C = results["C"]
@show norm(C)
@show norm(C - c_ref)

@info(" ... ASP_homotopy selected by error")
solver = ACEfit.ASP(P = P, select = (:byerror,1.5), params = (loglevel=0, traceFlag=true))
results = ACEfit.solve(solver, A, y)
C = results["C"]
full_path = results["path"]
@show results["nnzs"]
@show norm(A * C - y)
@show norm(C)
@show norm(C - c_ref)

@info(" ... ASP_homotopy selected by size")
solver = ACEfit.ASP(P = P, select = (:bysize,50), params = (loglevel=0, traceFlag=true))
results = ACEfit.solve(solver, A, y)
C = results["C"]
full_path = results["path"]
@show results["nnzs"]
@show norm(A * C - y)
@show norm(C)
@show norm(C - c_ref)

@info(" ... ASP_homotopy final solution")
solver = ACEfit.ASP(P = P, select = (:final,nothing), params = (loglevel=0, traceFlag=true))
results = ACEfit.solve(solver, A, y)
C = results["C"]
full_path = results["path"]
@show results["nnzs"]
@show norm(A * C - y)
@show norm(C)
@show norm(C - c_ref)

0 comments on commit bb9da84

Please sign in to comment.