From b2cc0e5d145b06d8dc9b75a5664ebc36c28f06b4 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Thu, 21 Mar 2024 14:56:47 +0800 Subject: [PATCH] pt: support multitask dp test (#3573) Fix #3471 --- deepmd/pt/infer/deep_eval.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/deepmd/pt/infer/deep_eval.py b/deepmd/pt/infer/deep_eval.py index f46d5fce49..1262a56310 100644 --- a/deepmd/pt/infer/deep_eval.py +++ b/deepmd/pt/infer/deep_eval.py @@ -90,6 +90,7 @@ def __init__( *args: List[Any], auto_batch_size: Union[bool, int, AutoBatchSize] = True, neighbor_list: Optional["ase.neighborlist.NewPrimitiveNeighborList"] = None, + head: Optional[str] = None, **kwargs: Dict[str, Any], ): self.output_def = output_def @@ -99,9 +100,24 @@ def __init__( if "model" in state_dict: state_dict = state_dict["model"] self.input_param = state_dict["_extra_state"]["model_params"] - self.input_param["resuming"] = True self.multi_task = "model_dict" in self.input_param - assert not self.multi_task, "multitask mode currently not supported!" + if self.multi_task: + model_keys = list(self.input_param["model_dict"].keys()) + assert ( + head is not None + ), f"Head must be set for multitask model! Available heads are: {model_keys}" + assert ( + head in model_keys + ), f"No head named {head} in model! Available heads are: {model_keys}" + self.input_param = self.input_param["model_dict"][head] + state_dict_head = {"_extra_state": state_dict["_extra_state"]} + for item in state_dict: + if f"model.{head}." in item: + state_dict_head[ + item.replace(f"model.{head}.", "model.Default.") + ] = state_dict[item].clone() + state_dict = state_dict_head + self.input_param["resuming"] = True model = get_model(self.input_param).to(DEVICE) model = torch.jit.script(model) self.dp = ModelWrapper(model)