Merge pull request #682 from CheukHinHoJerry/fix_dn_mh_ft
Fixing multihead finetuning with density normalization
ilyes319 authored Nov 11, 2024
2 parents c1184eb + 6da2377 commit abd7e5e
Showing 1 changed file with 16 additions and 2 deletions: mace/tools/finetuning_utils.py
@@ -75,7 +75,7 @@ def load_foundations_elements(
             )
         if (
             model.interactions[i].__class__.__name__
-            == "RealAgnosticResidualInteractionBlock"
+            in ["RealAgnosticResidualInteractionBlock", "RealAgnosticDensityResidualInteractionBlock"]
         ):
             model.interactions[i].skip_tp.weight = torch.nn.Parameter(
                 model_foundations.interactions[i]
@@ -101,7 +101,21 @@ def load_foundations_elements(
                 .clone()
                 / (num_species_foundations / num_species) ** 0.5
             )
-
+        if (
+            model.interactions[i].__class__.__name__
+            in ["RealAgnosticDensityInteractionBlock", "RealAgnosticDensityResidualInteractionBlock"]
+        ):
+            # Assuming only 1 layer in density_fn
+            getattr(model.interactions[i].density_fn, "layer0").weight = (
+                torch.nn.Parameter(
+                    getattr(
+                        model_foundations.interactions[i].density_fn,
+                        "layer0",
+                    )
+                    .weight
+                    .clone()
+                )
+            )
     # Transferring products
     for i in range(2):  # Assuming 2 products modules
         max_range = max_L + 1 if i == 0 else 1
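
For context on the change above: when the new model uses one of the density-normalized interaction blocks, the pretrained density_fn weights from the foundation model are cloned into the freshly built model before fine-tuning, in addition to the skip_tp weights that were already transferred. A minimal sketch of that copy pattern, using hypothetical stand-in torch.nn.Linear modules rather than the actual MACE interaction blocks:

import torch

# Stand-ins for a foundation model's density_fn "layer0" and the corresponding
# layer of a freshly initialized model that is about to be fine-tuned.
foundation_layer = torch.nn.Linear(8, 1, bias=False)
target_layer = torch.nn.Linear(8, 1, bias=False)

# .clone() copies the pretrained values into new storage, so later fine-tuning
# of the target model leaves the foundation model's weights untouched.
target_layer.weight = torch.nn.Parameter(foundation_layer.weight.clone())

assert torch.equal(target_layer.weight, foundation_layer.weight)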
