Skip to content

Commit

Permalink
style: add -> None type hints and apply TCH and PYI rules (#4352)
Browse files Browse the repository at this point in the history
use Ruff

Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz authored Nov 13, 2024
1 parent 47b76c8 commit 320c7fd
Show file tree
Hide file tree
Showing 456 changed files with 2,211 additions and 2,138 deletions.
2 changes: 1 addition & 1 deletion backend/read_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def get_argument_from_env() -> tuple[str, list, list, dict, str, str]:
)


def set_scikit_build_env():
def set_scikit_build_env() -> None:
"""Set scikit-build environment variables before executing scikit-build."""
cmake_minimum_required_version, cmake_args, _, _, _, _ = get_argument_from_env()
os.environ["SKBUILD_CMAKE_MINIMUM_VERSION"] = cmake_minimum_required_version
Expand Down
2 changes: 1 addition & 1 deletion data/json/json2yaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import yaml


def _main():
def _main() -> None:
parser = argparse.ArgumentParser(
description="convert json config file to yaml",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
Expand Down
4 changes: 2 additions & 2 deletions data/raw/copy_raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import numpy as np


def copy(in_dir, out_dir, ncopies=[1, 1, 1]):
def copy(in_dir, out_dir, ncopies=[1, 1, 1]) -> None:
has_energy = os.path.isfile(in_dir + "/energy.raw")
has_force = os.path.isfile(in_dir + "/force.raw")
has_virial = os.path.isfile(in_dir + "/virial.raw")
Expand Down Expand Up @@ -71,7 +71,7 @@ def copy(in_dir, out_dir, ncopies=[1, 1, 1]):
np.savetxt(out_dir + "/ncopies.raw", ncopies, fmt="%d")


def _main():
def _main() -> None:
parser = argparse.ArgumentParser(description="parse copy raw args")
parser.add_argument("INPUT", default=".", help="input dir of raw files")
parser.add_argument("OUTPUT", default=".", help="output dir of copied raw files")
Expand Down
2 changes: 1 addition & 1 deletion data/raw/shuffle_raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def detect_raw(path):
return raws


def _main():
def _main() -> None:
args = _parse_args()
raws = args.raws
inpath = args.INPUT
Expand Down
2 changes: 1 addition & 1 deletion deepmd/calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def calculate(
atoms: Optional["Atoms"] = None,
properties: list[str] = ["energy", "forces", "virial"],
system_changes: list[str] = all_changes,
):
) -> None:
"""Run calculation with deepmd model.
Parameters
Expand Down
14 changes: 6 additions & 8 deletions deepmd/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,11 @@

if TYPE_CHECKING:
_DICT_VAL = TypeVar("_DICT_VAL")
__all__.extend(
[
"_DICT_VAL",
"_PRECISION",
"_ACTIVATION",
]
)
__all__ += [
"_DICT_VAL",
"_PRECISION",
"_ACTIVATION",
]


def select_idx_map(atom_types: np.ndarray, select_types: np.ndarray) -> np.ndarray:
Expand Down Expand Up @@ -237,7 +235,7 @@ def get_np_precision(precision: "_PRECISION") -> np.dtype:
raise RuntimeError(f"{precision} is not a valid precision")


def symlink_prefix_files(old_prefix: str, new_prefix: str):
def symlink_prefix_files(old_prefix: str, new_prefix: str) -> None:
"""Create symlinks from old checkpoint prefix to new one.
On Windows this function will copy files instead of creating symlinks.
Expand Down
10 changes: 5 additions & 5 deletions deepmd/dpmodel/atomic_model/base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,15 @@ def __init__(
pair_exclude_types: list[tuple[int, int]] = [],
rcond: Optional[float] = None,
preset_out_bias: Optional[dict[str, np.ndarray]] = None,
):
) -> None:
super().__init__()
self.type_map = type_map
self.reinit_atom_exclude(atom_exclude_types)
self.reinit_pair_exclude(pair_exclude_types)
self.rcond = rcond
self.preset_out_bias = preset_out_bias

def init_out_stat(self):
def init_out_stat(self) -> None:
"""Initialize the output bias."""
ntypes = self.get_ntypes()
self.bias_keys: list[str] = list(self.fitting_output_def().keys())
Expand All @@ -68,7 +68,7 @@ def init_out_stat(self):
self.out_bias = out_bias_data
self.out_std = out_std_data

def __setitem__(self, key, value):
def __setitem__(self, key, value) -> None:
if key in ["out_bias"]:
self.out_bias = value
elif key in ["out_std"]:
Expand All @@ -91,7 +91,7 @@ def get_type_map(self) -> list[str]:
def reinit_atom_exclude(
self,
exclude_types: list[int] = [],
):
) -> None:
self.atom_exclude_types = exclude_types
if exclude_types == []:
self.atom_excl = None
Expand All @@ -101,7 +101,7 @@ def reinit_atom_exclude(
def reinit_pair_exclude(
self,
exclude_types: list[tuple[int, int]] = [],
):
) -> None:
self.pair_exclude_types = exclude_types
if exclude_types == []:
self.pair_excl = None
Expand Down
2 changes: 1 addition & 1 deletion deepmd/dpmodel/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def __init__(
fitting,
type_map: list[str],
**kwargs,
):
) -> None:
super().__init__(type_map, **kwargs)
self.type_map = type_map
self.descriptor = descriptor
Expand Down
4 changes: 2 additions & 2 deletions deepmd/dpmodel/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def __init__(
models: list[BaseAtomicModel],
type_map: list[str],
**kwargs,
):
) -> None:
super().__init__(type_map, **kwargs)
super().init_out_stat()

Expand Down Expand Up @@ -391,7 +391,7 @@ def __init__(
type_map: list[str],
smin_alpha: Optional[float] = 0.1,
**kwargs,
):
) -> None:
models = [dp_model, zbl_model]
kwargs["models"] = models
kwargs["type_map"] = type_map
Expand Down
2 changes: 1 addition & 1 deletion deepmd/dpmodel/atomic_model/pairtab_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def __init__(
rcond: Optional[float] = None,
atom_ener: Optional[list[float]] = None,
**kwargs,
):
) -> None:
super().__init__(type_map, **kwargs)
super().init_out_stat()
self.tab_file = tab_file
Expand Down
2 changes: 1 addition & 1 deletion deepmd/dpmodel/atomic_model/property_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,6 @@


class DPPropertyAtomicModel(DPAtomicModel):
def __init__(self, descriptor, fitting, type_map, **kwargs):
def __init__(self, descriptor, fitting, type_map, **kwargs) -> None:
assert isinstance(fitting, PropertyFittingNet)
super().__init__(descriptor, fitting, type_map, **kwargs)
7 changes: 4 additions & 3 deletions deepmd/dpmodel/descriptor/descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
)
from typing import (
Callable,
NoReturn,
Optional,
Union,
)
Expand Down Expand Up @@ -83,7 +84,7 @@ def compute_input_stats(
self,
merged: Union[Callable[[], list[dict]], list[dict]],
path: Optional[DPPath] = None,
):
) -> NoReturn:
"""
Compute the input statistics (e.g. mean and stddev) for the descriptors from packed data.
Expand All @@ -106,7 +107,7 @@ def get_stats(self) -> dict[str, StatItem]:
"""Get the statistics of the descriptor."""
raise NotImplementedError

def share_params(self, base_class, shared_level, resume=False):
def share_params(self, base_class, shared_level, resume=False) -> NoReturn:
"""
Share the parameters of self to the base_class with shared_level during multitask training.
If not start from checkpoint (resume is False),
Expand Down Expand Up @@ -135,7 +136,7 @@ def need_sorted_nlist_for_lower(self) -> bool:
"""Returns whether the descriptor block needs sorted nlist when using `forward_lower`."""


def extend_descrpt_stat(des, type_map, des_with_stat=None):
def extend_descrpt_stat(des, type_map, des_with_stat=None) -> None:
r"""
Extend the statistics of a descriptor block with types from newly provided `type_map`.
Expand Down
23 changes: 13 additions & 10 deletions deepmd/dpmodel/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import (
Any,
Callable,
NoReturn,
Optional,
Union,
)
Expand Down Expand Up @@ -388,7 +389,7 @@ def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix."""
return self.se_atten.get_env_protection()

def share_params(self, base_class, shared_level, resume=False):
def share_params(self, base_class, shared_level, resume=False) -> NoReturn:
"""
Share the parameters of self to the base_class with shared_level during multitask training.
If not start from checkpoint (resume is False),
Expand All @@ -404,7 +405,9 @@ def dim_out(self):
def dim_emb(self):
return self.get_dim_emb()

def compute_input_stats(self, merged: list[dict], path: Optional[DPPath] = None):
def compute_input_stats(
self, merged: list[dict], path: Optional[DPPath] = None
) -> NoReturn:
"""Update mean and stddev for descriptor elements."""
raise NotImplementedError

Expand Down Expand Up @@ -783,7 +786,7 @@ def get_dim_emb(self) -> int:
"""Returns the output dimension of embedding."""
return self.filter_neuron[-1]

def __setitem__(self, key, value):
def __setitem__(self, key, value) -> None:
if key in ("avg", "data_avg", "davg"):
self.mean = value
elif key in ("std", "data_std", "dstd"):
Expand Down Expand Up @@ -834,18 +837,18 @@ def compute_input_stats(
self,
merged: Union[Callable[[], list[dict]], list[dict]],
path: Optional[DPPath] = None,
):
) -> NoReturn:
"""Compute the input statistics (e.g. mean and stddev) for the descriptors from packed data."""
raise NotImplementedError

def get_stats(self):
def get_stats(self) -> NoReturn:
"""Get the statistics of the descriptor."""
raise NotImplementedError

def reinit_exclude(
self,
exclude_types: list[tuple[int, int]] = [],
):
) -> None:
self.exclude_types = exclude_types
self.emask = PairExcludeMask(self.ntypes, exclude_types=exclude_types)

Expand Down Expand Up @@ -1077,7 +1080,7 @@ def __init__(
smooth: bool = True,
precision: str = DEFAULT_PRECISION,
seed: Optional[Union[int, list[int]]] = None,
):
) -> None:
"""Construct a neighbor-wise attention net."""
super().__init__()
self.layer_num = layer_num
Expand Down Expand Up @@ -1132,7 +1135,7 @@ def __getitem__(self, key):
else:
raise TypeError(key)

def __setitem__(self, key, value):
def __setitem__(self, key, value) -> None:
if not isinstance(key, int):
raise TypeError(key)
if isinstance(value, self.network_type):
Expand Down Expand Up @@ -1205,7 +1208,7 @@ def __init__(
smooth: bool = True,
precision: str = DEFAULT_PRECISION,
seed: Optional[Union[int, list[int]]] = None,
):
) -> None:
"""Construct a neighbor-wise attention layer."""
super().__init__()
self.nnei = nnei
Expand Down Expand Up @@ -1311,7 +1314,7 @@ def __init__(
smooth: bool = True,
precision: str = DEFAULT_PRECISION,
seed: Optional[Union[int, list[int]]] = None,
):
) -> None:
"""Construct a multi-head neighbor-wise attention net."""
super().__init__()
assert hidden_dim % num_heads == 0, "hidden_dim must be divisible by num_heads"
Expand Down
13 changes: 8 additions & 5 deletions deepmd/dpmodel/descriptor/dpa2.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
NoReturn,
Optional,
Union,
)
Expand Down Expand Up @@ -88,7 +89,7 @@ def __init__(
three_body_sel: int = 40,
three_body_rcut: float = 4.0,
three_body_rcut_smth: float = 0.5,
):
) -> None:
r"""The constructor for the RepinitArgs class which defines the parameters of the repinit block in DPA2 descriptor.
Parameters
Expand Down Expand Up @@ -212,7 +213,7 @@ def __init__(
g1_out_conv: bool = True,
g1_out_mlp: bool = True,
ln_eps: Optional[float] = 1e-5,
):
) -> None:
r"""The constructor for the RepformerArgs class which defines the parameters of the repformer block in DPA2 descriptor.
Parameters
Expand Down Expand Up @@ -384,7 +385,7 @@ def __init__(
use_econf_tebd: bool = False,
use_tebd_bias: bool = False,
type_map: Optional[list[str]] = None,
):
) -> None:
r"""The DPA-2 descriptor. see https://arxiv.org/abs/2312.15492.
Parameters
Expand Down Expand Up @@ -656,7 +657,7 @@ def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix."""
return self.env_protection

def share_params(self, base_class, shared_level, resume=False):
def share_params(self, base_class, shared_level, resume=False) -> NoReturn:
"""
Share the parameters of self to the base_class with shared_level during multitask training.
If not start from checkpoint (resume is False),
Expand Down Expand Up @@ -728,7 +729,9 @@ def dim_emb(self):
"""Returns the embedding dimension g2."""
return self.get_dim_emb()

def compute_input_stats(self, merged: list[dict], path: Optional[DPPath] = None):
def compute_input_stats(
self, merged: list[dict], path: Optional[DPPath] = None
) -> NoReturn:
"""Update mean and stddev for descriptor elements."""
raise NotImplementedError

Expand Down
7 changes: 5 additions & 2 deletions deepmd/dpmodel/descriptor/hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import math
from typing import (
Any,
NoReturn,
Optional,
Union,
)
Expand Down Expand Up @@ -165,7 +166,7 @@ def get_env_protection(self) -> float:
)
return all_protection[0]

def share_params(self, base_class, shared_level, resume=False):
def share_params(self, base_class, shared_level, resume=False) -> NoReturn:
"""
Share the parameters of self to the base_class with shared_level during multitask training.
If not start from checkpoint (resume is False),
Expand All @@ -187,7 +188,9 @@ def change_type_map(
else None,
)

def compute_input_stats(self, merged: list[dict], path: Optional[DPPath] = None):
def compute_input_stats(
self, merged: list[dict], path: Optional[DPPath] = None
) -> None:
"""Update mean and stddev for descriptor elements."""
for descrpt in self.descrpt_list:
descrpt.compute_input_stats(merged, path)
Expand Down
3 changes: 2 additions & 1 deletion deepmd/dpmodel/descriptor/make_base_descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
)
from typing import (
Callable,
NoReturn,
Optional,
Union,
)
Expand Down Expand Up @@ -143,7 +144,7 @@ def compute_input_stats(
self,
merged: Union[Callable[[], list[dict]], list[dict]],
path: Optional[DPPath] = None,
):
) -> NoReturn:
"""Update mean and stddev for descriptor elements."""
raise NotImplementedError

Expand Down
Loading

0 comments on commit 320c7fd

Please sign in to comment.