Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prediction for atomic tensor #15

Merged
merged 2 commits into from
Aug 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions notebooks/predict_atomic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
"""
An example script make predictions of any tensor.
"""

from pymatgen.core import Structure

from matten.predict import predict


def get_structure():
a = 5.46
lattice = [[0, a / 2, a / 2], [a / 2, 0, a / 2], [a / 2, a / 2, 0]]
basis = [[0.0, 0.0, 0.0], [0.25, 0.25, 0.25]]
Si = Structure(lattice, ["Si", "Si"], basis)

return Si


if __name__ == "__main__":
structure = get_structure()

# predict for one structure
tensors = predict(
structure,
model_identifier="/Users/mjwen.admin/Downloads/trained",
checkpoint="epoch=9-step=100.ckpt",
is_atomic_tensor=True,
)
print("value:", tensors)
116 changes: 0 additions & 116 deletions scripts/configs/pizeoelectric.yaml

This file was deleted.

28 changes: 22 additions & 6 deletions src/matten/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from matten.dataset.structure_scalar_tensor import TensorDatasetPrediction
from matten.log import set_logger
from matten.model_factory.tfn_atomic_tensor import AtomicTensorModel
from matten.model_factory.tfn_scalar_tensor import ScalarTensorModel
from matten.utils import CartesianTensorWrapper, yaml_load

Expand All @@ -31,9 +32,11 @@ def get_pretrained_model_dir(identifier: str) -> Path:
return Path(__file__).parent.parent.parent / "pretrained" / identifier


def get_pretrained_model(identifier: str, checkpoint: str = "model_final.ckpt"):
def get_pretrained_model(
identifier: str, checkpoint: str = "model_final.ckpt", model_class=ScalarTensorModel
):
directory = get_pretrained_model_dir(identifier)
model = ScalarTensorModel.load_from_checkpoint(
model = model_class.load_from_checkpoint(
checkpoint_path=directory.joinpath(checkpoint).as_posix(),
map_location="cpu",
)
Expand Down Expand Up @@ -62,6 +65,7 @@ def get_data_loader(
"valset_filename",
"testset_filename",
"compute_dataset_statistics",
"atom_selector",
]:
try:
config.pop(k)
Expand Down Expand Up @@ -151,6 +155,7 @@ def predict(
batch_size: int = 200,
logger_level: str = "ERROR",
is_elasticity_tensor: bool = True,
is_atomic_tensor: bool = False,
) -> Union[ElasticTensor, List[ElasticTensor]]:
f"""
Predict the property of a structure or a list of structures.
Expand All @@ -174,6 +179,8 @@ def predict(
is_elasticity_tensor: whether the target property is an elasticity tensor. If
`True`, the returned value will be a pymargen `ElasticTensor` object.
Otherwise, it will be numpy array.
is_atomic_tensor: whether the target property is an atomic tensor. If `True`,
we predict a tensor value for each atom in the structure.

Returns:
Predicted tensor(s). `None` if the model cannot make prediction for a structure.
Expand All @@ -186,7 +193,16 @@ def predict(
else:
single_struct = False

model = get_pretrained_model(identifier=model_identifier, checkpoint=checkpoint)
if is_atomic_tensor:
model_class = AtomicTensorModel
is_elasticity_tensor = False
else:
model_class = ScalarTensorModel
model = get_pretrained_model(
identifier=model_identifier,
checkpoint=checkpoint,
model_class=model_class,
)
check_species(model, structure)
loader = get_data_loader(structure, model_identifier, batch_size=batch_size)

Expand Down Expand Up @@ -223,10 +239,10 @@ def predict(
else:
pred_tensors = predictions

if single_struct:
if single_struct and not is_atomic_tensor:
return pred_tensors[0]
else:
return pred_tensors

return pred_tensors


if __name__ == "__main__":
Expand Down
Loading