From ea924695a059c7dc1f622a9905ae53a37c14eaea Mon Sep 17 00:00:00 2001 From: andrewdipper Date: Fri, 2 Aug 2024 21:19:44 -0700 Subject: [PATCH] rename --- numpyro/util.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/numpyro/util.py b/numpyro/util.py index 0d1d0e098..512edf513 100644 --- a/numpyro/util.py +++ b/numpyro/util.py @@ -201,8 +201,7 @@ def progress_bar_factory(num_samples, num_chains): remainder = num_samples % print_rate - chain_idx = {} - + idx_map = {} tqdm_bars = {} finished_chains = [] for chain in range(num_chains): @@ -211,16 +210,16 @@ def progress_bar_factory(num_samples, num_chains): def _calc_chain_idx(iter_num): try: - cval = chain_idx[iter_num] + idx = idx_map[iter_num] except KeyError: - cval = 0 - chain_idx[iter_num] = 0 + idx = 0 + idx_map[iter_num] = 0 - if cval + 1 == num_chains: - del chain_idx[iter_num] + if idx + 1 == num_chains: + del idx_map[iter_num] else: - chain_idx[iter_num] += 1 - return cval + idx_map[iter_num] += 1 + return idx def _update_tqdm(iter_num, increment): iter_num = int(iter_num) @@ -232,7 +231,7 @@ def _update_tqdm(iter_num, increment): def _close_tqdm(iter_num, increment): iter_num = int(iter_num) increment = int(increment) - chain = _calc_chain_idx(iter_num + 1) + chain = _calc_chain_idx(iter_num + 1) # +1 so no collision in idx_map tqdm_bars[chain].update(increment) finished_chains.append(chain) if len(finished_chains) == num_chains: