Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable predict for tensors other than elastic #14

Merged
merged 2 commits into from
Aug 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions notebooks/predict_general.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""
An example script make predictions of any tensor.
"""

from pymatgen.core import Structure

from matten.predict import predict


def get_structure():
a = 5.46
lattice = [[0, a / 2, a / 2], [a / 2, 0, a / 2], [a / 2, a / 2, 0]]
basis = [[0.0, 0.0, 0.0], [0.25, 0.25, 0.25]]
Si = Structure(lattice, ["Si", "Si"], basis)

return Si


if __name__ == "__main__":
structure = get_structure()

tensor = predict(
structure,
model_identifier="/Users/mjwen.admin/Packages/matten_wengroup/scripts",
checkpoint="epoch=9-step=10.ckpt",
is_elasticity_tensor=False,
)

print("value:", tensor)
print("type:", type(tensor))
print("shape:", tensor.shape)
116 changes: 116 additions & 0 deletions scripts/configs/pizeoelectric.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
seed_everything: 35
log_level: info

data:
root: .
tensor_target_name: piezoelectric_tensor_total
tensor_target_formula: ijk=ikj
trainset_filename: /Users/mjwen.admin/Documents/Dataset/di_pizeoelectric_tensor/piezoelectric_tensors_n20.json
valset_filename: /Users/mjwen.admin/Documents/Dataset/di_pizeoelectric_tensor/piezoelectric_tensors_n20.json
testset_filename: /Users/mjwen.admin/Documents/Dataset/di_pizeoelectric_tensor/piezoelectric_tensors_n20.json
r_cut: 5.0
reuse: false
loader_kwargs:
batch_size: 32
shuffle: true

model:
##########
# embedding
##########

# atom species embedding
species_embedding_dim: 16

# spherical harmonics embedding of edge direction
irreps_edge_sh: 0e + 1o + 2e + 3o

# radial edge distance embedding
radial_basis_type: bessel
num_radial_basis: 8
radial_basis_start: 0.
radial_basis_end: 5.

##########
# message passing conv layers
##########
num_layers: 3

# radial network
invariant_layers: 2 # number of radial layers
invariant_neurons: 32 # number of hidden neurons in radial function

# Average number of neighbors used for normalization. Options:
# 1. `auto` to determine it automatically, by setting it to average number
# of neighbors of the training set
# 2. float or int provided here.
# 3. `null` to not use it
average_num_neighbors: auto

# point convolution
conv_layer_irreps: 32x0o+32x0e + 16x1o+16x1e + 4x2o+4x2e + 2x3o+2x3e
nonlinearity_type: gate
normalization: batch
resnet: true

##########
# output
##########

conv_to_output_hidden_irreps_out: 16x1o + 4x2o + 2x3o

# output_format and output_formula should be used together.
# - output_format (can be `irreps` or `cartesian`) determines what the loss
# function will be on (either on the irreps space or the cartesian space).
# - output_formula gives what the cartesian formula of the tensor is.
# For example, ijkl=jikl=klij specifies a forth-rank elasticity tensor.
output_format: irreps
output_formula: ijk=ikj

# pooling node feats to graph feats
reduce: mean

trainer:
max_epochs: 10 # number of maximum training epochs
num_nodes: 1
accelerator: cpu
devices: 1

callbacks:
- class_path: pytorch_lightning.callbacks.ModelCheckpoint
init_args:
monitor: val/score
mode: min
save_top_k: 3
save_last: true
verbose: false
- class_path: pytorch_lightning.callbacks.EarlyStopping
init_args:
monitor: val/score
mode: min
patience: 150
min_delta: 0
verbose: true
- class_path: pytorch_lightning.callbacks.ModelSummary
init_args:
max_depth: -1

#logger:
# class_path: pytorch_lightning.loggers.wandb.WandbLogger
# init_args:
# save_dir: matten_logs
# project: matten_proj

optimizer:
class_path: torch.optim.Adam
init_args:
lr: 0.01
weight_decay: 0.00001

lr_scheduler:
class_path: torch.optim.lr_scheduler.ReduceLROnPlateau
init_args:
mode: min
factor: 0.5
patience: 50
verbose: true
12 changes: 11 additions & 1 deletion src/matten/dataset/structure_scalar_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,10 +388,20 @@ def __init__(
filename: str, # this can be None
r_cut: float,
structures: List[Structure],
tensor_target_name: str = "elastic_tensor_full",
tensor_target_format: str = "irreps",
tensor_target_formula: str = "ijkl=jikl=klij",
**kwargs,
):
self.structures = structures
super().__init__(filename, r_cut, **kwargs)
super().__init__(
filename,
r_cut=r_cut,
tensor_target_name=tensor_target_name,
tensor_target_format=tensor_target_format,
tensor_target_formula=tensor_target_formula,
**kwargs,
)

def get_data(self):
# process info like ijkl=jikl=klij to get unique indices, ijkl
Expand Down
35 changes: 27 additions & 8 deletions src/matten/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,10 @@ def get_data_loader(
"testset_filename",
"compute_dataset_statistics",
]:
config.pop(k)
try:
config.pop(k)
except KeyError:
pass

r_cut = config.pop("r_cut")
config["dataset_statistics_fn"] = None
Expand Down Expand Up @@ -110,29 +113,31 @@ def check_species(model, structures: List[Structure]):
def evaluate(
model,
loader,
target_name: str = "elastic_tensor_full",
tensor_target_name: str = "elastic_tensor_full",
tensor_target_formula="ijkl=jikl=klij",
) -> List[torch.Tensor]:
"""
Evaluate the model to generate predictions.

Args:
model: the model to evaluate.
loader: the data loader.
target_name: the name of the target property.
tensor_target_name: the name of the target property.
tensor_target_formula: the formula of the target property.

Returns:
a list of predicted elastic tensors.
"""

converter = CartesianTensorWrapper("ijkl=jikl=klij")
converter = CartesianTensorWrapper(tensor_target_formula)

predictions = []

model.eval()
with torch.no_grad():
for batch in tqdm.tqdm(loader):
preds, _ = model(batch, task_name=target_name)
p = preds[target_name]
preds, _ = model(batch, task_name=tensor_target_name)
p = preds[tensor_target_name]
p = converter.to_cartesian(p)
predictions.extend(p)

Expand All @@ -145,6 +150,7 @@ def predict(
checkpoint: str = "model_final.ckpt",
batch_size: int = 200,
logger_level: str = "ERROR",
is_elasticity_tensor: bool = True,
) -> Union[ElasticTensor, List[ElasticTensor]]:
f"""
Predict the property of a structure or a list of structures.
Expand All @@ -165,6 +171,9 @@ def predict(
but it may be limited by the CPU memory.
logger_level: the level of the logger. Options are `DEBUG`, `INFO`, `WARNING`,
`ERROR`, and `CRITICAL`.
is_elasticity_tensor: whether the target property is an elasticity tensor. If
`True`, the returned value will be a pymargen `ElasticTensor` object.
Otherwise, it will be numpy array.

Returns:
Predicted tensor(s). `None` if the model cannot make prediction for a structure.
Expand All @@ -181,8 +190,18 @@ def predict(
check_species(model, structure)
loader = get_data_loader(structure, model_identifier, batch_size=batch_size)

predictions = evaluate(model, loader)
predictions = [ElasticTensor(t) for t in predictions]
config = get_pretrained_config(model_identifier)

predictions = evaluate(
model,
loader,
tensor_target_name=config["data"]["tensor_target_name"],
tensor_target_formula=config["data"]["tensor_target_formula"],
)
if is_elasticity_tensor:
predictions = [ElasticTensor(t) for t in predictions]
else:
predictions = [t.numpy() for t in predictions]

# deal with failed entries
failed = set(loader.dataset.failed_entries)
Expand Down
Loading