
Commit

Merge pull request #56 from mila-iqia/difface_fixes
Difface fixes
rousseab authored Jun 10, 2024
2 parents 2b9df8d + babbc08 commit 85b1f6e
Showing 5 changed files with 89 additions and 8 deletions.
crystal_diffusion/models/diffusion_mace.py (1 addition, 4 deletions)

@@ -2,7 +2,7 @@

 import torch
 from e3nn import o3
-from e3nn.nn import Activation, BatchNorm
+from e3nn.nn import Activation
 from mace.modules import (EquivariantProductBasisBlock, InteractionBlock,
                           LinearNodeEmbeddingBlock, RadialEmbeddingBlock)
 from mace.modules.utils import get_edge_vectors_and_lengths

@@ -146,9 +146,6 @@ def __init__(
         non_linearity = Activation(irreps_in=diffusion_scalar_irreps_out, acts=[gate])
         self.diffusion_scalar_embedding.append(non_linearity)

-        normalization = BatchNorm(diffusion_scalar_irreps_out)
-        self.diffusion_scalar_embedding.append(normalization)
-
         linear = o3.Linear(irreps_in=diffusion_scalar_irreps_out,
                            irreps_out=diffusion_scalar_irreps_out,
                            biases=True)
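For context, this hunk drops the BatchNorm layer from the diffusion-scalar embedding, leaving a linear layer, an activation, and a second linear layer. Below is a minimal sketch of that reduced stack; the irreps, the SiLU activation, and the input shape are illustrative assumptions, not values taken from the repository.

    import torch
    from e3nn import o3
    from e3nn.nn import Activation

    # Illustrative irreps; the real values come from the model hyperparameters.
    diffusion_scalar_irreps_out = o3.Irreps("8x0e")

    embedding = torch.nn.ModuleList()
    embedding.append(o3.Linear(irreps_in=o3.Irreps("1x0e"),
                               irreps_out=diffusion_scalar_irreps_out,
                               biases=True))
    # Non-linearity on the scalar channels (SiLU assumed here for the gate).
    embedding.append(Activation(irreps_in=diffusion_scalar_irreps_out,
                                acts=[torch.nn.functional.silu]))
    # The BatchNorm that previously followed the activation is gone after this commit.
    embedding.append(o3.Linear(irreps_in=diffusion_scalar_irreps_out,
                               irreps_out=diffusion_scalar_irreps_out,
                               biases=True))

    scalars = torch.randn(4, 1)  # e.g. one noise-scale scalar per atom
    for layer in embedding:
        scalars = layer(scalars)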
crystal_diffusion/models/scheduler.py (1 addition, 2 deletions)

@@ -62,8 +62,7 @@ class CosineAnnealingLRSchedulerParameters(SchedulerParameters):
     eta_min: float = 0.0


-def load_scheduler_dictionary(hyper_params: SchedulerParameters,
-                              optimizer: optim.Optimizer) -> Dict[AnyStr, Union[optim.lr_scheduler, AnyStr]]:
+def load_scheduler_dictionary(hyper_params: SchedulerParameters, optimizer: optim.Optimizer) -> Dict[AnyStr, Any]:
     """Instantiate the Scheduler.

     Args:
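The return annotation is relaxed from Union[optim.lr_scheduler, AnyStr], which names a module rather than a type, to Dict[AnyStr, Any]. For illustration only, a sketch of the kind of heterogeneous dictionary such a signature can describe, assuming the PyTorch Lightning convention of "lr_scheduler" and "monitor" keys; the exact contents returned by load_scheduler_dictionary are not shown in this diff.

    from typing import Any, AnyStr, Dict

    import torch
    from torch import optim


    def make_scheduler_dictionary(optimizer: optim.Optimizer) -> Dict[AnyStr, Any]:
        # Hypothetical helper: the values mix a scheduler object and a plain
        # string, which is why the value type has to be Any.
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1, patience=20)
        return {"lr_scheduler": scheduler, "monitor": "validation_epoch_loss"}


    model = torch.nn.Linear(4, 1)
    optimizer = optim.AdamW(model.parameters(), lr=1.0e-3)
    scheduler_dict = make_scheduler_dictionary(optimizer)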
crystal_diffusion/utils/sample_trajectory.py (1 addition, 1 deletion)

@@ -52,7 +52,7 @@ def read_from_pickle(path_to_pickle: str):
         """Read from pickle."""
         with open(path_to_pickle, 'rb') as fd:
             sample_trajectory = SampleTrajectory()
-            sample_trajectory.data = torch.load(fd)
+            sample_trajectory.data = torch.load(fd, map_location=torch.device('cpu'))
         return sample_trajectory

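The added map_location argument is what lets a trajectory pickled on a GPU machine be read back on a CPU-only machine. A minimal sketch of the same idea in isolation; the file path below is hypothetical.

    import torch

    # Tensors saved from a CUDA run would fail to deserialize on a machine
    # without a GPU; map_location forces them onto the CPU at load time.
    with open("trajectory.pkl", "rb") as fd:  # hypothetical path
        data = torch.load(fd, map_location=torch.device("cpu"))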
examples/mila_cluster/diffusion/config_diffusion_mace.yaml (1 addition, 1 deletion)

@@ -30,7 +30,7 @@ model:
     num_interactions: 2
     hidden_irreps: 8x0e + 8x1o
     mlp_irreps: 8x0e
-
+    number_of_mlp_layers: 0
     avg_num_neighbors: 1
     correlation: 3
     gate: silu
experiments/si_diffusion_1x1x1/config_diffusion_mace.yaml (new file, 85 additions)

@@ -0,0 +1,85 @@
# general
exp_name: difface
run_name: run1
max_epoch: 25
log_every_n_steps: 1
gradient_clipping: 0.1

# set to null to avoid setting a seed (can speed up GPU computation, but
# results will not be reproducible)
seed: 1234

# data
data:
  batch_size: 512
  num_workers: 8
  max_atom: 8

# architecture
spatial_dimension: 3
model:
  score_network:
    architecture: diffusion_mace
    number_of_atoms: 8
    r_max: 5.0
    num_bessel: 8
    num_polynomial_cutoff: 5
    max_ell: 2
    interaction_cls: RealAgnosticResidualInteractionBlock
    interaction_cls_first: RealAgnosticInteractionBlock
    num_interactions: 2
    hidden_irreps: 128x0e + 128x1o + 128x2e
    mlp_irreps: 128x0e
    number_of_mlp_layers: 0
    avg_num_neighbors: 1
    correlation: 3
    gate: silu
    radial_MLP: [128, 128, 128]
    radial_type: bessel
  noise:
    total_time_steps: 100
    sigma_min: 0.001  # default value
    sigma_max: 0.5  # default value

# optimizer and scheduler
optimizer:
  name: adamw
  learning_rate: 0.001
  weight_decay: 1.0e-8

scheduler:
  name: ReduceLROnPlateau
  factor: 0.1
  patience: 20

# early stopping
early_stopping:
  metric: validation_epoch_loss
  mode: min
  patience: 10

model_checkpoint:
  monitor: validation_epoch_loss
  mode: min

# Sampling from the generative model
diffusion_sampling:
  noise:
    total_time_steps: 100
    sigma_min: 0.001  # default value
    sigma_max: 0.5  # default value
  sampling:
    spatial_dimension: 3
    number_of_corrector_steps: 1
    number_of_atoms: 8
    number_of_samples: 1000
    sample_every_n_epochs: 5
    cell_dimensions: [5.43, 5.43, 5.43]

# A callback to check the loss vs. sigma
loss_monitoring:
  number_of_bins: 50
  sample_every_n_epochs: 2

logging:
  - comet
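As a quick check of the structure above, a minimal sketch of loading the new experiment config with PyYAML and reading two nested fields; safe_load and plain dictionary access are standard PyYAML usage, the relative path points at the file added by this commit, and the nesting is the one shown in the reconstruction above.

    import yaml

    # Path of the file added by this commit, relative to a local checkout.
    with open("experiments/si_diffusion_1x1x1/config_diffusion_mace.yaml") as f:
        config = yaml.safe_load(f)

    print(config["model"]["score_network"]["hidden_irreps"])            # 128x0e + 128x1o + 128x2e
    print(config["diffusion_sampling"]["sampling"]["cell_dimensions"])  # [5.43, 5.43, 5.43]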
