Skip to content

Commit

Permalink
drop tqdm
Browse files Browse the repository at this point in the history
per discussion.

Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz committed Jan 28, 2024
1 parent a8168b5 commit 25405b4
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 40 deletions.
1 change: 0 additions & 1 deletion backend/dynamic_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,5 @@ def dynamic_metadata(
],
"torch": [
"torch>=2a",
"tqdm",
],
}
48 changes: 18 additions & 30 deletions deepmd/pt/train/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,6 @@

import numpy as np
import torch
from tqdm import (
tqdm,
)
from tqdm.contrib.logging import (
logging_redirect_tqdm,
)

from deepmd.common import (
symlink_prefix_files,
Expand Down Expand Up @@ -47,7 +41,6 @@
)
from deepmd.pt.utils.env import (
DEVICE,
DISABLE_TQDM,
JIT,
LOCAL_RANK,
NUM_WORKERS,
Expand Down Expand Up @@ -662,29 +655,24 @@ def log_loss_valid(_task_key="Default"):
f.write(str(self.latest_model))

self.t0 = time.time()
with logging_redirect_tqdm():
for step_id in tqdm(
range(self.num_steps),
disable=(bool(dist.get_rank()) if dist.is_initialized() else False)
or DISABLE_TQDM,
): # set to None to disable on non-TTY; disable on not rank 0
if step_id < self.start_step:
continue
if self.multi_task:
chosen_index_list = dp_random.choice(
np.arange(self.num_model),
p=np.array(self.model_prob),
size=self.world_size,
replace=True,
)
assert chosen_index_list.size == self.world_size
model_index = chosen_index_list[self.rank]
model_key = self.model_keys[model_index]
else:
model_key = "Default"
step(step_id, model_key)
if JIT:
break
for step_id in range(self.num_steps):
if step_id < self.start_step:
continue

Check warning on line 660 in deepmd/pt/train/training.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/train/training.py#L660

Added line #L660 was not covered by tests
if self.multi_task:
chosen_index_list = dp_random.choice(

Check warning on line 662 in deepmd/pt/train/training.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/train/training.py#L662

Added line #L662 was not covered by tests
np.arange(self.num_model),
p=np.array(self.model_prob),
size=self.world_size,
replace=True,
)

Check failure

Code scanning / CodeQL

Wrong name for an argument in a call Error

Keyword argument 'replace' is not a supported parameter name of
function choice
.
Keyword argument 'size' is not a supported parameter name of
function choice
.
assert chosen_index_list.size == self.world_size
model_index = chosen_index_list[self.rank]
model_key = self.model_keys[model_index]

Check warning on line 670 in deepmd/pt/train/training.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/train/training.py#L668-L670

Added lines #L668 - L670 were not covered by tests
else:
model_key = "Default"
step(step_id, model_key)
if JIT:
break

Check warning on line 675 in deepmd/pt/train/training.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/train/training.py#L675

Added line #L675 was not covered by tests

if (
self.rank == 0 or dist.get_rank() == 0
Expand Down
5 changes: 1 addition & 4 deletions deepmd/pt/utils/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,6 @@
from torch.utils.data import (
Dataset,
)
from tqdm import (
trange,
)

from deepmd.pt.utils import (
dp_random,
Expand Down Expand Up @@ -506,7 +503,7 @@ def preprocess(self, batch):
assert batch["atype"].max() < len(self._type_map)
nlist, nlist_loc, nlist_type, shift, mapping = [], [], [], [], []

for sid in trange(n_frames, disable=env.DISABLE_TQDM):
for sid in range(n_frames):
region = Region3D(box[sid])
nloc = atype[sid].shape[0]
_coord = normalize_coord(coord[sid], region, nloc)
Expand Down
1 change: 0 additions & 1 deletion deepmd/pt/utils/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
GLOBAL_NP_FLOAT_PRECISION = getattr(np, PRECISION)
GLOBAL_PT_FLOAT_PRECISION = getattr(torch, PRECISION)
GLOBAL_ENER_FLOAT_PRECISION = getattr(np, PRECISION)
DISABLE_TQDM = os.environ.get("DISABLE_TQDM", False)
SAMPLER_RECORD = os.environ.get("SAMPLER_RECORD", False)
try:
# only linux
Expand Down
5 changes: 1 addition & 4 deletions deepmd/pt/utils/stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,6 @@

import numpy as np
import torch
from tqdm import (
trange,
)

from deepmd.pt.utils import (
env,
Expand Down Expand Up @@ -40,7 +37,7 @@ def make_stat_input(datasets, dataloaders, nbatches):
if datasets[0].mixed_type:
keys.append("real_natoms_vec")
logging.info(f"Packing data for statistics from {len(datasets)} systems")
for i in trange(len(datasets), disable=env.DISABLE_TQDM):
for i in range(len(datasets)):
sys_stat = {key: [] for key in keys}
iterator = iter(dataloaders[i])
for _ in range(nbatches):
Expand Down

0 comments on commit 25405b4

Please sign in to comment.