From 9dca489982906292881a810d637dc46a1140760f Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Wed, 20 Nov 2024 15:06:27 +0000 Subject: [PATCH] remove warnings in convertion --- mace/__init__.py | 2 +- mace/cli/convert_cueq_e3nn.py | 12 ------------ mace/cli/convert_e3nn_cueq.py | 11 ----------- 3 files changed, 1 insertion(+), 24 deletions(-) diff --git a/mace/__init__.py b/mace/__init__.py index e9c9ef48..9226fe7e 100644 --- a/mace/__init__.py +++ b/mace/__init__.py @@ -1 +1 @@ -from .__version__ import __version__ \ No newline at end of file +from .__version__ import __version__ diff --git a/mace/cli/convert_cueq_e3nn.py b/mace/cli/convert_cueq_e3nn.py index cd72b343..732e2bf5 100644 --- a/mace/cli/convert_cueq_e3nn.py +++ b/mace/cli/convert_cueq_e3nn.py @@ -48,9 +48,6 @@ def transfer_symmetric_contractions( ): """Transfer symmetric contraction weights from CuEq to E3nn format""" kmax_pairs = get_kmax_pairs(max_L, correlation) - logging.warning( - f"Using kmax pairs {kmax_pairs} for max_L={max_L}, correlation={correlation}" - ) for i, kmax in kmax_pairs: # Get the combined weight tensor from source @@ -95,7 +92,6 @@ def transfer_weights( # Transfer main weights transfer_keys = get_transfer_keys() - logging.warning("Transferring main weights...") for key in transfer_keys: if key in source_dict: # Check if key exists target_dict[key] = source_dict[key] @@ -103,7 +99,6 @@ def transfer_weights( logging.warning(f"Key {key} not found in source model") # Transfer symmetric contractions - logging.warning("Transferring symmetric contractions...") transfer_symmetric_contractions(source_dict, target_dict, max_L, correlation) # Transfer remaining matching keys @@ -114,9 +109,6 @@ def transfer_weights( remaining_keys = {k for k in remaining_keys if "symmetric_contraction" not in k} if remaining_keys: - logging.warning( - f"Found {len(remaining_keys)} additional matching keys to transfer" - ) for key in remaining_keys: if source_dict[key].shape == target_dict[key].shape: logging.debug(f"Transferring additional key: {key}") @@ -140,7 +132,6 @@ def transfer_weights( def run(input_model, output_model="_e3nn.model", device="cuda", return_model=True): # Load CuEq model - logging.warning("Loading CuEq model") if isinstance(input_model, str): source_model = torch.load(input_model, map_location=device) else: @@ -148,13 +139,11 @@ def run(input_model, output_model="_e3nn.model", device="cuda", return_model=Tru default_dtype = next(source_model.parameters()).dtype torch.set_default_dtype(default_dtype) # Extract configuration - logging.warning("Extracting model configuration") config = extract_config_mace_model(source_model) # Get max_L and correlation from config max_L = config["hidden_irreps"].lmax correlation = config["correlation"] - logging.warning(f"Extracted max_L={max_L}, correlation={correlation}") # Remove CuEq config config.pop("cueq_config", None) @@ -164,7 +153,6 @@ def run(input_model, output_model="_e3nn.model", device="cuda", return_model=Tru target_model = source_model.__class__(**config) # Transfer weights with proper remapping - logging.warning("Transferring weights with remapping...") transfer_weights(source_model, target_model, max_L, correlation) if return_model: diff --git a/mace/cli/convert_e3nn_cueq.py b/mace/cli/convert_e3nn_cueq.py index fbcb72e5..45e07257 100644 --- a/mace/cli/convert_e3nn_cueq.py +++ b/mace/cli/convert_e3nn_cueq.py @@ -49,9 +49,6 @@ def transfer_symmetric_contractions( ): """Transfer symmetric contraction weights""" kmax_pairs = get_kmax_pairs(max_L, correlation) - logging.warning( - f"Using kmax pairs {kmax_pairs} for max_L={max_L}, correlation={correlation}" - ) for i, kmax in kmax_pairs: wm = torch.concatenate( @@ -80,7 +77,6 @@ def transfer_weights( # Transfer main weights transfer_keys = get_transfer_keys() - logging.warning("Transferring main weights...") for key in transfer_keys: if key in source_dict: # Check if key exists target_dict[key] = source_dict[key] @@ -88,7 +84,6 @@ def transfer_weights( logging.warning(f"Key {key} not found in source model") # Transfer symmetric contractions - logging.warning("Transferring symmetric contractions...") transfer_symmetric_contractions(source_dict, target_dict, max_L, correlation) transferred_keys = set(transfer_keys) @@ -97,9 +92,6 @@ def transfer_weights( ) remaining_keys = {k for k in remaining_keys if "symmetric_contraction" not in k} if remaining_keys: - logging.warning( - f"Found {len(remaining_keys)} additional matching keys to transfer" - ) for key in remaining_keys: if source_dict[key].shape == target_dict[key].shape: logging.debug(f"Transferring additional key: {key}") @@ -137,13 +129,11 @@ def run( default_dtype = next(source_model.parameters()).dtype torch.set_default_dtype(default_dtype) # Extract configuration - logging.warning("Extracting model configuration") config = extract_config_mace_model(source_model) # Get max_L and correlation from config max_L = config["hidden_irreps"].lmax correlation = config["correlation"] - logging.warning(f"Extracted max_L={max_L}, correlation={correlation}") # Add cuequivariance config config["cueq_config"] = CuEquivarianceConfig( @@ -158,7 +148,6 @@ def run( target_model = source_model.__class__(**config).to(device) # Transfer weights with proper remapping - logging.warning("Transferring weights with remapping...") transfer_weights(source_model, target_model, max_L, correlation) if return_model: