From db468b9e5802f80b312a4b1e84c553a40bc54ea3 Mon Sep 17 00:00:00 2001 From: CheukHinHoJerry Date: Sun, 10 Nov 2024 00:40:06 +0000 Subject: [PATCH 1/2] fixing multihead finetuning with density normalization --- mace/tools/finetuning_utils.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/mace/tools/finetuning_utils.py b/mace/tools/finetuning_utils.py index 0d4e2f52..71459fcd 100644 --- a/mace/tools/finetuning_utils.py +++ b/mace/tools/finetuning_utils.py @@ -75,7 +75,7 @@ def load_foundations_elements( ) if ( model.interactions[i].__class__.__name__ - == "RealAgnosticResidualInteractionBlock" + in ["RealAgnosticResidualInteractionBlock", "RealAgnosticDensityResidualInteractionBlock"] ): model.interactions[i].skip_tp.weight = torch.nn.Parameter( model_foundations.interactions[i] @@ -101,7 +101,21 @@ def load_foundations_elements( .clone() / (num_species_foundations / num_species) ** 0.5 ) - + if ( + model.interactions[i].__class__.__name__ + in ["RealAgnosticResidualInteractionBlock", "RealAgnosticDensityResidualInteractionBlock"] + ): + # Assuming only 1 layer in density_fn + getattr(model.interactions[i].density_fn, "layer0").weight = ( + torch.nn.Parameter( + getattr( + model_foundations.interactions[i].density_fn, + "layer0", + ) + .weight + .clone() + ) + ) # Transferring products for i in range(2): # Assuming 2 products modules max_range = max_L + 1 if i == 0 else 1 From 6da237701a4bf1abc125ad14b1200451ef8e68bf Mon Sep 17 00:00:00 2001 From: CheukHinHoJerry Date: Sun, 10 Nov 2024 03:12:18 +0000 Subject: [PATCH 2/2] minor fix --- mace/tools/finetuning_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mace/tools/finetuning_utils.py b/mace/tools/finetuning_utils.py index 71459fcd..71ca6a7c 100644 --- a/mace/tools/finetuning_utils.py +++ b/mace/tools/finetuning_utils.py @@ -103,7 +103,7 @@ def load_foundations_elements( ) if ( model.interactions[i].__class__.__name__ - in ["RealAgnosticResidualInteractionBlock", "RealAgnosticDensityResidualInteractionBlock"] + in ["RealAgnosticDensityInteractionBlock", "RealAgnosticDensityResidualInteractionBlock"] ): # Assuming only 1 layer in density_fn getattr(model.interactions[i].density_fn, "layer0").weight = (