Skip to content

Commit

Permalink
argcheck: restrict the type of elements in a list
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz committed Oct 23, 2023
1 parent e9be507 commit 3053927
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 44 deletions.
136 changes: 93 additions & 43 deletions deepmd/utils/argcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def type_embedding_args():
doc_trainable = "If the parameters in the embedding net are trainable"

return [
Argument("neuron", list, optional=True, default=[8], doc=doc_neuron),
Argument("neuron", List[int], optional=True, default=[8], doc=doc_neuron),
Argument(
"activation_function",
str,
Expand All @@ -77,9 +77,9 @@ def spin_args():
doc_virtual_len = "The distance between virtual atom representing spin and its corresponding real atom for each atom type with spin"

return [
Argument("use_spin", list, doc=doc_use_spin),
Argument("spin_norm", list, doc=doc_spin_norm),
Argument("virtual_len", list, doc=doc_virtual_len),
Argument("use_spin", List[bool], doc=doc_use_spin),
Argument("spin_norm", List[float], doc=doc_spin_norm),
Argument("virtual_len", List[float], doc=doc_virtual_len),
]


Expand Down Expand Up @@ -159,10 +159,10 @@ def descrpt_local_frame_args():
- axis_rule[i*6+5]: index of the axis atom defining the second axis. Note that the neighbors with the same class and type are sorted according to their relative distance."

return [
Argument("sel_a", list, optional=False, doc=doc_sel_a),
Argument("sel_r", list, optional=False, doc=doc_sel_r),
Argument("sel_a", List[int], optional=False, doc=doc_sel_a),
Argument("sel_r", List[int], optional=False, doc=doc_sel_r),
Argument("rcut", float, optional=True, default=6.0, doc=doc_rcut),
Argument("axis_rule", list, optional=False, doc=doc_axis_rule),
Argument("axis_rule", List[int], optional=False, doc=doc_axis_rule),
]


Expand All @@ -185,10 +185,12 @@ def descrpt_se_a_args():
doc_set_davg_zero = "Set the normalization average to zero. This option should be set when `atom_ener` in the energy fitting is used"

return [
Argument("sel", [list, str], optional=True, default="auto", doc=doc_sel),
Argument("sel", [List[int], str], optional=True, default="auto", doc=doc_sel),
Argument("rcut", float, optional=True, default=6.0, doc=doc_rcut),
Argument("rcut_smth", float, optional=True, default=0.5, doc=doc_rcut_smth),
Argument("neuron", list, optional=True, default=[10, 20, 40], doc=doc_neuron),
Argument(
"neuron", List[int], optional=True, default=[10, 20, 40], doc=doc_neuron
),
Argument(
"axis_neuron",
int,
Expand All @@ -212,7 +214,11 @@ def descrpt_se_a_args():
Argument("trainable", bool, optional=True, default=True, doc=doc_trainable),
Argument("seed", [int, None], optional=True, doc=doc_seed),
Argument(
"exclude_types", list, optional=True, default=[], doc=doc_exclude_types
"exclude_types",
List[List[int]],
optional=True,
default=[],
doc=doc_exclude_types,
),
Argument(
"set_davg_zero", bool, optional=True, default=False, doc=doc_set_davg_zero
Expand All @@ -236,10 +242,12 @@ def descrpt_se_t_args():
doc_set_davg_zero = "Set the normalization average to zero. This option should be set when `atom_ener` in the energy fitting is used"

return [
Argument("sel", [list, str], optional=True, default="auto", doc=doc_sel),
Argument("sel", [List[int], str], optional=True, default="auto", doc=doc_sel),
Argument("rcut", float, optional=True, default=6.0, doc=doc_rcut),
Argument("rcut_smth", float, optional=True, default=0.5, doc=doc_rcut_smth),
Argument("neuron", list, optional=True, default=[10, 20, 40], doc=doc_neuron),
Argument(
"neuron", List[int], optional=True, default=[10, 20, 40], doc=doc_neuron
),
Argument(
"activation_function",
str,
Expand Down Expand Up @@ -289,10 +297,12 @@ def descrpt_se_r_args():
doc_set_davg_zero = "Set the normalization average to zero. This option should be set when `atom_ener` in the energy fitting is used"

return [
Argument("sel", [list, str], optional=True, default="auto", doc=doc_sel),
Argument("sel", [List[int], str], optional=True, default="auto", doc=doc_sel),
Argument("rcut", float, optional=True, default=6.0, doc=doc_rcut),
Argument("rcut_smth", float, optional=True, default=0.5, doc=doc_rcut_smth),
Argument("neuron", list, optional=True, default=[10, 20, 40], doc=doc_neuron),
Argument(
"neuron", List[int], optional=True, default=[10, 20, 40], doc=doc_neuron
),
Argument(
"activation_function",
str,
Expand All @@ -308,7 +318,11 @@ def descrpt_se_r_args():
Argument("trainable", bool, optional=True, default=True, doc=doc_trainable),
Argument("seed", [int, None], optional=True, doc=doc_seed),
Argument(
"exclude_types", list, optional=True, default=[], doc=doc_exclude_types
"exclude_types",
List[List[int]],
optional=True,
default=[],
doc=doc_exclude_types,
),
Argument(
"set_davg_zero", bool, optional=True, default=False, doc=doc_set_davg_zero
Expand Down Expand Up @@ -356,10 +370,14 @@ def descrpt_se_atten_common_args():
doc_attn_mask = "Whether to do mask on the diagonal in the attention matrix"

return [
Argument("sel", [int, list, str], optional=True, default="auto", doc=doc_sel),
Argument(
"sel", [int, List[int], str], optional=True, default="auto", doc=doc_sel
),
Argument("rcut", float, optional=True, default=6.0, doc=doc_rcut),
Argument("rcut_smth", float, optional=True, default=0.5, doc=doc_rcut_smth),
Argument("neuron", list, optional=True, default=[10, 20, 40], doc=doc_neuron),
Argument(
"neuron", List[int], optional=True, default=[10, 20, 40], doc=doc_neuron
),
Argument(
"axis_neuron",
int,
Expand All @@ -383,7 +401,11 @@ def descrpt_se_atten_common_args():
Argument("trainable", bool, optional=True, default=True, doc=doc_trainable),
Argument("seed", [int, None], optional=True, doc=doc_seed),
Argument(
"exclude_types", list, optional=True, default=[], doc=doc_exclude_types
"exclude_types",
List[List[int]],
optional=True,
default=[],
doc=doc_exclude_types,
),
Argument("attn", int, optional=True, default=128, doc=doc_attn),
Argument("attn_layer", int, optional=True, default=2, doc=doc_attn_layer),
Expand Down Expand Up @@ -454,8 +476,10 @@ def descrpt_se_a_mask_args():
doc_seed = "Random seed for parameter initialization"

return [
Argument("sel", [list, str], optional=True, default="auto", doc=doc_sel),
Argument("neuron", list, optional=True, default=[10, 20, 40], doc=doc_neuron),
Argument("sel", [List[int], str], optional=True, default="auto", doc=doc_sel),
Argument(
"neuron", List[int], optional=True, default=[10, 20, 40], doc=doc_neuron
),
Argument(
"axis_neuron",
int,
Expand All @@ -476,7 +500,11 @@ def descrpt_se_a_mask_args():
"type_one_side", bool, optional=True, default=False, doc=doc_type_one_side
),
Argument(
"exclude_types", list, optional=True, default=[], doc=doc_exclude_types
"exclude_types",
List[List[int]],
optional=True,
default=[],
doc=doc_exclude_types,
),
Argument("precision", str, optional=True, default="default", doc=doc_precision),
Argument("trainable", bool, optional=True, default=True, doc=doc_trainable),
Expand Down Expand Up @@ -525,7 +553,7 @@ def fitting_ener():
doc_resnet_dt = 'Whether to use a "Timestep" in the skip connection'
doc_trainable = "Whether the parameters in the fitting net are trainable. This option can be\n\n\
- bool: True if all parameters of the fitting net are trainable, False otherwise.\n\n\
- list of bool: Specifies if each layer is trainable. Since the fitting net is composed by hidden layers followed by a output layer, the length of tihs list should be equal to len(`neuron`)+1."
- list of bool: Specifies if each layer is trainable. Since the fitting net is composed by hidden layers followed by a output layer, the length of this list should be equal to len(`neuron`)+1."
doc_rcond = "The condition number used to determine the inital energy shift for each type of atoms. See `rcond` in :py:meth:`numpy.linalg.lstsq` for more details."
doc_seed = "Random seed for parameter initialization of the fitting net"
doc_atom_ener = "Specify the atomic energy in vacuum for each type"
Expand All @@ -547,7 +575,7 @@ def fitting_ener():
Argument("numb_aparam", int, optional=True, default=0, doc=doc_numb_aparam),
Argument(
"neuron",
list,
List[int],
optional=True,
default=[120, 120, 120],
alias=["n_neuron"],
Expand All @@ -563,14 +591,24 @@ def fitting_ener():
Argument("precision", str, optional=True, default="default", doc=doc_precision),
Argument("resnet_dt", bool, optional=True, default=True, doc=doc_resnet_dt),
Argument(
"trainable", [list, bool], optional=True, default=True, doc=doc_trainable
"trainable",
[List[bool], bool],
optional=True,
default=True,
doc=doc_trainable,
),
Argument(
"rcond", [float, type(None)], optional=True, default=None, doc=doc_rcond
),
Argument("seed", [int, None], optional=True, doc=doc_seed),
Argument("atom_ener", list, optional=True, default=[], doc=doc_atom_ener),
Argument("layer_name", list, optional=True, doc=doc_layer_name),
Argument(
"atom_ener",
List[Optional[float]],
optional=True,
default=[],
doc=doc_atom_ener,
),
Argument("layer_name", List[str], optional=True, doc=doc_layer_name),
Argument(
"use_aparam_as_mask",
bool,
Expand Down Expand Up @@ -602,7 +640,7 @@ def fitting_dos():
Argument("numb_fparam", int, optional=True, default=0, doc=doc_numb_fparam),
Argument("numb_aparam", int, optional=True, default=0, doc=doc_numb_aparam),
Argument(
"neuron", list, optional=True, default=[120, 120, 120], doc=doc_neuron
"neuron", List[int], optional=True, default=[120, 120, 120], doc=doc_neuron
),
Argument(
"activation_function",
Expand All @@ -614,7 +652,11 @@ def fitting_dos():
Argument("precision", str, optional=True, default="float64", doc=doc_precision),
Argument("resnet_dt", bool, optional=True, default=True, doc=doc_resnet_dt),
Argument(
"trainable", [list, bool], optional=True, default=True, doc=doc_trainable
"trainable",
[List[bool], bool],
optional=True,
default=True,
doc=doc_trainable,
),
Argument(
"rcond", [float, type(None)], optional=True, default=None, doc=doc_rcond
Expand Down Expand Up @@ -642,7 +684,7 @@ def fitting_polar():
return [
Argument(
"neuron",
list,
List[int],
optional=True,
default=[120, 120, 120],
alias=["n_neuron"],
Expand All @@ -658,12 +700,14 @@ def fitting_polar():
Argument("resnet_dt", bool, optional=True, default=True, doc=doc_resnet_dt),
Argument("precision", str, optional=True, default="default", doc=doc_precision),
Argument("fit_diag", bool, optional=True, default=True, doc=doc_fit_diag),
Argument("scale", [list, float], optional=True, default=1.0, doc=doc_scale),
Argument(
"scale", [List[float], float], optional=True, default=1.0, doc=doc_scale
),
# Argument("diag_shift", [list,float], optional = True, default = 0.0, doc = doc_diag_shift),
Argument("shift_diag", bool, optional=True, default=True, doc=doc_shift_diag),
Argument(
"sel_type",
[list, int, None],
[List[int], int, None],
optional=True,
alias=["pol_type"],
doc=doc_sel_type,
Expand All @@ -687,7 +731,7 @@ def fitting_dipole():
return [
Argument(
"neuron",
list,
List[int],
optional=True,
default=[120, 120, 120],
alias=["n_neuron"],
Expand All @@ -704,7 +748,7 @@ def fitting_dipole():
Argument("precision", str, optional=True, default="default", doc=doc_precision),
Argument(
"sel_type",
[list, int, None],
[List[int], int, None],
optional=True,
alias=["dipole_type"],
doc=doc_sel_type,
Expand Down Expand Up @@ -740,8 +784,10 @@ def modifier_dipole_charge():

return [
Argument("model_name", str, optional=False, doc=doc_model_name),
Argument("model_charge_map", list, optional=False, doc=doc_model_charge_map),
Argument("sys_charge_map", list, optional=False, doc=doc_sys_charge_map),
Argument(
"model_charge_map", List[float], optional=False, doc=doc_model_charge_map
),
Argument("sys_charge_map", List[float], optional=False, doc=doc_sys_charge_map),
Argument("ewald_beta", float, optional=True, default=0.4, doc=doc_ewald_beta),
Argument("ewald_h", float, optional=True, default=1.0, doc=doc_ewald_h),
]
Expand Down Expand Up @@ -770,7 +816,7 @@ def model_compression():

return [
Argument("model_file", str, optional=False, doc=doc_model_file),
Argument("table_config", list, optional=False, doc=doc_table_config),
Argument("table_config", List[float], optional=False, doc=doc_table_config),
Argument("min_nbor_dist", float, optional=False, doc=doc_min_nbor_dist),
]

Expand Down Expand Up @@ -814,7 +860,7 @@ def model_args(exclude_hybrid=False):
"model",
dict,
[
Argument("type_map", list, optional=True, doc=doc_type_map),
Argument("type_map", List[str], optional=True, doc=doc_type_map),
Argument(
"data_stat_nbatch",
int,
Expand Down Expand Up @@ -1456,11 +1502,13 @@ def training_data_args(): # ! added by Ziyao: new specification style for data
)

args = [
Argument("systems", [list, str], optional=False, default=".", doc=doc_systems),
Argument(
"systems", [List[str], str], optional=False, default=".", doc=doc_systems
),
Argument("set_prefix", str, optional=True, default="set", doc=doc_set_prefix),
Argument(
"batch_size",
[list, int, str],
[List[int], int, str],
optional=True,
default="auto",
doc=doc_batch_size,
Expand All @@ -1477,7 +1525,7 @@ def training_data_args(): # ! added by Ziyao: new specification style for data
),
Argument(
"sys_probs",
list,
List[float],
optional=True,
default=None,
doc=doc_sys_probs,
Expand Down Expand Up @@ -1521,11 +1569,13 @@ def validation_data_args(): # ! added by Ziyao: new specification style for dat
doc_numb_btch = "An integer that specifies the number of batches to be sampled for each validation period."

args = [
Argument("systems", [list, str], optional=False, default=".", doc=doc_systems),
Argument(
"systems", [List[str], str], optional=False, default=".", doc=doc_systems
),
Argument("set_prefix", str, optional=True, default="set", doc=doc_set_prefix),
Argument(
"batch_size",
[list, int, str],
[List[int], int, str],
optional=True,
default="auto",
doc=doc_batch_size,
Expand All @@ -1542,7 +1592,7 @@ def validation_data_args(): # ! added by Ziyao: new specification style for dat
),
Argument(
"sys_probs",
list,
List[float],
optional=True,
default=None,
doc=doc_sys_probs,
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ dependencies = [
'numpy',
'scipy',
'pyyaml',
'dargs >= 0.3.5',
'dargs >= 0.4.0',
'python-hostlist >= 1.21',
'typing_extensions; python_version < "3.8"',
'importlib_metadata>=1.4; python_version < "3.8"',
Expand Down

0 comments on commit 3053927

Please sign in to comment.