Skip to content

Commit

Permalink
Fix verbose printing (#86)
Browse files Browse the repository at this point in the history
The `verbose=true` option was broken in #83, but it wasn't detected
because it was never tested.
  • Loading branch information
timholy authored Jul 14, 2024
1 parent 6e29f4d commit e74f54b
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 6 deletions.
14 changes: 8 additions & 6 deletions src/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ function nmf_skeleton!(updater::NMFUpdater{T},
if verbose
start = time()
objv = evaluate_objv(updater, state, X, W, H)
@printf("%-5s %-13s %-13s %-13s %-13s\n", "Iter", "Elapsed time", "objv", "objv.change", "(W & H).change")
@printf("%-5s %-13s %-13s %-13s %-13s\n", "Iter", "Elapsed time", "objv", "objv.change", "(W & H).relchange")
@printf("%5d %13.6e %13.6e\n", 0, 0.0, objv)
end

Expand All @@ -70,7 +70,7 @@ function nmf_skeleton!(updater::NMFUpdater{T},
update_wh!(updater, state, X, W, H)

# determine convergence
converged = stop_condition(W, preW, H, preH, tol)
converged, dev = stop_condition(W, preW, H, preH, tol)

# display info
if verbose
Expand All @@ -90,6 +90,7 @@ end


function stop_condition(W::AbstractArray{T}, preW::AbstractArray, H::AbstractArray, preH::AbstractArray, eps::AbstractFloat) where T
devmax = zero(T)
for j in axes(W,2)
dev_w = sum_w = zero(T)
for i in axes(W,1)
Expand All @@ -100,10 +101,11 @@ function stop_condition(W::AbstractArray{T}, preW::AbstractArray, H::AbstractArr
for i in axes(H,2)
dev_h += (H[j,i] - preH[j,i])^2
sum_h += (H[j,i] + preH[j,i])^2
end
end
devmax = max(devmax, sqrt(max(dev_w/sum_w, dev_h/sum_h)))
if sqrt(dev_w) > eps*sqrt(sum_w) || sqrt(dev_h) > eps*sqrt(sum_h)
return false
end
return false, devmax
end
end
return true
return true, devmax
end
5 changes: 5 additions & 0 deletions test/interf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,5 +35,10 @@
@test all(H .== ret.H)
@test any(W .!= ret.W)
end

# printing test
redirect_stdout(devnull) do
ret = NMF.nnmf(X, k, alg=:cd, init=:nndsvd, verbose=true)
end
end
end

0 comments on commit e74f54b

Please sign in to comment.