From 9d54a3317df860edb669b770600fb6e158af1dce Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Thu, 6 Jun 2024 16:28:28 -0400 Subject: [PATCH 1/5] Removing batch norm. --- crystal_diffusion/models/diffusion_mace.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/crystal_diffusion/models/diffusion_mace.py b/crystal_diffusion/models/diffusion_mace.py index deef4eb0..80b79f1c 100644 --- a/crystal_diffusion/models/diffusion_mace.py +++ b/crystal_diffusion/models/diffusion_mace.py @@ -2,7 +2,7 @@ import torch from e3nn import o3 -from e3nn.nn import Activation, BatchNorm +from e3nn.nn import Activation from mace.modules import (EquivariantProductBasisBlock, InteractionBlock, LinearNodeEmbeddingBlock, RadialEmbeddingBlock) from mace.modules.utils import get_edge_vectors_and_lengths @@ -146,9 +146,6 @@ def __init__( non_linearity = Activation(irreps_in=diffusion_scalar_irreps_out, acts=[gate]) self.diffusion_scalar_embedding.append(non_linearity) - normalization = BatchNorm(diffusion_scalar_irreps_out) - self.diffusion_scalar_embedding.append(normalization) - linear = o3.Linear(irreps_in=diffusion_scalar_irreps_out, irreps_out=diffusion_scalar_irreps_out, biases=True) From a2347503d64a2b13884de6d395b0e0f43aba7fa6 Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Thu, 6 Jun 2024 16:29:18 -0400 Subject: [PATCH 2/5] Fix example config. --- examples/mila_cluster/diffusion/config_diffusion_mace.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/mila_cluster/diffusion/config_diffusion_mace.yaml b/examples/mila_cluster/diffusion/config_diffusion_mace.yaml index 0efa2096..145a0133 100644 --- a/examples/mila_cluster/diffusion/config_diffusion_mace.yaml +++ b/examples/mila_cluster/diffusion/config_diffusion_mace.yaml @@ -30,7 +30,7 @@ model: num_interactions: 2 hidden_irreps: 8x0e + 8x1o mlp_irreps: 8x0e - + number_of_mlp_layers: 0 avg_num_neighbors: 1 correlation: 3 gate: silu From ce7495488ec74f22004247f53598ae6c2d7b29b8 Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Thu, 6 Jun 2024 16:42:16 -0400 Subject: [PATCH 3/5] Config for diffusion mace. --- .../config_diffusion_mace.yaml | 85 +++++++++++++++++++ 1 file changed, 85 insertions(+) create mode 100644 experiments/si_diffusion_1x1x1/config_diffusion_mace.yaml diff --git a/experiments/si_diffusion_1x1x1/config_diffusion_mace.yaml b/experiments/si_diffusion_1x1x1/config_diffusion_mace.yaml new file mode 100644 index 00000000..6acefc6e --- /dev/null +++ b/experiments/si_diffusion_1x1x1/config_diffusion_mace.yaml @@ -0,0 +1,85 @@ +# general +exp_name: difface +run_name: run1 +max_epoch: 25 +log_every_n_steps: 1 +gradient_clipping: 0.1 + +# set to null to avoid setting a seed (can speed up GPU computation, but +# results will not be reproducible) +seed: 1234 + +# data +data: + batch_size: 512 + num_workers: 8 + max_atom: 8 + +# architecture +spatial_dimension: 3 +model: + score_network: + architecture: diffusion_mace + number_of_atoms: 8 + r_max: 5.0 + num_bessel: 8 + num_polynomial_cutoff: 5 + max_ell: 2 + interaction_cls: RealAgnosticResidualInteractionBlock + interaction_cls_first: RealAgnosticInteractionBlock + num_interactions: 2 + hidden_irreps: 128x0e + 128x1o + 128x2e + mlp_irreps: 128x0e + number_of_mlp_layers: 0 + avg_num_neighbors: 1 + correlation: 3 + gate: silu + radial_MLP: [128, 128, 128] + radial_type: bessel + noise: + total_time_steps: 100 + sigma_min: 0.001 # default value + sigma_max: 0.5 # default value' + +# optimizer and scheduler +optimizer: + name: adamw + learning_rate: 0.001 + weight_decay: 1.0e-8 + +scheduler: + name: ReduceLROnPlateau + factor: 0.1 + patience: 20 + +# early stopping +early_stopping: + metric: validation_epoch_loss + mode: min + patience: 10 + +model_checkpoint: + monitor: validation_epoch_loss + mode: min + +# Sampling from the generative model +diffusion_sampling: + noise: + total_time_steps: 100 + sigma_min: 0.001 # default value + sigma_max: 0.5 # default value + sampling: + spatial_dimension: 3 + number_of_corrector_steps: 1 + number_of_atoms: 8 + number_of_samples: 1000 + sample_every_n_epochs: 5 + cell_dimensions: [5.43, 5.43, 5.43] + +# A callback to check the loss vs. sigma +loss_monitoring: + number_of_bins: 50 + sample_every_n_epochs: 2 + +logging: + - comet \ No newline at end of file From c4fe039e46286518a80a677dbc97289dc274eb43 Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Thu, 6 Jun 2024 17:17:42 -0400 Subject: [PATCH 4/5] Fix type hint bjork. --- crystal_diffusion/models/scheduler.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/crystal_diffusion/models/scheduler.py b/crystal_diffusion/models/scheduler.py index 42074156..9c1fee6d 100644 --- a/crystal_diffusion/models/scheduler.py +++ b/crystal_diffusion/models/scheduler.py @@ -62,8 +62,7 @@ class CosineAnnealingLRSchedulerParameters(SchedulerParameters): eta_min: float = 0.0 -def load_scheduler_dictionary(hyper_params: SchedulerParameters, - optimizer: optim.Optimizer) -> Dict[AnyStr, Union[optim.lr_scheduler, AnyStr]]: +def load_scheduler_dictionary(hyper_params: SchedulerParameters, optimizer: optim.Optimizer) -> Dict[AnyStr, Any]: """Instantiate the Scheduler. Args: From babbc0851550d7a5aeb4b8bfaf778de185603897 Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Thu, 6 Jun 2024 20:27:52 -0400 Subject: [PATCH 5/5] Always read to CPU. --- crystal_diffusion/utils/sample_trajectory.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crystal_diffusion/utils/sample_trajectory.py b/crystal_diffusion/utils/sample_trajectory.py index 8e9cc64c..29bb5a26 100644 --- a/crystal_diffusion/utils/sample_trajectory.py +++ b/crystal_diffusion/utils/sample_trajectory.py @@ -52,7 +52,7 @@ def read_from_pickle(path_to_pickle: str): """Read from pickle.""" with open(path_to_pickle, 'rb') as fd: sample_trajectory = SampleTrajectory() - sample_trajectory.data = torch.load(fd) + sample_trajectory.data = torch.load(fd, map_location=torch.device('cpu')) return sample_trajectory