-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #10 from wengroup/devel
Add docs for training and preparing data
- Loading branch information
Showing
5 changed files
with
242 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,197 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"source": [ | ||
"# Data Preparation\n", | ||
"\n", | ||
"This notebook gives an example of how to prepare your own data to train the model. " | ||
], | ||
"metadata": { | ||
"collapsed": false | ||
}, | ||
"id": "b5c6b7dd5f67094c" | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"outputs": [], | ||
"source": [ | ||
"import numpy as np\n", | ||
"import pandas as pd\n", | ||
"from pathlib import Path\n", | ||
"from pymatgen.core import Structure" | ||
], | ||
"metadata": { | ||
"collapsed": false, | ||
"ExecuteTime": { | ||
"end_time": "2024-03-08T22:47:02.268364Z", | ||
"start_time": "2024-03-08T22:47:01.432158Z" | ||
} | ||
}, | ||
"id": "45862d0d3d88e62e", | ||
"execution_count": 1 | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"outputs": [], | ||
"source": [ | ||
"def get_structures():\n", | ||
" \"\"\"Create a pymatgen structure for Si.\"\"\"\n", | ||
" return Structure(\n", | ||
" lattice=np.array([[0, 2.73, 2.73], [2.73, 0, 2.73], [2.73, 2.73, 0]]),\n", | ||
" species=[\"Si\", \"Si\"],\n", | ||
" coords=[[0, 0, 0], [0.25, 0.25, 0.25]],\n", | ||
" )\n", | ||
"\n", | ||
"\n", | ||
"def get_tensor(seed: int = 35):\n", | ||
" \"\"\"Generate random 3x3x3x3 elastic tensor.\n", | ||
"\n", | ||
" Note, this is by no means a physical tensor that satisfies the symmetry of any\n", | ||
" crystal. It is just a random array to show the data preparation process.\n", | ||
" \"\"\"\n", | ||
" np.random.seed(seed)\n", | ||
" t = np.random.rand(3, 3, 3, 3)\n", | ||
"\n", | ||
" return t" | ||
], | ||
"metadata": { | ||
"collapsed": false, | ||
"ExecuteTime": { | ||
"end_time": "2024-03-08T22:47:02.272659Z", | ||
"start_time": "2024-03-08T22:47:02.270735Z" | ||
} | ||
}, | ||
"id": "39bef8ecadee2005", | ||
"execution_count": 2 | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"source": [ | ||
"## Get data \n", | ||
"\n", | ||
"Here we simply make 10 copies of the Si structure and 10 copies of the elastic tensor. \n", | ||
"You should replace this with your own data." | ||
], | ||
"metadata": { | ||
"collapsed": false | ||
}, | ||
"id": "a8bfc64bcfa05bd2" | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"outputs": [], | ||
"source": [ | ||
"Si = get_structures()\n", | ||
"t = get_tensor()\n", | ||
"\n", | ||
"structures = [Si for _ in range(10)]\n", | ||
"tensors = [t for _ in range(10)]" | ||
], | ||
"metadata": { | ||
"collapsed": false, | ||
"ExecuteTime": { | ||
"end_time": "2024-03-08T22:47:02.276072Z", | ||
"start_time": "2024-03-08T22:47:02.272804Z" | ||
} | ||
}, | ||
"id": "2edf1f8dd8c323ea", | ||
"execution_count": 3 | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"source": [ | ||
"## Write data to file" | ||
], | ||
"metadata": { | ||
"collapsed": false | ||
}, | ||
"id": "2cfecfec86d3444a" | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"outputs": [], | ||
"source": [ | ||
"def write_data(\n", | ||
" structures: list[Structure],\n", | ||
" tensors: list[np.ndarray],\n", | ||
" path: Path = \"elasticity_tensors.json\",\n", | ||
"):\n", | ||
" \"\"\"Write structures and tensors to file.\n", | ||
"\n", | ||
" Args:\n", | ||
" structures: list of pymatgen structures.\n", | ||
" tensors: list of 3x3x3x3 elastic tensors.\n", | ||
" path: path to write the data.\n", | ||
" \"\"\"\n", | ||
" data = {\n", | ||
" \"structure\": [s.as_dict() for s in structures],\n", | ||
" \"elastic_tensor_full\": [t.tolist() for t in tensors],\n", | ||
" }\n", | ||
" df = pd.DataFrame(data)\n", | ||
"\n", | ||
" df.to_json(path)" | ||
], | ||
"metadata": { | ||
"collapsed": false, | ||
"ExecuteTime": { | ||
"end_time": "2024-03-08T22:47:02.279873Z", | ||
"start_time": "2024-03-08T22:47:02.276910Z" | ||
} | ||
}, | ||
"id": "55b15d7d8af1a98d", | ||
"execution_count": 4 | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"outputs": [], | ||
"source": [ | ||
"write_data(structures, tensors)" | ||
], | ||
"metadata": { | ||
"collapsed": false, | ||
"ExecuteTime": { | ||
"end_time": "2024-03-08T22:47:02.286930Z", | ||
"start_time": "2024-03-08T22:47:02.281900Z" | ||
} | ||
}, | ||
"id": "26057e33efd0eaf9", | ||
"execution_count": 5 | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"outputs": [], | ||
"source": [], | ||
"metadata": { | ||
"collapsed": false, | ||
"ExecuteTime": { | ||
"end_time": "2024-03-08T22:47:02.287405Z", | ||
"start_time": "2024-03-08T22:47:02.285609Z" | ||
} | ||
}, | ||
"id": "ba7de54ec1bf9ba3", | ||
"execution_count": 5 | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 2 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython2", | ||
"version": "2.7.6" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,42 @@ | ||
# scripts | ||
|
||
This directory contains scripts to train/evaluate/test the models. | ||
|
||
## First try | ||
|
||
To train an example model, run the below command in this directory: | ||
|
||
```bash | ||
python train_materials_tensor.py | ||
``` | ||
|
||
This will train a toy model on a very small example dataset for only 10 epochs and then stop. | ||
|
||
After training, you can see the metrics and the location of the best model. | ||
|
||
### Under the hood | ||
|
||
All training configurations (data, optimizer, model etc.) are defined in the [./configs/materials_tensor.yaml](./configs/materials_tensor.yaml) file. Particularly you can see from the `data` section (copied below) that the data file we are using is `example_crystal_elasticity_tensor_n100.json` and it is stored at the [../datasets/](../datasets). Here we use the same file for training, validation and testing. This is just for demonstration purposes. In a real-world scenario, you would have separate files for each of these. | ||
|
||
```yaml | ||
data: | ||
root: ../datasets/ | ||
r_cut: 5.0 | ||
trainset_filename: example_crystal_elasticity_tensor_n100.json | ||
valset_filename: example_crystal_elasticity_tensor_n100.json | ||
testset_filename: example_crystal_elasticity_tensor_n100.json | ||
reuse: false | ||
loader_kwargs: | ||
batch_size: 32 | ||
shuffle: true | ||
``` | ||
Feel free to change other settings and see how the model behaves. You might find [this config](../pretrained/20230627/config_final.yaml) file useful for reference, which is the final configuration used to train the model in the paper. | ||
You might also want to uncomment the `logger` section to use [Weights and Biases](https://wandb.ai/site) for logging, which makes tracking the training process much easier. | ||
|
||
## Training with your own data | ||
|
||
All you need to do is to prepare your data as json file(s), similar to `example_crystal_elasticity_tensor_n100.json`. | ||
It needs a list of pymatgen `structure`s and the corresponding `elastic_tensor_full` for the structures. | ||
See this [notebook](../notebooks/prepare_data.ipynb) for more details on how to prepare the data. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters