Skip to content

Commit

Permalink
Add more docs
Browse files Browse the repository at this point in the history
  • Loading branch information
mjwen committed Jun 22, 2024
1 parent ebd4564 commit 35e2b02
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions src/matten/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,9 @@ def predict(
If you want to use your own model, you can provide the path to a directory,
and the directory should contain two files:
- `model_final.ckpt`: your trained model, which can be found in the
job directory after training.
job directory after training. Note, in your job directory, it may
have a different name, e.g., `model_epoch=100.ckpt`. You can rename
it or change the `checkpoint` argument to match the name.
- `config_final.yaml`: the configuration file used to train the model.
checkpoint: the checkpoint file to use. The default is `model_final.ckpt`.
batch_size: the batch size for prediction. In general, the larger the faster,
Expand Down Expand Up @@ -187,25 +189,25 @@ def predict(

if failed:
idx = 0
elastic_tensors = []
pred_tensors = []
for i in range(len(structure)):
if i in failed:
elastic_tensors.append(None)
pred_tensors.append(None)
else:
elastic_tensors.append(predictions[idx])
pred_tensors.append(predictions[idx])
idx += 1

warnings.warn(
"Cannot make predictions for the following structures. Their returned "
f"elasticity tensor set to `None`: {sorted(failed)}."
)
else:
elastic_tensors = predictions
pred_tensors = predictions

if single_struct:
return elastic_tensors[0]
return pred_tensors[0]
else:
return elastic_tensors
return pred_tensors


if __name__ == "__main__":
Expand Down

0 comments on commit 35e2b02

Please sign in to comment.