Skip to content

Commit

Permalink
remove warnings in convertion
Browse files Browse the repository at this point in the history
  • Loading branch information
ilyes319 committed Nov 20, 2024
1 parent 315d7c5 commit 9dca489
Show file tree
Hide file tree
Showing 3 changed files with 1 addition and 24 deletions.
2 changes: 1 addition & 1 deletion mace/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .__version__ import __version__
from .__version__ import __version__
12 changes: 0 additions & 12 deletions mace/cli/convert_cueq_e3nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,6 @@ def transfer_symmetric_contractions(
):
"""Transfer symmetric contraction weights from CuEq to E3nn format"""
kmax_pairs = get_kmax_pairs(max_L, correlation)
logging.warning(
f"Using kmax pairs {kmax_pairs} for max_L={max_L}, correlation={correlation}"
)

for i, kmax in kmax_pairs:
# Get the combined weight tensor from source
Expand Down Expand Up @@ -95,15 +92,13 @@ def transfer_weights(

# Transfer main weights
transfer_keys = get_transfer_keys()
logging.warning("Transferring main weights...")
for key in transfer_keys:
if key in source_dict: # Check if key exists
target_dict[key] = source_dict[key]
else:
logging.warning(f"Key {key} not found in source model")

# Transfer symmetric contractions
logging.warning("Transferring symmetric contractions...")
transfer_symmetric_contractions(source_dict, target_dict, max_L, correlation)

# Transfer remaining matching keys
Expand All @@ -114,9 +109,6 @@ def transfer_weights(
remaining_keys = {k for k in remaining_keys if "symmetric_contraction" not in k}

if remaining_keys:
logging.warning(
f"Found {len(remaining_keys)} additional matching keys to transfer"
)
for key in remaining_keys:
if source_dict[key].shape == target_dict[key].shape:
logging.debug(f"Transferring additional key: {key}")
Expand All @@ -140,21 +132,18 @@ def transfer_weights(
def run(input_model, output_model="_e3nn.model", device="cuda", return_model=True):

# Load CuEq model
logging.warning("Loading CuEq model")
if isinstance(input_model, str):
source_model = torch.load(input_model, map_location=device)
else:
source_model = input_model
default_dtype = next(source_model.parameters()).dtype
torch.set_default_dtype(default_dtype)
# Extract configuration
logging.warning("Extracting model configuration")
config = extract_config_mace_model(source_model)

# Get max_L and correlation from config
max_L = config["hidden_irreps"].lmax
correlation = config["correlation"]
logging.warning(f"Extracted max_L={max_L}, correlation={correlation}")

# Remove CuEq config
config.pop("cueq_config", None)
Expand All @@ -164,7 +153,6 @@ def run(input_model, output_model="_e3nn.model", device="cuda", return_model=Tru
target_model = source_model.__class__(**config)

# Transfer weights with proper remapping
logging.warning("Transferring weights with remapping...")
transfer_weights(source_model, target_model, max_L, correlation)

if return_model:
Expand Down
11 changes: 0 additions & 11 deletions mace/cli/convert_e3nn_cueq.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,6 @@ def transfer_symmetric_contractions(
):
"""Transfer symmetric contraction weights"""
kmax_pairs = get_kmax_pairs(max_L, correlation)
logging.warning(
f"Using kmax pairs {kmax_pairs} for max_L={max_L}, correlation={correlation}"
)

for i, kmax in kmax_pairs:
wm = torch.concatenate(
Expand Down Expand Up @@ -80,15 +77,13 @@ def transfer_weights(

# Transfer main weights
transfer_keys = get_transfer_keys()
logging.warning("Transferring main weights...")
for key in transfer_keys:
if key in source_dict: # Check if key exists
target_dict[key] = source_dict[key]
else:
logging.warning(f"Key {key} not found in source model")

# Transfer symmetric contractions
logging.warning("Transferring symmetric contractions...")
transfer_symmetric_contractions(source_dict, target_dict, max_L, correlation)

transferred_keys = set(transfer_keys)
Expand All @@ -97,9 +92,6 @@ def transfer_weights(
)
remaining_keys = {k for k in remaining_keys if "symmetric_contraction" not in k}
if remaining_keys:
logging.warning(
f"Found {len(remaining_keys)} additional matching keys to transfer"
)
for key in remaining_keys:
if source_dict[key].shape == target_dict[key].shape:
logging.debug(f"Transferring additional key: {key}")
Expand Down Expand Up @@ -137,13 +129,11 @@ def run(
default_dtype = next(source_model.parameters()).dtype
torch.set_default_dtype(default_dtype)
# Extract configuration
logging.warning("Extracting model configuration")
config = extract_config_mace_model(source_model)

# Get max_L and correlation from config
max_L = config["hidden_irreps"].lmax
correlation = config["correlation"]
logging.warning(f"Extracted max_L={max_L}, correlation={correlation}")

# Add cuequivariance config
config["cueq_config"] = CuEquivarianceConfig(
Expand All @@ -158,7 +148,6 @@ def run(
target_model = source_model.__class__(**config).to(device)

# Transfer weights with proper remapping
logging.warning("Transferring weights with remapping...")
transfer_weights(source_model, target_model, max_L, correlation)

if return_model:
Expand Down

0 comments on commit 9dca489

Please sign in to comment.