Skip to content

Commit

Permalink
pt: support multitask dp test (#3573)
Browse files Browse the repository at this point in the history
Fix #3471
  • Loading branch information
iProzd authored Mar 21, 2024
1 parent 5aa1b89 commit b2cc0e5
Showing 1 changed file with 18 additions and 2 deletions.
20 changes: 18 additions & 2 deletions deepmd/pt/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def __init__(
*args: List[Any],
auto_batch_size: Union[bool, int, AutoBatchSize] = True,
neighbor_list: Optional["ase.neighborlist.NewPrimitiveNeighborList"] = None,
head: Optional[str] = None,
**kwargs: Dict[str, Any],
):
self.output_def = output_def
Expand All @@ -99,9 +100,24 @@ def __init__(
if "model" in state_dict:
state_dict = state_dict["model"]
self.input_param = state_dict["_extra_state"]["model_params"]
self.input_param["resuming"] = True
self.multi_task = "model_dict" in self.input_param
assert not self.multi_task, "multitask mode currently not supported!"
if self.multi_task:
model_keys = list(self.input_param["model_dict"].keys())
assert (
head is not None
), f"Head must be set for multitask model! Available heads are: {model_keys}"
assert (
head in model_keys
), f"No head named {head} in model! Available heads are: {model_keys}"
self.input_param = self.input_param["model_dict"][head]
state_dict_head = {"_extra_state": state_dict["_extra_state"]}
for item in state_dict:
if f"model.{head}." in item:
state_dict_head[
item.replace(f"model.{head}.", "model.Default.")
] = state_dict[item].clone()
state_dict = state_dict_head
self.input_param["resuming"] = True
model = get_model(self.input_param).to(DEVICE)
model = torch.jit.script(model)
self.dp = ModelWrapper(model)
Expand Down

0 comments on commit b2cc0e5

Please sign in to comment.