diff --git a/mace/tools/finetuning_utils.py b/mace/tools/finetuning_utils.py index 0d4e2f52..71ca6a7c 100644 --- a/mace/tools/finetuning_utils.py +++ b/mace/tools/finetuning_utils.py @@ -75,7 +75,7 @@ def load_foundations_elements( ) if ( model.interactions[i].__class__.__name__ - == "RealAgnosticResidualInteractionBlock" + in ["RealAgnosticResidualInteractionBlock", "RealAgnosticDensityResidualInteractionBlock"] ): model.interactions[i].skip_tp.weight = torch.nn.Parameter( model_foundations.interactions[i] @@ -101,7 +101,21 @@ def load_foundations_elements( .clone() / (num_species_foundations / num_species) ** 0.5 ) - + if ( + model.interactions[i].__class__.__name__ + in ["RealAgnosticDensityInteractionBlock", "RealAgnosticDensityResidualInteractionBlock"] + ): + # Assuming only 1 layer in density_fn + getattr(model.interactions[i].density_fn, "layer0").weight = ( + torch.nn.Parameter( + getattr( + model_foundations.interactions[i].density_fn, + "layer0", + ) + .weight + .clone() + ) + ) # Transferring products for i in range(2): # Assuming 2 products modules max_range = max_L + 1 if i == 0 else 1