Skip to content

Commit

Permalink
adding samplers
Browse files Browse the repository at this point in the history
  • Loading branch information
khaledkah committed Apr 3, 2024
1 parent 54c8a7e commit 89ba818
Show file tree
Hide file tree
Showing 28 changed files with 1,608 additions and 1,072 deletions.
732 changes: 509 additions & 223 deletions notebooks/denoising_tutorial.ipynb

Large diffs are not rendered by default.

20 changes: 10 additions & 10 deletions notebooks/diffusion_tutorial.ipynb

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions src/morered/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
noise_schedules,
optimization,
transform,
sampling,
callbacks,
utils,
)
from morered.task import *
33 changes: 19 additions & 14 deletions src/morered/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,8 @@ class SamplerCallback(Callback):
def __init__(
self,
sampler: Sampler,
t: Optional[Union[int, torch.Tensor]] = None,
t: Optional[int] = None,
max_steps: Optional[int] = None,
sample_prior: bool = True,
name: str = "sampling",
store_path: str = "samples",
Expand All @@ -115,6 +116,7 @@ def __init__(
Args:
sampler: sampler to be used for sampling/denoising.
t: time step to start denoising from. Defaults to None, in which case denoising starts from the prior.
max_steps: maximum number of reverse steps when using MoreRed.
sample_prior: whether to sample from the prior or use input as start sample.
name: name of the callback.
store_path: path to store the results and samples.
Expand All @@ -129,6 +131,7 @@ def __init__(
super().__init__()
self.sampler = sampler
self.t = t
self.max_steps = max_steps
self.sample_prior = sample_prior
self.name = name
self.store_path = store_path
Expand All @@ -139,9 +142,6 @@ def __init__(
self.log_validity = log_validity
self.bonds_data = generate_bonds_data(bonds_data_path)

if isinstance(self.t, int):
self.t = torch.tensor([self.t])

if not os.path.exists(self.store_path):
os.makedirs(self.store_path)

Expand All @@ -158,9 +158,14 @@ def sample(
# update the sampling model
self.sampler.update_model(model)

# sample from the prior
if self.sample_prior:
x_t = self.sampler.sample_prior(batch, self.t)
batch.update(x_t)

# sample / denoise
samples, num_steps, hist = self.sampler(
batch, self.t, self.sample_prior # type: ignore
batch, t=self.t, max_steps=self.max_steps
)

# add important properties to save along with the sampled ones
Expand All @@ -186,9 +191,9 @@ def sample(
results = {
"samples": samples,
"hist": hist,
"num_steps": num_steps.cpu()
if isinstance(num_steps, torch.Tensor)
else num_steps,
"num_steps": (
num_steps.cpu() if isinstance(num_steps, torch.Tensor) else num_steps
),
"t": self.t.cpu() if isinstance(self.t, torch.Tensor) else self.t,
}

Expand Down Expand Up @@ -255,14 +260,14 @@ def _step(
connected = np.array(validity_res["connected"])
connected_wo_h = np.array(validity_res["connected_wo_h"])
results["bonds"] = validity_res["bonds"]
results["connectivity"] = torch.from_numpy(connected).to("cpu")
results["stable_atoms"] = torch.from_numpy(stable_ats).to("cpu")
results["stable_molecules"] = torch.from_numpy(stable_mols).to("cpu")
results["stable_atoms_wo_h"] = torch.from_numpy(stable_ats_wo_h).to("cpu")
results["connectivity"] = torch.from_numpy(connected).cpu()
results["stable_atoms"] = torch.from_numpy(stable_ats).cpu()
results["stable_molecules"] = torch.from_numpy(stable_mols).cpu()
results["stable_atoms_wo_h"] = torch.from_numpy(stable_ats_wo_h).cpu()
results["stable_molecules_wo_h"] = torch.from_numpy(stable_mols_wo_h).to(
"cpu"
)
results["connectivity_wo_h"] = torch.from_numpy(connected_wo_h).to("cpu")
results["connectivity_wo_h"] = torch.from_numpy(connected_wo_h).cpu()

# infer metrics from validity results
metrics = {
Expand Down Expand Up @@ -295,7 +300,7 @@ def _step(
else batch[properties.R]
)

res_rmsd = batch_rmsd(reference_R, results["samples"]).to("cpu")
res_rmsd = batch_rmsd(reference_R, results["samples"]).cpu()

results["rmsd"] = res_rmsd
metrics["rmsd"] = res_rmsd.mean()
Expand Down
15 changes: 15 additions & 0 deletions src/morered/configs/callbacks/sampling.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
sampling:
_target_: morered.callbacks.SamplerCallback
sampler: ${sampler}
name: sampling
t: ???
max_steps: ???
sample_prior: True
store_path: samples
every_n_batchs: 1
every_n_epochs: 200
start_epoch: 1
log_rmsd: False
log_validity: True
bonds_data_path: null

5 changes: 5 additions & 0 deletions src/morered/configs/experiment/vp_gauss_ddpm _qm9.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# @package _global_

defaults:
- vp_gauss_ddpm
- override /data: qm9_filtered
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@ defaults:
- override /model: nnp
- override /data: qm7x
- override /task: diffusion_task
- override /sampler: ddpm

run:
experiment: vp_gauss_ddpm_jt
experiment: vp_gauss_ddpm

globals:
cutoff: 5.
Expand All @@ -15,7 +16,6 @@ globals:
noise_target_key: eps
noise_output_key: eps_pred
time_target_key: t
time_output_key: t_pred

noise_schedule:
_target_: morered.noise_schedules.PolynomialSchedule
Expand Down Expand Up @@ -46,7 +46,6 @@ data:
- _target_: morered.transform.Diffuse
diffuse_property: _positions
diffusion_process: ${globals.diffusion_process}
T: ${globals.noise_schedule.T}
time_key: ${globals.time_target_key}

- _target_: schnetpack.transform.MatScipyNeighborList
Expand All @@ -61,22 +60,15 @@ model:
cutoff: ${globals.cutoff}
n_atom_basis: ${globals.n_atom_basis}
output_modules:
- _target_: morered.model.heads.DiffusionTime
n_in: ${globals.n_atom_basis}
n_hidden: null
n_layers: 3
output_key: ${globals.time_output_key}
aggregation_mode: null
detach_representation: False
- _target_: morered.model.heads.TimeAwareEquivariant
n_in: ${globals.n_atom_basis}
n_hidden: null
n_layers: 3
output_key: ${globals.noise_output_key}
include_time: True
time_head: ${model.output_modules.0}
time_head: null
detach_time_head: False
time_key: ${globals.time_output_key}
time_key: ${globals.time_target_key}
do_postprocessing: True
postprocessors:
- _target_: morered.transform.BatchSubtractCenterOfMass
Expand All @@ -88,16 +80,6 @@ task:
skip_exploding_batches: True
include_l0: False
outputs:
- _target_: schnetpack.task.ModelOutput
name: ${globals.time_output_key}
target_property: ${globals.time_target_key}
loss_fn:
_target_: torch.nn.MSELoss
metrics:
mse:
_target_: torchmetrics.regression.MeanSquaredError
squared: True
loss_weight: 0.1
- _target_: morered.task.DiffModelOutput
name: ${globals.noise_output_key}
target_property: ${globals.noise_target_key}
Expand All @@ -107,7 +89,7 @@ task:
mse:
_target_: torchmetrics.regression.MeanSquaredError
squared: True
loss_weight: 0.9
loss_weight: 1.0
nll_metric: null
# _target_: morered.optimization.metrics.NLL
# noise_schedule: ${globals.noise_schedule}
Expand All @@ -118,3 +100,11 @@ task:
# time_key: ${globals.time_target_key}
# noise_key: ${globals.noise_target_key}
# noise_pred_key: ${globals.noise_output_key}

sampler:
denoiser: null

callbacks:
sampling:
t: null
max_steps: null
77 changes: 77 additions & 0 deletions src/morered/configs/experiment/vp_gauss_morered_jt.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# @package _global_

defaults:
- vp_gauss_ddpm
- override /sampler: morered_jt

run:
experiment: vp_gauss_morered_jt

globals:
time_output_key: t_pred

model:
output_modules:
- _target_: morered.model.heads.DiffusionTime
n_in: ${globals.n_atom_basis}
n_hidden: null
n_layers: 3
output_key: ${globals.time_output_key}
aggregation_mode: null
detach_representation: False
- _target_: morered.model.heads.TimeAwareEquivariant
n_in: ${globals.n_atom_basis}
n_hidden: null
n_layers: 3
output_key: ${globals.noise_output_key}
include_time: True
time_head: ${model.output_modules.0}
detach_time_head: False
time_key: ${globals.time_output_key}

task:
outputs:
- _target_: schnetpack.task.ModelOutput
name: ${globals.time_output_key}
target_property: ${globals.time_target_key}
loss_fn:
_target_: torch.nn.MSELoss
metrics:
mse:
_target_: torchmetrics.regression.MeanSquaredError
squared: True
loss_weight: 0.1
- _target_: morered.task.DiffModelOutput
name: ${globals.noise_output_key}
target_property: ${globals.noise_target_key}
loss_fn:
_target_: torch.nn.MSELoss
metrics:
mse:
_target_: torchmetrics.regression.MeanSquaredError
squared: True
loss_weight: 0.9
nll_metric: null

sampler:
denoiser: null

callbacks:
sampling:
t: null
max_steps: 2000

# denoising:
# _target_: morered.callbacks.SamplerCallback
# sampler: ${sampler}
# name: denoising
# t: 150
# max_steps: 1000
# sample_prior: True
# store_path: denoised
# every_n_batchs: 1
# every_n_epochs: 200
# start_epoch: 1
# log_rmsd: True
# log_validity: True
# bonds_data_path: null
5 changes: 5 additions & 0 deletions src/morered/configs/experiment/vp_gauss_morered_jt_qm9.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# @package _global_

defaults:
- vp_gauss_morered_jt
- override /data: qm9_filtered
42 changes: 42 additions & 0 deletions src/morered/configs/experiment/vp_gauss_time_predictor.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# @package _global_

defaults:
- vp_gauss_ddpm
- override /callbacks:
- checkpoint
- earlystopping
- lrmonitor
- ema

run:
experiment: vp_gauss_time_predictor

globals:
time_output_key: t_pred

model:
output_modules:
- _target_: morered.model.heads.DiffusionTime
n_in: ${globals.n_atom_basis}
n_hidden: null
n_layers: 3
output_key: ${globals.time_output_key}
aggregation_mode: null
detach_representation: False
postprocessors:
- _target_: schnetpack.transform.CastTo64

task:
outputs:
- _target_: schnetpack.task.ModelOutput
name: ${globals.time_output_key}
target_property: ${globals.time_target_key}
loss_fn:
_target_: torch.nn.MSELoss
metrics:
mse:
_target_: torchmetrics.regression.MeanSquaredError
squared: True
loss_weight: 1.0

sampler: null
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# @package _global_

defaults:
- vp_gauss_time_predictor
- override /data: qm9_filtered
11 changes: 11 additions & 0 deletions src/morered/configs/sampler/ddpm.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
_target_: morered.sampling.DDPM
diffusion_process: ${globals.diffusion_process}
denoiser: ???
time_key: ${globals.time_target_key}
noise_pred_key: ${globals.noise_output_key}
cutoff: ${globals.cutoff}
recompute_neighbors: False
save_progress: False
progress_stride: 1
results_on_cpu: True
device: null
13 changes: 13 additions & 0 deletions src/morered/configs/sampler/morered.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
_target_: morered.sampling.MoreRed
diffusion_process: ${globals.diffusion_process}
denoiser: ???
time_key: ${globals.time_target_key}
noise_pred_key: ${globals.noise_output_key}
time_pred_key: ${globals.time_output_key}
convergence_step: 0
cutoff: ${globals.cutoff}
recompute_neighbors: False
save_progress: False
progress_stride: 1
results_on_cpu: True
device: null
5 changes: 5 additions & 0 deletions src/morered/configs/sampler/morered_as.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
defaults:
- morered

_target_: morered.sampling.MoreRedAS
time_predictor: ???
4 changes: 4 additions & 0 deletions src/morered/configs/sampler/morered_itp.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
defaults:
- morered_as

_target_: morered.sampling.MoreRedITP
Loading

0 comments on commit 89ba818

Please sign in to comment.