Skip to content

Commit

Permalink
Add type hint for Callable
Browse files Browse the repository at this point in the history
  • Loading branch information
iProzd committed Feb 29, 2024
1 parent ab35653 commit 64d6079
Show file tree
Hide file tree
Showing 12 changed files with 33 additions and 13 deletions.
4 changes: 3 additions & 1 deletion deepmd/dpmodel/descriptor/make_base_descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,9 @@ def share_params(self, base_class, shared_level, resume=False):
pass

def compute_input_stats(
self, merged: Union[Callable, List[dict]], path: Optional[DPPath] = None
self,
merged: Union[Callable[[], List[dict]], List[dict]],
path: Optional[DPPath] = None,
):
"""Update mean and stddev for descriptor elements."""
raise NotImplementedError
Expand Down
4 changes: 3 additions & 1 deletion deepmd/pt/model/descriptor/descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,9 @@ def get_dim_emb(self) -> int:
pass

def compute_input_stats(
self, merged: Union[Callable, List[dict]], path: Optional[DPPath] = None
self,
merged: Union[Callable[[], List[dict]], List[dict]],
path: Optional[DPPath] = None,
):
"""Update mean and stddev for DescriptorBlock elements."""
raise NotImplementedError
Expand Down
4 changes: 3 additions & 1 deletion deepmd/pt/model/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,9 @@ def dim_emb(self):
return self.get_dim_emb()

def compute_input_stats(
self, merged: Union[Callable, List[dict]], path: Optional[DPPath] = None
self,
merged: Union[Callable[[], List[dict]], List[dict]],
path: Optional[DPPath] = None,
):
return self.se_atten.compute_input_stats(merged, path)

Expand Down
4 changes: 3 additions & 1 deletion deepmd/pt/model/descriptor/dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,9 @@ def dim_emb(self):
return self.get_dim_emb()

def compute_input_stats(
self, merged: Union[Callable, List[dict]], path: Optional[DPPath] = None
self,
merged: Union[Callable[[], List[dict]], List[dict]],
path: Optional[DPPath] = None,
):
for ii, descrpt in enumerate([self.repinit, self.repformers]):
descrpt.compute_input_stats(merged, path)
Expand Down
4 changes: 3 additions & 1 deletion deepmd/pt/model/descriptor/hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,9 @@ def share_params(self, base_class, shared_level, resume=False):
raise NotImplementedError

def compute_input_stats(
self, merged: Union[Callable, List[dict]], path: Optional[DPPath] = None
self,
merged: Union[Callable[[], List[dict]], List[dict]],
path: Optional[DPPath] = None,
):
"""Update mean and stddev for descriptor elements."""
for ii, descrpt in enumerate(self.descriptor_list):
Expand Down
4 changes: 3 additions & 1 deletion deepmd/pt/model/descriptor/repformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,9 @@ def forward(
return g1, g2, h2, rot_mat.view(-1, nloc, self.dim_emb, 3), sw

def compute_input_stats(
self, merged: Union[Callable, List[dict]], path: Optional[DPPath] = None
self,
merged: Union[Callable[[], List[dict]], List[dict]],
path: Optional[DPPath] = None,
):
"""Update mean and stddev for descriptor elements."""
env_mat_stat = EnvMatStatSe(self)
Expand Down
8 changes: 6 additions & 2 deletions deepmd/pt/model/descriptor/se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,9 @@ def dim_out(self):
return self.sea.dim_out

def compute_input_stats(
self, merged: Union[Callable, List[dict]], path: Optional[DPPath] = None
self,
merged: Union[Callable[[], List[dict]], List[dict]],
path: Optional[DPPath] = None,
):
"""Update mean and stddev for descriptor elements."""
return self.sea.compute_input_stats(merged, path)
Expand Down Expand Up @@ -405,7 +407,9 @@ def __getitem__(self, key):
raise KeyError(key)

def compute_input_stats(
self, merged: Union[Callable, List[dict]], path: Optional[DPPath] = None
self,
merged: Union[Callable[[], List[dict]], List[dict]],
path: Optional[DPPath] = None,
):
"""Update mean and stddev for descriptor elements."""
env_mat_stat = EnvMatStatSe(self)
Expand Down
4 changes: 3 additions & 1 deletion deepmd/pt/model/descriptor/se_atten.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,9 @@ def dim_emb(self):
return self.get_dim_emb()

def compute_input_stats(
self, merged: Union[Callable, List[dict]], path: Optional[DPPath] = None
self,
merged: Union[Callable[[], List[dict]], List[dict]],
path: Optional[DPPath] = None,
):
"""Update mean and stddev for descriptor elements."""
env_mat_stat = EnvMatStatSe(self)
Expand Down
4 changes: 3 additions & 1 deletion deepmd/pt/model/descriptor/se_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,9 @@ def share_params(self, base_class, shared_level, resume=False):
raise NotImplementedError

def compute_input_stats(
self, merged: Union[Callable, List[dict]], path: Optional[DPPath] = None
self,
merged: Union[Callable[[], List[dict]], List[dict]],
path: Optional[DPPath] = None,
):
"""Update mean and stddev for descriptor elements."""
env_mat_stat = EnvMatStatSe(self)
Expand Down
2 changes: 1 addition & 1 deletion deepmd/pt/model/task/dipole.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def output_def(self) -> FittingOutputDef:

def compute_output_stats(
self,
merged: Union[Callable, List[dict]],
merged: Union[Callable[[], List[dict]], List[dict]],
stat_file_path: Optional[DPPath] = None,
):
raise NotImplementedError
Expand Down
2 changes: 1 addition & 1 deletion deepmd/pt/model/task/ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def serialize(self) -> dict:

def compute_output_stats(
self,
merged: Union[Callable, List[dict]],
merged: Union[Callable[[], List[dict]], List[dict]],
stat_file_path: Optional[DPPath] = None,
):
if stat_file_path is not None:
Expand Down
2 changes: 1 addition & 1 deletion deepmd/pt/model/task/polarizability.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def output_def(self) -> FittingOutputDef:

def compute_output_stats(
self,
merged: Union[Callable, List[dict]],
merged: Union[Callable[[], List[dict]], List[dict]],
stat_file_path: Optional[DPPath] = None,
):
raise NotImplementedError
Expand Down

0 comments on commit 64d6079

Please sign in to comment.