Skip to content

Commit

Permalink
chore: rename j_must_have to j_deprecated and only warn about depreca…
Browse files Browse the repository at this point in the history
…ted keys (#3816)

Fix #3523.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit


- **Refactor**
- Simplified code by removing the `j_must_have` function and directly
accessing dictionary keys in various test files.
- Replaced `j_must_have` with direct dictionary access for improved code
readability and maintenance.

- **Chores**
- Updated test files to directly access dictionary values, enhancing
code readability and maintainability.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Jinzhe Zeng <[email protected]>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
  • Loading branch information
njzjz and coderabbitai[bot] authored May 24, 2024
1 parent 7b16911 commit 6aac9f8
Show file tree
Hide file tree
Showing 36 changed files with 191 additions and 326 deletions.
16 changes: 10 additions & 6 deletions deepmd/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
__all__ = [
"select_idx_map",
"make_default_mesh",
"j_must_have",
"j_loader",
"expand_sys_str",
"get_np_precision",
Expand Down Expand Up @@ -127,15 +126,20 @@ def make_default_mesh(pbc: bool, mixed_type: bool) -> np.ndarray:
return default_mesh


# TODO: rename j_must_have to j_deprecated and only warn about deprecated keys
# maybe rename this to j_deprecated and only warn about deprecated keys,
# if the deprecated_key argument is left empty function puppose is only custom
# error since dict[key] already raises KeyError when the key is missing
def j_must_have(
def j_deprecated(
jdata: Dict[str, "_DICT_VAL"], key: str, deprecated_key: List[str] = []
) -> "_DICT_VAL":
"""Assert that supplied dictionary conaines specified key.
Parameters
----------
jdata : Dict[str, _DICT_VAL]
dictionary to check
key : str
key to check
deprecated_key : List[str], optional
list of deprecated keys, by default []
Returns
-------
_DICT_VAL
Expand Down
2 changes: 0 additions & 2 deletions deepmd/tf/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
expand_sys_str,
get_np_precision,
j_loader,
j_must_have,
make_default_mesh,
select_idx_map,
)
Expand All @@ -47,7 +46,6 @@
# from deepmd.common
"select_idx_map",
"make_default_mesh",
"j_must_have",
"j_loader",
"expand_sys_str",
"get_np_precision",
Expand Down
3 changes: 1 addition & 2 deletions deepmd/tf/entrypoints/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

from deepmd.tf.common import (
j_loader,
j_must_have,
)
from deepmd.tf.env import (
reset_default_tf_session_config,
Expand Down Expand Up @@ -211,7 +210,7 @@ def _do_work(jdata: Dict[str, Any], run_opt: RunOptions, is_compress: bool = Fal
modifier.build_fv_graph()

# get training info
stop_batch = j_must_have(jdata["training"], "numb_steps")
stop_batch = jdata["training"]["numb_steps"]
origin_type_map = jdata["model"].get("origin_type_map", None)
if (
origin_type_map is not None and not origin_type_map
Expand Down
5 changes: 2 additions & 3 deletions deepmd/tf/train/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
)
from deepmd.tf.common import (
get_precision,
j_must_have,
)
from deepmd.tf.env import (
GLOBAL_ENER_FLOAT_PRECISION,
Expand Down Expand Up @@ -91,7 +90,7 @@ def __init__(self, jdata, run_opt, is_compress=False):

def _init_param(self, jdata):
# model config
model_param = j_must_have(jdata, "model")
model_param = jdata["model"]

# nvnmd
self.nvnmd_param = jdata.get("nvnmd", {})
Expand Down Expand Up @@ -123,7 +122,7 @@ def get_lr_and_coef(lr_param):
return lr, scale_lr_coef

# learning rate
lr_param = j_must_have(jdata, "learning_rate")
lr_param = jdata["learning_rate"]
self.lr, self.scale_lr_coef = get_lr_and_coef(lr_param)
# loss
# infer loss type by fitting_type
Expand Down
24 changes: 12 additions & 12 deletions deepmd/utils/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import numpy as np

from deepmd.common import (
j_must_have,
j_deprecated,
)


Expand Down Expand Up @@ -127,8 +127,8 @@ def _smth_descriptor(jdata: Dict[str, Any]) -> Dict[str, Any]:
descriptor["sel"] = jdata["sel_a"]
_jcopy(jdata, descriptor, ("rcut",))
descriptor["rcut_smth"] = jdata.get("rcut_smth", descriptor["rcut"])
descriptor["neuron"] = j_must_have(jdata, "filter_neuron")
descriptor["axis_neuron"] = j_must_have(jdata, "axis_neuron", ["n_axis_neuron"])
descriptor["neuron"] = jdata["filter_neuron"]
descriptor["axis_neuron"] = j_deprecated(jdata, "axis_neuron", ["n_axis_neuron"])
descriptor["resnet_dt"] = False
if "resnet_dt" in jdata:
descriptor["resnet_dt"] = jdata["filter_resnet_dt"]
Expand All @@ -154,7 +154,7 @@ def _fitting_net(jdata: Dict[str, Any]) -> Dict[str, Any]:
seed = jdata.get("seed", None)
if seed is not None:
fitting_net["seed"] = seed
fitting_net["neuron"] = j_must_have(jdata, "fitting_neuron", ["n_neuron"])
fitting_net["neuron"] = j_deprecated(jdata, "fitting_neuron", ["n_neuron"])
fitting_net["resnet_dt"] = True
if "resnet_dt" in jdata:
fitting_net["resnet_dt"] = jdata["resnet_dt"]
Expand Down Expand Up @@ -237,16 +237,16 @@ def _training(jdata: Dict[str, Any]) -> Dict[str, Any]:
training["disp_file"] = "lcurve.out"
if "disp_file" in jdata:
training["disp_file"] = jdata["disp_file"]
training["disp_freq"] = j_must_have(jdata, "disp_freq")
training["numb_test"] = j_must_have(jdata, "numb_test")
training["save_freq"] = j_must_have(jdata, "save_freq")
training["save_ckpt"] = j_must_have(jdata, "save_ckpt")
training["disp_training"] = j_must_have(jdata, "disp_training")
training["time_training"] = j_must_have(jdata, "time_training")
training["disp_freq"] = jdata["disp_freq"]
training["numb_test"] = jdata["numb_test"]
training["save_freq"] = jdata["save_freq"]
training["save_ckpt"] = jdata["save_ckpt"]
training["disp_training"] = jdata["disp_training"]
training["time_training"] = jdata["time_training"]
if "profiling" in jdata:
training["profiling"] = jdata["profiling"]
if training["profiling"]:
training["profiling_file"] = j_must_have(jdata, "profiling_file")
training["profiling_file"] = jdata["profiling_file"]
return training


Expand Down Expand Up @@ -378,7 +378,7 @@ def is_deepmd_v0_input(jdata):
return "model" not in jdata.keys()

def is_deepmd_v1_input(jdata):
return "systems" in j_must_have(jdata, "training").keys()
return "systems" in jdata["training"].keys()

if is_deepmd_v0_input(jdata):
jdata = convert_input_v0_v1(jdata, warning, None)
Expand Down
5 changes: 2 additions & 3 deletions deepmd/utils/data_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import deepmd.utils.random as dp_random
from deepmd.common import (
expand_sys_str,
j_must_have,
make_default_mesh,
)
from deepmd.env import (
Expand Down Expand Up @@ -792,10 +791,10 @@ def get_data(
DeepmdDataSystem
The data system
"""
systems = j_must_have(jdata, "systems")
systems = jdata["systems"]
systems = process_systems(systems)

batch_size = j_must_have(jdata, "batch_size")
batch_size = jdata["batch_size"]
sys_probs = jdata.get("sys_probs", None)
auto_prob = jdata.get("auto_prob", "prob_sys_size")
optional_type_map = not multi_task_mode
Expand Down
21 changes: 9 additions & 12 deletions source/tests/tf/test_data_large_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,6 @@
import numpy as np
from packaging.version import parse as parse_version

from deepmd.tf.common import (
j_must_have,
)
from deepmd.tf.descriptor import (
DescrptSeAtten,
)
Expand Down Expand Up @@ -50,11 +47,11 @@ def test_data_mixed_type(self):
jfile = "water_se_atten_mixed_type.json"
jdata = j_loader(jfile)

systems = j_must_have(jdata, "systems")
systems = jdata["systems"]
batch_size = 1
test_size = 1
rcut = j_must_have(jdata["model"]["descriptor"], "rcut")
type_map = j_must_have(jdata["model"], "type_map")
rcut = jdata["model"]["descriptor"]["rcut"]
type_map = jdata["model"]["type_map"]

data = DeepmdDataSystem(systems, batch_size, test_size, rcut, type_map=type_map)
data_requirement = {
Expand Down Expand Up @@ -248,11 +245,11 @@ def test_stripped_data_mixed_type(self):
jfile = "water_se_atten_mixed_type.json"
jdata = j_loader(jfile)

systems = j_must_have(jdata, "systems")
systems = jdata["systems"]
batch_size = 1
test_size = 1
rcut = j_must_have(jdata["model"]["descriptor"], "rcut")
type_map = j_must_have(jdata["model"], "type_map")
rcut = jdata["model"]["descriptor"]["rcut"]
type_map = jdata["model"]["type_map"]

data = DeepmdDataSystem(systems, batch_size, test_size, rcut, type_map=type_map)
data_requirement = {
Expand Down Expand Up @@ -446,11 +443,11 @@ def test_compressible_data_mixed_type(self):
jfile = "water_se_atten_mixed_type.json"
jdata = j_loader(jfile)

systems = j_must_have(jdata, "systems")
systems = jdata["systems"]
batch_size = 1
test_size = 1
rcut = j_must_have(jdata["model"]["descriptor"], "rcut")
type_map = j_must_have(jdata["model"], "type_map")
rcut = jdata["model"]["descriptor"]["rcut"]
type_map = jdata["model"]["type_map"]

data = DeepmdDataSystem(systems, batch_size, test_size, rcut, type_map=type_map)
data_requirement = {
Expand Down
9 changes: 3 additions & 6 deletions source/tests/tf/test_data_modifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,6 @@

import numpy as np

from deepmd.tf.common import (
j_must_have,
)
from deepmd.tf.env import (
GLOBAL_NP_FLOAT_PRECISION,
tf,
Expand Down Expand Up @@ -61,12 +58,12 @@ def _setUp(self):
rcut = model.model.get_rcut()

# init data system
systems = j_must_have(jdata["training"], "systems")
systems = jdata["training"]["systems"]
# systems[0] = tests_path / systems[0]
systems = [tests_path / ii for ii in systems]
set_pfx = "set"
batch_size = j_must_have(jdata["training"], "batch_size")
test_size = j_must_have(jdata["training"], "numb_test")
batch_size = jdata["training"]["batch_size"]
test_size = jdata["training"]["numb_test"]
data = DeepmdDataSystem(
systems, batch_size, test_size, rcut, set_prefix=set_pfx
)
Expand Down
9 changes: 3 additions & 6 deletions source/tests/tf/test_data_modifier_shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,6 @@

import numpy as np

from deepmd.tf.common import (
j_must_have,
)
from deepmd.tf.env import (
GLOBAL_NP_FLOAT_PRECISION,
tf,
Expand Down Expand Up @@ -64,10 +61,10 @@ def _setUp(self):
rcut = model.model.get_rcut()

# init data system
systems = j_must_have(jdata["training"], "systems")
systems = jdata["training"]["systems"]
set_pfx = "set"
batch_size = j_must_have(jdata["training"], "batch_size")
test_size = j_must_have(jdata["training"], "numb_test")
batch_size = jdata["training"]["batch_size"]
test_size = jdata["training"]["numb_test"]
data = DeepmdDataSystem(
systems, batch_size, test_size, rcut, set_prefix=set_pfx
)
Expand Down
7 changes: 1 addition & 6 deletions source/tests/tf/test_descrpt_hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,6 @@
import numpy as np
from packaging.version import parse as parse_version

from deepmd.tf.common import (
j_must_have,
)
from deepmd.tf.descriptor import (
DescrptHybrid,
)
Expand Down Expand Up @@ -40,10 +37,8 @@ def test_descriptor_hybrid(self):
jfile = "water_hybrid.json"
jdata = j_loader(jfile)

systems = j_must_have(jdata, "systems")
systems = jdata["systems"]
set_pfx = "set"
batch_size = j_must_have(jdata, "batch_size")
test_size = j_must_have(jdata, "numb_test")
batch_size = 2
test_size = 1
rcut = 6
Expand Down
7 changes: 2 additions & 5 deletions source/tests/tf/test_descrpt_se_a_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,6 @@

import numpy as np

from deepmd.tf.common import (
j_must_have,
)
from deepmd.tf.descriptor import (
DescrptSeAMask,
)
Expand Down Expand Up @@ -231,12 +228,12 @@ def test_descriptor_se_a_mask(self):
jdata["training"]["validation_data"]["systems"] = [
str(tests_path / "data_dp_mask")
]
systems = j_must_have(jdata["training"]["validation_data"], "systems")
systems = jdata["training"]["validation_data"]["systems"]
set_pfx = "set"
batch_size = 2
test_size = 1
rcut = 20.0 # For DataSystem interface compatibility, not used in this test.
sel = j_must_have(jdata["model"]["descriptor"], "sel")
sel = jdata["model"]["descriptor"]["sel"]
ntypes = len(sel)
total_atom_num = np.cumsum(sel)[-1]

Expand Down
23 changes: 8 additions & 15 deletions source/tests/tf/test_descrpt_se_a_type.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import numpy as np

from deepmd.tf.common import (
j_must_have,
)
from deepmd.tf.descriptor import (
DescrptSeA,
)
Expand Down Expand Up @@ -33,15 +30,12 @@ def test_descriptor_two_sides(self):
jfile = "water_se_a_type.json"
jdata = j_loader(jfile)

systems = j_must_have(jdata, "systems")
systems = jdata["systems"]
set_pfx = "set"
batch_size = j_must_have(jdata, "batch_size")
test_size = j_must_have(jdata, "numb_test")
batch_size = 2
test_size = 1
stop_batch = j_must_have(jdata, "stop_batch")
rcut = j_must_have(jdata["model"]["descriptor"], "rcut")
sel = j_must_have(jdata["model"]["descriptor"], "sel")
rcut = jdata["model"]["descriptor"]["rcut"]
sel = jdata["model"]["descriptor"]["sel"]
ntypes = len(sel)

data = DataSystem(systems, set_pfx, batch_size, test_size, rcut, run_opt=None)
Expand Down Expand Up @@ -197,15 +191,14 @@ def test_descriptor_one_side(self):
jfile = "water_se_a_type.json"
jdata = j_loader(jfile)

systems = j_must_have(jdata, "systems")
systems = jdata["systems"]
set_pfx = "set"
batch_size = j_must_have(jdata, "batch_size")
test_size = j_must_have(jdata, "numb_test")
batch_size = jdata["batch_size"]
test_size = jdata["numb_test"]
batch_size = 1
test_size = 1
stop_batch = j_must_have(jdata, "stop_batch")
rcut = j_must_have(jdata["model"]["descriptor"], "rcut")
sel = j_must_have(jdata["model"]["descriptor"], "sel")
rcut = jdata["model"]["descriptor"]["rcut"]
sel = jdata["model"]["descriptor"]["sel"]
ntypes = len(sel)

data = DataSystem(systems, set_pfx, batch_size, test_size, rcut, run_opt=None)
Expand Down
Loading

0 comments on commit 6aac9f8

Please sign in to comment.