Merge pull request #13 from wengroup/devel
Add training script for atomic tensor
mjwen authored Jul 30, 2024
2 parents af698e7 + ad11b36 commit ccc770d
Showing 7 changed files with 449 additions and 2 deletions.
117 changes: 117 additions & 0 deletions scripts/configs/atomic_tensor.yaml
@@ -0,0 +1,117 @@
## Config file for atomic tensor (i.e., a tensor value for each atom)

seed_everything: 35
log_level: info

data:
  tensor_target_name: nmr_tensor
  atom_selector: atom_selector
  tensor_target_formula: ij=ji
  root: .
  trainset_filename: /Users/mjwen.admin/Documents/Dataset/NMR_tensor/si_nmr_data_small.json
  valset_filename: /Users/mjwen.admin/Documents/Dataset/NMR_tensor/si_nmr_data_small.json
  testset_filename: /Users/mjwen.admin/Documents/Dataset/NMR_tensor/si_nmr_data_small.json
  r_cut: 5.0
  reuse: false
  loader_kwargs:
    batch_size: 2
    shuffle: true

model:
  ##########
  # embedding
  ##########

  # atom species embedding
  species_embedding_dim: 16

  # spherical harmonics embedding of edge direction
  irreps_edge_sh: 0e + 1o + 2e

  # radial edge distance embedding
  radial_basis_type: bessel
  num_radial_basis: 8
  radial_basis_start: 0.
  radial_basis_end: 5.

  ##########
  # message passing conv layers
  ##########
  num_layers: 3

  # radial network
  invariant_layers: 2 # number of radial layers
  invariant_neurons: 32 # number of hidden neurons in radial function

  # Average number of neighbors used for normalization. Options:
  # 1. `auto`: determine it automatically as the average number of neighbors
  #    in the training set
  # 2. a float or int provided here
  # 3. `null`: do not use it
  average_num_neighbors: auto

  # point convolution
  conv_layer_irreps: 32x0o+32x0e + 16x1o+16x1e + 4x2o+4x2e
  nonlinearity_type: gate
  normalization: batch
  resnet: true

  ##########
  # output
  ##########

  # output_format and output_formula should be used together.
  # - output_format (either `irreps` or `cartesian`) determines the space in which
  #   the loss is computed (the irreps space or the Cartesian space).
  # - output_formula gives the Cartesian symmetry formula of the tensor. For example,
  #   ij=ji (used here) specifies a symmetric rank-2 tensor, and ijkl=jikl=klij
  #   specifies a fourth-rank elasticity tensor.
  output_format: irreps
  output_formula: ij=ji

  # pooling node feats to graph feats
  reduce: mean

trainer:
  max_epochs: 10 # maximum number of training epochs
  num_nodes: 1
  accelerator: cpu
  devices: 1

  callbacks:
    - class_path: pytorch_lightning.callbacks.ModelCheckpoint
      init_args:
        monitor: val/score
        mode: min
        save_top_k: 3
        save_last: true
        verbose: false
    - class_path: pytorch_lightning.callbacks.EarlyStopping
      init_args:
        monitor: val/score
        mode: min
        patience: 150
        min_delta: 0
        verbose: true
    - class_path: pytorch_lightning.callbacks.ModelSummary
      init_args:
        max_depth: -1

  #logger:
  #  class_path: pytorch_lightning.loggers.wandb.WandbLogger
  #  init_args:
  #    save_dir: matten_logs
  #    project: matten_proj

optimizer:
  class_path: torch.optim.Adam
  init_args:
    lr: 0.01
    weight_decay: 0.00001

lr_scheduler:
  class_path: torch.optim.lr_scheduler.ReduceLROnPlateau
  init_args:
    mode: min
    factor: 0.5
    patience: 50
    verbose: true
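
As a side note on the `output_format` / `output_formula` settings above (an illustration, not part of the commit): matten's `CartesianTensorWrapper` presumably wraps e3nn's `CartesianTensor`, and for the symmetric rank-2 formula `ij=ji` the irreps decomposition is `1x0e+1x2e`, i.e. the six independent Cartesian components map to a 6-dimensional irreps vector, which is the space the loss is computed in when `output_format: irreps`. A minimal sketch, assuming e3nn is installed:

import torch
from e3nn.io import CartesianTensor

# "ij=ji" declares a symmetric rank-2 Cartesian tensor (e.g. an NMR tensor).
ct = CartesianTensor("ij=ji")
print(ct)  # 1x0e+1x2e: 1 (trace) + 5 (traceless symmetric) = 6 components

# A symmetric 3x3 tensor maps to a 6-dimensional irreps vector and back.
t = torch.randn(3, 3)
t = 0.5 * (t + t.T)
vec = ct.from_cartesian(t)
print(vec.shape)                   # torch.Size([6])
print(ct.to_cartesian(vec).shape)  # torch.Size([3, 3])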
1 change: 1 addition & 0 deletions scripts/configs/materials_tensor.yaml
@@ -3,6 +3,7 @@ log_level: info

data:
  root: ../datasets/
  tensor_target_name: elastic_tensor_full
  trainset_filename: example_crystal_elasticity_tensor_n100.json
  valset_filename: example_crystal_elasticity_tensor_n100.json
  testset_filename: example_crystal_elasticity_tensor_n100.json
81 changes: 81 additions & 0 deletions scripts/train_atomic_tensor.py
@@ -0,0 +1,81 @@
"""Script to train the materials tensor model."""

from pathlib import Path
from typing import Dict, List, Union

import yaml
from loguru import logger
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.cli import instantiate_class as lit_instantiate_class

from matten.dataset.structure_scalar_tensor import TensorDataModule
from matten.log import set_logger
from matten.model_factory.task import TensorRegressionTask
from matten.model_factory.tfn_atomic_tensor import AtomicTensorModel


def instantiate_class(d: Union[Dict, List]):
    args = tuple()  # no positional args
    if isinstance(d, dict):
        return lit_instantiate_class(args, d)
    elif isinstance(d, list):
        return [lit_instantiate_class(args, x) for x in d]
    else:
        raise ValueError(f"Cannot instantiate class from {d}")


def get_args(path: Path):
"""Get the arguments from the config file."""
with open(path, "r") as f:
config = yaml.safe_load(f)
return config


def main(config: Dict):
    dm = TensorDataModule(**config["data"])
    dm.prepare_data()
    dm.setup()

    model = AtomicTensorModel(
        tasks=TensorRegressionTask(name=config["data"]["tensor_target_name"]),
        backbone_hparams=config["model"],
        dataset_hparams=dm.get_to_model_info(),
        optimizer_hparams=config["optimizer"],
        lr_scheduler_hparams=config["lr_scheduler"],
    )

    try:
        callbacks = instantiate_class(config["trainer"].pop("callbacks"))
        lit_logger = instantiate_class(config["trainer"].pop("logger"))
    except KeyError:
        callbacks = None
        lit_logger = None

    trainer = Trainer(
        callbacks=callbacks,
        logger=lit_logger,
        **config["trainer"],
    )

    logger.info("Start training!")
    trainer.fit(model, datamodule=dm)

    # test
    logger.info("Start testing!")
    trainer.test(ckpt_path="best", datamodule=dm)

    # print path of best checkpoint
    logger.info(f"Best checkpoint path: {trainer.checkpoint_callback.best_model_path}")


if __name__ == "__main__":
    config_file = Path(__file__).parent / "configs" / "atomic_tensor.yaml"
    config = get_args(config_file)

    seed = config.get("seed_everything", 1)
    seed_everything(seed)

    log_level = config.get("log_level", "INFO")
    set_logger(log_level)

    main(config)
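
For context on how the `instantiate_class` helper in the script turns the `class_path`/`init_args` entries of the config into objects, here is a small sketch (not part of the commit; it only uses `pytorch_lightning.cli.instantiate_class`, which the script imports as `lit_instantiate_class`, and a callback entry copied from the config above):

from pytorch_lightning.cli import instantiate_class

# One callback entry from the config: a dict with `class_path` and `init_args`.
checkpoint_cfg = {
    "class_path": "pytorch_lightning.callbacks.ModelCheckpoint",
    "init_args": {"monitor": "val/score", "mode": "min", "save_top_k": 3},
}

# instantiate_class imports the class named by `class_path` and calls it with
# `init_args` (plus any positional args, empty here).
ckpt_callback = instantiate_class(tuple(), checkpoint_cfg)
print(type(ckpt_callback).__name__)  # ModelCheckpoint

The script itself is meant to be run directly, e.g. `python scripts/train_atomic_tensor.py`; it loads `scripts/configs/atomic_tensor.yaml` relative to its own location.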
2 changes: 1 addition & 1 deletion scripts/train_materials_tensor.py
@@ -37,7 +37,7 @@ def main(config: Dict):
    dm.setup()

    model = ScalarTensorModel(
-        tasks=TensorRegressionTask(name="elastic_tensor_full"),
+        tasks=TensorRegressionTask(name=config["data"]["tensor_target_name"]),
        backbone_hparams=config["model"],
        dataset_hparams=dm.get_to_model_info(),
        optimizer_hparams=config["optimizer"],
15 changes: 14 additions & 1 deletion src/matten/dataset/structure_scalar_tensor.py
@@ -50,6 +50,8 @@ class TensorDataset(InMemoryDataset):
            10.0, 0:1.0}}, then for data points with minLC less than 0, it will have
            a weight of 10, and for those with minLC larger than 0, it will have a
            weight of 1. By default, `None` means all data points have a weight of 1.
        atom_selector: a list of bools indicating which atoms in the structure to use
            to compute the target. If `None`, all atoms are used.
        global_featurizer: featurizer to compute global features.
        normalize_global_features: whether to normalize the global features.
        atom_featurizer: featurizer to compute atom features.
@@ -79,6 +81,7 @@ def __init__(
        log_scalar_targets: List[bool] = None,
        normalize_scalar_targets: List[bool] = None,
        tensor_target_weight: Dict[str, Dict[str, float]] = None,
        atom_selector: List[bool] = None,
        global_featurizer: None = None,
        normalize_global_features: bool = False,
        atom_featurizer: None = None,
@@ -97,6 +100,7 @@ def __init__(
        self.normalize_tensor_target = normalize_tensor_target

        self.tensor_target_weight = tensor_target_weight
        self.atom_selector = atom_selector

        self.scalar_target_names = (
            [] if scalar_target_names is None else scalar_target_names
@@ -259,7 +263,7 @@ def _get_crystals(self, df):
            # TODO: convert to irreps tensor, assuming all input tensors are Cartesian
            converter = CartesianTensorWrapper(formula=self.tensor_target_formula)
            df[self.tensor_target_name] = df[self.tensor_target_name].apply(
-                lambda x: converter.from_cartesian(x).reshape(1, -1)
+                lambda x: torch.atleast_2d(converter.from_cartesian(x))
            )
        elif self.tensor_target_format == "cartesian":
            df[self.tensor_target_name] = df[self.tensor_target_name].apply(
@@ -303,6 +307,10 @@ def _get_crystals(self, df):
                y[self.tensor_target_name] * self.tensor_target_scale
            )

            # atom selector
            if self.atom_selector is not None:
                y["atom_selector"] = torch.as_tensor(row[self.atom_selector])

            x = None
            if self.global_featurizer:
                # feats
@@ -414,6 +422,7 @@ def __init__(
        tensor_target_scale: float = 1.0,
        normalize_tensor_target: bool = False,
        tensor_target_weight: Dict[str, Dict[str, float]] = None,
        atom_selector: List[bool] = None,
        scalar_target_names: List[str] = None,
        log_scalar_targets: List[bool] = None,
        normalize_scalar_targets: List[bool] = None,
@@ -457,6 +466,7 @@ def __init__(
        self.tensor_target_scale = tensor_target_scale
        self.normalize_tensor_target = normalize_tensor_target
        self.tensor_target_weight = tensor_target_weight
        self.atom_selector = atom_selector

        self.scalar_target_names = scalar_target_names
        self.log_scalar_targets = log_scalar_targets
@@ -563,6 +573,7 @@ def setup(self, stage: Optional[str] = None):
            log_scalar_targets=self.log_scalar_targets,
            normalize_scalar_targets=self.normalize_scalar_targets,
            tensor_target_weight=self.tensor_target_weight,
            atom_selector=self.atom_selector,
            global_featurizer=gf,
            normalize_global_features=self.normalize_global_features,
            atom_featurizer=af,
@@ -584,6 +595,7 @@ def setup(self, stage: Optional[str] = None):
            log_scalar_targets=self.log_scalar_targets,
            normalize_scalar_targets=self.normalize_scalar_targets,
            tensor_target_weight=self.tensor_target_weight,
            atom_selector=self.atom_selector,
            global_featurizer=gf,
            normalize_global_features=self.normalize_global_features,
            atom_featurizer=af,
@@ -605,6 +617,7 @@ def setup(self, stage: Optional[str] = None):
            log_scalar_targets=self.log_scalar_targets,
            normalize_scalar_targets=self.normalize_scalar_targets,
            tensor_target_weight=self.tensor_target_weight,
            atom_selector=self.atom_selector,
            global_featurizer=gf,
            normalize_global_features=self.normalize_global_features,
            atom_featurizer=af,
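
A brief aside on the `_get_crystals` change above, where `reshape(1, -1)` is replaced by `torch.atleast_2d` (an illustration, not from the commit; the shapes are assumptions based on one irreps vector per structure for a global tensor versus one per atom for an atomic tensor):

import torch

irreps_dim = 6  # e.g. the 0e + 2e irreps of an ij=ji tensor

# Global (per-structure) target: a single irreps vector; both calls give (1, 6).
global_target = torch.randn(irreps_dim)
print(global_target.reshape(1, -1).shape)     # torch.Size([1, 6])
print(torch.atleast_2d(global_target).shape)  # torch.Size([1, 6])

# Atomic (per-atom) target for a 4-atom structure: one irreps vector per atom.
atomic_target = torch.randn(4, irreps_dim)
print(atomic_target.reshape(1, -1).shape)     # torch.Size([1, 24]) -- atoms flattened together
print(torch.atleast_2d(atomic_target).shape)  # torch.Size([4, 6])  -- per-atom rows kept

Under these assumptions the new form behaves the same for the existing per-structure targets while keeping per-atom rows intact for atomic tensors.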
7 changes: 7 additions & 0 deletions src/matten/model/model.py
@@ -338,6 +338,11 @@ def shared_step(self, batch, mode: str):
        # ========== compute predictions ==========
        preds = self.decode(graphs)

        # select atoms
        if "atom_selector" in labels:
            selector = labels["atom_selector"]
            preds = {k: v[selector] for k, v in preds.items()}

        # ========== compute losses ==========
        target_weight = graphs.get("target_weight", None)
        individual_loss, total_loss = self.compute_loss(
@@ -504,6 +509,8 @@ def preprocess_batch(self, batch: DataPoint) -> Tuple[DataPoint, Dict[str, Tensor]]:

# task labels
labels = {name: graphs.y[name] for name in self.tasks}
if "atom_selector" in graphs.y:
labels["atom_selector"] = graphs.y["atom_selector"]

# convert graphs to a dict to use NequIP stuff
graphs = graphs.tensor_property_to_dict()
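
To illustrate the effect of the `atom_selector` mask added to `shared_step` above (a sketch with made-up shapes; `nmr_tensor` is the target name from the config, and presumably the stored targets already contain only the selected atoms, so masking the predictions aligns the two):

import torch

# Per-atom predictions for a 4-atom structure: one 6-dim irreps vector per atom.
preds = {"nmr_tensor": torch.randn(4, 6)}

# Boolean mask carried in labels["atom_selector"]: keep atoms 0 and 2 only.
selector = torch.tensor([True, False, True, False])

# Same selection as in shared_step: index every prediction with the mask.
preds = {k: v[selector] for k, v in preds.items()}
print(preds["nmr_tensor"].shape)  # torch.Size([2, 6])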