Skip to content

Commit

Permalink
rename
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewdipper committed Aug 3, 2024
1 parent 6f11983 commit ea92469
Showing 1 changed file with 9 additions and 10 deletions.
19 changes: 9 additions & 10 deletions numpyro/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,8 +201,7 @@ def progress_bar_factory(num_samples, num_chains):

remainder = num_samples % print_rate

chain_idx = {}

idx_map = {}
tqdm_bars = {}
finished_chains = []
for chain in range(num_chains):
Expand All @@ -211,16 +210,16 @@ def progress_bar_factory(num_samples, num_chains):

def _calc_chain_idx(iter_num):
try:
cval = chain_idx[iter_num]
idx = idx_map[iter_num]
except KeyError:
cval = 0
chain_idx[iter_num] = 0
idx = 0
idx_map[iter_num] = 0

if cval + 1 == num_chains:
del chain_idx[iter_num]
if idx + 1 == num_chains:
del idx_map[iter_num]
else:
chain_idx[iter_num] += 1
return cval
idx_map[iter_num] += 1
return idx

def _update_tqdm(iter_num, increment):
iter_num = int(iter_num)
Expand All @@ -232,7 +231,7 @@ def _update_tqdm(iter_num, increment):
def _close_tqdm(iter_num, increment):
iter_num = int(iter_num)
increment = int(increment)
chain = _calc_chain_idx(iter_num + 1)
chain = _calc_chain_idx(iter_num + 1) # +1 so no collision in idx_map
tqdm_bars[chain].update(increment)
finished_chains.append(chain)
if len(finished_chains) == num_chains:
Expand Down

0 comments on commit ea92469

Please sign in to comment.