diff --git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py index 3ea2965fa7..7ba5f0b63a 100644 --- a/deepmd/pt/entrypoints/main.py +++ b/deepmd/pt/entrypoints/main.py @@ -297,8 +297,8 @@ def train(FLAGS): "Calculate neighbor statistics... (add --skip-neighbor-stat to skip this step)" ) - type_map = config["model"].get("type_map") if not multi_task: + type_map = config["model"].get("type_map") train_data = get_data( config["training"]["training_data"], 0, type_map, None ) @@ -308,6 +308,7 @@ def train(FLAGS): else: min_nbor_dist = {} for model_item in config["model"]["model_dict"]: + type_map = config["model"]["model_dict"][model_item].get("type_map") train_data = get_data( config["training"]["data_dict"][model_item]["training_data"], 0,