Skip to content

Commit

Permalink
Improve inferability
Browse files Browse the repository at this point in the history
  • Loading branch information
timholy committed Jul 14, 2024
1 parent 6e29f4d commit ca3d469
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 14 deletions.
8 changes: 4 additions & 4 deletions src/alspgrad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -385,13 +385,13 @@ solve!(alg::ALSPGrad, X, W, H) =

struct ALSPGradUpd_State{T}
WH::Matrix{T}
uhstate::ALSGradUpdH_State
uwstate::ALSGradUpdW_State
uhstate::ALSGradUpdH_State{T}
uwstate::ALSGradUpdW_State{T}

ALSPGradUpd_State{T}(X, W, H) where {T} =
new{T}(W * H,
ALSGradUpdH_State(X, W, H),
ALSGradUpdW_State(X, W, H))
ALSGradUpdH_State{T}(X, W, H),
ALSGradUpdW_State{T}(X, W, H))
end

prepare_state(::ALSPGradUpd{T}, X, W, H) where {T} = ALSPGradUpd_State{T}(X, W, H)
Expand Down
26 changes: 16 additions & 10 deletions src/interf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,37 +54,43 @@ function nnmf(X::AbstractMatrix{T}, k::Integer;
else
throw(ArgumentError("Invalid value for init."))
end
W = W::Matrix{T}
H = H::Matrix{T}

# choose algorithm
if alg == :projals
alginst = ProjectedALS{T}(maxiter=maxiter, tol=tol, verbose=verbose, update_H=update_H)
ret = solve_replicates!(ProjectedALS{T}(maxiter=maxiter, tol=tol, verbose=verbose, update_H=update_H), X, W, H; replicates, initH)
elseif alg == :alspgrad
alginst = ALSPGrad{T}(maxiter=maxiter, tol=tol, verbose=verbose, update_H=update_H)
ret = solve_replicates!(ALSPGrad{T}(maxiter=maxiter, tol=tol, verbose=verbose, update_H=update_H), X, W, H; replicates, initH)
elseif alg == :multmse
alginst = MultUpdate{T}(obj=:mse, maxiter=maxiter, tol=tol, verbose=verbose, update_H=update_H)
ret = solve_replicates!(MultUpdate{T}(obj=:mse, maxiter=maxiter, tol=tol, verbose=verbose, update_H=update_H), X, W, H; replicates, initH)
elseif alg == :multdiv
alginst = MultUpdate{T}(obj=:div, maxiter=maxiter, tol=tol, verbose=verbose, update_H=update_H)
ret = solve_replicates!(MultUpdate{T}(obj=:div, maxiter=maxiter, tol=tol, verbose=verbose, update_H=update_H), X, W, H; replicates, initH)
elseif alg == :cd
alginst = CoordinateDescent{T}(maxiter=maxiter, tol=tol, verbose=verbose, update_H=update_H)
ret = solve_replicates!(CoordinateDescent{T}(maxiter=maxiter, tol=tol, verbose=verbose, update_H=update_H), X, W, H; replicates, initH)
elseif alg == :greedycd
alginst = GreedyCD{T}(maxiter=maxiter, tol=tol, verbose=verbose, update_H=update_H)
ret = solve_replicates!(GreedyCD{T}(maxiter=maxiter, tol=tol, verbose=verbose, update_H=update_H), X, W, H; replicates, initH)
elseif alg == :spa
if init != :spa
throw(ArgumentError("Invalid value for init, use :spa instead."))
end
alginst = SPA{T}(obj=:mse)
ret = solve_replicates!(SPA{T}(obj=:mse), X, W, H; replicates, initH)
else
throw(ArgumentError("Invalid algorithm."))
end

# run optimization
return ret
end

function solve_replicates!(alginst, X, W, H; replicates, initH)
ret = solve!(alginst, X, W, H)
k = size(W, 2)

# replicates
minobjv = ret.objvalue
for _ in 2:replicates
W, H = randinit(X, k; zeroh=!initH, normalize=true)
tmp = solve!(alginst, X, W, H)
Wrand, Hrand = randinit(X, k; zeroh=!initH, normalize=true)
tmp = solve!(alginst, X, Wrand, Hrand)
if minobjv > tmp.objvalue
ret = tmp
minobjv = tmp.objvalue
Expand Down

0 comments on commit ca3d469

Please sign in to comment.