Skip to content

Commit

Permalink
chore: refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
anyangml committed Mar 8, 2024
1 parent 3de4709 commit 831158c
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 8 deletions.
4 changes: 0 additions & 4 deletions deepmd/dpmodel/fitting/general_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,8 +191,6 @@ def __setitem__(self, key, value):
self.aparam_inv_std = value
elif key in ["scale"]:
self.scale = value
elif key in ["constant_matrix"]:
self.constant_matrix = value
else:
raise KeyError(key)

Expand All @@ -209,8 +207,6 @@ def __getitem__(self, key):
return self.aparam_inv_std
elif key in ["scale"]:
return self.scale
elif key in ["constant_matrix"]:
return self.constant_matrix
else:
raise KeyError(key)

Expand Down
12 changes: 12 additions & 0 deletions deepmd/dpmodel/fitting/polarizability_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,18 @@ def _net_out_dim(self):
else self.embedding_width * self.embedding_width
)

def __setitem__(self, key, value):
if key in ["constant_matrix"]:
self.constant_matrix = value
else:
super().__setitem__(key, value)

def __getitem__(self, key):
if key in ["constant_matrix"]:
return self.constant_matrix
else:
super().__getitem__(key, value)

def serialize(self) -> dict:
data = super().serialize()
data["type"] = "polar"
Expand Down
4 changes: 0 additions & 4 deletions deepmd/pt/model/task/fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,8 +415,6 @@ def __setitem__(self, key, value):
self.aparam_inv_std = value
elif key in ["scale"]:
self.scale = value
elif key in ["constant_matrix"]:
self.constant_matrix = value
else:
raise KeyError(key)

Expand All @@ -433,8 +431,6 @@ def __getitem__(self, key):
return self.aparam_inv_std
elif key in ["scale"]:
return self.scale
elif key in ["constant_matrix"]:
return self.constant_matrix
else:
raise KeyError(key)

Expand Down
12 changes: 12 additions & 0 deletions deepmd/pt/model/task/polarizability.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,18 @@ def _net_out_dim(self):
else self.embedding_width * self.embedding_width
)

def __setitem__(self, key, value):
if key in ["constant_matrix"]:
self.constant_matrix = value
else:
super().__setitem__(key, value)

def __getitem__(self, key):
if key in ["constant_matrix"]:
return self.constant_matrix
else:
super().__getitem__(key, value)

def serialize(self) -> dict:
data = super().serialize()
data["type"] = "polar"
Expand Down

0 comments on commit 831158c

Please sign in to comment.