Merge pull request #10 from wengroup/devel
Add docs for training and preparing data
mjwen authored Mar 8, 2024
2 parents 2a04699 + 4cf6136 commit 78fa1ea
Showing 5 changed files with 242 additions and 2 deletions.
4 changes: 4 additions & 0 deletions README.md
@@ -46,6 +46,10 @@ crystals, pass a list of structures to `predict`.
- An example of 100 crystals is available in the [datasets](./datasets) directory.
- The full dataset is available at: https://doi.org/10.5281/zenodo.8190849

## Train the model (using your own data)

See instructions [here](./scripts/README.md).

## Reference

197 changes: 197 additions & 0 deletions notebooks/prepare_data.ipynb
@@ -0,0 +1,197 @@
{
"cells": [
{
"cell_type": "markdown",
"source": [
"# Data Preparation\n",
"\n",
"This notebook gives an example of how to prepare your own data to train the model. "
],
"metadata": {
"collapsed": false
},
"id": "b5c6b7dd5f67094c"
},
{
"cell_type": "code",
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"from pathlib import Path\n",
"from pymatgen.core import Structure"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-03-08T22:47:02.268364Z",
"start_time": "2024-03-08T22:47:01.432158Z"
}
},
"id": "45862d0d3d88e62e",
"execution_count": 1
},
{
"cell_type": "code",
"outputs": [],
"source": [
"def get_structures():\n",
"    \"\"\"Create a pymatgen structure for Si.\"\"\"\n",
"    return Structure(\n",
"        lattice=np.array([[0, 2.73, 2.73], [2.73, 0, 2.73], [2.73, 2.73, 0]]),\n",
"        species=[\"Si\", \"Si\"],\n",
"        coords=[[0, 0, 0], [0.25, 0.25, 0.25]],\n",
"    )\n",
"\n",
"\n",
"def get_tensor(seed: int = 35):\n",
"    \"\"\"Generate random 3x3x3x3 elastic tensor.\n",
"\n",
"    Note, this is by no means a physical tensor that satisfies the symmetry of any\n",
"    crystal. It is just a random array to show the data preparation process.\n",
"    \"\"\"\n",
"    np.random.seed(seed)\n",
"    t = np.random.rand(3, 3, 3, 3)\n",
"\n",
"    return t"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-03-08T22:47:02.272659Z",
"start_time": "2024-03-08T22:47:02.270735Z"
}
},
"id": "39bef8ecadee2005",
"execution_count": 2
},
{
"cell_type": "markdown",
"source": [
"## Get data \n",
"\n",
"Here we simply make 10 copies of the Si structure and 10 copies of the elastic tensor. \n",
"You should replace this with your own data."
],
"metadata": {
"collapsed": false
},
"id": "a8bfc64bcfa05bd2"
},
{
"cell_type": "code",
"outputs": [],
"source": [
"Si = get_structures()\n",
"t = get_tensor()\n",
"\n",
"structures = [Si for _ in range(10)]\n",
"tensors = [t for _ in range(10)]"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-03-08T22:47:02.276072Z",
"start_time": "2024-03-08T22:47:02.272804Z"
}
},
"id": "2edf1f8dd8c323ea",
"execution_count": 3
},
{
"cell_type": "markdown",
"source": [
"## Write data to file"
],
"metadata": {
"collapsed": false
},
"id": "2cfecfec86d3444a"
},
{
"cell_type": "code",
"outputs": [],
"source": [
"def write_data(\n",
"    structures: list[Structure],\n",
"    tensors: list[np.ndarray],\n",
"    path: Path = Path(\"elasticity_tensors.json\"),\n",
"):\n",
"    \"\"\"Write structures and tensors to a JSON file.\n",
"\n",
"    Args:\n",
"        structures: list of pymatgen structures.\n",
"        tensors: list of 3x3x3x3 elastic tensors, one per structure.\n",
"        path: path of the JSON file to write.\n",
"    \"\"\"\n",
"    # Serialize to plain Python objects so the data can be stored as JSON.\n",
"    data = {\n",
"        \"structure\": [s.as_dict() for s in structures],\n",
"        \"elastic_tensor_full\": [t.tolist() for t in tensors],\n",
"    }\n",
"    df = pd.DataFrame(data)\n",
"\n",
"    df.to_json(path)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-03-08T22:47:02.279873Z",
"start_time": "2024-03-08T22:47:02.276910Z"
}
},
"id": "55b15d7d8af1a98d",
"execution_count": 4
},
{
"cell_type": "code",
"outputs": [],
"source": [
"write_data(structures, tensors)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-03-08T22:47:02.286930Z",
"start_time": "2024-03-08T22:47:02.281900Z"
}
},
"id": "26057e33efd0eaf9",
"execution_count": 5
},
{
"cell_type": "code",
"outputs": [],
"source": [],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-03-08T22:47:02.287405Z",
"start_time": "2024-03-08T22:47:02.285609Z"
}
},
"id": "ba7de54ec1bf9ba3",
"execution_count": 5
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
39 changes: 39 additions & 0 deletions scripts/README.md
@@ -1,3 +1,42 @@
# scripts

This directory contains scripts to train/evaluate/test the models.

## First try

To train an example model, run the command below in this directory:

```bash
python train_materials_tensor.py
```

This will train a toy model on a very small example dataset for only 10 epochs and then stop.

After training, you can see the metrics and the location of the best model.

### Under the hood

All training configurations (data, optimizer, model, etc.) are defined in the [./configs/materials_tensor.yaml](./configs/materials_tensor.yaml) file. In particular, the `data` section (copied below) shows that the data file is `example_crystal_elasticity_tensor_n100.json`, stored in the [../datasets/](../datasets) directory. Here we use the same file for training, validation, and testing, purely for demonstration. In a real-world scenario, you would use a separate file for each.

```yaml
data:
  root: ../datasets/
  r_cut: 5.0
  trainset_filename: example_crystal_elasticity_tensor_n100.json
  valset_filename: example_crystal_elasticity_tensor_n100.json
  testset_filename: example_crystal_elasticity_tensor_n100.json
  reuse: false
  loader_kwargs:
    batch_size: 32
    shuffle: true
```

Feel free to change other settings and see how the model behaves. For reference, [this config](../pretrained/20230627/config_final.yaml) is the final configuration used to train the model in the paper.
You might also want to uncomment the `logger` section to use [Weights and Biases](https://wandb.ai/site) for logging, which makes tracking the training process much easier.
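
Since real-world training uses separate files for training, validation, and testing, one way to produce them from a single data file is sketched below. This is only an illustration, not part of the repository: the function name `split_dataset`, the output names `train.json`, `val.json`, and `test.json`, and the 80/10/10 split are all assumptions, and the code assumes the column-oriented JSON layout that pandas `DataFrame.to_json` writes by default.

```python
import json
import random
from pathlib import Path


def split_dataset(path, out_dir, fractions=(0.8, 0.1, 0.1), seed=0):
    """Split a column-oriented JSON dataset into train/val/test files.

    Assumes the layout {column_name: {row_label: value}}, i.e. what pandas
    ``DataFrame.to_json`` writes with its default settings.
    """
    data = json.loads(Path(path).read_text())
    columns = list(data)
    row_labels = list(data[columns[0]])

    # Shuffle reproducibly so the three subsets are disjoint random samples.
    random.Random(seed).shuffle(row_labels)

    n = len(row_labels)
    n_train = int(fractions[0] * n)
    n_val = int(fractions[1] * n)
    subsets = {
        "train": row_labels[:n_train],
        "val": row_labels[n_train : n_train + n_val],
        "test": row_labels[n_train + n_val :],  # whatever remains
    }

    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    written = {}
    for name, labels in subsets.items():
        # Re-label the rows 0..k-1 within each subset, keeping columns aligned.
        subset = {
            col: {str(i): data[col][label] for i, label in enumerate(labels)}
            for col in columns
        }
        out_path = out_dir / f"{name}.json"
        out_path.write_text(json.dumps(subset))
        written[name] = out_path
    return written
```

With files written this way, `root` in the `data` section would point to the output directory, and `trainset_filename`, `valset_filename`, and `testset_filename` would be set to the three file names.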

## Training with your own data

All you need to do is prepare your data as JSON file(s), similar to `example_crystal_elasticity_tensor_n100.json`.
Each file needs a list of pymatgen structures under the `structure` key and the corresponding elastic tensors under the `elastic_tensor_full` key.
See this [notebook](../notebooks/prepare_data.ipynb) for more details on how to prepare the data.
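
Before training on your own file, a quick sanity check of its layout can catch simple mistakes. The sketch below uses only the standard library. The helper names `nested_shape` and `check_dataset` are hypothetical and not part of the repository, and the code assumes the column-oriented layout produced by `df.to_json(path)` in the notebook, with one 3x3x3x3 tensor per structure.

```python
import json
from pathlib import Path


def nested_shape(x):
    """Return the shape of a nested list, e.g. (3, 3, 3, 3).

    Only the first element at each level is inspected, so ragged lists
    are not detected.
    """
    shape = []
    while isinstance(x, list):
        shape.append(len(x))
        x = x[0] if x else None
    return tuple(shape)


def check_dataset(path):
    """Return a list of problems found in a data file (empty if none)."""
    data = json.loads(Path(path).read_text())
    problems = []

    # Both columns written by the notebook must be present.
    for key in ("structure", "elastic_tensor_full"):
        if key not in data:
            problems.append(f"missing column: {key}")
    if problems:
        return problems

    structures, tensors = data["structure"], data["elastic_tensor_full"]

    # Every structure needs a tensor with the same row label, and vice versa.
    if set(structures) != set(tensors):
        problems.append("structure and tensor rows do not match")

    # Each tensor should be a full 3x3x3x3 array.
    for label, tensor in tensors.items():
        shape = nested_shape(tensor)
        if shape != (3, 3, 3, 3):
            problems.append(f"row {label}: tensor shape is {shape}")
    return problems
```

Calling `check_dataset("elasticity_tensors.json")` on the file written by the notebook should return an empty list if the layout matches these assumptions. It does not check that a tensor is physically meaningful.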
2 changes: 1 addition & 1 deletion scripts/configs/materials_tensor.yaml
@@ -3,10 +3,10 @@ log_level: info

data:
  root: ../datasets/
  r_cut: 5.0
  trainset_filename: example_crystal_elasticity_tensor_n100.json
  valset_filename: example_crystal_elasticity_tensor_n100.json
  testset_filename: example_crystal_elasticity_tensor_n100.json
  r_cut: 5.0
  reuse: false
  loader_kwargs:
    batch_size: 32
2 changes: 1 addition & 1 deletion src/matten/utils_wandb.py
@@ -13,7 +13,7 @@

 def get_git_repo_commit(repo_path: Path) -> str:
     """
-    Get the latest git commit info of a github repository.
+    Get the latest git commit info of a git repository.
     Args:
         repo_path: path to the repo
