better mem stat
albertz committed Jan 9, 2025
1 parent 169ca36 commit e579c74
Showing 1 changed file with 14 additions and 3 deletions.
17 changes: 14 additions & 3 deletions users/zeyer/experiments/exp2024_04_23_baselines/ctc_recog_ext.py
@@ -480,8 +480,10 @@ def model_recog_flashlight(
     dev_s = rf.get_default_device()
     dev = torch.device(dev_s)
 
+    total_mem = None
     if dev.type == "cuda":
         torch.cuda.reset_peak_memory_stats(dev)
+        _, total_mem = torch.cuda.mem_get_info(dev if dev.index is not None else None)
 
     def _collect_mem_stats():
         if dev.type == "cuda":
@@ -493,7 +495,12 @@ def _collect_mem_stats():
             ]
         return ["(unknown)"]
 
-    print(f"Memory usage {dev_s} before encoder forward:", " ".join(_collect_mem_stats()))
+    print(
+        f"Memory usage {dev_s} before encoder forward:",
+        " ".join(_collect_mem_stats()),
+        "total:",
+        util.human_bytes_size(total_mem) if total_mem else "(unknown)",
+    )
 
     lm_initial_state = lm.default_initial_state(batch_dims=[])
 
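Note on the hunks above (not part of the commit): torch.cuda.mem_get_info reports driver-level (free, total) memory, where "free" also excludes memory merely reserved by PyTorch's caching allocator, while torch.cuda.memory_allocated counts only bytes held by live tensors. Since total device memory never changes, the commit queries it once and caches it in total_mem. A minimal runnable sketch of the pattern; human_bytes_size here is a rough stand-in for RETURNN's util.human_bytes_size, assumed only to keep the sketch self-contained:

import torch

def human_bytes_size(num_bytes: int) -> str:
    # Rough stand-in for util.human_bytes_size, for illustration only.
    size = float(num_bytes)
    for unit in ("B", "KB", "MB", "GB"):
        if size < 1024.0 or unit == "GB":
            return f"{size:.1f} {unit}"
        size /= 1024.0

dev = torch.device("cuda", 0)
_, total_mem = torch.cuda.mem_get_info(dev)  # total is fixed, so query once and cache

x = torch.empty(1024, 1024, device=dev)  # allocate something
alloc = torch.cuda.memory_allocated(dev)  # bytes held by live tensors only
print("allocated", human_bytes_size(alloc), "of", human_bytes_size(total_mem), f"({alloc / total_mem:.1%})")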
@@ -534,11 +541,15 @@ def _calc_next_lm_state(self, state: LMState) -> Tuple[Any, torch.Tensor]:
 
             if dev.type == "cuda":
                 # Maybe check if we should free some more memory.
+                count_pop = 0
                 while self._calc_next_lm_state.cache_len() > 0:
-                    free, total = torch.cuda.mem_get_info(dev if dev.index is not None else None)
-                    if free / total > 0.2:
+                    alloc = torch.cuda.memory_allocated(dev)
+                    if alloc / total_mem < 0.8:
                         break
                     self._calc_next_lm_state.cache_pop_oldest()
+                    count_pop += 1
+                if count_pop > 0:
+                    print(f"Pop {count_pop} from cache, mem usage {dev_s}: {' '.join(_collect_mem_stats())}")
 
             if prev_lm_state is not None or lm_initial_state is None:
                 # We have the prev state, or there is no state at all.
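This eviction check is the substance of the commit: the old test popped cache entries while driver-level free memory was below 20% of total, which presumably stops recovering once the caching allocator has reserved most of the device, since freed tensors return to the allocator's pool rather than to the driver. The new test compares live allocations against the cached total_mem, a ratio that does drop as entries are popped, and it logs how many were evicted. A hypothetical standalone sketch of the same policy; the FIFO cache and names are illustrative, not the commit's _calc_next_lm_state cache decorator:

from collections import OrderedDict

import torch

_cache: "OrderedDict[int, torch.Tensor]" = OrderedDict()  # illustrative oldest-first cache

def maybe_evict(dev: torch.device, total_mem: int, threshold: float = 0.8) -> None:
    # Pop oldest entries while live tensor allocations are at or above the threshold.
    count_pop = 0
    while _cache and torch.cuda.memory_allocated(dev) / total_mem >= threshold:
        _cache.popitem(last=False)  # drop oldest; its tensor is freed if nothing else references it
        count_pop += 1
    if count_pop > 0:
        print(f"Popped {count_pop} entries from cache")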
