From e74f54b75d1217c64c5bd6740d8d337461ef871f Mon Sep 17 00:00:00 2001 From: Tim Holy Date: Sun, 14 Jul 2024 03:40:15 -0500 Subject: [PATCH] Fix verbose printing (#86) The `verbose=true` option was broken in #83, but it wasn't detected because it was never tested. --- src/common.jl | 14 ++++++++------ test/interf.jl | 5 +++++ 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/src/common.jl b/src/common.jl index 0f089a3..e374f19 100644 --- a/src/common.jl +++ b/src/common.jl @@ -54,7 +54,7 @@ function nmf_skeleton!(updater::NMFUpdater{T}, if verbose start = time() objv = evaluate_objv(updater, state, X, W, H) - @printf("%-5s %-13s %-13s %-13s %-13s\n", "Iter", "Elapsed time", "objv", "objv.change", "(W & H).change") + @printf("%-5s %-13s %-13s %-13s %-13s\n", "Iter", "Elapsed time", "objv", "objv.change", "(W & H).relchange") @printf("%5d %13.6e %13.6e\n", 0, 0.0, objv) end @@ -70,7 +70,7 @@ function nmf_skeleton!(updater::NMFUpdater{T}, update_wh!(updater, state, X, W, H) # determine convergence - converged = stop_condition(W, preW, H, preH, tol) + converged, dev = stop_condition(W, preW, H, preH, tol) # display info if verbose @@ -90,6 +90,7 @@ end function stop_condition(W::AbstractArray{T}, preW::AbstractArray, H::AbstractArray, preH::AbstractArray, eps::AbstractFloat) where T + devmax = zero(T) for j in axes(W,2) dev_w = sum_w = zero(T) for i in axes(W,1) @@ -100,10 +101,11 @@ function stop_condition(W::AbstractArray{T}, preW::AbstractArray, H::AbstractArr for i in axes(H,2) dev_h += (H[j,i] - preH[j,i])^2 sum_h += (H[j,i] + preH[j,i])^2 - end + end + devmax = max(devmax, sqrt(max(dev_w/sum_w, dev_h/sum_h))) if sqrt(dev_w) > eps*sqrt(sum_w) || sqrt(dev_h) > eps*sqrt(sum_h) - return false - end + return false, devmax + end end - return true + return true, devmax end diff --git a/test/interf.jl b/test/interf.jl index f32ca84..b90abd8 100644 --- a/test/interf.jl +++ b/test/interf.jl @@ -35,5 +35,10 @@ @test all(H .== ret.H) @test any(W .!= ret.W) end + + # printing test + redirect_stdout(devnull) do + ret = NMF.nnmf(X, k, alg=:cd, init=:nndsvd, verbose=true) + end end end