From 0555144eea0a171e178b03afb094786b8b3e3077 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Wed, 13 Nov 2024 20:50:38 +0000 Subject: [PATCH 01/17] initial repo --- .gitattributes | 2 + .github/ISSUE_TEMPLATE/bug_report.md | 38 + .github/ISSUE_TEMPLATE/feature_request.md | 20 + .github/workflows/lint.yml | 46 + .github/workflows/pre-commit.yaml | 28 + .github/workflows/release.yaml | 36 + .github/workflows/unittest.yaml | 29 + .gitignore | 35 + .mypy.ini | 39 + .pre-commit-config.yaml | 59 + LICENSE.md | 13 + MANIFEST.in | 1 + README.md | 349 ++++++ mace/__init__.py | 71 ++ mace/__version__.py | 3 + mace/calculators/__init__.py | 11 + mace/calculators/foundations_models.py | 262 ++++ .../foundations_models/mp_vasp_e0.json | 91 ++ mace/calculators/lammps_mace.py | 105 ++ mace/calculators/mace.py | 414 ++++++ mace/cli/__init__.py | 0 mace/cli/active_learning_md.py | 193 +++ mace/cli/convert_device.py | 31 + mace/cli/create_lammps_model.py | 92 ++ mace/cli/eval_configs.py | 165 +++ mace/cli/fine_tuning_select.py | 348 ++++++ mace/cli/plot_train.py | 193 +++ mace/cli/preprocess_data.py | 288 +++++ mace/cli/run_train.py | 811 ++++++++++++ mace/cli/select_head.py | 33 + mace/data/__init__.py | 34 + mace/data/atomic_data.py | 241 ++++ mace/data/hdf5_dataset.py | 93 ++ mace/data/neighborhood.py | 66 + mace/data/utils.py | 408 ++++++ mace/modules/__init__.py | 113 ++ mace/modules/blocks.py | 964 ++++++++++++++ mace/modules/irreps_tools.py | 94 ++ mace/modules/loss.py | 383 ++++++ mace/modules/models.py | 1109 +++++++++++++++++ mace/modules/radial.py | 323 +++++ mace/modules/symmetric_contraction.py | 233 ++++ mace/modules/utils.py | 442 +++++++ mace/py.typed | 1 + mace/tools/__init__.py | 71 ++ mace/tools/arg_parser.py | 878 +++++++++++++ mace/tools/arg_parser_tools.py | 113 ++ mace/tools/cg.py | 131 ++ mace/tools/checkpoint.py | 227 ++++ mace/tools/compile.py | 95 ++ mace/tools/finetuning_utils.py | 204 +++ mace/tools/model_script_utils.py | 228 ++++ mace/tools/multihead_tools.py | 185 +++ mace/tools/scatter.py | 112 ++ mace/tools/scripts_utils.py | 785 ++++++++++++ mace/tools/slurm_distributed.py | 34 + mace/tools/tables_utils.py | 241 ++++ mace/tools/torch_geometric/README.md | 12 + mace/tools/torch_geometric/__init__.py | 7 + mace/tools/torch_geometric/batch.py | 257 ++++ mace/tools/torch_geometric/data.py | 441 +++++++ mace/tools/torch_geometric/dataloader.py | 87 ++ mace/tools/torch_geometric/dataset.py | 280 +++++ mace/tools/torch_geometric/seed.py | 17 + mace/tools/torch_geometric/utils.py | 54 + mace/tools/torch_tools.py | 141 +++ mace/tools/train.py | 538 ++++++++ mace/tools/utils.py | 147 +++ pyproject.toml | 41 + scripts/__init__.py | 0 scripts/distributed_example.sbatch | 34 + scripts/eval_configs.py | 6 + scripts/preprocess_data.py | 6 + scripts/run_checks.sh | 9 + scripts/run_train.py | 6 + setup.cfg | 59 + tests/__init__.py | 0 tests/test_calculator.py | 508 ++++++++ tests/test_cg.py | 12 + tests/test_compile.py | 154 +++ tests/test_data.py | 207 +++ tests/test_foundations.py | 447 +++++++ tests/test_hessian.py | 54 + tests/test_models.py | 251 ++++ tests/test_modules.py | 249 ++++ tests/test_preprocess.py | 166 +++ tests/test_run_train.py | 849 +++++++++++++ tests/test_schedulefree.py | 127 ++ tests/test_tools.py | 48 + 89 files changed, 16828 insertions(+) create mode 100644 .gitattributes create mode 100644 .github/ISSUE_TEMPLATE/bug_report.md create mode 100644 .github/ISSUE_TEMPLATE/feature_request.md create mode 100644 .github/workflows/lint.yml create mode 100644 .github/workflows/pre-commit.yaml create mode 100644 .github/workflows/release.yaml create mode 100644 .github/workflows/unittest.yaml create mode 100644 .gitignore create mode 100644 .mypy.ini create mode 100644 .pre-commit-config.yaml create mode 100644 LICENSE.md create mode 100644 MANIFEST.in create mode 100644 README.md create mode 100644 mace/__init__.py create mode 100644 mace/__version__.py create mode 100644 mace/calculators/__init__.py create mode 100644 mace/calculators/foundations_models.py create mode 100644 mace/calculators/foundations_models/mp_vasp_e0.json create mode 100644 mace/calculators/lammps_mace.py create mode 100644 mace/calculators/mace.py create mode 100644 mace/cli/__init__.py create mode 100644 mace/cli/active_learning_md.py create mode 100644 mace/cli/convert_device.py create mode 100644 mace/cli/create_lammps_model.py create mode 100644 mace/cli/eval_configs.py create mode 100644 mace/cli/fine_tuning_select.py create mode 100644 mace/cli/plot_train.py create mode 100644 mace/cli/preprocess_data.py create mode 100644 mace/cli/run_train.py create mode 100644 mace/cli/select_head.py create mode 100644 mace/data/__init__.py create mode 100644 mace/data/atomic_data.py create mode 100644 mace/data/hdf5_dataset.py create mode 100644 mace/data/neighborhood.py create mode 100644 mace/data/utils.py create mode 100644 mace/modules/__init__.py create mode 100644 mace/modules/blocks.py create mode 100644 mace/modules/irreps_tools.py create mode 100644 mace/modules/loss.py create mode 100644 mace/modules/models.py create mode 100644 mace/modules/radial.py create mode 100644 mace/modules/symmetric_contraction.py create mode 100644 mace/modules/utils.py create mode 100644 mace/py.typed create mode 100644 mace/tools/__init__.py create mode 100644 mace/tools/arg_parser.py create mode 100644 mace/tools/arg_parser_tools.py create mode 100644 mace/tools/cg.py create mode 100644 mace/tools/checkpoint.py create mode 100644 mace/tools/compile.py create mode 100644 mace/tools/finetuning_utils.py create mode 100644 mace/tools/model_script_utils.py create mode 100644 mace/tools/multihead_tools.py create mode 100644 mace/tools/scatter.py create mode 100644 mace/tools/scripts_utils.py create mode 100644 mace/tools/slurm_distributed.py create mode 100644 mace/tools/tables_utils.py create mode 100644 mace/tools/torch_geometric/README.md create mode 100644 mace/tools/torch_geometric/__init__.py create mode 100644 mace/tools/torch_geometric/batch.py create mode 100644 mace/tools/torch_geometric/data.py create mode 100644 mace/tools/torch_geometric/dataloader.py create mode 100644 mace/tools/torch_geometric/dataset.py create mode 100644 mace/tools/torch_geometric/seed.py create mode 100644 mace/tools/torch_geometric/utils.py create mode 100644 mace/tools/torch_tools.py create mode 100644 mace/tools/train.py create mode 100644 mace/tools/utils.py create mode 100644 pyproject.toml create mode 100644 scripts/__init__.py create mode 100644 scripts/distributed_example.sbatch create mode 100644 scripts/eval_configs.py create mode 100644 scripts/preprocess_data.py create mode 100644 scripts/run_checks.sh create mode 100644 scripts/run_train.py create mode 100644 setup.cfg create mode 100644 tests/__init__.py create mode 100644 tests/test_calculator.py create mode 100644 tests/test_cg.py create mode 100644 tests/test_compile.py create mode 100644 tests/test_data.py create mode 100644 tests/test_foundations.py create mode 100644 tests/test_hessian.py create mode 100644 tests/test_models.py create mode 100644 tests/test_modules.py create mode 100644 tests/test_preprocess.py create mode 100644 tests/test_run_train.py create mode 100644 tests/test_schedulefree.py create mode 100644 tests/test_tools.py diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 00000000..dfe07704 --- /dev/null +++ b/.gitattributes @@ -0,0 +1,2 @@ +# Auto detect text files and perform LF normalization +* text=auto diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md new file mode 100644 index 00000000..dd84ea78 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -0,0 +1,38 @@ +--- +name: Bug report +about: Create a report to help us improve +title: '' +labels: '' +assignees: '' + +--- + +**Describe the bug** +A clear and concise description of what the bug is. + +**To Reproduce** +Steps to reproduce the behavior: +1. Go to '...' +2. Click on '....' +3. Scroll down to '....' +4. See error + +**Expected behavior** +A clear and concise description of what you expected to happen. + +**Screenshots** +If applicable, add screenshots to help explain your problem. + +**Desktop (please complete the following information):** + - OS: [e.g. iOS] + - Browser [e.g. chrome, safari] + - Version [e.g. 22] + +**Smartphone (please complete the following information):** + - Device: [e.g. iPhone6] + - OS: [e.g. iOS8.1] + - Browser [e.g. stock browser, safari] + - Version [e.g. 22] + +**Additional context** +Add any other context about the problem here. diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md new file mode 100644 index 00000000..bbcbbe7d --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature_request.md @@ -0,0 +1,20 @@ +--- +name: Feature request +about: Suggest an idea for this project +title: '' +labels: '' +assignees: '' + +--- + +**Is your feature request related to a problem? Please describe.** +A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] + +**Describe the solution you'd like** +A clear and concise description of what you want to happen. + +**Describe alternatives you've considered** +A clear and concise description of any alternative solutions or features you've considered. + +**Additional context** +Add any other context or screenshots about the feature request here. diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 00000000..998ae250 --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,46 @@ +name: Linting and code formatting + +on: [] + # Trigger the workflow on push or pull request, + # but only for the main branch + # push: + # branches: [] + # pull_request: + # branches: [] + + +jobs: + build-linux: + runs-on: ubuntu-latest + + steps: + # Setup + - name: Checkout + uses: actions/checkout@v2 + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: 3.8.10 + - name: Get cache + uses: actions/cache@v2 + with: + path: /opt/hostedtoolcache/Python/3.8.10/x64/lib/python3.8/site-packages + # Look to see if there is a cache hit for the corresponding requirements file + key: ${{ runner.os }}-pip-${{ hashFiles('requirements.txt') }} + + # Install packages + - name: Install packages required for installation + run: python -m pip install --upgrade pip setuptools wheel + - name: Install dependencies + run: pip install -r requirements.txt + + # Check code + - name: Check formatting with yapf + run: python -m yapf --style=.style.yapf --diff --recursive . +# - name: Lint with flake8 +# run: flake8 --config=.flake8 . +# - name: Check type annotations with mypy +# run: mypy --config-file=.mypy.ini . + + - name: Test with pytest + run: python -m pytest tests diff --git a/.github/workflows/pre-commit.yaml b/.github/workflows/pre-commit.yaml new file mode 100644 index 00000000..57e5d142 --- /dev/null +++ b/.github/workflows/pre-commit.yaml @@ -0,0 +1,28 @@ +name: Pre-Commit Checks + +on: + pull_request: + push: + branches: [main] + +jobs: + pre-commit: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: "3.10" + cache: "pip" + - name: Install requirements + run: | + pip install -U pip + pip install pylint + pip install -U black + pip install .[dev] + pip install wandb + pip install tqdm + - name: Run black + run: | + python -m black . + - uses: pre-commit/action@v3.0.0 diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml new file mode 100644 index 00000000..7c6c82f6 --- /dev/null +++ b/.github/workflows/release.yaml @@ -0,0 +1,36 @@ +name: Release on tag + +on: + push: + tags: + # After vMajor.Minor.Patch _anything_ is allowed (without "/") ! + - v[0-9]+.[0-9]+.[0-9]+* + +jobs: + publish: + runs-on: ubuntu-latest + if: github.repository == 'ACEsuit/mace' && startsWith(github.ref, 'refs/tags/v') + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Set up Python 3.10 + uses: actions/setup-python@v5 + with: + python-version: "3.10" + cache: "pip" + + - name: Build project for distribution + run: | + python -m pip install --upgrade pip + python -m pip install build + python -m build + + - name: Create Release + uses: ncipollo/release-action@v1 + with: + artifacts: "dist/*" + token: ${{ secrets.GITHUB_TOKEN }} + draft: false + skipIfReleaseExists: true diff --git a/.github/workflows/unittest.yaml b/.github/workflows/unittest.yaml new file mode 100644 index 00000000..857bf894 --- /dev/null +++ b/.github/workflows/unittest.yaml @@ -0,0 +1,29 @@ +name: unit tests +on: + pull_request: + push: + branches: [main] + +jobs: + pytest-container: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: "3.10" + cache: "pip" + + - name: Install requirements + run: | + pip install -U pip + pip install .[dev] + + - name: Log installed environment + run: | + python3 -m pip freeze + + - name: Run unit tests + run: | + pytest tests diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..3817d9f3 --- /dev/null +++ b/.gitignore @@ -0,0 +1,35 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# Distribution / packaging +build/ +*.egg-info/ + +# pytest +.pytest_cache/ + +# mypy +.mypy_cache/ + +# IDE +.idea/ +.vscode/ +*.txt +*.log + +# Distribution +dist/ + +# Jupyter Notebook +.ipynb_checkpoints + +# DS_Store +.DS_Store +*.models +*.pt +/wandb +*.xyz +/checkpoints +*.model diff --git a/.mypy.ini b/.mypy.ini new file mode 100644 index 00000000..7d3c0433 --- /dev/null +++ b/.mypy.ini @@ -0,0 +1,39 @@ +[mypy] +# Platform configuration +python_version = 3.8 + +# Untyped definitions and calls +check_untyped_defs = True + +[mypy-mace.tools.torch_geometric.*] +ignore_errors = True + +[mypy-mace.tools.scatter] +ignore_errors = True + +[mypy-setuptools.*] +ignore_missing_imports = True + +[mypy-e3nn.*] +ignore_missing_imports = True + +[mypy-scipy.*] +ignore_missing_imports = True + +[mypy-ase.*] +ignore_missing_imports = True + +[mypy-prettytable.*] +ignore_missing_imports = True + +[mypy-torch_ema.*] +ignore_missing_imports = True + +[mypy-matplotlib.*] +ignore_missing_imports = True + +[mypy-pandas.*] +ignore_missing_imports = True + +[mypy-opt_einsum.*] +ignore_missing_imports = True diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000..d78624bb --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,59 @@ +exclude: &exclude_files > + (?x)^( + docs/.*| + tests/.*| + .github/.*| + LICENSE.md| + README.md| + )$ + +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v2.5.0 + hooks: + - id: mixed-line-ending + - id: trailing-whitespace + exclude: *exclude_files + + - repo: https://github.com/psf/black-pre-commit-mirror + rev: 24.4.0 + hooks: + - id: black + name: Black Formating + exclude: *exclude_files + + - repo: https://github.com/pycqa/isort + rev: 5.13.2 + hooks: + - id: isort + name: Sort imports + exclude: *exclude_files + + - repo: https://github.com/PyCQA/pylint + rev: pylint-2.5.2 + hooks: + - id: pylint + language: system + args: [ + '--disable=line-too-long', + '--disable=no-member', + '--disable=missing-module-docstring', + '--disable=missing-class-docstring', + '--disable=missing-function-docstring', + '--disable=too-many-arguments', + '--disable=too-many-positional-arguments', + '--disable=too-many-locals', + '--disable=not-callable', + '--disable=logging-fstring-interpolation', + '--disable=logging-not-lazy', + '--disable=invalid-name', + '--disable=too-few-public-methods', + '--disable=too-many-instance-attributes', + '--disable=too-many-statements', + '--disable=too-many-branches', + '--disable=import-outside-toplevel', + '--disable=cell-var-from-loop', + '--disable=duplicate-code', + '--disable=use-dict-literal', + ] + exclude: *exclude_files \ No newline at end of file diff --git a/LICENSE.md b/LICENSE.md new file mode 100644 index 00000000..b4ec6253 --- /dev/null +++ b/LICENSE.md @@ -0,0 +1,13 @@ +MIT License + +Copyright (c) 2022 ACEsuit/mace + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 00000000..cf8dae6c --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1 @@ +include mace/py.typed diff --git a/README.md b/README.md new file mode 100644 index 00000000..8481760d --- /dev/null +++ b/README.md @@ -0,0 +1,349 @@ +# MACE + +[![GitHub release](https://img.shields.io/github/release/ACEsuit/mace.svg)](https://GitHub.com/ACEsuit/mace/releases/) +[![Paper](https://img.shields.io/badge/Paper-NeurIPs2022-blue)](https://openreview.net/forum?id=YPpSngE-ZU) +[![License](https://img.shields.io/badge/License-MIT%202.0-blue.svg)](https://opensource.org/licenses/mit) +[![GitHub issues](https://img.shields.io/github/issues/ACEsuit/mace.svg)](https://GitHub.com/ACEsuit/mace/issues/) +[![Documentation Status](https://readthedocs.org/projects/mace/badge/)](https://mace-docs.readthedocs.io/en/latest/) +[![DOI](https://zenodo.org/badge/505964914.svg)](https://doi.org/10.5281/zenodo.14103332) + +## Table of contents + +- [MACE](#mace) + - [Table of contents](#table-of-contents) + - [About MACE](#about-mace) + - [Documentation](#documentation) + - [Installation](#installation) + - [pip installation](#installation-from-pypi) + - [pip installation from source](#installation-from-source) + - [Usage](#usage) + - [Training](#training) + - [Evaluation](#evaluation) + - [Tutorials](#tutorials) + - [Weights and Biases for experiment tracking](#weights-and-biases-for-experiment-tracking) + - [Pretrained Foundation Models](#pretrained-foundation-models) + - [MACE-MP: Materials Project Force Fields](#mace-mp-materials-project-force-fields) + - [Example usage in ASE](#example-usage-in-ase) + - [MACE-OFF: Transferable Organic Force Fields](#mace-off-transferable-organic-force-fields) + - [Example usage in ASE](#example-usage-in-ase-1) + - [Finetuning foundation models](#finetuning-foundation-models) + - [Development](#development) + - [References](#references) + - [Contact](#contact) + - [License](#license) + +## About MACE + +MACE provides fast and accurate machine learning interatomic potentials with higher order equivariant message passing. + +This repository contains the MACE reference implementation developed by +Ilyes Batatia, Gregor Simm, David Kovacs, and the group of Gabor Csanyi, and friends (see Contributors). + +Also available: + +- [MACE in JAX](https://github.com/ACEsuit/mace-jax), currently about 2x times faster at evaluation, but training is recommended in Pytorch for optimal performances. +- [MACE layers](https://github.com/ACEsuit/mace-layer) for constructing higher order equivariant graph neural networks for arbitrary 3D point clouds. + +## Documentation + +A partial documentation is available at: https://mace-docs.readthedocs.io + +## Installation + +### 1. Requirements: + +- Python >= 3.7 (for openMM, use Python = 3.9) +- [PyTorch](https://pytorch.org/) >= 1.12 **(training with float64 is not supported with PyTorch 2.1 but is supported with 2.2 and later, Pytorch 2.4.1 is not supported)** + +**Make sure to install PyTorch.** Please refer to the [official PyTorch installation](https://pytorch.org/get-started/locally/) for the installation instructions. Select the appropriate options for your system. + +### Installation from PyPI +This is the recommended way to install MACE. + +```sh +pip install --upgrade pip +pip install mace-torch +``` +**Note:** The homonymous package on [PyPI](https://pypi.org/project/MACE/) has nothing to do with this one. + + +### Installation from source + + +```sh +git clone https://github.com/ACEsuit/mace.git +pip install ./mace +``` + + + + +## Usage + +### Training + +To train a MACE model, you can use the `mace_run_train` script, which should be in the usual place that pip places binaries (or you can explicitly run `python3 /mace/cli/run_train.py`) + +```sh +mace_run_train \ + --name="MACE_model" \ + --train_file="train.xyz" \ + --valid_fraction=0.05 \ + --test_file="test.xyz" \ + --config_type_weights='{"Default":1.0}' \ + --E0s='{1:-13.663181292231226, 6:-1029.2809654211628, 7:-1484.1187695035828, 8:-2042.0330099956639}' \ + --model="MACE" \ + --hidden_irreps='128x0e + 128x1o' \ + --r_max=5.0 \ + --batch_size=10 \ + --max_num_epochs=1500 \ + --swa \ + --start_swa=1200 \ + --ema \ + --ema_decay=0.99 \ + --amsgrad \ + --restart_latest \ + --device=cuda \ +``` + +To give a specific validation set, use the argument `--valid_file`. To set a larger batch size for evaluating the validation set, specify `--valid_batch_size`. + +To control the model's size, you need to change `--hidden_irreps`. For most applications, the recommended default model size is `--hidden_irreps='256x0e'` (meaning 256 invariant messages) or `--hidden_irreps='128x0e + 128x1o'`. If the model is not accurate enough, you can include higher order features, e.g., `128x0e + 128x1o + 128x2e`, or increase the number of channels to `256`. It is also possible to specify the model using the `--num_channels=128` and `--max_L=1`keys. + +It is usually preferred to add the isolated atoms to the training set, rather than reading in their energies through the command line like in the example above. To label them in the training set, set `config_type=IsolatedAtom` in their info fields. If you prefer not to use or do not know the energies of the isolated atoms, you can use the option `--E0s="average"` which estimates the atomic energies using least squares regression. + +If the keyword `--swa` is enabled, the energy weight of the loss is increased for the last ~20% of the training epochs (from `--start_swa` epochs). This setting usually helps lower the energy errors. + +The precision can be changed using the keyword `--default_dtype`, the default is `float64` but `float32` gives a significant speed-up (usually a factor of x2 in training). + +The keywords `--batch_size` and `--max_num_epochs` should be adapted based on the size of the training set. The batch size should be increased when the number of training data increases, and the number of epochs should be decreased. An heuristic for initial settings, is to consider the number of gradient update constant to 200 000, which can be computed as $\text{max-num-epochs}*\frac{\text{num-configs-training}}{\text{batch-size}}$. + +The code can handle training set with heterogeneous labels, for example containing both bulk structures with stress and isolated molecules. In this example, to make the code ignore stress on molecules, append to your molecules configuration a `config_stress_weight = 0.0`. + +#### Apple Silicon GPU acceleration + +To use Apple Silicon GPU acceleration make sure to install the latest PyTorch version and specify `--device=mps`. + +#### Multi-GPU training + +For multi-GPU training, use the `--distributed` flag. This will use PyTorch's DistributedDataParallel module to train the model on multiple GPUs. Combine with on-line data loading for large datasets (see below). An example slurm script can be found in `mace/scripts/distributed_example.sbatch`. + +#### YAML configuration + +Option to parse all or some arguments using a YAML is available. For example, to train a model using the arguments above, you can create a YAML file `your_configs.yaml` with the following content: + +```yaml +name: nacl +seed: 2024 +train_file: train.xyz +swa: yes +start_swa: 1200 +max_num_epochs: 1500 +device: cpu +test_file: test.xyz +E0s: + 41: -1029.2809654211628 + 38: -1484.1187695035828 + 8: -2042.0330099956639 +config_type_weights: + Default: 1.0 + +``` +And append to the command line `--config="your_configs.yaml"`. Any argument specified in the command line will overwrite the one in the YAML file. + +### Evaluation + +To evaluate your MACE model on an XYZ file, run the `mace_eval_configs`: + +```sh +mace_eval_configs \ + --configs="your_configs.xyz" \ + --model="your_model.model" \ + --output="./your_output.xyz" +``` + +## Tutorials + +You can run our [Colab tutorial](https://colab.research.google.com/drive/1D6EtMUjQPey_GkuxUAbPgld6_9ibIa-V?authuser=1#scrollTo=Z10787RE1N8T) to quickly get started with MACE. + +We also have a more detailed Colab tutorials on: + - [Introduction to MACE training and evaluation](https://colab.research.google.com/drive/1ZrTuTvavXiCxTFyjBV4GqlARxgFwYAtX) + - [Introduction to MACE active learning and fine-tuning](https://colab.research.google.com/drive/1oCSVfMhWrqHTeHbKgUSQN9hTKxLzoNyb) + - [MACE theory and code (advanced)](https://colab.research.google.com/drive/1AlfjQETV_jZ0JQnV5M3FGwAM2SGCl2aU) + + +## On-line data loading for large datasets + +If you have a large dataset that might not fit into the GPU memory it is recommended to preprocess the data on a CPU and use on-line dataloading for training the model. To preprocess your dataset specified as an xyz file run the `preprocess_data.py` script. An example is given here: + +```sh +mkdir processed_data +python ./mace/scripts/preprocess_data.py \ + --train_file="/path/to/train_large.xyz" \ + --valid_fraction=0.05 \ + --test_file="/path/to/test_large.xyz" \ + --atomic_numbers="[1, 6, 7, 8, 9, 15, 16, 17, 35, 53]" \ + --r_max=4.5 \ + --h5_prefix="processed_data/" \ + --compute_statistics \ + --E0s="average" \ + --seed=123 \ +``` + +To see all options and a little description of them run `python ./mace/scripts/preprocess_data.py --help` . The script will create a number of HDF5 files in the `processed_data` folder which can be used for training. There will be one folder for training, one for validation and a separate one for each `config_type` in the test set. To train the model use the `run_train.py` script as follows: + +```sh +python ./mace/scripts/run_train.py \ + --name="MACE_on_big_data" \ + --num_workers=16 \ + --train_file="./processed_data/train.h5" \ + --valid_file="./processed_data/valid.h5" \ + --test_dir="./processed_data" \ + --statistics_file="./processed_data/statistics.json" \ + --model="ScaleShiftMACE" \ + --num_interactions=2 \ + --num_channels=128 \ + --max_L=1 \ + --correlation=3 \ + --batch_size=32 \ + --valid_batch_size=32 \ + --max_num_epochs=100 \ + --swa \ + --start_swa=60 \ + --ema \ + --ema_decay=0.99 \ + --amsgrad \ + --error_table='PerAtomMAE' \ + --device=cuda \ + --seed=123 \ +``` + +## Weights and Biases for experiment tracking + +If you would like to use MACE with Weights and Biases to log your experiments simply install with + +```sh +pip install ./mace[wandb] +``` + +And specify the necessary keyword arguments (`--wandb`, `--wandb_project`, `--wandb_entity`, `--wandb_name`, `--wandb_log_hypers`) + + +## Pretrained Foundation Models + +### MACE-MP: Materials Project Force Fields + +We have collaborated with the Materials Project (MP) to train a universal MACE potential covering 89 elements on 1.6 M bulk crystals in the [MPTrj dataset](https://figshare.com/articles/dataset/23713842) selected from MP relaxation trajectories. +The models are releaed on GitHub at https://github.com/ACEsuit/mace-mp. +If you use them please cite [our paper](https://arxiv.org/abs/2401.00096) which also contains an large range of example applications and benchmarks. + +> [!CAUTION] +> The MACE-MP models are trained on MPTrj raw DFT energies from VASP outputs, and are not directly comparable to the MP's DFT energies or CHGNet's energies, which have been applied MP2020Compatibility corrections for some transition metal oxides, fluorides (GGA/GGA+U mixing corrections), and 14 anions species (anion corrections). For more details, please refer to the [MP Documentation](https://docs.materialsproject.org/methodology/materials-methodology/thermodynamic-stability/thermodynamic-stability/anion-and-gga-gga+u-mixing) and [MP2020Compatibility.yaml](https://github.com/materialsproject/pymatgen/blob/master/pymatgen/entries/MP2020Compatibility.yaml). + +#### Example usage in ASE +```py +from mace.calculators import mace_mp +from ase import build + +atoms = build.molecule('H2O') +calc = mace_mp(model="medium", dispersion=False, default_dtype="float32", device='cuda') +atoms.calc = calc +print(atoms.get_potential_energy()) +``` + +### MACE-OFF: Transferable Organic Force Fields + +There is a series (small, medium, large) transferable organic force fields. These can be used for the simulation of organic molecules, crystals and molecular liquids, or as a starting point for fine-tuning on a new dataset. The models are released under the [ASL license](https://github.com/gabor1/ASL). +The models are releaed on GitHub at https://github.com/ACEsuit/mace-off. +If you use them please cite [our paper](https://arxiv.org/abs/2312.15211) which also contains detailed benchmarks and example applications. + +#### Example usage in ASE +```py +from mace.calculators import mace_off +from ase import build + +atoms = build.molecule('H2O') +calc = mace_off(model="medium", device='cuda') +atoms.calc = calc +print(atoms.get_potential_energy()) +``` + +### Finetuning foundation models + +To finetune one of the mace-mp-0 foundation model, you can use the `mace_run_train` script with the extra argument `--foundation_model=model_type`. For example to finetune the small model on a new dataset, you can use: + +```sh +mace_run_train \ + --name="MACE" \ + --foundation_model="small" \ + --train_file="train.xyz" \ + --valid_fraction=0.05 \ + --test_file="test.xyz" \ + --energy_weight=1.0 \ + --forces_weight=1.0 \ + --E0s="average" \ + --lr=0.01 \ + --scaling="rms_forces_scaling" \ + --batch_size=2 \ + --max_num_epochs=6 \ + --ema \ + --ema_decay=0.99 \ + --amsgrad \ + --default_dtype="float32" \ + --device=cuda \ + --seed=3 +``` +Other options are "medium" and "large", or the path to a foundation model. +If you want to finetune another model, the model will be loaded from the path provided `--foundation_model=$path_model`, but you will need to provide the full set of hyperparameters (hidden irreps, r_max, etc.) matching the model. + +## Development + +This project uses [pre-commit](https://pre-commit.com/) to execute code formatting and linting on commit. +We also use `black`, `isort`, `pylint`, and `mypy`. +We recommend setting up your development environment by installing the `dev` packages +into your python environment: +```bash +pip install -e ".[dev]" +pre-commit install +``` +The second line will initialise `pre-commit` to automaticaly run code checks on commit. +We have CI set up to check this, but we _highly_ recommend that you run those commands +before you commit (and push) to avoid accidentally committing bad code. + +We are happy to accept pull requests under an [MIT license](https://choosealicense.com/licenses/mit/). Please copy/paste the license text as a comment into your pull request. + +## References + +If you use this code, please cite our papers: + +```bibtex +@inproceedings{Batatia2022mace, + title={{MACE}: Higher Order Equivariant Message Passing Neural Networks for Fast and Accurate Force Fields}, + author={Ilyes Batatia and David Peter Kovacs and Gregor N. C. Simm and Christoph Ortner and Gabor Csanyi}, + booktitle={Advances in Neural Information Processing Systems}, + editor={Alice H. Oh and Alekh Agarwal and Danielle Belgrave and Kyunghyun Cho}, + year={2022}, + url={https://openreview.net/forum?id=YPpSngE-ZU} +} + +@misc{Batatia2022Design, + title = {The Design Space of E(3)-Equivariant Atom-Centered Interatomic Potentials}, + author = {Batatia, Ilyes and Batzner, Simon and Kov{\'a}cs, D{\'a}vid P{\'e}ter and Musaelian, Albert and Simm, Gregor N. C. and Drautz, Ralf and Ortner, Christoph and Kozinsky, Boris and Cs{\'a}nyi, G{\'a}bor}, + year = {2022}, + number = {arXiv:2205.06643}, + eprint = {2205.06643}, + eprinttype = {arxiv}, + doi = {10.48550/arXiv.2205.06643}, + archiveprefix = {arXiv} + } +``` + +## Contact + +If you have any questions, please contact us at ilyes.batatia@ens-paris-saclay.fr. + +For bugs or feature requests, please use [GitHub Issues](https://github.com/ACEsuit/mace/issues). + +## License + +MACE is published and distributed under the [MIT License](MIT.md). diff --git a/mace/__init__.py b/mace/__init__.py new file mode 100644 index 00000000..8ad80243 --- /dev/null +++ b/mace/__init__.py @@ -0,0 +1,71 @@ +from .arg_parser import build_default_arg_parser, build_preprocess_arg_parser +from .arg_parser_tools import check_args +from .cg import U_matrix_real +from .checkpoint import CheckpointHandler, CheckpointIO, CheckpointState +from .finetuning_utils import load_foundations, load_foundations_elements +from .torch_tools import ( + TensorDict, + cartesian_to_spherical, + count_parameters, + init_device, + init_wandb, + set_default_dtype, + set_seeds, + spherical_to_cartesian, + to_numpy, + to_one_hot, + voigt_to_matrix, +) +from .train import SWAContainer, evaluate, train +from .utils import ( + AtomicNumberTable, + MetricsLogger, + atomic_numbers_to_indices, + compute_c, + compute_mae, + compute_q95, + compute_rel_mae, + compute_rel_rmse, + compute_rmse, + get_atomic_number_table_from_zs, + get_tag, + setup_logger, +) + +__all__ = [ + "TensorDict", + "AtomicNumberTable", + "atomic_numbers_to_indices", + "to_numpy", + "to_one_hot", + "build_default_arg_parser", + "check_args", + "set_seeds", + "init_device", + "setup_logger", + "get_tag", + "count_parameters", + "MetricsLogger", + "get_atomic_number_table_from_zs", + "train", + "evaluate", + "SWAContainer", + "CheckpointHandler", + "CheckpointIO", + "CheckpointState", + "set_default_dtype", + "compute_mae", + "compute_rel_mae", + "compute_rmse", + "compute_rel_rmse", + "compute_q95", + "compute_c", + "U_matrix_real", + "spherical_to_cartesian", + "cartesian_to_spherical", + "voigt_to_matrix", + "init_wandb", + "load_foundations", + "load_foundations_elements", + "build_preprocess_arg_parser", +] diff --git a/mace/__version__.py b/mace/__version__.py new file mode 100644 index 00000000..2eb279ae --- /dev/null +++ b/mace/__version__.py @@ -0,0 +1,3 @@ +__version__ = "0.3.8" + +__all__ = ["__version__"] diff --git a/mace/calculators/__init__.py b/mace/calculators/__init__.py new file mode 100644 index 00000000..8511eb9e --- /dev/null +++ b/mace/calculators/__init__.py @@ -0,0 +1,11 @@ +from .foundations_models import mace_anicc, mace_mp, mace_off +from .lammps_mace import LAMMPS_MACE +from .mace import MACECalculator + +__all__ = [ + "MACECalculator", + "LAMMPS_MACE", + "mace_mp", + "mace_off", + "mace_anicc", +] diff --git a/mace/calculators/foundations_models.py b/mace/calculators/foundations_models.py new file mode 100644 index 00000000..ed814f1a --- /dev/null +++ b/mace/calculators/foundations_models.py @@ -0,0 +1,262 @@ +import os +import urllib.request +from pathlib import Path +from typing import Union + +import torch +from ase import units +from ase.calculators.mixing import SumCalculator + +from .mace import MACECalculator + +module_dir = os.path.dirname(__file__) +local_model_path = os.path.join( + module_dir, "foundations_models/2023-12-03-mace-mp.model" +) + + +def download_mace_mp_checkpoint(model: Union[str, Path] = None) -> str: + """ + Downloads or locates the MACE-MP checkpoint file. + + Args: + model (str, optional): Path to the model or size specification. + Defaults to None which uses the medium model. + + Returns: + str: Path to the downloaded (or cached, if previously loaded) checkpoint file. + """ + if model in (None, "medium") and os.path.isfile(local_model_path): + return local_model_path + + urls = { + "small": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/2023-12-10-mace-128-L0_energy_epoch-249.model", + "medium": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/2023-12-03-mace-128-L1_epoch-199.model", + "large": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/MACE_MPtrj_2022.9.model", + } + + checkpoint_url = ( + urls.get(model, urls["medium"]) + if model in (None, "small", "medium", "large") + else model + ) + + cache_dir = os.path.expanduser("~/.cache/mace") + checkpoint_url_name = "".join( + c for c in os.path.basename(checkpoint_url) if c.isalnum() or c in "_" + ) + cached_model_path = f"{cache_dir}/{checkpoint_url_name}" + + if not os.path.isfile(cached_model_path): + os.makedirs(cache_dir, exist_ok=True) + print(f"Downloading MACE model from {checkpoint_url!r}") + _, http_msg = urllib.request.urlretrieve(checkpoint_url, cached_model_path) + if "Content-Type: text/html" in http_msg: + raise RuntimeError( + f"Model download failed, please check the URL {checkpoint_url}" + ) + print(f"Cached MACE model to {cached_model_path}") + + return cached_model_path + + +def mace_mp( + model: Union[str, Path] = None, + device: str = "", + default_dtype: str = "float32", + dispersion: bool = False, + damping: str = "bj", # choices: ["zero", "bj", "zerom", "bjm"] + dispersion_xc: str = "pbe", + dispersion_cutoff: float = 40.0 * units.Bohr, + return_raw_model: bool = False, + **kwargs, +) -> MACECalculator: + """ + Constructs a MACECalculator with a pretrained model based on the Materials Project (89 elements). + The model is released under the MIT license. See https://github.com/ACEsuit/mace-mp for all models. + Note: + If you are using this function, please cite the relevant paper for the Materials Project, + any paper associated with the MACE model, and also the following: + - MACE-MP by Ilyes Batatia, Philipp Benner, Yuan Chiang, Alin M. Elena, + Dávid P. Kovács, Janosh Riebesell, et al., 2023, arXiv:2401.00096 + - MACE-Universal by Yuan Chiang, 2023, Hugging Face, Revision e5ebd9b, + DOI: 10.57967/hf/1202, URL: https://huggingface.co/cyrusyc/mace-universal + - Matbench Discovery by Janosh Riebesell, Rhys EA Goodall, Philipp Benner, Yuan Chiang, + Alpha A Lee, Anubhav Jain, Kristin A Persson, 2023, arXiv:2308.14920 + + Args: + model (str, optional): Path to the model. Defaults to None which first checks for + a local model and then downloads the default model from figshare. Specify "small", + "medium" or "large" to download a smaller or larger model from figshare. + device (str, optional): Device to use for the model. Defaults to "cuda" if available. + default_dtype (str, optional): Default dtype for the model. Defaults to "float32". + dispersion (bool, optional): Whether to use D3 dispersion corrections. Defaults to False. + damping (str): The damping function associated with the D3 correction. Defaults to "bj" for D3(BJ). + dispersion_xc (str, optional): Exchange-correlation functional for D3 dispersion corrections. + dispersion_cutoff (float, optional): Cutoff radius in Bohr for D3 dispersion corrections. + return_raw_model (bool, optional): Whether to return the raw model or an ASE calculator. Defaults to False. + **kwargs: Passed to MACECalculator and TorchDFTD3Calculator. + + Returns: + MACECalculator: trained on the MPtrj dataset (unless model otherwise specified). + """ + try: + if model in (None, "small", "medium", "large") or str(model).startswith( + "https:" + ): + model_path = download_mace_mp_checkpoint(model) + print(f"Using Materials Project MACE for MACECalculator with {model_path}") + else: + if not Path(model).exists(): + raise FileNotFoundError(f"{model} not found locally") + model_path = model + except Exception as exc: + raise RuntimeError("Model download failed and no local model found") from exc + + device = device or ("cuda" if torch.cuda.is_available() else "cpu") + if default_dtype == "float64": + print( + "Using float64 for MACECalculator, which is slower but more accurate. Recommended for geometry optimization." + ) + if default_dtype == "float32": + print( + "Using float32 for MACECalculator, which is faster but less accurate. Recommended for MD. Use float64 for geometry optimization." + ) + + if return_raw_model: + return torch.load(model_path, map_location=device) + + mace_calc = MACECalculator( + model_paths=model_path, device=device, default_dtype=default_dtype, **kwargs + ) + + if not dispersion: + return mace_calc + + try: + from torch_dftd.torch_dftd3_calculator import TorchDFTD3Calculator + except ImportError as exc: + raise RuntimeError( + "Please install torch-dftd to use dispersion corrections (see https://github.com/pfnet-research/torch-dftd)" + ) from exc + + print("Using TorchDFTD3Calculator for D3 dispersion corrections") + dtype = torch.float32 if default_dtype == "float32" else torch.float64 + d3_calc = TorchDFTD3Calculator( + device=device, + damping=damping, + dtype=dtype, + xc=dispersion_xc, + cutoff=dispersion_cutoff, + **kwargs, + ) + + return SumCalculator([mace_calc, d3_calc]) + + +def mace_off( + model: Union[str, Path] = None, + device: str = "", + default_dtype: str = "float64", + return_raw_model: bool = False, + **kwargs, +) -> MACECalculator: + """ + Constructs a MACECalculator with a pretrained model based on the MACE-OFF23 models. + The model is released under the ASL license. + Note: + If you are using this function, please cite the relevant paper by Kovacs et.al., arXiv:2312.15211 + + Args: + model (str, optional): Path to the model. Defaults to None which first checks for + a local model and then downloads the default medium model from https://github.com/ACEsuit/mace-off. + Specify "small", "medium" or "large" to download a smaller or larger model. + device (str, optional): Device to use for the model. Defaults to "cuda". + default_dtype (str, optional): Default dtype for the model. Defaults to "float64". + return_raw_model (bool, optional): Whether to return the raw model or an ASE calculator. Defaults to False. + **kwargs: Passed to MACECalculator. + + Returns: + MACECalculator: trained on the MACE-OFF23 dataset + """ + try: + if model in (None, "small", "medium", "large") or str(model).startswith( + "https:" + ): + urls = dict( + small="https://github.com/ACEsuit/mace-off/blob/main/mace_off23/MACE-OFF23_small.model?raw=true", + medium="https://github.com/ACEsuit/mace-off/raw/main/mace_off23/MACE-OFF23_medium.model?raw=true", + large="https://github.com/ACEsuit/mace-off/blob/main/mace_off23/MACE-OFF23_large.model?raw=true", + ) + checkpoint_url = ( + urls.get(model, urls["medium"]) + if model in (None, "small", "medium", "large") + else model + ) + cache_dir = os.path.expanduser("~/.cache/mace") + checkpoint_url_name = os.path.basename(checkpoint_url).split("?")[0] + cached_model_path = f"{cache_dir}/{checkpoint_url_name}" + if not os.path.isfile(cached_model_path): + os.makedirs(cache_dir, exist_ok=True) + # download and save to disk + print(f"Downloading MACE model from {checkpoint_url!r}") + print( + "The model is distributed under the Academic Software License (ASL) license, see https://github.com/gabor1/ASL \n To use the model you accept the terms of the license." + ) + print( + "ASL is based on the Gnu Public License, but does not permit commercial use" + ) + urllib.request.urlretrieve(checkpoint_url, cached_model_path) + print(f"Cached MACE model to {cached_model_path}") + model = cached_model_path + msg = f"Using MACE-OFF23 MODEL for MACECalculator with {model}" + print(msg) + else: + if not Path(model).exists(): + raise FileNotFoundError(f"{model} not found locally") + except Exception as exc: + raise RuntimeError("Model download failed and no local model found") from exc + + device = device or ("cuda" if torch.cuda.is_available() else "cpu") + + if return_raw_model: + return torch.load(model, map_location=device) + + if default_dtype == "float64": + print( + "Using float64 for MACECalculator, which is slower but more accurate. Recommended for geometry optimization." + ) + if default_dtype == "float32": + print( + "Using float32 for MACECalculator, which is faster but less accurate. Recommended for MD. Use float64 for geometry optimization." + ) + mace_calc = MACECalculator( + model_paths=model, device=device, default_dtype=default_dtype, **kwargs + ) + return mace_calc + + +def mace_anicc( + device: str = "cuda", + model_path: str = None, + return_raw_model: bool = False, +) -> MACECalculator: + """ + Constructs a MACECalculator with a pretrained model based on the ANI (H, C, N, O). + The model is released under the MIT license. + Note: + If you are using this function, please cite the relevant paper associated with the MACE model, ANI dataset, and also the following: + - "Evaluation of the MACE Force Field Architecture by Dávid Péter Kovács, Ilyes Batatia, Eszter Sára Arany, and Gábor Csányi, The Journal of Chemical Physics, 2023, URL: https://doi.org/10.1063/5.0155322 + """ + if model_path is None: + model_path = os.path.join( + module_dir, "foundations_models/ani500k_large_CC.model" + ) + print( + "Using ANI couple cluster model for MACECalculator, see https://doi.org/10.1063/5.0155322" + ) + if return_raw_model: + return torch.load(model_path, map_location=device) + return MACECalculator( + model_paths=model_path, device=device, default_dtype="float64" + ) diff --git a/mace/calculators/foundations_models/mp_vasp_e0.json b/mace/calculators/foundations_models/mp_vasp_e0.json new file mode 100644 index 00000000..01771879 --- /dev/null +++ b/mace/calculators/foundations_models/mp_vasp_e0.json @@ -0,0 +1,91 @@ +{ + "pbe": { + "1": -1.11734008, + "2": 0.00096759, + "3": -0.29754725, + "4": -0.01781697, + "5": -0.26885011, + "6": -1.26173507, + "7": -3.12438806, + "8": -1.54838784, + "9": -0.51882044, + "10": -0.01241601, + "11": -0.22883163, + "12": -0.00951015, + "13": -0.21630193, + "14": -0.8263903, + "15": -1.88816619, + "16": -0.89160769, + "17": -0.25828273, + "18": -0.04925973, + "19": -0.22697913, + "20": -0.0927795, + "21": -2.11396364, + "22": -2.50054871, + "23": -3.70477179, + "24": -5.60261985, + "25": -5.32541181, + "26": -3.52004933, + "27": -1.93555024, + "28": -0.9351969, + "29": -0.60025846, + "30": -0.1651332, + "31": -0.32990651, + "32": -0.77971828, + "33": -1.68367812, + "34": -0.76941032, + "35": -0.22213843, + "36": -0.0335879, + "37": -0.1881724, + "38": -0.06826294, + "39": -2.17084228, + "40": -2.28579303, + "41": -3.13429753, + "42": -4.60211419, + "43": -3.45201492, + "44": -2.38073513, + "45": -1.46855515, + "46": -1.4773126, + "47": -0.33954585, + "48": -0.16843877, + "49": -0.35470981, + "50": -0.83642657, + "51": -1.41101987, + "52": -0.65740879, + "53": -0.18964571, + "54": -0.00857582, + "55": -0.13771876, + "56": -0.03457659, + "57": -0.45580806, + "58": -1.3309175, + "59": -0.29671824, + "60": -0.30391193, + "61": -0.30898427, + "62": -0.25470891, + "63": -8.38001538, + "64": -10.38896525, + "65": -0.3059505, + "66": -0.30676216, + "67": -0.30874667, + "69": -0.25190039, + "70": -0.06431414, + "71": -0.31997586, + "72": -3.52770927, + "73": -3.54492209, + "75": -4.70108713, + "76": -2.88257209, + "77": -1.46779304, + "78": -0.50269936, + "79": -0.28801193, + "80": -0.12454674, + "81": -0.31737194, + "82": -0.77644932, + "83": -1.32627283, + "89": -0.26827152, + "90": -0.90817426, + "91": -2.47653193, + "92": -4.90438537, + "93": -7.63378961, + "94": -10.77237713 + } +} \ No newline at end of file diff --git a/mace/calculators/lammps_mace.py b/mace/calculators/lammps_mace.py new file mode 100644 index 00000000..4211c37f --- /dev/null +++ b/mace/calculators/lammps_mace.py @@ -0,0 +1,105 @@ +from typing import Dict, List, Optional + +import torch +from e3nn.util.jit import compile_mode + +from mace.tools.scatter import scatter_sum + + +@compile_mode("script") +class LAMMPS_MACE(torch.nn.Module): + def __init__(self, model, **kwargs): + super().__init__() + self.model = model + self.register_buffer("atomic_numbers", model.atomic_numbers) + self.register_buffer("r_max", model.r_max) + self.register_buffer("num_interactions", model.num_interactions) + if not hasattr(model, "heads"): + model.heads = [None] + self.register_buffer( + "head", + torch.tensor( + self.model.heads.index(kwargs.get("head", self.model.heads[-1])), + dtype=torch.long, + ).unsqueeze(0), + ) + + for param in self.model.parameters(): + param.requires_grad = False + + def forward( + self, + data: Dict[str, torch.Tensor], + local_or_ghost: torch.Tensor, + compute_virials: bool = False, + ) -> Dict[str, Optional[torch.Tensor]]: + num_graphs = data["ptr"].numel() - 1 + compute_displacement = False + if compute_virials: + compute_displacement = True + data["head"] = self.head + out = self.model( + data, + training=False, + compute_force=False, + compute_virials=False, + compute_stress=False, + compute_displacement=compute_displacement, + ) + node_energy = out["node_energy"] + if node_energy is None: + return { + "total_energy_local": None, + "node_energy": None, + "forces": None, + "virials": None, + } + positions = data["positions"] + displacement = out["displacement"] + forces: Optional[torch.Tensor] = torch.zeros_like(positions) + virials: Optional[torch.Tensor] = torch.zeros_like(data["cell"]) + # accumulate energies of local atoms + node_energy_local = node_energy * local_or_ghost + total_energy_local = scatter_sum( + src=node_energy_local, index=data["batch"], dim=-1, dim_size=num_graphs + ) + # compute partial forces and (possibly) partial virials + grad_outputs: List[Optional[torch.Tensor]] = [ + torch.ones_like(total_energy_local) + ] + if compute_virials and displacement is not None: + forces, virials = torch.autograd.grad( + outputs=[total_energy_local], + inputs=[positions, displacement], + grad_outputs=grad_outputs, + retain_graph=False, + create_graph=False, + allow_unused=True, + ) + if forces is not None: + forces = -1 * forces + else: + forces = torch.zeros_like(positions) + if virials is not None: + virials = -1 * virials + else: + virials = torch.zeros_like(displacement) + else: + forces = torch.autograd.grad( + outputs=[total_energy_local], + inputs=[positions], + grad_outputs=grad_outputs, + retain_graph=False, + create_graph=False, + allow_unused=True, + )[0] + if forces is not None: + forces = -1 * forces + else: + forces = torch.zeros_like(positions) + return { + "total_energy_local": total_energy_local, + "node_energy": node_energy, + "forces": forces, + "virials": virials, + } diff --git a/mace/calculators/mace.py b/mace/calculators/mace.py new file mode 100644 index 00000000..dcd2b8e5 --- /dev/null +++ b/mace/calculators/mace.py @@ -0,0 +1,414 @@ +########################################################################################### +# The ASE Calculator for MACE +# Authors: Ilyes Batatia, David Kovacs +# This program is distributed under the MIT License (see MIT.md) +########################################################################################### + + +import logging +from glob import glob +from pathlib import Path +from typing import Union + +import numpy as np +import torch +from ase.calculators.calculator import Calculator, all_changes +from ase.stress import full_3x3_to_voigt_6_stress + +from mace import data +from mace.modules.utils import extract_invariant +from mace.tools import torch_geometric, torch_tools, utils +from mace.tools.compile import prepare +from mace.tools.scripts_utils import extract_model + + +def get_model_dtype(model: torch.nn.Module) -> torch.dtype: + """Get the dtype of the model""" + mode_dtype = next(model.parameters()).dtype + if mode_dtype == torch.float64: + return "float64" + if mode_dtype == torch.float32: + return "float32" + raise ValueError(f"Unknown dtype {mode_dtype}") + + +class MACECalculator(Calculator): + """MACE ASE Calculator + args: + model_paths: str, path to model or models if a committee is produced + to make a committee use a wild card notation like mace_*.model + device: str, device to run on (cuda or cpu) + energy_units_to_eV: float, conversion factor from model energy units to eV + length_units_to_A: float, conversion factor from model length units to Angstroms + default_dtype: str, default dtype of model + charges_key: str, Array field of atoms object where atomic charges are stored + model_type: str, type of model to load + Options: [MACE, DipoleMACE, EnergyDipoleMACE] + + Dipoles are returned in units of Debye + """ + + def __init__( + self, + model_paths: Union[list, str, None] = None, + models: Union[list[torch.nn.Module], torch.nn.Module, None] = None, + device: str = "cpu", + energy_units_to_eV: float = 1.0, + length_units_to_A: float = 1.0, + default_dtype="", + charges_key="Qs", + model_type="MACE", + compile_mode=None, + fullgraph=True, + **kwargs, + ): + Calculator.__init__(self, **kwargs) + + if "model_path" in kwargs: + deprecation_message = ( + "'model_path' argument is deprecated, please use 'model_paths'" + ) + if model_paths is None: + logging.warning(f"{deprecation_message} in the future.") + model_paths = kwargs["model_path"] + else: + raise ValueError( + f"both 'model_path' and 'model_paths' given, {deprecation_message} only." + ) + + if (model_paths is None) == (models is None): + raise ValueError( + "Exactly one of 'model_paths' or 'models' must be provided" + ) + + self.results = {} + + self.model_type = model_type + + if model_type == "MACE": + self.implemented_properties = [ + "energy", + "free_energy", + "node_energy", + "forces", + "stress", + ] + elif model_type == "DipoleMACE": + self.implemented_properties = ["dipole"] + elif model_type == "EnergyDipoleMACE": + self.implemented_properties = [ + "energy", + "free_energy", + "node_energy", + "forces", + "stress", + "dipole", + ] + else: + raise ValueError( + f"Give a valid model_type: [MACE, DipoleMACE, EnergyDipoleMACE], {model_type} not supported" + ) + + if model_paths is not None: + if isinstance(model_paths, str): + # Find all models that satisfy the wildcard (e.g. mace_model_*.pt) + model_paths_glob = glob(model_paths) + + if len(model_paths_glob) == 0: + raise ValueError(f"Couldn't find MACE model files: {model_paths}") + + model_paths = model_paths_glob + elif isinstance(model_paths, Path): + model_paths = [model_paths] + + if len(model_paths) == 0: + raise ValueError("No mace file names supplied") + self.num_models = len(model_paths) + + # Load models from files + self.models = [ + torch.load(f=model_path, map_location=device) + for model_path in model_paths + ] + + elif models is not None: + if not isinstance(models, list): + models = [models] + + if len(models) == 0: + raise ValueError("No models supplied") + + self.models = models + self.num_models = len(models) + + if self.num_models > 1: + print(f"Running committee mace with {self.num_models} models") + + if model_type in ["MACE", "EnergyDipoleMACE"]: + self.implemented_properties.extend( + ["energies", "energy_var", "forces_comm", "stress_var"] + ) + elif model_type == "DipoleMACE": + self.implemented_properties.extend(["dipole_var"]) + + if compile_mode is not None: + print(f"Torch compile is enabled with mode: {compile_mode}") + self.models = [ + torch.compile( + prepare(extract_model)(model=model, map_location=device), + mode=compile_mode, + fullgraph=fullgraph, + ) + for model in self.models + ] + self.use_compile = True + else: + self.use_compile = False + + # Ensure all models are on the same device + for model in self.models: + model.to(device) + + r_maxs = [model.r_max.cpu() for model in self.models] + r_maxs = np.array(r_maxs) + if not np.all(r_maxs == r_maxs[0]): + raise ValueError(f"committee r_max are not all the same {' '.join(r_maxs)}") + self.r_max = float(r_maxs[0]) + + self.device = torch_tools.init_device(device) + self.energy_units_to_eV = energy_units_to_eV + self.length_units_to_A = length_units_to_A + self.z_table = utils.AtomicNumberTable( + [int(z) for z in self.models[0].atomic_numbers] + ) + self.charges_key = charges_key + try: + self.heads = self.models[0].heads + except AttributeError: + self.heads = ["Default"] + model_dtype = get_model_dtype(self.models[0]) + if default_dtype == "": + print( + f"No dtype selected, switching to {model_dtype} to match model dtype." + ) + default_dtype = model_dtype + if model_dtype != default_dtype: + print( + f"Default dtype {default_dtype} does not match model dtype {model_dtype}, converting models to {default_dtype}." + ) + if default_dtype == "float64": + self.models = [model.double() for model in self.models] + elif default_dtype == "float32": + self.models = [model.float() for model in self.models] + torch_tools.set_default_dtype(default_dtype) + for model in self.models: + for param in model.parameters(): + param.requires_grad = False + + def _create_result_tensors( + self, model_type: str, num_models: int, num_atoms: int + ) -> dict: + """ + Create tensors to store the results of the committee + :param model_type: str, type of model to load + Options: [MACE, DipoleMACE, EnergyDipoleMACE] + :param num_models: int, number of models in the committee + :return: tuple of torch tensors + """ + dict_of_tensors = {} + if model_type in ["MACE", "EnergyDipoleMACE"]: + energies = torch.zeros(num_models, device=self.device) + node_energy = torch.zeros(num_models, num_atoms, device=self.device) + forces = torch.zeros(num_models, num_atoms, 3, device=self.device) + stress = torch.zeros(num_models, 3, 3, device=self.device) + dict_of_tensors.update( + { + "energies": energies, + "node_energy": node_energy, + "forces": forces, + "stress": stress, + } + ) + if model_type in ["EnergyDipoleMACE", "DipoleMACE"]: + dipole = torch.zeros(num_models, 3, device=self.device) + dict_of_tensors.update({"dipole": dipole}) + return dict_of_tensors + + def _atoms_to_batch(self, atoms): + config = data.config_from_atoms(atoms, charges_key=self.charges_key) + data_loader = torch_geometric.dataloader.DataLoader( + dataset=[ + data.AtomicData.from_config( + config, z_table=self.z_table, cutoff=self.r_max, heads=self.heads + ) + ], + batch_size=1, + shuffle=False, + drop_last=False, + ) + batch = next(iter(data_loader)).to(self.device) + return batch + + def _clone_batch(self, batch): + batch_clone = batch.clone() + if self.use_compile: + batch_clone["node_attrs"].requires_grad_(True) + batch_clone["positions"].requires_grad_(True) + return batch_clone + + # pylint: disable=dangerous-default-value + def calculate(self, atoms=None, properties=None, system_changes=all_changes): + """ + Calculate properties. + :param atoms: ase.Atoms object + :param properties: [str], properties to be computed, used by ASE internally + :param system_changes: [str], system changes since last calculation, used by ASE internally + :return: + """ + # call to base-class to set atoms attribute + Calculator.calculate(self, atoms) + + batch_base = self._atoms_to_batch(atoms) + + if self.model_type in ["MACE", "EnergyDipoleMACE"]: + batch = self._clone_batch(batch_base) + node_heads = batch["head"][batch["batch"]] + num_atoms_arange = torch.arange(batch["positions"].shape[0]) + node_e0 = self.models[0].atomic_energies_fn(batch["node_attrs"])[ + num_atoms_arange, node_heads + ] + compute_stress = not self.use_compile + else: + compute_stress = False + + ret_tensors = self._create_result_tensors( + self.model_type, self.num_models, len(atoms) + ) + for i, model in enumerate(self.models): + batch = self._clone_batch(batch_base) + out = model( + batch.to_dict(), + compute_stress=compute_stress, + training=self.use_compile, + ) + if self.model_type in ["MACE", "EnergyDipoleMACE"]: + ret_tensors["energies"][i] = out["energy"].detach() + ret_tensors["node_energy"][i] = (out["node_energy"] - node_e0).detach() + ret_tensors["forces"][i] = out["forces"].detach() + if out["stress"] is not None: + ret_tensors["stress"][i] = out["stress"].detach() + if self.model_type in ["DipoleMACE", "EnergyDipoleMACE"]: + ret_tensors["dipole"][i] = out["dipole"].detach() + + self.results = {} + if self.model_type in ["MACE", "EnergyDipoleMACE"]: + self.results["energy"] = ( + torch.mean(ret_tensors["energies"], dim=0).cpu().item() + * self.energy_units_to_eV + ) + self.results["free_energy"] = self.results["energy"] + self.results["node_energy"] = ( + torch.mean(ret_tensors["node_energy"], dim=0).cpu().numpy() + ) + self.results["forces"] = ( + torch.mean(ret_tensors["forces"], dim=0).cpu().numpy() + * self.energy_units_to_eV + / self.length_units_to_A + ) + if self.num_models > 1: + self.results["energies"] = ( + ret_tensors["energies"].cpu().numpy() * self.energy_units_to_eV + ) + self.results["energy_var"] = ( + torch.var(ret_tensors["energies"], dim=0, unbiased=False) + .cpu() + .item() + * self.energy_units_to_eV + ) + self.results["forces_comm"] = ( + ret_tensors["forces"].cpu().numpy() + * self.energy_units_to_eV + / self.length_units_to_A + ) + if out["stress"] is not None: + self.results["stress"] = full_3x3_to_voigt_6_stress( + torch.mean(ret_tensors["stress"], dim=0).cpu().numpy() + * self.energy_units_to_eV + / self.length_units_to_A**3 + ) + if self.num_models > 1: + self.results["stress_var"] = full_3x3_to_voigt_6_stress( + torch.var(ret_tensors["stress"], dim=0, unbiased=False) + .cpu() + .numpy() + * self.energy_units_to_eV + / self.length_units_to_A**3 + ) + if self.model_type in ["DipoleMACE", "EnergyDipoleMACE"]: + self.results["dipole"] = ( + torch.mean(ret_tensors["dipole"], dim=0).cpu().numpy() + ) + if self.num_models > 1: + self.results["dipole_var"] = ( + torch.var(ret_tensors["dipole"], dim=0, unbiased=False) + .cpu() + .numpy() + ) + + def get_hessian(self, atoms=None): + if atoms is None and self.atoms is None: + raise ValueError("atoms not set") + if atoms is None: + atoms = self.atoms + if self.model_type != "MACE": + raise NotImplementedError("Only implemented for MACE models") + batch = self._atoms_to_batch(atoms) + hessians = [ + model( + self._clone_batch(batch).to_dict(), + compute_hessian=True, + compute_stress=False, + training=self.use_compile, + )["hessian"] + for model in self.models + ] + hessians = [hessian.detach().cpu().numpy() for hessian in hessians] + if self.num_models == 1: + return hessians[0] + return hessians + + def get_descriptors(self, atoms=None, invariants_only=True, num_layers=-1): + """Extracts the descriptors from MACE model. + :param atoms: ase.Atoms object + :param invariants_only: bool, if True only the invariant descriptors are returned + :param num_layers: int, number of layers to extract descriptors from, if -1 all layers are used + :return: np.ndarray (num_atoms, num_interactions, invariant_features) of invariant descriptors if num_models is 1 or list[np.ndarray] otherwise + """ + if atoms is None and self.atoms is None: + raise ValueError("atoms not set") + if atoms is None: + atoms = self.atoms + if self.model_type != "MACE": + raise NotImplementedError("Only implemented for MACE models") + if num_layers == -1: + num_layers = int(self.models[0].num_interactions) + batch = self._atoms_to_batch(atoms) + descriptors = [model(batch.to_dict())["node_feats"] for model in self.models] + if invariants_only: + irreps_out = self.models[0].products[0].linear.__dict__["irreps_out"] + l_max = irreps_out.lmax + num_features = irreps_out.dim // (l_max + 1) ** 2 + descriptors = [ + extract_invariant( + descriptor, + num_layers=num_layers, + num_features=num_features, + l_max=l_max, + ) + for descriptor in descriptors + ] + descriptors = [descriptor.detach().cpu().numpy() for descriptor in descriptors] + + if self.num_models == 1: + return descriptors[0] + return descriptors diff --git a/mace/cli/__init__.py b/mace/cli/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/mace/cli/active_learning_md.py b/mace/cli/active_learning_md.py new file mode 100644 index 00000000..9cf4f4a8 --- /dev/null +++ b/mace/cli/active_learning_md.py @@ -0,0 +1,193 @@ +"""Demonstrates active learning molecular dynamics with constant temperature.""" + +import argparse +import os +import time + +import ase.io +import numpy as np +from ase import units +from ase.md.langevin import Langevin +from ase.md.velocitydistribution import MaxwellBoltzmannDistribution + +from mace.calculators.mace import MACECalculator + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("--config", help="path to XYZ configurations", required=True) + parser.add_argument( + "--config_index", help="index of configuration", type=int, default=-1 + ) + parser.add_argument( + "--error_threshold", help="error threshold", type=float, default=0.1 + ) + parser.add_argument("--temperature_K", help="temperature", type=float, default=300) + parser.add_argument("--friction", help="friction", type=float, default=0.01) + parser.add_argument("--timestep", help="timestep", type=float, default=1) + parser.add_argument("--nsteps", help="number of steps", type=int, default=1000) + parser.add_argument( + "--nprint", help="number of steps between prints", type=int, default=10 + ) + parser.add_argument( + "--nsave", help="number of steps between saves", type=int, default=10 + ) + parser.add_argument( + "--ncheckerror", help="number of steps between saves", type=int, default=10 + ) + + parser.add_argument( + "--model", + help="path to model. Use wildcards to add multiple models as committee eg " + "(`mace_*.model` to load mace_1.model, mace_2.model) ", + required=True, + ) + parser.add_argument("--output", help="output path", required=True) + parser.add_argument( + "--device", + help="select device", + type=str, + choices=["cpu", "cuda"], + default="cuda", + ) + parser.add_argument( + "--default_dtype", + help="set default dtype", + type=str, + choices=["float32", "float64"], + default="float64", + ) + parser.add_argument( + "--compute_stress", + help="compute stress", + action="store_true", + default=False, + ) + parser.add_argument( + "--info_prefix", + help="prefix for energy, forces and stress keys", + type=str, + default="MACE_", + ) + return parser.parse_args() + + +def printenergy(dyn, start_time=None): # store a reference to atoms in the definition. + """Function to print the potential, kinetic and total energy.""" + a = dyn.atoms + epot = a.get_potential_energy() / len(a) + ekin = a.get_kinetic_energy() / len(a) + if start_time is None: + elapsed_time = 0 + else: + elapsed_time = time.time() - start_time + forces_var = np.var(a.calc.results["forces_comm"], axis=0) + print( + "%.1fs: Energy per atom: Epot = %.3feV Ekin = %.3feV (T=%3.0fK) " # pylint: disable=C0209 + "Etot = %.3feV t=%.1ffs Eerr = %.3feV Ferr = %.3feV/A" + % ( + elapsed_time, + epot, + ekin, + ekin / (1.5 * units.kB), + epot + ekin, + dyn.get_time() / units.fs, + a.calc.results["energy_var"], + np.max(np.linalg.norm(forces_var, axis=1)), + ), + flush=True, + ) + + +def save_config(dyn, fname): + atomsi = dyn.atoms + ens = atomsi.get_potential_energy() + frcs = atomsi.get_forces() + + atomsi.info.update( + { + "mlff_energy": ens, + "time": np.round(dyn.get_time() / units.fs, 5), + "mlff_energy_var": atomsi.calc.results["energy_var"], + } + ) + atomsi.arrays.update( + { + "mlff_forces": frcs, + "mlff_forces_var": np.var(atomsi.calc.results["forces_comm"], axis=0), + } + ) + + ase.io.write(fname, atomsi, append=True) + + +def stop_error(dyn, threshold, reg=0.2): + atomsi = dyn.atoms + force_var = np.var(atomsi.calc.results["forces_comm"], axis=0) + force = atomsi.get_forces() + ferr = np.sqrt(np.sum(force_var, axis=1)) + ferr_rel = ferr / (np.linalg.norm(force, axis=1) + reg) + + if np.max(ferr_rel) > threshold: + print( + "Error too large {:.3}. Stopping t={:.2} fs.".format( # pylint: disable=C0209 + np.max(ferr_rel), dyn.get_time() / units.fs + ), + flush=True, + ) + dyn.max_steps = 0 + + +def main() -> None: + args = parse_args() + run(args) + + +def run(args: argparse.Namespace) -> None: + mace_fname = args.model + atoms_fname = args.config + atoms_index = args.config_index + + mace_calc = MACECalculator( + model_paths=mace_fname, + device=args.device, + default_dtype=args.default_dtype, + ) + + NSTEPS = args.nsteps + + if os.path.exists(args.output): + print("Trajectory exists. Continuing from last step.") + atoms = ase.io.read(args.output, index=-1) + len_save = len(ase.io.read(args.output, ":")) + print("Last step: ", atoms.info["time"], "Number of configs: ", len_save) + NSTEPS -= len_save * args.nsave + else: + atoms = ase.io.read(atoms_fname, index=atoms_index) + MaxwellBoltzmannDistribution(atoms, temperature_K=args.temperature_K) + + atoms.calc = mace_calc + + # We want to run MD with constant energy using the Langevin algorithm + # with a time step of 5 fs, the temperature T and the friction + # coefficient to 0.02 atomic units. + dyn = Langevin( + atoms=atoms, + timestep=args.timestep * units.fs, + temperature_K=args.temperature_K, + friction=args.friction, + ) + + dyn.attach(printenergy, interval=args.nsave, dyn=dyn, start_time=time.time()) + dyn.attach(save_config, interval=args.nsave, dyn=dyn, fname=args.output) + dyn.attach( + stop_error, interval=args.ncheckerror, dyn=dyn, threshold=args.error_threshold + ) + # Now run the dynamics + dyn.run(NSTEPS) + + +if __name__ == "__main__": + main() diff --git a/mace/cli/convert_device.py b/mace/cli/convert_device.py new file mode 100644 index 00000000..9dd8c61d --- /dev/null +++ b/mace/cli/convert_device.py @@ -0,0 +1,31 @@ +from argparse import ArgumentParser + +import torch + + +def main(): + parser = ArgumentParser() + parser.add_argument( + "--target_device", + "-t", + help="device to convert to, usually 'cpu' or 'cuda'", + default="cpu", + ) + parser.add_argument( + "--output_file", + "-o", + help="name for output model, defaults to model_file.target_device", + ) + parser.add_argument("model_file", help="input model file path") + args = parser.parse_args() + + if args.output_file is None: + args.output_file = args.model_file + "." + args.target_device + + model = torch.load(args.model_file) + model.to(args.target_device) + torch.save(model, args.output_file) + + +if __name__ == "__main__": + main() diff --git a/mace/cli/create_lammps_model.py b/mace/cli/create_lammps_model.py new file mode 100644 index 00000000..507a2cd0 --- /dev/null +++ b/mace/cli/create_lammps_model.py @@ -0,0 +1,92 @@ +import argparse + +import torch +from e3nn.util import jit + +from mace.calculators import LAMMPS_MACE + + +def parse_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "model_path", + type=str, + help="Path to the model to be converted to LAMMPS", + ) + parser.add_argument( + "--head", + type=str, + nargs="?", + help="Head of the model to be converted to LAMMPS", + default=None, + ) + parser.add_argument( + "--dtype", + type=str, + nargs="?", + help="Data type of the model to be converted to LAMMPS", + default="float64", + ) + return parser.parse_args() + + +def select_head(model): + if hasattr(model, "heads"): + heads = model.heads + else: + heads = [None] + + if len(heads) == 1: + print(f"Only one head found in the model: {heads[0]}. Skipping selection.") + return heads[0] + + print("Available heads in the model:") + for i, head in enumerate(heads): + print(f"{i + 1}: {head}") + + # Ask the user to select a head + selected = input( + f"Select a head by number (Defaulting to head: {len(heads)}, press Enter to accept): " + ) + + if selected.isdigit() and 1 <= int(selected) <= len(heads): + return heads[int(selected) - 1] + if selected == "": + print("No head selected. Proceeding without specifying a head.") + return None + print(f"No valid selection made. Defaulting to the last head: {heads[-1]}") + return heads[-1] + + +def main(): + args = parse_args() + model_path = args.model_path # takes model name as command-line input + model = torch.load( + model_path, + map_location=torch.device("cuda" if torch.cuda.is_available() else "cpu"), + ) + if args.dtype == "float64": + model = model.double().to("cpu") + elif args.dtype == "float32": + print("Converting model to float32, this may cause loss of precision.") + model = model.float().to("cpu") + + if args.head is None: + head = select_head(model) + else: + head = args.head + print( + f"Selected head: {head} from command line in the list available heads: {model.heads}" + ) + + lammps_model = ( + LAMMPS_MACE(model, head=head) if head is not None else LAMMPS_MACE(model) + ) + lammps_model_compiled = jit.compile(lammps_model) + lammps_model_compiled.save(model_path + "-lammps.pt") + + +if __name__ == "__main__": + main() diff --git a/mace/cli/eval_configs.py b/mace/cli/eval_configs.py new file mode 100644 index 00000000..d00c54c6 --- /dev/null +++ b/mace/cli/eval_configs.py @@ -0,0 +1,165 @@ +########################################################################################### +# Script for evaluating configurations contained in an xyz file with a trained model +# Authors: Ilyes Batatia, Gregor Simm +# This program is distributed under the MIT License (see MIT.md) +########################################################################################### + +import argparse + +import ase.data +import ase.io +import numpy as np +import torch + +from mace import data +from mace.tools import torch_geometric, torch_tools, utils + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("--configs", help="path to XYZ configurations", required=True) + parser.add_argument("--model", help="path to model", required=True) + parser.add_argument("--output", help="output path", required=True) + parser.add_argument( + "--device", + help="select device", + type=str, + choices=["cpu", "cuda"], + default="cpu", + ) + parser.add_argument( + "--default_dtype", + help="set default dtype", + type=str, + choices=["float32", "float64"], + default="float64", + ) + parser.add_argument("--batch_size", help="batch size", type=int, default=64) + parser.add_argument( + "--compute_stress", + help="compute stress", + action="store_true", + default=False, + ) + parser.add_argument( + "--return_contributions", + help="model outputs energy contributions for each body order, only supported for MACE, not ScaleShiftMACE", + action="store_true", + default=False, + ) + parser.add_argument( + "--info_prefix", + help="prefix for energy, forces and stress keys", + type=str, + default="MACE_", + ) + parser.add_argument( + "--head", + help="Model head used for evaluation", + type=str, + required=False, + default=None, + ) + return parser.parse_args() + + +def main() -> None: + args = parse_args() + run(args) + + +def run(args: argparse.Namespace) -> None: + torch_tools.set_default_dtype(args.default_dtype) + device = torch_tools.init_device(args.device) + + # Load model + model = torch.load(f=args.model, map_location=args.device) + model = model.to( + args.device + ) # shouldn't be necessary but seems to help with CUDA problems + + for param in model.parameters(): + param.requires_grad = False + + # Load data and prepare input + atoms_list = ase.io.read(args.configs, index=":") + if args.head is not None: + for atoms in atoms_list: + atoms.info["head"] = args.head + configs = [data.config_from_atoms(atoms) for atoms in atoms_list] + + z_table = utils.AtomicNumberTable([int(z) for z in model.atomic_numbers]) + + try: + heads = model.heads + except AttributeError: + heads = None + + data_loader = torch_geometric.dataloader.DataLoader( + dataset=[ + data.AtomicData.from_config( + config, z_table=z_table, cutoff=float(model.r_max), heads=heads + ) + for config in configs + ], + batch_size=args.batch_size, + shuffle=False, + drop_last=False, + ) + + # Collect data + energies_list = [] + contributions_list = [] + stresses_list = [] + forces_collection = [] + + for batch in data_loader: + batch = batch.to(device) + output = model(batch.to_dict(), compute_stress=args.compute_stress) + energies_list.append(torch_tools.to_numpy(output["energy"])) + if args.compute_stress: + stresses_list.append(torch_tools.to_numpy(output["stress"])) + + if args.return_contributions: + contributions_list.append(torch_tools.to_numpy(output["contributions"])) + + forces = np.split( + torch_tools.to_numpy(output["forces"]), + indices_or_sections=batch.ptr[1:], + axis=0, + ) + forces_collection.append(forces[:-1]) # drop last as its empty + + energies = np.concatenate(energies_list, axis=0) + forces_list = [ + forces for forces_list in forces_collection for forces in forces_list + ] + assert len(atoms_list) == len(energies) == len(forces_list) + if args.compute_stress: + stresses = np.concatenate(stresses_list, axis=0) + assert len(atoms_list) == stresses.shape[0] + + if args.return_contributions: + contributions = np.concatenate(contributions_list, axis=0) + assert len(atoms_list) == contributions.shape[0] + + # Store data in atoms objects + for i, (atoms, energy, forces) in enumerate(zip(atoms_list, energies, forces_list)): + atoms.calc = None # crucial + atoms.info[args.info_prefix + "energy"] = energy + atoms.arrays[args.info_prefix + "forces"] = forces + + if args.compute_stress: + atoms.info[args.info_prefix + "stress"] = stresses[i] + + if args.return_contributions: + atoms.info[args.info_prefix + "BO_contributions"] = contributions[i] + + # Write atoms to output path + ase.io.write(args.output, images=atoms_list, format="extxyz") + + +if __name__ == "__main__": + main() diff --git a/mace/cli/fine_tuning_select.py b/mace/cli/fine_tuning_select.py new file mode 100644 index 00000000..94baf0dd --- /dev/null +++ b/mace/cli/fine_tuning_select.py @@ -0,0 +1,348 @@ +########################################################################################### +# This program is distributed under the MIT License (see MIT.md) +########################################################################################### + +import argparse +import logging +from typing import List + +import ase.data +import ase.io +import numpy as np +import torch + +from mace.calculators import MACECalculator, mace_mp + +try: + import fpsample +except ImportError: + pass + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--configs_pt", + help="path to XYZ configurations for the pretraining", + required=True, + ) + parser.add_argument( + "--configs_ft", + help="path or list of paths to XYZ configurations for the finetuning", + required=True, + ) + parser.add_argument( + "--num_samples", + help="number of samples to select for the pretraining", + type=int, + required=False, + default=None, + ) + parser.add_argument( + "--subselect", + help="method to subselect the configurations of the pretraining set", + type=str, + choices=["fps", "random"], + default="fps", + ) + parser.add_argument( + "--model", help="path to model", default="small", required=False + ) + parser.add_argument("--output", help="output path", required=True) + parser.add_argument( + "--descriptors", help="path to descriptors", required=False, default=None + ) + parser.add_argument( + "--device", + help="select device", + type=str, + choices=["cpu", "cuda"], + default="cpu", + ) + parser.add_argument( + "--default_dtype", + help="set default dtype", + type=str, + choices=["float32", "float64"], + default="float64", + ) + parser.add_argument( + "--head_pt", + help="level of head for the pretraining set", + type=str, + default=None, + ) + parser.add_argument( + "--head_ft", + help="level of head for the finetuning set", + type=str, + default=None, + ) + parser.add_argument( + "--filtering_type", + help="filtering type", + type=str, + choices=[None, "combinations", "exclusive", "inclusive"], + default=None, + ) + parser.add_argument( + "--weight_ft", + help="weight for the finetuning set", + type=float, + default=1.0, + ) + parser.add_argument( + "--weight_pt", + help="weight for the pretraining set", + type=float, + default=1.0, + ) + parser.add_argument("--seed", help="random seed", type=int, default=42) + return parser.parse_args() + + +def calculate_descriptors(atoms: List[ase.Atoms], calc: MACECalculator) -> None: + logging.info("Calculating descriptors") + for mol in atoms: + descriptors = calc.get_descriptors(mol.copy(), invariants_only=True) + # average descriptors over atoms for each element + descriptors_dict = { + element: np.mean(descriptors[mol.symbols == element], axis=0) + for element in np.unique(mol.symbols) + } + mol.info["mace_descriptors"] = descriptors_dict + + +def filter_atoms( + atoms: ase.Atoms, element_subset: List[str], filtering_type: str +) -> bool: + """ + Filters atoms based on the provided filtering type and element subset. + + Parameters: + atoms (ase.Atoms): The atoms object to filter. + element_subset (list): The list of elements to consider during filtering. + filtering_type (str): The type of filtering to apply. Can be 'none', 'exclusive', or 'inclusive'. + 'none' - No filtering is applied. + 'combinations' - Return true if `atoms` is composed of combinations of elements in the subset, false otherwise. I.e. does not require all of the specified elements to be present. + 'exclusive' - Return true if `atoms` contains *only* elements in the subset, false otherwise. + 'inclusive' - Return true if `atoms` contains all elements in the subset, false otherwise. I.e. allows additional elements. + + Returns: + bool: True if the atoms pass the filter, False otherwise. + """ + if filtering_type == "none": + return True + if filtering_type == "combinations": + atom_symbols = np.unique(atoms.symbols) + return all( + x in element_subset for x in atom_symbols + ) # atoms must *only* contain elements in the subset + if filtering_type == "exclusive": + atom_symbols = set(list(atoms.symbols)) + return atom_symbols == set(element_subset) + if filtering_type == "inclusive": + atom_symbols = np.unique(atoms.symbols) + return all( + x in atom_symbols for x in element_subset + ) # atoms must *at least* contain elements in the subset + raise ValueError( + f"Filtering type {filtering_type} not recognised. Must be one of 'none', 'exclusive', or 'inclusive'." + ) + + +class FPS: + def __init__(self, atoms_list: List[ase.Atoms], n_samples: int): + self.n_samples = n_samples + self.atoms_list = atoms_list + self.species = np.unique([x.symbol for atoms in atoms_list for x in atoms]) + self.species_dict = {x: i for i, x in enumerate(self.species)} + # start from a random configuration + self.list_index = [np.random.randint(0, len(atoms_list))] + self.assemble_descriptors() + + def run( + self, + ) -> List[int]: + """ + Run the farthest point sampling algorithm. + """ + descriptor_dataset_reshaped = ( + self.descriptors_dataset.reshape( # pylint: disable=E1121 + (len(self.atoms_list), -1) + ) + ) + logging.info(f"{descriptor_dataset_reshaped.shape}") + logging.info(f"n_samples: {self.n_samples}") + self.list_index = fpsample.fps_npdu_kdtree_sampling( + descriptor_dataset_reshaped, + self.n_samples, + ) + return self.list_index + + def assemble_descriptors(self) -> np.ndarray: + """ + Assemble the descriptors for all the configurations. + """ + self.descriptors_dataset: np.ndarray = 10e10 * np.ones( + ( + len(self.atoms_list), + len(self.species), + len(list(self.atoms_list[0].info["mace_descriptors"].values())[0]), + ), + dtype=np.float32, + ).astype(np.float32) + + for i, atoms in enumerate(self.atoms_list): + descriptors = atoms.info["mace_descriptors"] + for z in descriptors: + self.descriptors_dataset[i, self.species_dict[z]] = np.array( + descriptors[z] + ).astype(np.float32) + + +def select_samples( + args: argparse.Namespace, +) -> None: + np.random.seed(args.seed) + torch.manual_seed(args.seed) + if args.model in ["small", "medium", "large"]: + calc = mace_mp(args.model, device=args.device, default_dtype=args.default_dtype) + else: + calc = MACECalculator( + model_paths=args.model, device=args.device, default_dtype=args.default_dtype + ) + if isinstance(args.configs_ft, str): + atoms_list_ft = ase.io.read(args.configs_ft, index=":") + else: + atoms_list_ft = [] + for path in args.configs_ft: + atoms_list_ft += ase.io.read(path, index=":") + + if args.filtering_type is not None: + all_species_ft = np.unique([x.symbol for atoms in atoms_list_ft for x in atoms]) + logging.info( + "Filtering configurations based on the finetuning set, " + f"filtering type: combinations, elements: {all_species_ft}" + ) + if args.subselect != "random": + if args.descriptors is not None: + logging.info("Loading descriptors") + descriptors = np.load(args.descriptors, allow_pickle=True) + atoms_list_pt = ase.io.read(args.configs_pt, index=":") + for i, atoms in enumerate(atoms_list_pt): + atoms.info["mace_descriptors"] = descriptors[i] + atoms_list_pt_filtered = [ + x + for x in atoms_list_pt + if filter_atoms(x, all_species_ft, "combinations") + ] + else: + atoms_list_pt = ase.io.read(args.configs_pt, index=":") + atoms_list_pt_filtered = [ + x + for x in atoms_list_pt + if filter_atoms(x, all_species_ft, "combinations") + ] + else: + atoms_list_pt = ase.io.read(args.configs_pt, index=":") + atoms_list_pt_filtered = [ + x + for x in atoms_list_pt + if filter_atoms(x, all_species_ft, "combinations") + ] + if len(atoms_list_pt_filtered) <= args.num_samples: + logging.info( + f"Number of configurations after filtering {len(atoms_list_pt_filtered)} " + f"is less than the number of samples {args.num_samples}, " + "selecting random configurations for the rest." + ) + atoms_list_pt_minus_filtered = [ + x for x in atoms_list_pt if x not in atoms_list_pt_filtered + ] + atoms_list_pt_random_inds = np.random.choice( + list(range(len(atoms_list_pt_minus_filtered))), + args.num_samples - len(atoms_list_pt_filtered), + replace=False, + ) + atoms_list_pt = atoms_list_pt_filtered + [ + atoms_list_pt_minus_filtered[ind] for ind in atoms_list_pt_random_inds + ] + else: + atoms_list_pt = atoms_list_pt_filtered + + else: + atoms_list_pt = ase.io.read(args.configs_pt, index=":") + if args.descriptors is not None: + logging.info( + f"Loading descriptors for the pretraining set from {args.descriptors}" + ) + descriptors = np.load(args.descriptors, allow_pickle=True) + for i, atoms in enumerate(atoms_list_pt): + atoms.info["mace_descriptors"] = descriptors[i] + + if args.num_samples is not None and args.num_samples < len(atoms_list_pt): + if args.subselect == "fps": + if args.descriptors is None: + logging.info("Calculating descriptors for the pretraining set") + calculate_descriptors(atoms_list_pt, calc) + descriptors_list = [ + atoms.info["mace_descriptors"] for atoms in atoms_list_pt + ] + logging.info( + f"Saving descriptors at {args.output.replace('.xyz', '_descriptors.npy')}" + ) + np.save( + args.output.replace(".xyz", "_descriptors.npy"), descriptors_list + ) + logging.info("Selecting configurations using Farthest Point Sampling") + try: + fps_pt = FPS(atoms_list_pt, args.num_samples) + idx_pt = fps_pt.run() + logging.info(f"Selected {len(idx_pt)} configurations") + except Exception as e: # pylint: disable=W0703 + logging.error( + f"FPS failed, selecting random configurations instead: {e}" + ) + idx_pt = np.random.choice( + list(range(len(atoms_list_pt))), args.num_samples, replace=False + ) + atoms_list_pt = [atoms_list_pt[i] for i in idx_pt] + else: + logging.info("Selecting random configurations") + idx_pt = np.random.choice( + list(range(len(atoms_list_pt))), args.num_samples, replace=False + ) + atoms_list_pt = [atoms_list_pt[i] for i in idx_pt] + for atoms in atoms_list_pt: + # del atoms.info["mace_descriptors"] + atoms.info["pretrained"] = True + atoms.info["config_weight"] = args.weight_pt + atoms.info["mace_descriptors"] = None + if args.head_pt is not None: + atoms.info["head"] = args.head_pt + + logging.info("Saving the selected configurations") + ase.io.write(args.output, atoms_list_pt, format="extxyz") + logging.info("Saving a combined XYZ file") + for atoms in atoms_list_ft: + atoms.info["pretrained"] = False + atoms.info["config_weight"] = args.weight_ft + atoms.info["mace_descriptors"] = None + if args.head_ft is not None: + atoms.info["head"] = args.head_ft + atoms_fps_pt_ft = atoms_list_pt + atoms_list_ft + ase.io.write( + args.output.replace(".xyz", "_combined.xyz"), atoms_fps_pt_ft, format="extxyz" + ) + + +def main(): + args = parse_args() + select_samples(args) + + +if __name__ == "__main__": + main() diff --git a/mace/cli/plot_train.py b/mace/cli/plot_train.py new file mode 100644 index 00000000..a1c424df --- /dev/null +++ b/mace/cli/plot_train.py @@ -0,0 +1,193 @@ +import argparse +import dataclasses +import glob +import json +import os +import re +from typing import List + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd + +fig_width = 2.5 +fig_height = 2.1 + +plt.rcParams.update({"font.size": 6}) + +colors = [ + "#1f77b4", # muted blue + "#d62728", # brick red + "#ff7f0e", # safety orange + "#2ca02c", # cooked asparagus green + "#9467bd", # muted purple + "#8c564b", # chestnut brown + "#e377c2", # raspberry yogurt pink + "#7f7f7f", # middle gray + "#bcbd22", # curry yellow-green + "#17becf", # blue-teal +] + + +@dataclasses.dataclass +class RunInfo: + name: str + seed: int + + +name_re = re.compile(r"(?P.+)_run-(?P\d+)_train.txt") + + +def parse_path(path: str) -> RunInfo: + match = name_re.match(os.path.basename(path)) + if not match: + raise RuntimeError(f"Cannot parse {path}") + + return RunInfo(name=match.group("name"), seed=int(match.group("seed"))) + + +def parse_training_results(path: str) -> List[dict]: + run_info = parse_path(path) + results = [] + with open(path, mode="r", encoding="utf-8") as f: + for line in f: + d = json.loads(line) + d["name"] = run_info.name + d["seed"] = run_info.seed + results.append(d) + + return results + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Plot mace training statistics", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--path", help="path to results file or directory", required=True + ) + parser.add_argument( + "--min_epoch", help="minimum epoch", default=50, type=int, required=False + ) + return parser.parse_args() + + +def plot(data: pd.DataFrame, min_epoch: int, output_path: str) -> None: + data = data[data["epoch"] > min_epoch] + + data = data.groupby(["name", "mode", "epoch"]).agg([np.mean, np.std]).reset_index() + + valid_data = data[data["mode"] == "eval"] + train_data = data[data["mode"] == "opt"] + + fig, axes = plt.subplots( + nrows=1, ncols=2, figsize=(2 * fig_width, fig_height), constrained_layout=True + ) + + ax = axes[0] + ax.plot( + valid_data["epoch"], + valid_data["loss"]["mean"], + color=colors[0], + zorder=1, + label="Validation", + ) + ax.fill_between( + x=valid_data["epoch"], + y1=valid_data["loss"]["mean"] - valid_data["loss"]["std"], + y2=valid_data["loss"]["mean"] + valid_data["loss"]["std"], + alpha=0.5, + zorder=-1, + color=colors[0], + ) + ax.plot( + train_data["epoch"], + train_data["loss"]["mean"], + color=colors[3], + zorder=1, + label="Training", + ) + ax.fill_between( + x=train_data["epoch"], + y1=train_data["loss"]["mean"] - train_data["loss"]["std"], + y2=train_data["loss"]["mean"] + train_data["loss"]["std"], + alpha=0.5, + zorder=-1, + color=colors[3], + ) + + ax.set_ylim(bottom=0.0) + ax.set_xlabel("Epoch") + ax.set_ylabel("Loss") + ax.legend() + + ax = axes[1] + ax.plot( + valid_data["epoch"], + valid_data["mae_e"]["mean"], + color=colors[1], + zorder=1, + label="MAE Energy [eV]", + ) + ax.fill_between( + x=valid_data["epoch"], + y1=valid_data["mae_e"]["mean"] - valid_data["mae_e"]["std"], + y2=valid_data["mae_e"]["mean"] + valid_data["mae_e"]["std"], + alpha=0.5, + zorder=-1, + color=colors[1], + ) + ax.plot( + valid_data["epoch"], + valid_data["mae_f"]["mean"], + color=colors[2], + zorder=1, + label="MAE Forces [eV/Å]", + ) + ax.fill_between( + x=valid_data["epoch"], + y1=valid_data["mae_f"]["mean"] - valid_data["mae_f"]["std"], + y2=valid_data["mae_f"]["mean"] + valid_data["mae_f"]["std"], + alpha=0.5, + zorder=-1, + color=colors[2], + ) + + ax.set_ylim(bottom=0.0) + ax.set_xlabel("Epoch") + ax.legend() + + fig.savefig(output_path) + plt.close(fig) + + +def get_paths(path: str) -> List[str]: + if os.path.isfile(path): + return [path] + paths = glob.glob(os.path.join(path, "*_train.txt")) + + if len(paths) == 0: + raise RuntimeError(f"Cannot find results in '{path}'") + + return paths + + +def main() -> None: + args = parse_args() + run(args) + + +def run(args: argparse.Namespace) -> None: + data = pd.DataFrame( + results + for path in get_paths(args.path) + for results in parse_training_results(path) + ) + + for name, group in data.groupby("name"): + plot(group, min_epoch=args.min_epoch, output_path=f"{name}.pdf") + + +if __name__ == "__main__": + main() diff --git a/mace/cli/preprocess_data.py b/mace/cli/preprocess_data.py new file mode 100644 index 00000000..ef9f1343 --- /dev/null +++ b/mace/cli/preprocess_data.py @@ -0,0 +1,288 @@ +# This file loads an xyz dataset and prepares +# new hdf5 file that is ready for training with on-the-fly dataloading + +import argparse +import ast +import json +import logging +import multiprocessing as mp +import os +import random +from functools import partial +from glob import glob +from typing import List, Tuple + +import h5py +import numpy as np +import tqdm + +from mace import data, tools +from mace.data.utils import save_configurations_as_HDF5 +from mace.modules import compute_statistics +from mace.tools import torch_geometric +from mace.tools.scripts_utils import get_atomic_energies, get_dataset_from_xyz +from mace.tools.utils import AtomicNumberTable + + +def compute_stats_target( + file: str, + z_table: AtomicNumberTable, + r_max: float, + atomic_energies: Tuple, + batch_size: int, +): + train_dataset = data.HDF5Dataset(file, z_table=z_table, r_max=r_max) + train_loader = torch_geometric.dataloader.DataLoader( + dataset=train_dataset, + batch_size=batch_size, + shuffle=False, + drop_last=False, + ) + + avg_num_neighbors, mean, std = compute_statistics(train_loader, atomic_energies) + output = [avg_num_neighbors, mean, std] + return output + + +def pool_compute_stats(inputs: List): + path_to_files, z_table, r_max, atomic_energies, batch_size, num_process = inputs + + with mp.Pool(processes=num_process) as pool: + re = [ + pool.apply_async( + compute_stats_target, + args=( + file, + z_table, + r_max, + atomic_energies, + batch_size, + ), + ) + for file in glob(path_to_files + "/*") + ] + + pool.close() + pool.join() + + results = [r.get() for r in tqdm.tqdm(re)] + + if not results: + raise ValueError( + "No results were computed. Check if the input files exist and are readable." + ) + + # Separate avg_num_neighbors, mean, and std + avg_num_neighbors = np.mean([r[0] for r in results]) + means = np.array([r[1] for r in results]) + stds = np.array([r[2] for r in results]) + + # Compute averages + mean = np.mean(means, axis=0).item() + std = np.mean(stds, axis=0).item() + + return avg_num_neighbors, mean, std + + +def split_array(a: np.ndarray, max_size: int): + drop_last = False + if len(a) % 2 == 1: + a = np.append(a, a[-1]) + drop_last = True + factors = get_prime_factors(len(a)) + max_factor = 1 + for i in range(1, len(factors) + 1): + for j in range(0, len(factors) - i + 1): + if np.prod(factors[j : j + i]) <= max_size: + test = np.prod(factors[j : j + i]) + max_factor = max(test, max_factor) + return np.array_split(a, max_factor), drop_last + + +def get_prime_factors(n: int): + factors = [] + for i in range(2, n + 1): + while n % i == 0: + factors.append(i) + n = n / i + return factors + + +# Define Task for Multiprocessiing +def multi_train_hdf5(process, args, split_train, drop_last): + with h5py.File(args.h5_prefix + "train/train_" + str(process) + ".h5", "w") as f: + f.attrs["drop_last"] = drop_last + save_configurations_as_HDF5(split_train[process], process, f) + + +def multi_valid_hdf5(process, args, split_valid, drop_last): + with h5py.File(args.h5_prefix + "val/val_" + str(process) + ".h5", "w") as f: + f.attrs["drop_last"] = drop_last + save_configurations_as_HDF5(split_valid[process], process, f) + + +def multi_test_hdf5(process, name, args, split_test, drop_last): + with h5py.File( + args.h5_prefix + "test/" + name + "_" + str(process) + ".h5", "w" + ) as f: + f.attrs["drop_last"] = drop_last + save_configurations_as_HDF5(split_test[process], process, f) + + +def main() -> None: + """ + This script loads an xyz dataset and prepares + new hdf5 file that is ready for training with on-the-fly dataloading + """ + args = tools.build_preprocess_arg_parser().parse_args() + run(args) + + +def run(args: argparse.Namespace): + """ + This script loads an xyz dataset and prepares + new hdf5 file that is ready for training with on-the-fly dataloading + """ + + # Setup + tools.set_seeds(args.seed) + random.seed(args.seed) + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s %(levelname)-8s %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + handlers=[logging.StreamHandler()], + ) + + try: + config_type_weights = ast.literal_eval(args.config_type_weights) + assert isinstance(config_type_weights, dict) + except Exception as e: # pylint: disable=W0703 + logging.warning( + f"Config type weights not specified correctly ({e}), using Default" + ) + config_type_weights = {"Default": 1.0} + + folders = ["train", "val", "test"] + for sub_dir in folders: + if not os.path.exists(args.h5_prefix + sub_dir): + os.makedirs(args.h5_prefix + sub_dir) + + # Data preparation + collections, atomic_energies_dict = get_dataset_from_xyz( + work_dir=args.work_dir, + train_path=args.train_file, + valid_path=args.valid_file, + valid_fraction=args.valid_fraction, + config_type_weights=config_type_weights, + test_path=args.test_file, + seed=args.seed, + energy_key=args.energy_key, + forces_key=args.forces_key, + stress_key=args.stress_key, + virials_key=args.virials_key, + dipole_key=args.dipole_key, + charges_key=args.charges_key, + ) + + # Atomic number table + # yapf: disable + if args.atomic_numbers is None: + z_table = tools.get_atomic_number_table_from_zs( + z + for configs in (collections.train, collections.valid) + for config in configs + for z in config.atomic_numbers + ) + else: + logging.info("Using atomic numbers from command line argument") + zs_list = ast.literal_eval(args.atomic_numbers) + assert isinstance(zs_list, list) + z_table = tools.get_atomic_number_table_from_zs(zs_list) + + logging.info("Preparing training set") + if args.shuffle: + random.shuffle(collections.train) + + # split collections.train into batches and save them to hdf5 + split_train = np.array_split(collections.train,args.num_process) + drop_last = False + if len(collections.train) % 2 == 1: + drop_last = True + + multi_train_hdf5_ = partial(multi_train_hdf5, args=args, split_train=split_train, drop_last=drop_last) + processes = [] + for i in range(args.num_process): + p = mp.Process(target=multi_train_hdf5_, args=[i]) + p.start() + processes.append(p) + + for i in processes: + i.join() + + if args.compute_statistics: + logging.info("Computing statistics") + if len(atomic_energies_dict) == 0: + atomic_energies_dict = get_atomic_energies(args.E0s, collections.train, z_table) + atomic_energies: np.ndarray = np.array( + [atomic_energies_dict[z] for z in z_table.zs] + ) + logging.info(f"Atomic Energies: {atomic_energies.tolist()}") + _inputs = [args.h5_prefix+'train', z_table, args.r_max, atomic_energies, args.batch_size, args.num_process] + avg_num_neighbors, mean, std=pool_compute_stats(_inputs) + logging.info(f"Average number of neighbors: {avg_num_neighbors}") + logging.info(f"Mean: {mean}") + logging.info(f"Standard deviation: {std}") + + # save the statistics as a json + statistics = { + "atomic_energies": str(atomic_energies_dict), + "avg_num_neighbors": avg_num_neighbors, + "mean": mean, + "std": std, + "atomic_numbers": str(z_table.zs), + "r_max": args.r_max, + } + + with open(args.h5_prefix + "statistics.json", "w") as f: # pylint: disable=W1514 + json.dump(statistics, f) + + logging.info("Preparing validation set") + if args.shuffle: + random.shuffle(collections.valid) + split_valid = np.array_split(collections.valid, args.num_process) + drop_last = False + if len(collections.valid) % 2 == 1: + drop_last = True + + multi_valid_hdf5_ = partial(multi_valid_hdf5, args=args, split_valid=split_valid, drop_last=drop_last) + processes = [] + for i in range(args.num_process): + p = mp.Process(target=multi_valid_hdf5_, args=[i]) + p.start() + processes.append(p) + + for i in processes: + i.join() + + if args.test_file is not None: + logging.info("Preparing test sets") + for name, subset in collections.tests: + drop_last = False + if len(subset) % 2 == 1: + drop_last = True + split_test = np.array_split(subset, args.num_process) + multi_test_hdf5_ = partial(multi_test_hdf5, args=args, split_test=split_test, drop_last=drop_last) + + processes = [] + for i in range(args.num_process): + p = mp.Process(target=multi_test_hdf5_, args=[i, name]) + p.start() + processes.append(p) + + for i in processes: + i.join() + + +if __name__ == "__main__": + main() diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py new file mode 100644 index 00000000..3813b055 --- /dev/null +++ b/mace/cli/run_train.py @@ -0,0 +1,811 @@ +########################################################################################### +# Training script for MACE +# Authors: Ilyes Batatia, Gregor Simm, David Kovacs +# This program is distributed under the MIT License (see MIT.md) +########################################################################################### + +import argparse +import ast +import glob +import json +import logging +import os +from copy import deepcopy +from pathlib import Path +from typing import List, Optional + +import torch.distributed +import torch.nn.functional +from e3nn.util import jit +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.data import ConcatDataset +from torch_ema import ExponentialMovingAverage + +import mace +from mace import data, tools +from mace.calculators.foundations_models import mace_mp, mace_off +from mace.tools import torch_geometric +from mace.tools.model_script_utils import configure_model +from mace.tools.multihead_tools import ( + HeadConfig, + assemble_mp_data, + dict_head_to_dataclass, + prepare_default_head, +) +from mace.tools.scripts_utils import ( + LRScheduler, + check_path_ase_read, + convert_to_json_format, + dict_to_array, + extract_config_mace_model, + get_atomic_energies, + get_avg_num_neighbors, + get_config_type_weights, + get_dataset_from_xyz, + get_files_with_suffix, + get_loss_fn, + get_optimizer, + get_params_options, + get_swa, + print_git_commit, + remove_pt_head, + setup_wandb, +) +from mace.tools.slurm_distributed import DistributedEnvironment +from mace.tools.tables_utils import create_error_table +from mace.tools.utils import AtomicNumberTable + + +def main() -> None: + """ + This script runs the training/fine tuning for mace + """ + args = tools.build_default_arg_parser().parse_args() + run(args) + + +def run(args: argparse.Namespace) -> None: + """ + This script runs the training/fine tuning for mace + """ + tag = tools.get_tag(name=args.name, seed=args.seed) + args, input_log_messages = tools.check_args(args) + + if args.device == "xpu": + try: + import intel_extension_for_pytorch as ipex + except ImportError as e: + raise ImportError( + "Error: Intel extension for PyTorch not found, but XPU device was specified" + ) from e + if args.distributed: + try: + distr_env = DistributedEnvironment() + except Exception as e: # pylint: disable=W0703 + logging.error(f"Failed to initialize distributed environment: {e}") + return + world_size = distr_env.world_size + local_rank = distr_env.local_rank + rank = distr_env.rank + if rank == 0: + print(distr_env) + torch.distributed.init_process_group(backend="nccl") + else: + rank = int(0) + + # Setup + tools.set_seeds(args.seed) + tools.setup_logger(level=args.log_level, tag=tag, directory=args.log_dir, rank=rank) + logging.info("===========VERIFYING SETTINGS===========") + for message, loglevel in input_log_messages: + logging.log(level=loglevel, msg=message) + + if args.distributed: + torch.cuda.set_device(local_rank) + logging.info(f"Process group initialized: {torch.distributed.is_initialized()}") + logging.info(f"Processes: {world_size}") + + try: + logging.info(f"MACE version: {mace.__version__}") + except AttributeError: + logging.info("Cannot find MACE version, please install MACE via pip") + logging.debug(f"Configuration: {args}") + + tools.set_default_dtype(args.default_dtype) + device = tools.init_device(args.device) + commit = print_git_commit() + model_foundation: Optional[torch.nn.Module] = None + if args.foundation_model is not None: + if args.foundation_model in ["small", "medium", "large"]: + logging.info( + f"Using foundation model mace-mp-0 {args.foundation_model} as initial checkpoint." + ) + calc = mace_mp( + model=args.foundation_model, + device=args.device, + default_dtype=args.default_dtype, + ) + model_foundation = calc.models[0] + elif args.foundation_model in ["small_off", "medium_off", "large_off"]: + model_type = args.foundation_model.split("_")[0] + logging.info( + f"Using foundation model mace-off-2023 {model_type} as initial checkpoint. ASL license." + ) + calc = mace_off( + model=model_type, + device=args.device, + default_dtype=args.default_dtype, + ) + model_foundation = calc.models[0] + else: + model_foundation = torch.load( + args.foundation_model, map_location=args.device + ) + logging.info( + f"Using foundation model {args.foundation_model} as initial checkpoint." + ) + args.r_max = model_foundation.r_max.item() + if ( + args.foundation_model not in ["small", "medium", "large"] + and args.pt_train_file is None + ): + logging.warning( + "Using multiheads finetuning with a foundation model that is not a Materials Project model, need to provied a path to a pretraining file with --pt_train_file." + ) + args.multiheads_finetuning = False + if args.multiheads_finetuning: + assert ( + args.E0s != "average" + ), "average atomic energies cannot be used for multiheads finetuning" + # check that the foundation model has a single head, if not, use the first head + if hasattr(model_foundation, "heads"): + if len(model_foundation.heads) > 1: + logging.warning( + "Mutlihead finetuning with models with more than one head is not supported, using the first head as foundation head." + ) + model_foundation = remove_pt_head( + model_foundation, args.foundation_head + ) + else: + args.multiheads_finetuning = False + + if args.heads is not None: + args.heads = ast.literal_eval(args.heads) + else: + args.heads = prepare_default_head(args) + + logging.info("===========LOADING INPUT DATA===========") + heads = list(args.heads.keys()) + logging.info(f"Using heads: {heads}") + head_configs: List[HeadConfig] = [] + for head, head_args in args.heads.items(): + logging.info(f"============= Processing head {head} ===========") + head_config = dict_head_to_dataclass(head_args, head, args) + if head_config.statistics_file is not None: + with open(head_config.statistics_file, "r") as f: # pylint: disable=W1514 + statistics = json.load(f) + logging.info("Using statistics json file") + head_config.r_max = ( + statistics["r_max"] if args.foundation_model is None else args.r_max + ) + head_config.atomic_numbers = statistics["atomic_numbers"] + head_config.mean = statistics["mean"] + head_config.std = statistics["std"] + head_config.avg_num_neighbors = statistics["avg_num_neighbors"] + head_config.compute_avg_num_neighbors = False + if isinstance(statistics["atomic_energies"], str) and statistics[ + "atomic_energies" + ].endswith(".json"): + with open(statistics["atomic_energies"], "r", encoding="utf-8") as f: + atomic_energies = json.load(f) + head_config.E0s = atomic_energies + head_config.atomic_energies_dict = ast.literal_eval(atomic_energies) + else: + head_config.E0s = statistics["atomic_energies"] + head_config.atomic_energies_dict = ast.literal_eval( + statistics["atomic_energies"] + ) + + # Data preparation + if check_path_ase_read(head_config.train_file): + if head_config.valid_file is not None: + assert check_path_ase_read( + head_config.valid_file + ), "valid_file if given must be same format as train_file" + config_type_weights = get_config_type_weights( + head_config.config_type_weights + ) + collections, atomic_energies_dict = get_dataset_from_xyz( + work_dir=args.work_dir, + train_path=head_config.train_file, + valid_path=head_config.valid_file, + valid_fraction=head_config.valid_fraction, + config_type_weights=config_type_weights, + test_path=head_config.test_file, + seed=args.seed, + energy_key=head_config.energy_key, + forces_key=head_config.forces_key, + stress_key=head_config.stress_key, + virials_key=head_config.virials_key, + dipole_key=head_config.dipole_key, + charges_key=head_config.charges_key, + head_name=head_config.head_name, + keep_isolated_atoms=head_config.keep_isolated_atoms, + ) + head_config.collections = collections + head_config.atomic_energies_dict = atomic_energies_dict + logging.info( + f"Total number of configurations: train={len(collections.train)}, valid={len(collections.valid)}, " + f"tests=[{', '.join([name + ': ' + str(len(test_configs)) for name, test_configs in collections.tests])}]," + ) + head_configs.append(head_config) + + if all(check_path_ase_read(head_config.train_file) for head_config in head_configs): + size_collections_train = sum( + len(head_config.collections.train) for head_config in head_configs + ) + size_collections_valid = sum( + len(head_config.collections.valid) for head_config in head_configs + ) + if size_collections_train < args.batch_size: + logging.error( + f"Batch size ({args.batch_size}) is larger than the number of training data ({size_collections_train})" + ) + if size_collections_valid < args.valid_batch_size: + logging.warning( + f"Validation batch size ({args.valid_batch_size}) is larger than the number of validation data ({size_collections_valid})" + ) + + if args.multiheads_finetuning: + logging.info( + "==================Using multiheads finetuning mode==================" + ) + args.loss = "universal" + if ( + args.foundation_model in ["small", "medium", "large"] + or args.pt_train_file is None + ): + logging.info( + "Using foundation model for multiheads finetuning with Materials Project data" + ) + heads = list(dict.fromkeys(["pt_head"] + heads)) + head_config_pt = HeadConfig( + head_name="pt_head", + E0s="foundation", + statistics_file=args.statistics_file, + compute_avg_num_neighbors=False, + avg_num_neighbors=model_foundation.interactions[0].avg_num_neighbors, + ) + collections = assemble_mp_data(args, tag, head_configs) + head_config_pt.collections = collections + head_config_pt.train_file = f"mp_finetuning-{tag}.xyz" + head_configs.append(head_config_pt) + else: + logging.info( + f"Using foundation model for multiheads finetuning with {args.pt_train_file}" + ) + heads = list(dict.fromkeys(["pt_head"] + heads)) + collections, atomic_energies_dict = get_dataset_from_xyz( + work_dir=args.work_dir, + train_path=args.pt_train_file, + valid_path=args.pt_valid_file, + valid_fraction=args.valid_fraction, + config_type_weights=None, + test_path=None, + seed=args.seed, + energy_key=args.energy_key, + forces_key=args.forces_key, + stress_key=args.stress_key, + virials_key=args.virials_key, + dipole_key=args.dipole_key, + charges_key=args.charges_key, + head_name="pt_head", + keep_isolated_atoms=args.keep_isolated_atoms, + ) + head_config_pt = HeadConfig( + head_name="pt_head", + train_file=args.pt_train_file, + valid_file=args.pt_valid_file, + E0s="foundation", + statistics_file=args.statistics_file, + valid_fraction=args.valid_fraction, + config_type_weights=None, + energy_key=args.energy_key, + forces_key=args.forces_key, + stress_key=args.stress_key, + virials_key=args.virials_key, + dipole_key=args.dipole_key, + charges_key=args.charges_key, + keep_isolated_atoms=args.keep_isolated_atoms, + collections=collections, + avg_num_neighbors=model_foundation.interactions[0].avg_num_neighbors, + compute_avg_num_neighbors=False, + ) + head_config_pt.collections = collections + head_configs.append(head_config_pt) + logging.info( + f"Total number of configurations: train={len(collections.train)}, valid={len(collections.valid)}" + ) + + # Atomic number table + # yapf: disable + for head_config in head_configs: + if head_config.atomic_numbers is None: + assert check_path_ase_read(head_config.train_file), "Must specify atomic_numbers when using .h5 train_file input" + z_table_head = tools.get_atomic_number_table_from_zs( + z + for configs in (head_config.collections.train, head_config.collections.valid) + for config in configs + for z in config.atomic_numbers + ) + head_config.atomic_numbers = z_table_head.zs + head_config.z_table = z_table_head + else: + if head_config.statistics_file is None: + logging.info("Using atomic numbers from command line argument") + else: + logging.info("Using atomic numbers from statistics file") + zs_list = ast.literal_eval(head_config.atomic_numbers) + assert isinstance(zs_list, list) + z_table_head = tools.AtomicNumberTable(zs_list) + head_config.atomic_numbers = zs_list + head_config.z_table = z_table_head + # yapf: enable + all_atomic_numbers = set() + for head_config in head_configs: + all_atomic_numbers.update(head_config.atomic_numbers) + z_table = AtomicNumberTable(sorted(list(all_atomic_numbers))) + logging.info(f"Atomic Numbers used: {z_table.zs}") + + # Atomic energies + atomic_energies_dict = {} + for head_config in head_configs: + if head_config.atomic_energies_dict is None or len(head_config.atomic_energies_dict) == 0: + assert head_config.E0s is not None, "Atomic energies must be provided" + if check_path_ase_read(head_config.train_file) and head_config.E0s.lower() != "foundation": + atomic_energies_dict[head_config.head_name] = get_atomic_energies( + head_config.E0s, head_config.collections.train, head_config.z_table + ) + elif head_config.E0s.lower() == "foundation": + assert args.foundation_model is not None + z_table_foundation = AtomicNumberTable( + [int(z) for z in model_foundation.atomic_numbers] + ) + foundation_atomic_energies = model_foundation.atomic_energies_fn.atomic_energies + if foundation_atomic_energies.ndim > 1: + foundation_atomic_energies = foundation_atomic_energies.squeeze() + if foundation_atomic_energies.ndim == 2: + foundation_atomic_energies = foundation_atomic_energies[0] + logging.info("Foundation model has multiple heads, using the first head as foundation E0s.") + atomic_energies_dict[head_config.head_name] = { + z: foundation_atomic_energies[ + z_table_foundation.z_to_index(z) + ].item() + for z in z_table.zs + } + else: + atomic_energies_dict[head_config.head_name] = get_atomic_energies(head_config.E0s, None, head_config.z_table) + else: + atomic_energies_dict[head_config.head_name] = head_config.atomic_energies_dict + + # Atomic energies for multiheads finetuning + if args.multiheads_finetuning: + assert ( + model_foundation is not None + ), "Model foundation must be provided for multiheads finetuning" + z_table_foundation = AtomicNumberTable( + [int(z) for z in model_foundation.atomic_numbers] + ) + foundation_atomic_energies = model_foundation.atomic_energies_fn.atomic_energies + if foundation_atomic_energies.ndim > 1: + foundation_atomic_energies = foundation_atomic_energies.squeeze() + if foundation_atomic_energies.ndim == 2: + foundation_atomic_energies = foundation_atomic_energies[0] + logging.info("Foundation model has multiple heads, using the first head as foundation E0s.") + atomic_energies_dict["pt_head"] = { + z: foundation_atomic_energies[ + z_table_foundation.z_to_index(z) + ].item() + for z in z_table.zs + } + + if args.model == "AtomicDipolesMACE": + atomic_energies = None + dipole_only = True + args.compute_dipole = True + args.compute_energy = False + args.compute_forces = False + args.compute_virials = False + args.compute_stress = False + else: + dipole_only = False + if args.model == "EnergyDipolesMACE": + args.compute_dipole = True + args.compute_energy = True + args.compute_forces = True + args.compute_virials = False + args.compute_stress = False + else: + args.compute_energy = True + args.compute_dipole = False + # atomic_energies: np.ndarray = np.array( + # [atomic_energies_dict[z] for z in z_table.zs] + # ) + atomic_energies = dict_to_array(atomic_energies_dict, heads) + for head_config in head_configs: + try: + logging.info(f"Atomic Energies used (z: eV) for head {head_config.head_name}: " + "{" + ", ".join([f"{z}: {atomic_energies_dict[head_config.head_name][z]}" for z in head_config.z_table.zs]) + "}") + except KeyError as e: + raise KeyError(f"Atomic number {e} not found in atomic_energies_dict for head {head_config.head_name}, add E0s for this atomic number") from e + + + valid_sets = {head: [] for head in heads} + train_sets = {head: [] for head in heads} + for head_config in head_configs: + if check_path_ase_read(head_config.train_file): + train_sets[head_config.head_name] = [ + data.AtomicData.from_config( + config, z_table=z_table, cutoff=args.r_max, heads=heads + ) + for config in head_config.collections.train + ] + valid_sets[head_config.head_name] = [ + data.AtomicData.from_config( + config, z_table=z_table, cutoff=args.r_max, heads=heads + ) + for config in head_config.collections.valid + ] + + elif head_config.train_file.endswith(".h5"): + train_sets[head_config.head_name] = data.HDF5Dataset( + head_config.train_file, r_max=args.r_max, z_table=z_table, heads=heads, head=head_config.head_name + ) + valid_sets[head_config.head_name] = data.HDF5Dataset( + head_config.valid_file, r_max=args.r_max, z_table=z_table, heads=heads, head=head_config.head_name + ) + else: # This case would be for when the file path is to a directory of multiple .h5 files + train_sets[head_config.head_name] = data.dataset_from_sharded_hdf5( + head_config.train_file, r_max=args.r_max, z_table=z_table, heads=heads, head=head_config.head_name + ) + valid_sets[head_config.head_name] = data.dataset_from_sharded_hdf5( + head_config.valid_file, r_max=args.r_max, z_table=z_table, heads=heads, head=head_config.head_name + ) + train_loader_head = torch_geometric.dataloader.DataLoader( + dataset=train_sets[head_config.head_name], + batch_size=args.batch_size, + shuffle=True, + drop_last=True, + pin_memory=args.pin_memory, + num_workers=args.num_workers, + generator=torch.Generator().manual_seed(args.seed), + ) + head_config.train_loader = train_loader_head + # concatenate all the trainsets + train_set = ConcatDataset([train_sets[head] for head in heads]) + train_sampler, valid_sampler = None, None + if args.distributed: + train_sampler = torch.utils.data.distributed.DistributedSampler( + train_set, + num_replicas=world_size, + rank=rank, + shuffle=True, + drop_last=True, + seed=args.seed, + ) + valid_samplers = {} + for head, valid_set in valid_sets.items(): + valid_sampler = torch.utils.data.distributed.DistributedSampler( + valid_set, + num_replicas=world_size, + rank=rank, + shuffle=True, + drop_last=True, + seed=args.seed, + ) + valid_samplers[head] = valid_sampler + train_loader = torch_geometric.dataloader.DataLoader( + dataset=train_set, + batch_size=args.batch_size, + sampler=train_sampler, + shuffle=(train_sampler is None), + drop_last=(train_sampler is None), + pin_memory=args.pin_memory, + num_workers=args.num_workers, + generator=torch.Generator().manual_seed(args.seed), + ) + valid_loaders = {heads[i]: None for i in range(len(heads))} + if not isinstance(valid_sets, dict): + valid_sets = {"Default": valid_sets} + for head, valid_set in valid_sets.items(): + valid_loaders[head] = torch_geometric.dataloader.DataLoader( + dataset=valid_set, + batch_size=args.valid_batch_size, + sampler=valid_samplers[head] if args.distributed else None, + shuffle=False, + drop_last=False, + pin_memory=args.pin_memory, + num_workers=args.num_workers, + generator=torch.Generator().manual_seed(args.seed), + ) + + loss_fn = get_loss_fn(args, dipole_only, args.compute_dipole) + args.avg_num_neighbors = get_avg_num_neighbors(head_configs, args, train_loader, device) + + # Model + model, output_args = configure_model(args, train_loader, atomic_energies, model_foundation, heads, z_table) + model.to(device) + + logging.debug(model) + logging.info(f"Total number of parameters: {tools.count_parameters(model)}") + logging.info("") + logging.info("===========OPTIMIZER INFORMATION===========") + logging.info(f"Using {args.optimizer.upper()} as parameter optimizer") + logging.info(f"Batch size: {args.batch_size}") + if args.ema: + logging.info(f"Using Exponential Moving Average with decay: {args.ema_decay}") + logging.info( + f"Number of gradient updates: {int(args.max_num_epochs*len(train_set)/args.batch_size)}" + ) + logging.info(f"Learning rate: {args.lr}, weight decay: {args.weight_decay}") + logging.info(loss_fn) + + # Optimizer + param_options = get_params_options(args, model) + optimizer: torch.optim.Optimizer + optimizer = get_optimizer(args, param_options) + if args.device == "xpu": + logging.info("Optimzing model and optimzier for XPU") + model, optimizer = ipex.optimize(model, optimizer=optimizer) + logger = tools.MetricsLogger( + directory=args.results_dir, tag=tag + "_train" + ) # pylint: disable=E1123 + + lr_scheduler = LRScheduler(optimizer, args) + + swa: Optional[tools.SWAContainer] = None + swas = [False] + if args.swa: + swa, swas = get_swa(args, model, optimizer, swas, dipole_only) + + checkpoint_handler = tools.CheckpointHandler( + directory=args.checkpoints_dir, + tag=tag, + keep=args.keep_checkpoints, + swa_start=args.start_swa, + ) + + start_epoch = 0 + if args.restart_latest: + try: + opt_start_epoch = checkpoint_handler.load_latest( + state=tools.CheckpointState(model, optimizer, lr_scheduler), + swa=True, + device=device, + ) + except Exception: # pylint: disable=W0703 + opt_start_epoch = checkpoint_handler.load_latest( + state=tools.CheckpointState(model, optimizer, lr_scheduler), + swa=False, + device=device, + ) + if opt_start_epoch is not None: + start_epoch = opt_start_epoch + + ema: Optional[ExponentialMovingAverage] = None + if args.ema: + ema = ExponentialMovingAverage(model.parameters(), decay=args.ema_decay) + else: + for group in optimizer.param_groups: + group["lr"] = args.lr + + if args.wandb: + setup_wandb(args) + + if args.distributed: + distributed_model = DDP(model, device_ids=[local_rank]) + else: + distributed_model = None + tools.train( + model=model, + loss_fn=loss_fn, + train_loader=train_loader, + valid_loaders=valid_loaders, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + checkpoint_handler=checkpoint_handler, + eval_interval=args.eval_interval, + start_epoch=start_epoch, + max_num_epochs=args.max_num_epochs, + logger=logger, + patience=args.patience, + save_all_checkpoints=args.save_all_checkpoints, + output_args=output_args, + device=device, + swa=swa, + ema=ema, + max_grad_norm=args.clip_grad, + log_errors=args.error_table, + log_wandb=args.wandb, + distributed=args.distributed, + distributed_model=distributed_model, + train_sampler=train_sampler, + rank=rank, + ) + + logging.info("") + logging.info("===========RESULTS===========") + logging.info("Computing metrics for training, validation, and test sets") + + train_valid_data_loader = {} + for head_config in head_configs: + data_loader_name = "train_" + head_config.head_name + train_valid_data_loader[data_loader_name] = head_config.train_loader + for head, valid_loader in valid_loaders.items(): + data_load_name = "valid_" + head + train_valid_data_loader[data_load_name] = valid_loader + + test_sets = {} + stop_first_test = False + test_data_loader = {} + if all( + head_config.test_file == head_configs[0].test_file + for head_config in head_configs + ) and head_configs[0].test_file is not None: + stop_first_test = True + if all( + head_config.test_dir == head_configs[0].test_dir + for head_config in head_configs + ) and head_configs[0].test_dir is not None: + stop_first_test = True + for head_config in head_configs: + if check_path_ase_read(head_config.train_file): + for name, subset in head_config.collections.tests: + test_sets[name] = [ + data.AtomicData.from_config( + config, z_table=z_table, cutoff=args.r_max, heads=heads + ) + for config in subset + ] + if head_config.test_dir is not None: + if not args.multi_processed_test: + test_files = get_files_with_suffix(head_config.test_dir, "_test.h5") + for test_file in test_files: + name = os.path.splitext(os.path.basename(test_file))[0] + test_sets[name] = data.HDF5Dataset( + test_file, r_max=args.r_max, z_table=z_table, heads=heads, head=head_config.head_name + ) + else: + test_folders = glob(head_config.test_dir + "/*") + for folder in test_folders: + name = os.path.splitext(os.path.basename(test_file))[0] + test_sets[name] = data.dataset_from_sharded_hdf5( + folder, r_max=args.r_max, z_table=z_table, heads=heads, head=head_config.head_name + ) + for test_name, test_set in test_sets.items(): + test_sampler = None + if args.distributed: + test_sampler = torch.utils.data.distributed.DistributedSampler( + test_set, + num_replicas=world_size, + rank=rank, + shuffle=True, + drop_last=True, + seed=args.seed, + ) + try: + drop_last = test_set.drop_last + except AttributeError as e: # pylint: disable=W0612 + drop_last = False + test_loader = torch_geometric.dataloader.DataLoader( + test_set, + batch_size=args.valid_batch_size, + shuffle=(test_sampler is None), + drop_last=drop_last, + num_workers=args.num_workers, + pin_memory=args.pin_memory, + ) + test_data_loader[test_name] = test_loader + if stop_first_test: + break + + for swa_eval in swas: + epoch = checkpoint_handler.load_latest( + state=tools.CheckpointState(model, optimizer, lr_scheduler), + swa=swa_eval, + device=device, + ) + model.to(device) + if args.distributed: + distributed_model = DDP(model, device_ids=[local_rank]) + model_to_evaluate = model if not args.distributed else distributed_model + if swa_eval: + logging.info(f"Loaded Stage two model from epoch {epoch} for evaluation") + else: + logging.info(f"Loaded Stage one model from epoch {epoch} for evaluation") + + for param in model.parameters(): + param.requires_grad = False + table_train_valid = create_error_table( + table_type=args.error_table, + all_data_loaders=train_valid_data_loader, + model=model_to_evaluate, + loss_fn=loss_fn, + output_args=output_args, + log_wandb=args.wandb, + device=device, + distributed=args.distributed, + ) + logging.info("Error-table on TRAIN and VALID:\n" + str(table_train_valid)) + + if test_data_loader: + table_test = create_error_table( + table_type=args.error_table, + all_data_loaders=test_data_loader, + model=model_to_evaluate, + loss_fn=loss_fn, + output_args=output_args, + log_wandb=args.wandb, + device=device, + distributed=args.distributed, + ) + logging.info("Error-table on TEST:\n" + str(table_test)) + + if rank == 0: + # Save entire model + if swa_eval: + model_path = Path(args.checkpoints_dir) / (tag + "_stagetwo.model") + else: + model_path = Path(args.checkpoints_dir) / (tag + ".model") + logging.info(f"Saving model to {model_path}") + if args.save_cpu: + model = model.to("cpu") + torch.save(model, model_path) + extra_files = { + "commit.txt": commit.encode("utf-8") if commit is not None else b"", + "config.yaml": json.dumps( + convert_to_json_format(extract_config_mace_model(model)) + ), + } + if swa_eval: + torch.save( + model, Path(args.model_dir) / (args.name + "_stagetwo.model") + ) + try: + path_complied = Path(args.model_dir) / ( + args.name + "_stagetwo_compiled.model" + ) + logging.info(f"Compiling model, saving metadata {path_complied}") + model_compiled = jit.compile(deepcopy(model)) + torch.jit.save( + model_compiled, + path_complied, + _extra_files=extra_files, + ) + except Exception as e: # pylint: disable=W0703 + pass + else: + torch.save(model, Path(args.model_dir) / (args.name + ".model")) + try: + path_complied = Path(args.model_dir) / ( + args.name + "_compiled.model" + ) + logging.info(f"Compiling model, saving metadata to {path_complied}") + model_compiled = jit.compile(deepcopy(model)) + torch.jit.save( + model_compiled, + path_complied, + _extra_files=extra_files, + ) + except Exception as e: # pylint: disable=W0703 + pass + + if args.distributed: + torch.distributed.barrier() + + logging.info("Done") + if args.distributed: + torch.distributed.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/mace/cli/select_head.py b/mace/cli/select_head.py new file mode 100644 index 00000000..a1e27229 --- /dev/null +++ b/mace/cli/select_head.py @@ -0,0 +1,33 @@ +from argparse import ArgumentParser + +import torch + +from mace.tools.scripts_utils import remove_pt_head + + +def main(): + parser = ArgumentParser() + parser.add_argument( + "--head_name", + "-n", + help="name of the head to extract", + default=None, + ) + parser.add_argument( + "--output_file", + "-o", + help="name for output model, defaults to model_file.target_device", + ) + parser.add_argument("model_file", help="input model file path") + args = parser.parse_args() + + if args.output_file is None: + args.output_file = args.model_file + "." + args.target_device + + model = torch.load(args.model_file) + model_single = remove_pt_head(model, args.head_name) + torch.save(model_single, args.output_file) + + +if __name__ == "__main__": + main() diff --git a/mace/data/__init__.py b/mace/data/__init__.py new file mode 100644 index 00000000..c10a3698 --- /dev/null +++ b/mace/data/__init__.py @@ -0,0 +1,34 @@ +from .atomic_data import AtomicData +from .hdf5_dataset import HDF5Dataset, dataset_from_sharded_hdf5 +from .neighborhood import get_neighborhood +from .utils import ( + Configuration, + Configurations, + compute_average_E0s, + config_from_atoms, + config_from_atoms_list, + load_from_xyz, + random_train_valid_split, + save_AtomicData_to_HDF5, + save_configurations_as_HDF5, + save_dataset_as_HDF5, + test_config_types, +) + +__all__ = [ + "get_neighborhood", + "Configuration", + "Configurations", + "random_train_valid_split", + "load_from_xyz", + "test_config_types", + "config_from_atoms", + "config_from_atoms_list", + "AtomicData", + "compute_average_E0s", + "save_dataset_as_HDF5", + "HDF5Dataset", + "dataset_from_sharded_hdf5", + "save_AtomicData_to_HDF5", + "save_configurations_as_HDF5", +] diff --git a/mace/data/atomic_data.py b/mace/data/atomic_data.py new file mode 100644 index 00000000..cb4edd94 --- /dev/null +++ b/mace/data/atomic_data.py @@ -0,0 +1,241 @@ +########################################################################################### +# Atomic Data Class for handling molecules as graphs +# Authors: Ilyes Batatia, Gregor Simm +# This program is distributed under the MIT License (see MIT.md) +########################################################################################### + +from typing import Optional, Sequence + +import torch.utils.data + +from mace.tools import ( + AtomicNumberTable, + atomic_numbers_to_indices, + to_one_hot, + torch_geometric, + voigt_to_matrix, +) + +from .neighborhood import get_neighborhood +from .utils import Configuration + + +class AtomicData(torch_geometric.data.Data): + num_graphs: torch.Tensor + batch: torch.Tensor + edge_index: torch.Tensor + node_attrs: torch.Tensor + edge_vectors: torch.Tensor + edge_lengths: torch.Tensor + positions: torch.Tensor + shifts: torch.Tensor + unit_shifts: torch.Tensor + cell: torch.Tensor + forces: torch.Tensor + energy: torch.Tensor + stress: torch.Tensor + virials: torch.Tensor + dipole: torch.Tensor + charges: torch.Tensor + weight: torch.Tensor + energy_weight: torch.Tensor + forces_weight: torch.Tensor + stress_weight: torch.Tensor + virials_weight: torch.Tensor + + def __init__( + self, + edge_index: torch.Tensor, # [2, n_edges] + node_attrs: torch.Tensor, # [n_nodes, n_node_feats] + positions: torch.Tensor, # [n_nodes, 3] + shifts: torch.Tensor, # [n_edges, 3], + unit_shifts: torch.Tensor, # [n_edges, 3] + cell: Optional[torch.Tensor], # [3,3] + weight: Optional[torch.Tensor], # [,] + head: Optional[torch.Tensor], # [,] + energy_weight: Optional[torch.Tensor], # [,] + forces_weight: Optional[torch.Tensor], # [,] + stress_weight: Optional[torch.Tensor], # [,] + virials_weight: Optional[torch.Tensor], # [,] + forces: Optional[torch.Tensor], # [n_nodes, 3] + energy: Optional[torch.Tensor], # [, ] + stress: Optional[torch.Tensor], # [1,3,3] + virials: Optional[torch.Tensor], # [1,3,3] + dipole: Optional[torch.Tensor], # [, 3] + charges: Optional[torch.Tensor], # [n_nodes, ] + ): + # Check shapes + num_nodes = node_attrs.shape[0] + + assert edge_index.shape[0] == 2 and len(edge_index.shape) == 2 + assert positions.shape == (num_nodes, 3) + assert shifts.shape[1] == 3 + assert unit_shifts.shape[1] == 3 + assert len(node_attrs.shape) == 2 + assert weight is None or len(weight.shape) == 0 + assert head is None or len(head.shape) == 0 + assert energy_weight is None or len(energy_weight.shape) == 0 + assert forces_weight is None or len(forces_weight.shape) == 0 + assert stress_weight is None or len(stress_weight.shape) == 0 + assert virials_weight is None or len(virials_weight.shape) == 0 + assert cell is None or cell.shape == (3, 3) + assert forces is None or forces.shape == (num_nodes, 3) + assert energy is None or len(energy.shape) == 0 + assert stress is None or stress.shape == (1, 3, 3) + assert virials is None or virials.shape == (1, 3, 3) + assert dipole is None or dipole.shape[-1] == 3 + assert charges is None or charges.shape == (num_nodes,) + # Aggregate data + data = { + "num_nodes": num_nodes, + "edge_index": edge_index, + "positions": positions, + "shifts": shifts, + "unit_shifts": unit_shifts, + "cell": cell, + "node_attrs": node_attrs, + "weight": weight, + "head": head, + "energy_weight": energy_weight, + "forces_weight": forces_weight, + "stress_weight": stress_weight, + "virials_weight": virials_weight, + "forces": forces, + "energy": energy, + "stress": stress, + "virials": virials, + "dipole": dipole, + "charges": charges, + } + super().__init__(**data) + + @classmethod + def from_config( + cls, + config: Configuration, + z_table: AtomicNumberTable, + cutoff: float, + heads: Optional[list] = None, + ) -> "AtomicData": + if heads is None: + heads = ["default"] + edge_index, shifts, unit_shifts, cell = get_neighborhood( + positions=config.positions, cutoff=cutoff, pbc=config.pbc, cell=config.cell + ) + indices = atomic_numbers_to_indices(config.atomic_numbers, z_table=z_table) + one_hot = to_one_hot( + torch.tensor(indices, dtype=torch.long).unsqueeze(-1), + num_classes=len(z_table), + ) + try: + head = torch.tensor(heads.index(config.head), dtype=torch.long) + except ValueError: + head = torch.tensor(len(heads) - 1, dtype=torch.long) + + cell = ( + torch.tensor(cell, dtype=torch.get_default_dtype()) + if cell is not None + else torch.tensor( + 3 * [0.0, 0.0, 0.0], dtype=torch.get_default_dtype() + ).view(3, 3) + ) + + weight = ( + torch.tensor(config.weight, dtype=torch.get_default_dtype()) + if config.weight is not None + else 1 + ) + + energy_weight = ( + torch.tensor(config.energy_weight, dtype=torch.get_default_dtype()) + if config.energy_weight is not None + else 1 + ) + + forces_weight = ( + torch.tensor(config.forces_weight, dtype=torch.get_default_dtype()) + if config.forces_weight is not None + else 1 + ) + + stress_weight = ( + torch.tensor(config.stress_weight, dtype=torch.get_default_dtype()) + if config.stress_weight is not None + else 1 + ) + + virials_weight = ( + torch.tensor(config.virials_weight, dtype=torch.get_default_dtype()) + if config.virials_weight is not None + else 1 + ) + + forces = ( + torch.tensor(config.forces, dtype=torch.get_default_dtype()) + if config.forces is not None + else None + ) + energy = ( + torch.tensor(config.energy, dtype=torch.get_default_dtype()) + if config.energy is not None + else None + ) + stress = ( + voigt_to_matrix( + torch.tensor(config.stress, dtype=torch.get_default_dtype()) + ).unsqueeze(0) + if config.stress is not None + else None + ) + virials = ( + voigt_to_matrix( + torch.tensor(config.virials, dtype=torch.get_default_dtype()) + ).unsqueeze(0) + if config.virials is not None + else None + ) + dipole = ( + torch.tensor(config.dipole, dtype=torch.get_default_dtype()).unsqueeze(0) + if config.dipole is not None + else None + ) + charges = ( + torch.tensor(config.charges, dtype=torch.get_default_dtype()) + if config.charges is not None + else None + ) + + return cls( + edge_index=torch.tensor(edge_index, dtype=torch.long), + positions=torch.tensor(config.positions, dtype=torch.get_default_dtype()), + shifts=torch.tensor(shifts, dtype=torch.get_default_dtype()), + unit_shifts=torch.tensor(unit_shifts, dtype=torch.get_default_dtype()), + cell=cell, + node_attrs=one_hot, + weight=weight, + head=head, + energy_weight=energy_weight, + forces_weight=forces_weight, + stress_weight=stress_weight, + virials_weight=virials_weight, + forces=forces, + energy=energy, + stress=stress, + virials=virials, + dipole=dipole, + charges=charges, + ) + + +def get_data_loader( + dataset: Sequence[AtomicData], + batch_size: int, + shuffle=True, + drop_last=False, +) -> torch.utils.data.DataLoader: + return torch_geometric.dataloader.DataLoader( + dataset=dataset, + batch_size=batch_size, + shuffle=shuffle, + drop_last=drop_last, + ) diff --git a/mace/data/hdf5_dataset.py b/mace/data/hdf5_dataset.py new file mode 100644 index 00000000..477ccd3f --- /dev/null +++ b/mace/data/hdf5_dataset.py @@ -0,0 +1,93 @@ +from glob import glob +from typing import List + +import h5py +from torch.utils.data import ConcatDataset, Dataset + +from mace.data.atomic_data import AtomicData +from mace.data.utils import Configuration +from mace.tools.utils import AtomicNumberTable + + +class HDF5Dataset(Dataset): + def __init__(self, file_path, r_max, z_table, **kwargs): + super(HDF5Dataset, self).__init__() # pylint: disable=super-with-arguments + self.file_path = file_path + self._file = None + batch_key = list(self.file.keys())[0] + self.batch_size = len(self.file[batch_key].keys()) + self.length = len(self.file.keys()) * self.batch_size + self.r_max = r_max + self.z_table = z_table + try: + self.drop_last = bool(self.file.attrs["drop_last"]) + except KeyError: + self.drop_last = False + self.kwargs = kwargs + + @property + def file(self): + if self._file is None: + # If a file has not already been opened, open one here + self._file = h5py.File(self.file_path, "r") + return self._file + + def __getstate__(self): + _d = dict(self.__dict__) + + # An opened h5py.File cannot be pickled, so we must exclude it from the state + _d["_file"] = None + return _d + + def __len__(self): + return self.length + + def __getitem__(self, index): + # compute the index of the batch + batch_index = index // self.batch_size + config_index = index % self.batch_size + grp = self.file["config_batch_" + str(batch_index)] + subgrp = grp["config_" + str(config_index)] + config = Configuration( + atomic_numbers=subgrp["atomic_numbers"][()], + positions=subgrp["positions"][()], + energy=unpack_value(subgrp["energy"][()]), + forces=unpack_value(subgrp["forces"][()]), + stress=unpack_value(subgrp["stress"][()]), + virials=unpack_value(subgrp["virials"][()]), + dipole=unpack_value(subgrp["dipole"][()]), + charges=unpack_value(subgrp["charges"][()]), + weight=unpack_value(subgrp["weight"][()]), + energy_weight=unpack_value(subgrp["energy_weight"][()]), + forces_weight=unpack_value(subgrp["forces_weight"][()]), + stress_weight=unpack_value(subgrp["stress_weight"][()]), + virials_weight=unpack_value(subgrp["virials_weight"][()]), + config_type=unpack_value(subgrp["config_type"][()]), + pbc=unpack_value(subgrp["pbc"][()]), + cell=unpack_value(subgrp["cell"][()]), + ) + if config.head is None: + config.head = self.kwargs.get("head") + atomic_data = AtomicData.from_config( + config, + z_table=self.z_table, + cutoff=self.r_max, + heads=self.kwargs.get("heads", ["Default"]), + ) + return atomic_data + + +def dataset_from_sharded_hdf5( + files: List, z_table: AtomicNumberTable, r_max: float, **kwargs +): + files = glob(files + "/*") + datasets = [] + for file in files: + datasets.append(HDF5Dataset(file, z_table=z_table, r_max=r_max, **kwargs)) + full_dataset = ConcatDataset(datasets) + return full_dataset + + +def unpack_value(value): + value = value.decode("utf-8") if isinstance(value, bytes) else value + return None if str(value) == "None" else value diff --git a/mace/data/neighborhood.py b/mace/data/neighborhood.py new file mode 100644 index 00000000..21296fa6 --- /dev/null +++ b/mace/data/neighborhood.py @@ -0,0 +1,66 @@ +from typing import Optional, Tuple + +import numpy as np +from matscipy.neighbours import neighbour_list + + +def get_neighborhood( + positions: np.ndarray, # [num_positions, 3] + cutoff: float, + pbc: Optional[Tuple[bool, bool, bool]] = None, + cell: Optional[np.ndarray] = None, # [3, 3] + true_self_interaction=False, +) -> Tuple[np.ndarray, np.ndarray]: + if pbc is None: + pbc = (False, False, False) + + if cell is None or cell.any() == np.zeros((3, 3)).any(): + cell = np.identity(3, dtype=float) + + assert len(pbc) == 3 and all(isinstance(i, (bool, np.bool_)) for i in pbc) + assert cell.shape == (3, 3) + + pbc_x = pbc[0] + pbc_y = pbc[1] + pbc_z = pbc[2] + identity = np.identity(3, dtype=float) + max_positions = np.max(np.absolute(positions)) + 1 + # Extend cell in non-periodic directions + # For models with more than 5 layers, the multiplicative constant needs to be increased. + # temp_cell = np.copy(cell) + if not pbc_x: + cell[0, :] = max_positions * 5 * cutoff * identity[0, :] + if not pbc_y: + cell[1, :] = max_positions * 5 * cutoff * identity[1, :] + if not pbc_z: + cell[2, :] = max_positions * 5 * cutoff * identity[2, :] + + sender, receiver, unit_shifts = neighbour_list( + quantities="ijS", + pbc=pbc, + cell=cell, + positions=positions, + cutoff=cutoff, + # self_interaction=True, # we want edges from atom to itself in different periodic images + # use_scaled_positions=False, # positions are not scaled positions + ) + + if not true_self_interaction: + # Eliminate self-edges that don't cross periodic boundaries + true_self_edge = sender == receiver + true_self_edge &= np.all(unit_shifts == 0, axis=1) + keep_edge = ~true_self_edge + + # Note: after eliminating self-edges, it can be that no edges remain in this system + sender = sender[keep_edge] + receiver = receiver[keep_edge] + unit_shifts = unit_shifts[keep_edge] + + # Build output + edge_index = np.stack((sender, receiver)) # [2, n_edges] + + # From the docs: With the shift vector S, the distances D between atoms can be computed from + # D = positions[j]-positions[i]+S.dot(cell) + shifts = np.dot(unit_shifts, cell) # [n_edges, 3] + + return edge_index, shifts, unit_shifts, cell diff --git a/mace/data/utils.py b/mace/data/utils.py new file mode 100644 index 00000000..bb8e5448 --- /dev/null +++ b/mace/data/utils.py @@ -0,0 +1,408 @@ +########################################################################################### +# Data parsing utilities +# Authors: Ilyes Batatia, Gregor Simm and David Kovacs +# This program is distributed under the MIT License (see MIT.md) +########################################################################################### + +import logging +from dataclasses import dataclass +from typing import Dict, List, Optional, Sequence, Tuple + +import ase.data +import ase.io +import h5py +import numpy as np + +from mace.tools import AtomicNumberTable + +Vector = np.ndarray # [3,] +Positions = np.ndarray # [..., 3] +Forces = np.ndarray # [..., 3] +Stress = np.ndarray # [6, ], [3,3], [9, ] +Virials = np.ndarray # [6, ], [3,3], [9, ] +Charges = np.ndarray # [..., 1] +Cell = np.ndarray # [3,3] +Pbc = tuple # (3,) + +DEFAULT_CONFIG_TYPE = "Default" +DEFAULT_CONFIG_TYPE_WEIGHTS = {DEFAULT_CONFIG_TYPE: 1.0} + + +@dataclass +class Configuration: + atomic_numbers: np.ndarray + positions: Positions # Angstrom + energy: Optional[float] = None # eV + forces: Optional[Forces] = None # eV/Angstrom + stress: Optional[Stress] = None # eV/Angstrom^3 + virials: Optional[Virials] = None # eV + dipole: Optional[Vector] = None # Debye + charges: Optional[Charges] = None # atomic unit + cell: Optional[Cell] = None + pbc: Optional[Pbc] = None + + weight: float = 1.0 # weight of config in loss + energy_weight: float = 1.0 # weight of config energy in loss + forces_weight: float = 1.0 # weight of config forces in loss + stress_weight: float = 1.0 # weight of config stress in loss + virials_weight: float = 1.0 # weight of config virial in loss + config_type: Optional[str] = DEFAULT_CONFIG_TYPE # config_type of config + head: Optional[str] = "Default" # head used to compute the config + + +Configurations = List[Configuration] + + +def random_train_valid_split( + items: Sequence, valid_fraction: float, seed: int, work_dir: str +) -> Tuple[List, List]: + assert 0.0 < valid_fraction < 1.0 + + size = len(items) + train_size = size - int(valid_fraction * size) + + indices = list(range(size)) + rng = np.random.default_rng(seed) + rng.shuffle(indices) + if len(indices[train_size:]) < 10: + logging.info( + f"Using random {100 * valid_fraction:.0f}% of training set for validation with following indices: {indices[train_size:]}" + ) + else: + # Save indices to file + with open(work_dir + f"/valid_indices_{seed}.txt", "w", encoding="utf-8") as f: + for index in indices[train_size:]: + f.write(f"{index}\n") + + logging.info( + f"Using random {100 * valid_fraction:.0f}% of training set for validation with indices saved in: {work_dir}/valid_indices_{seed}.txt" + ) + + return ( + [items[i] for i in indices[:train_size]], + [items[i] for i in indices[train_size:]], + ) + + +def config_from_atoms_list( + atoms_list: List[ase.Atoms], + energy_key="REF_energy", + forces_key="REF_forces", + stress_key="REF_stress", + virials_key="REF_virials", + dipole_key="REF_dipole", + charges_key="REF_charges", + head_key="head", + config_type_weights: Optional[Dict[str, float]] = None, +) -> Configurations: + """Convert list of ase.Atoms into Configurations""" + if config_type_weights is None: + config_type_weights = DEFAULT_CONFIG_TYPE_WEIGHTS + + all_configs = [] + for atoms in atoms_list: + all_configs.append( + config_from_atoms( + atoms, + energy_key=energy_key, + forces_key=forces_key, + stress_key=stress_key, + virials_key=virials_key, + dipole_key=dipole_key, + charges_key=charges_key, + head_key=head_key, + config_type_weights=config_type_weights, + ) + ) + return all_configs + + +def config_from_atoms( + atoms: ase.Atoms, + energy_key="REF_energy", + forces_key="REF_forces", + stress_key="REF_stress", + virials_key="REF_virials", + dipole_key="REF_dipole", + charges_key="REF_charges", + head_key="head", + config_type_weights: Optional[Dict[str, float]] = None, +) -> Configuration: + """Convert ase.Atoms to Configuration""" + if config_type_weights is None: + config_type_weights = DEFAULT_CONFIG_TYPE_WEIGHTS + + energy = atoms.info.get(energy_key, None) # eV + forces = atoms.arrays.get(forces_key, None) # eV / Ang + stress = atoms.info.get(stress_key, None) # eV / Ang ^ 3 + virials = atoms.info.get(virials_key, None) + dipole = atoms.info.get(dipole_key, None) # Debye + # Charges default to 0 instead of None if not found + charges = atoms.arrays.get(charges_key, np.zeros(len(atoms))) # atomic unit + atomic_numbers = np.array( + [ase.data.atomic_numbers[symbol] for symbol in atoms.symbols] + ) + pbc = tuple(atoms.get_pbc()) + cell = np.array(atoms.get_cell()) + config_type = atoms.info.get("config_type", "Default") + weight = atoms.info.get("config_weight", 1.0) * config_type_weights.get( + config_type, 1.0 + ) + energy_weight = atoms.info.get("config_energy_weight", 1.0) + forces_weight = atoms.info.get("config_forces_weight", 1.0) + stress_weight = atoms.info.get("config_stress_weight", 1.0) + virials_weight = atoms.info.get("config_virials_weight", 1.0) + + head = atoms.info.get(head_key, "Default") + + # fill in missing quantities but set their weight to 0.0 + if energy is None: + energy = 0.0 + energy_weight = 0.0 + if forces is None: + forces = np.zeros(np.shape(atoms.positions)) + forces_weight = 0.0 + if stress is None: + stress = np.zeros(6) + stress_weight = 0.0 + if virials is None: + virials = np.zeros((3, 3)) + virials_weight = 0.0 + if dipole is None: + dipole = np.zeros(3) + # dipoles_weight = 0.0 + + return Configuration( + atomic_numbers=atomic_numbers, + positions=atoms.get_positions(), + energy=energy, + forces=forces, + stress=stress, + virials=virials, + dipole=dipole, + charges=charges, + weight=weight, + head=head, + energy_weight=energy_weight, + forces_weight=forces_weight, + stress_weight=stress_weight, + virials_weight=virials_weight, + config_type=config_type, + pbc=pbc, + cell=cell, + ) + + +def test_config_types( + test_configs: Configurations, +) -> List[Tuple[Optional[str], List[Configuration]]]: + """Split test set based on config_type-s""" + test_by_ct = [] + all_cts = [] + for conf in test_configs: + config_type_name = conf.config_type + "_" + conf.head + if config_type_name not in all_cts: + all_cts.append(config_type_name) + test_by_ct.append((config_type_name, [conf])) + else: + ind = all_cts.index(config_type_name) + test_by_ct[ind][1].append(conf) + return test_by_ct + + +def load_from_xyz( + file_path: str, + config_type_weights: Dict, + energy_key: str = "REF_energy", + forces_key: str = "REF_forces", + stress_key: str = "REF_stress", + virials_key: str = "REF_virials", + dipole_key: str = "REF_dipole", + charges_key: str = "REF_charges", + head_key: str = "head", + head_name: str = "Default", + extract_atomic_energies: bool = False, + keep_isolated_atoms: bool = False, +) -> Tuple[Dict[int, float], Configurations]: + atoms_list = ase.io.read(file_path, index=":") + if energy_key == "energy": + logging.warning( + "Since ASE version 3.23.0b1, using energy_key 'energy' is no longer safe when communicating between MACE and ASE. We recommend using a different key, rewriting 'energy' to 'REF_energy'. You need to use --energy_key='REF_energy' to specify the chosen key name." + ) + energy_key = "REF_energy" + for atoms in atoms_list: + try: + atoms.info["REF_energy"] = atoms.get_potential_energy() + except Exception as e: # pylint: disable=W0703 + logging.error(f"Failed to extract energy: {e}") + atoms.info["REF_energy"] = None + if forces_key == "forces": + logging.warning( + "Since ASE version 3.23.0b1, using forces_key 'forces' is no longer safe when communicating between MACE and ASE. We recommend using a different key, rewriting 'forces' to 'REF_forces'. You need to use --forces_key='REF_forces' to specify the chosen key name." + ) + forces_key = "REF_forces" + for atoms in atoms_list: + try: + atoms.arrays["REF_forces"] = atoms.get_forces() + except Exception as e: # pylint: disable=W0703 + logging.error(f"Failed to extract forces: {e}") + atoms.arrays["REF_forces"] = None + if stress_key == "stress": + logging.warning( + "Since ASE version 3.23.0b1, using stress_key 'stress' is no longer safe when communicating between MACE and ASE. We recommend using a different key, rewriting 'stress' to 'REF_stress'. You need to use --stress_key='REF_stress' to specify the chosen key name." + ) + stress_key = "REF_stress" + for atoms in atoms_list: + try: + atoms.info["REF_stress"] = atoms.get_stress() + except Exception as e: # pylint: disable=W0703 + atoms.info["REF_stress"] = None + if not isinstance(atoms_list, list): + atoms_list = [atoms_list] + + atomic_energies_dict = {} + if extract_atomic_energies: + atoms_without_iso_atoms = [] + + for idx, atoms in enumerate(atoms_list): + atoms.info[head_key] = head_name + isolated_atom_config = ( + len(atoms) == 1 and atoms.info.get("config_type") == "IsolatedAtom" + ) + if isolated_atom_config: + if energy_key in atoms.info.keys(): + atomic_energies_dict[atoms.get_atomic_numbers()[0]] = atoms.info[ + energy_key + ] + else: + logging.warning( + f"Configuration '{idx}' is marked as 'IsolatedAtom' " + "but does not contain an energy. Zero energy will be used." + ) + atomic_energies_dict[atoms.get_atomic_numbers()[0]] = np.zeros(1) + else: + atoms_without_iso_atoms.append(atoms) + + if len(atomic_energies_dict) > 0: + logging.info("Using isolated atom energies from training file") + if not keep_isolated_atoms: + atoms_list = atoms_without_iso_atoms + + configs = config_from_atoms_list( + atoms_list, + config_type_weights=config_type_weights, + energy_key=energy_key, + forces_key=forces_key, + stress_key=stress_key, + virials_key=virials_key, + dipole_key=dipole_key, + charges_key=charges_key, + head_key=head_key, + ) + return atomic_energies_dict, configs + + +def compute_average_E0s( + collections_train: Configurations, z_table: AtomicNumberTable +) -> Dict[int, float]: + """ + Function to compute the average interaction energy of each chemical element + returns dictionary of E0s + """ + len_train = len(collections_train) + len_zs = len(z_table) + A = np.zeros((len_train, len_zs)) + B = np.zeros(len_train) + for i in range(len_train): + B[i] = collections_train[i].energy + for j, z in enumerate(z_table.zs): + A[i, j] = np.count_nonzero(collections_train[i].atomic_numbers == z) + try: + E0s = np.linalg.lstsq(A, B, rcond=None)[0] + atomic_energies_dict = {} + for i, z in enumerate(z_table.zs): + atomic_energies_dict[z] = E0s[i] + except np.linalg.LinAlgError: + logging.error( + "Failed to compute E0s using least squares regression, using the same for all atoms" + ) + atomic_energies_dict = {} + for i, z in enumerate(z_table.zs): + atomic_energies_dict[z] = 0.0 + return atomic_energies_dict + + +def save_dataset_as_HDF5(dataset: List, out_name: str) -> None: + with h5py.File(out_name, "w") as f: + for i, data in enumerate(dataset): + grp = f.create_group(f"config_{i}") + grp["num_nodes"] = data.num_nodes + grp["edge_index"] = data.edge_index + grp["positions"] = data.positions + grp["shifts"] = data.shifts + grp["unit_shifts"] = data.unit_shifts + grp["cell"] = data.cell + grp["node_attrs"] = data.node_attrs + grp["weight"] = data.weight + grp["energy_weight"] = data.energy_weight + grp["forces_weight"] = data.forces_weight + grp["stress_weight"] = data.stress_weight + grp["virials_weight"] = data.virials_weight + grp["forces"] = data.forces + grp["energy"] = data.energy + grp["stress"] = data.stress + grp["virials"] = data.virials + grp["dipole"] = data.dipole + grp["charges"] = data.charges + grp["head"] = data.head + + +def save_AtomicData_to_HDF5(data, i, h5_file) -> None: + grp = h5_file.create_group(f"config_{i}") + grp["num_nodes"] = data.num_nodes + grp["edge_index"] = data.edge_index + grp["positions"] = data.positions + grp["shifts"] = data.shifts + grp["unit_shifts"] = data.unit_shifts + grp["cell"] = data.cell + grp["node_attrs"] = data.node_attrs + grp["weight"] = data.weight + grp["energy_weight"] = data.energy_weight + grp["forces_weight"] = data.forces_weight + grp["stress_weight"] = data.stress_weight + grp["virials_weight"] = data.virials_weight + grp["forces"] = data.forces + grp["energy"] = data.energy + grp["stress"] = data.stress + grp["virials"] = data.virials + grp["dipole"] = data.dipole + grp["charges"] = data.charges + grp["head"] = data.head + + +def save_configurations_as_HDF5(configurations: Configurations, _, h5_file) -> None: + grp = h5_file.create_group("config_batch_0") + for j, config in enumerate(configurations): + subgroup_name = f"config_{j}" + subgroup = grp.create_group(subgroup_name) + subgroup["atomic_numbers"] = write_value(config.atomic_numbers) + subgroup["positions"] = write_value(config.positions) + subgroup["energy"] = write_value(config.energy) + subgroup["forces"] = write_value(config.forces) + subgroup["stress"] = write_value(config.stress) + subgroup["virials"] = write_value(config.virials) + subgroup["head"] = write_value(config.head) + subgroup["dipole"] = write_value(config.dipole) + subgroup["charges"] = write_value(config.charges) + subgroup["cell"] = write_value(config.cell) + subgroup["pbc"] = write_value(config.pbc) + subgroup["weight"] = write_value(config.weight) + subgroup["energy_weight"] = write_value(config.energy_weight) + subgroup["forces_weight"] = write_value(config.forces_weight) + subgroup["stress_weight"] = write_value(config.stress_weight) + subgroup["virials_weight"] = write_value(config.virials_weight) + subgroup["config_type"] = write_value(config.config_type) + + +def write_value(value): + return value if value is not None else "None" diff --git a/mace/modules/__init__.py b/mace/modules/__init__.py new file mode 100644 index 00000000..e48e0b23 --- /dev/null +++ b/mace/modules/__init__.py @@ -0,0 +1,113 @@ +from typing import Callable, Dict, Optional, Type + +import torch + +from .blocks import ( + AgnosticNonlinearInteractionBlock, + AgnosticResidualNonlinearInteractionBlock, + AtomicEnergiesBlock, + EquivariantProductBasisBlock, + InteractionBlock, + LinearDipoleReadoutBlock, + LinearNodeEmbeddingBlock, + LinearReadoutBlock, + NonLinearDipoleReadoutBlock, + NonLinearReadoutBlock, + RadialEmbeddingBlock, + RealAgnosticAttResidualInteractionBlock, + RealAgnosticDensityInteractionBlock, + RealAgnosticDensityResidualInteractionBlock, + RealAgnosticInteractionBlock, + RealAgnosticResidualInteractionBlock, + ResidualElementDependentInteractionBlock, + ScaleShiftBlock, +) +from .loss import ( + DipoleSingleLoss, + UniversalLoss, + WeightedEnergyForcesDipoleLoss, + WeightedEnergyForcesLoss, + WeightedEnergyForcesStressLoss, + WeightedEnergyForcesVirialsLoss, + WeightedForcesLoss, + WeightedHuberEnergyForcesStressLoss, +) +from .models import ( + MACE, + AtomicDipolesMACE, + BOTNet, + EnergyDipolesMACE, + ScaleShiftBOTNet, + ScaleShiftMACE, +) +from .radial import BesselBasis, GaussianBasis, PolynomialCutoff, ZBLBasis +from .symmetric_contraction import SymmetricContraction +from .utils import ( + compute_avg_num_neighbors, + compute_fixed_charge_dipole, + compute_mean_rms_energy_forces, + compute_mean_std_atomic_inter_energy, + compute_rms_dipoles, + compute_statistics, +) + +interaction_classes: Dict[str, Type[InteractionBlock]] = { + "AgnosticNonlinearInteractionBlock": AgnosticNonlinearInteractionBlock, + "ResidualElementDependentInteractionBlock": ResidualElementDependentInteractionBlock, + "AgnosticResidualNonlinearInteractionBlock": AgnosticResidualNonlinearInteractionBlock, + "RealAgnosticResidualInteractionBlock": RealAgnosticResidualInteractionBlock, + "RealAgnosticAttResidualInteractionBlock": RealAgnosticAttResidualInteractionBlock, + "RealAgnosticInteractionBlock": RealAgnosticInteractionBlock, + "RealAgnosticDensityInteractionBlock": RealAgnosticDensityInteractionBlock, + "RealAgnosticDensityResidualInteractionBlock": RealAgnosticDensityResidualInteractionBlock, +} + +scaling_classes: Dict[str, Callable] = { + "std_scaling": compute_mean_std_atomic_inter_energy, + "rms_forces_scaling": compute_mean_rms_energy_forces, + "rms_dipoles_scaling": compute_rms_dipoles, +} + +gate_dict: Dict[str, Optional[Callable]] = { + "abs": torch.abs, + "tanh": torch.tanh, + "silu": torch.nn.functional.silu, + "None": None, +} + +__all__ = [ + "AtomicEnergiesBlock", + "RadialEmbeddingBlock", + "ZBLBasis", + "LinearNodeEmbeddingBlock", + "LinearReadoutBlock", + "EquivariantProductBasisBlock", + "ScaleShiftBlock", + "LinearDipoleReadoutBlock", + "NonLinearDipoleReadoutBlock", + "InteractionBlock", + "NonLinearReadoutBlock", + "PolynomialCutoff", + "BesselBasis", + "GaussianBasis", + "MACE", + "ScaleShiftMACE", + "BOTNet", + "ScaleShiftBOTNet", + "AtomicDipolesMACE", + "EnergyDipolesMACE", + "WeightedEnergyForcesLoss", + "WeightedForcesLoss", + "WeightedEnergyForcesVirialsLoss", + "WeightedEnergyForcesStressLoss", + "DipoleSingleLoss", + "WeightedEnergyForcesDipoleLoss", + "WeightedHuberEnergyForcesStressLoss", + "UniversalLoss", + "SymmetricContraction", + "interaction_classes", + "compute_mean_std_atomic_inter_energy", + "compute_avg_num_neighbors", + "compute_statistics", + "compute_fixed_charge_dipole", +] diff --git a/mace/modules/blocks.py b/mace/modules/blocks.py new file mode 100644 index 00000000..0db3b02e --- /dev/null +++ b/mace/modules/blocks.py @@ -0,0 +1,964 @@ +########################################################################################### +# Elementary Block for Building O(3) Equivariant Higher Order Message Passing Neural Network +# Authors: Ilyes Batatia, Gregor Simm +# This program is distributed under the MIT License (see MIT.md) +########################################################################################### + +from abc import abstractmethod +from typing import Callable, List, Optional, Tuple, Union + +import numpy as np +import torch.nn.functional +from e3nn import nn, o3 +from e3nn.util.jit import compile_mode + +from mace.tools.compile import simplify_if_compile +from mace.tools.scatter import scatter_sum + +from .irreps_tools import ( + linear_out_irreps, + mask_head, + reshape_irreps, + tp_out_irreps_with_instructions, +) +from .radial import ( + AgnesiTransform, + BesselBasis, + ChebychevBasis, + GaussianBasis, + PolynomialCutoff, + SoftTransform, +) +from .symmetric_contraction import SymmetricContraction + + +@compile_mode("script") +class LinearNodeEmbeddingBlock(torch.nn.Module): + def __init__(self, irreps_in: o3.Irreps, irreps_out: o3.Irreps): + super().__init__() + self.linear = o3.Linear(irreps_in=irreps_in, irreps_out=irreps_out) + + def forward( + self, + node_attrs: torch.Tensor, + ) -> torch.Tensor: # [n_nodes, irreps] + return self.linear(node_attrs) + + +@compile_mode("script") +class LinearReadoutBlock(torch.nn.Module): + def __init__(self, irreps_in: o3.Irreps, irrep_out: o3.Irreps = o3.Irreps("0e")): + super().__init__() + self.linear = o3.Linear(irreps_in=irreps_in, irreps_out=irrep_out) + + def forward( + self, + x: torch.Tensor, + heads: Optional[torch.Tensor] = None, # pylint: disable=unused-argument + ) -> torch.Tensor: # [n_nodes, irreps] # [..., ] + return self.linear(x) # [n_nodes, 1] + + +@simplify_if_compile +@compile_mode("script") +class NonLinearReadoutBlock(torch.nn.Module): + def __init__( + self, + irreps_in: o3.Irreps, + MLP_irreps: o3.Irreps, + gate: Optional[Callable], + irrep_out: o3.Irreps = o3.Irreps("0e"), + num_heads: int = 1, + ): + super().__init__() + self.hidden_irreps = MLP_irreps + self.num_heads = num_heads + self.linear_1 = o3.Linear(irreps_in=irreps_in, irreps_out=self.hidden_irreps) + self.non_linearity = nn.Activation(irreps_in=self.hidden_irreps, acts=[gate]) + self.linear_2 = o3.Linear(irreps_in=self.hidden_irreps, irreps_out=irrep_out) + + def forward( + self, x: torch.Tensor, heads: Optional[torch.Tensor] = None + ) -> torch.Tensor: # [n_nodes, irreps] # [..., ] + x = self.non_linearity(self.linear_1(x)) + if hasattr(self, "num_heads"): + if self.num_heads > 1 and heads is not None: + x = mask_head(x, heads, self.num_heads) + return self.linear_2(x) # [n_nodes, len(heads)] + + +@compile_mode("script") +class LinearDipoleReadoutBlock(torch.nn.Module): + def __init__(self, irreps_in: o3.Irreps, dipole_only: bool = False): + super().__init__() + if dipole_only: + self.irreps_out = o3.Irreps("1x1o") + else: + self.irreps_out = o3.Irreps("1x0e + 1x1o") + self.linear = o3.Linear(irreps_in=irreps_in, irreps_out=self.irreps_out) + + def forward(self, x: torch.Tensor) -> torch.Tensor: # [n_nodes, irreps] # [..., ] + return self.linear(x) # [n_nodes, 1] + + +@compile_mode("script") +class NonLinearDipoleReadoutBlock(torch.nn.Module): + def __init__( + self, + irreps_in: o3.Irreps, + MLP_irreps: o3.Irreps, + gate: Callable, + dipole_only: bool = False, + ): + super().__init__() + self.hidden_irreps = MLP_irreps + if dipole_only: + self.irreps_out = o3.Irreps("1x1o") + else: + self.irreps_out = o3.Irreps("1x0e + 1x1o") + irreps_scalars = o3.Irreps( + [(mul, ir) for mul, ir in MLP_irreps if ir.l == 0 and ir in self.irreps_out] + ) + irreps_gated = o3.Irreps( + [(mul, ir) for mul, ir in MLP_irreps if ir.l > 0 and ir in self.irreps_out] + ) + irreps_gates = o3.Irreps([mul, "0e"] for mul, _ in irreps_gated) + self.equivariant_nonlin = nn.Gate( + irreps_scalars=irreps_scalars, + act_scalars=[gate for _, ir in irreps_scalars], + irreps_gates=irreps_gates, + act_gates=[gate] * len(irreps_gates), + irreps_gated=irreps_gated, + ) + self.irreps_nonlin = self.equivariant_nonlin.irreps_in.simplify() + self.linear_1 = o3.Linear(irreps_in=irreps_in, irreps_out=self.irreps_nonlin) + self.linear_2 = o3.Linear( + irreps_in=self.hidden_irreps, irreps_out=self.irreps_out + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: # [n_nodes, irreps] # [..., ] + x = self.equivariant_nonlin(self.linear_1(x)) + return self.linear_2(x) # [n_nodes, 1] + + +@compile_mode("script") +class AtomicEnergiesBlock(torch.nn.Module): + atomic_energies: torch.Tensor + + def __init__(self, atomic_energies: Union[np.ndarray, torch.Tensor]): + super().__init__() + # assert len(atomic_energies.shape) == 1 + + self.register_buffer( + "atomic_energies", + torch.tensor(atomic_energies, dtype=torch.get_default_dtype()), + ) # [n_elements, n_heads] + + def forward( + self, x: torch.Tensor # one-hot of elements [..., n_elements] + ) -> torch.Tensor: # [..., ] + return torch.matmul(x, torch.atleast_2d(self.atomic_energies).T) + + def __repr__(self): + formatted_energies = ", ".join( + [ + "[" + ", ".join([f"{x:.4f}" for x in group]) + "]" + for group in torch.atleast_2d(self.atomic_energies) + ] + ) + return f"{self.__class__.__name__}(energies=[{formatted_energies}])" + + +@compile_mode("script") +class RadialEmbeddingBlock(torch.nn.Module): + def __init__( + self, + r_max: float, + num_bessel: int, + num_polynomial_cutoff: int, + radial_type: str = "bessel", + distance_transform: str = "None", + ): + super().__init__() + if radial_type == "bessel": + self.bessel_fn = BesselBasis(r_max=r_max, num_basis=num_bessel) + elif radial_type == "gaussian": + self.bessel_fn = GaussianBasis(r_max=r_max, num_basis=num_bessel) + elif radial_type == "chebyshev": + self.bessel_fn = ChebychevBasis(r_max=r_max, num_basis=num_bessel) + if distance_transform == "Agnesi": + self.distance_transform = AgnesiTransform() + elif distance_transform == "Soft": + self.distance_transform = SoftTransform() + self.cutoff_fn = PolynomialCutoff(r_max=r_max, p=num_polynomial_cutoff) + self.out_dim = num_bessel + + def forward( + self, + edge_lengths: torch.Tensor, # [n_edges, 1] + node_attrs: torch.Tensor, + edge_index: torch.Tensor, + atomic_numbers: torch.Tensor, + ): + cutoff = self.cutoff_fn(edge_lengths) # [n_edges, 1] + if hasattr(self, "distance_transform"): + edge_lengths = self.distance_transform( + edge_lengths, node_attrs, edge_index, atomic_numbers + ) + radial = self.bessel_fn(edge_lengths) # [n_edges, n_basis] + return radial * cutoff # [n_edges, n_basis] + + +@compile_mode("script") +class EquivariantProductBasisBlock(torch.nn.Module): + def __init__( + self, + node_feats_irreps: o3.Irreps, + target_irreps: o3.Irreps, + correlation: int, + use_sc: bool = True, + num_elements: Optional[int] = None, + ) -> None: + super().__init__() + + self.use_sc = use_sc + self.symmetric_contractions = SymmetricContraction( + irreps_in=node_feats_irreps, + irreps_out=target_irreps, + correlation=correlation, + num_elements=num_elements, + ) + # Update linear + self.linear = o3.Linear( + target_irreps, + target_irreps, + internal_weights=True, + shared_weights=True, + ) + + def forward( + self, + node_feats: torch.Tensor, + sc: Optional[torch.Tensor], + node_attrs: torch.Tensor, + ) -> torch.Tensor: + node_feats = self.symmetric_contractions(node_feats, node_attrs) + if self.use_sc and sc is not None: + return self.linear(node_feats) + sc + return self.linear(node_feats) + + +@compile_mode("script") +class InteractionBlock(torch.nn.Module): + def __init__( + self, + node_attrs_irreps: o3.Irreps, + node_feats_irreps: o3.Irreps, + edge_attrs_irreps: o3.Irreps, + edge_feats_irreps: o3.Irreps, + target_irreps: o3.Irreps, + hidden_irreps: o3.Irreps, + avg_num_neighbors: float, + radial_MLP: Optional[List[int]] = None, + ) -> None: + super().__init__() + self.node_attrs_irreps = node_attrs_irreps + self.node_feats_irreps = node_feats_irreps + self.edge_attrs_irreps = edge_attrs_irreps + self.edge_feats_irreps = edge_feats_irreps + self.target_irreps = target_irreps + self.hidden_irreps = hidden_irreps + self.avg_num_neighbors = avg_num_neighbors + if radial_MLP is None: + radial_MLP = [64, 64, 64] + self.radial_MLP = radial_MLP + + self._setup() + + @abstractmethod + def _setup(self) -> None: + raise NotImplementedError + + @abstractmethod + def forward( + self, + node_attrs: torch.Tensor, + node_feats: torch.Tensor, + edge_attrs: torch.Tensor, + edge_feats: torch.Tensor, + edge_index: torch.Tensor, + ) -> torch.Tensor: + raise NotImplementedError + + +nonlinearities = {1: torch.nn.functional.silu, -1: torch.tanh} + + +@compile_mode("script") +class TensorProductWeightsBlock(torch.nn.Module): + def __init__(self, num_elements: int, num_edge_feats: int, num_feats_out: int): + super().__init__() + + weights = torch.empty( + (num_elements, num_edge_feats, num_feats_out), + dtype=torch.get_default_dtype(), + ) + torch.nn.init.xavier_uniform_(weights) + self.weights = torch.nn.Parameter(weights) + + def forward( + self, + sender_or_receiver_node_attrs: torch.Tensor, # assumes that the node attributes are one-hot encoded + edge_feats: torch.Tensor, + ): + return torch.einsum( + "be, ba, aek -> bk", edge_feats, sender_or_receiver_node_attrs, self.weights + ) + + def __repr__(self): + return ( + f'{self.__class__.__name__}(shape=({", ".join(str(s) for s in self.weights.shape)}), ' + f"weights={np.prod(self.weights.shape)})" + ) + + +@compile_mode("script") +class ResidualElementDependentInteractionBlock(InteractionBlock): + def _setup(self) -> None: + self.linear_up = o3.Linear( + self.node_feats_irreps, + self.node_feats_irreps, + internal_weights=True, + shared_weights=True, + ) + # TensorProduct + irreps_mid, instructions = tp_out_irreps_with_instructions( + self.node_feats_irreps, self.edge_attrs_irreps, self.target_irreps + ) + self.conv_tp = o3.TensorProduct( + self.node_feats_irreps, + self.edge_attrs_irreps, + irreps_mid, + instructions=instructions, + shared_weights=False, + internal_weights=False, + ) + self.conv_tp_weights = TensorProductWeightsBlock( + num_elements=self.node_attrs_irreps.num_irreps, + num_edge_feats=self.edge_feats_irreps.num_irreps, + num_feats_out=self.conv_tp.weight_numel, + ) + + # Linear + irreps_mid = irreps_mid.simplify() + self.irreps_out = linear_out_irreps(irreps_mid, self.target_irreps) + self.irreps_out = self.irreps_out.simplify() + self.linear = o3.Linear( + irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True + ) + + # Selector TensorProduct + self.skip_tp = o3.FullyConnectedTensorProduct( + self.node_feats_irreps, self.node_attrs_irreps, self.irreps_out + ) + + def forward( + self, + node_attrs: torch.Tensor, + node_feats: torch.Tensor, + edge_attrs: torch.Tensor, + edge_feats: torch.Tensor, + edge_index: torch.Tensor, + ) -> torch.Tensor: + sender = edge_index[0] + receiver = edge_index[1] + num_nodes = node_feats.shape[0] + sc = self.skip_tp(node_feats, node_attrs) + node_feats = self.linear_up(node_feats) + tp_weights = self.conv_tp_weights(node_attrs[sender], edge_feats) + mji = self.conv_tp( + node_feats[sender], edge_attrs, tp_weights + ) # [n_edges, irreps] + message = scatter_sum( + src=mji, index=receiver, dim=0, dim_size=num_nodes + ) # [n_nodes, irreps] + message = self.linear(message) / self.avg_num_neighbors + return message + sc # [n_nodes, irreps] + + +@compile_mode("script") +class AgnosticNonlinearInteractionBlock(InteractionBlock): + def _setup(self) -> None: + self.linear_up = o3.Linear( + self.node_feats_irreps, + self.node_feats_irreps, + internal_weights=True, + shared_weights=True, + ) + # TensorProduct + irreps_mid, instructions = tp_out_irreps_with_instructions( + self.node_feats_irreps, self.edge_attrs_irreps, self.target_irreps + ) + self.conv_tp = o3.TensorProduct( + self.node_feats_irreps, + self.edge_attrs_irreps, + irreps_mid, + instructions=instructions, + shared_weights=False, + internal_weights=False, + ) + + # Convolution weights + input_dim = self.edge_feats_irreps.num_irreps + self.conv_tp_weights = nn.FullyConnectedNet( + [input_dim] + self.radial_MLP + [self.conv_tp.weight_numel], + torch.nn.functional.silu, + ) + + # Linear + irreps_mid = irreps_mid.simplify() + self.irreps_out = linear_out_irreps(irreps_mid, self.target_irreps) + self.irreps_out = self.irreps_out.simplify() + self.linear = o3.Linear( + irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True + ) + + # Selector TensorProduct + self.skip_tp = o3.FullyConnectedTensorProduct( + self.irreps_out, self.node_attrs_irreps, self.irreps_out + ) + + def forward( + self, + node_attrs: torch.Tensor, + node_feats: torch.Tensor, + edge_attrs: torch.Tensor, + edge_feats: torch.Tensor, + edge_index: torch.Tensor, + ) -> torch.Tensor: + sender = edge_index[0] + receiver = edge_index[1] + num_nodes = node_feats.shape[0] + tp_weights = self.conv_tp_weights(edge_feats) + node_feats = self.linear_up(node_feats) + mji = self.conv_tp( + node_feats[sender], edge_attrs, tp_weights + ) # [n_edges, irreps] + message = scatter_sum( + src=mji, index=receiver, dim=0, dim_size=num_nodes + ) # [n_nodes, irreps] + message = self.linear(message) / self.avg_num_neighbors + message = self.skip_tp(message, node_attrs) + return message # [n_nodes, irreps] + + +@compile_mode("script") +class AgnosticResidualNonlinearInteractionBlock(InteractionBlock): + def _setup(self) -> None: + # First linear + self.linear_up = o3.Linear( + self.node_feats_irreps, + self.node_feats_irreps, + internal_weights=True, + shared_weights=True, + ) + # TensorProduct + irreps_mid, instructions = tp_out_irreps_with_instructions( + self.node_feats_irreps, self.edge_attrs_irreps, self.target_irreps + ) + self.conv_tp = o3.TensorProduct( + self.node_feats_irreps, + self.edge_attrs_irreps, + irreps_mid, + instructions=instructions, + shared_weights=False, + internal_weights=False, + ) + + # Convolution weights + input_dim = self.edge_feats_irreps.num_irreps + self.conv_tp_weights = nn.FullyConnectedNet( + [input_dim] + self.radial_MLP + [self.conv_tp.weight_numel], + torch.nn.functional.silu, + ) + + # Linear + irreps_mid = irreps_mid.simplify() + self.irreps_out = linear_out_irreps(irreps_mid, self.target_irreps) + self.irreps_out = self.irreps_out.simplify() + self.linear = o3.Linear( + irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True + ) + + # Selector TensorProduct + self.skip_tp = o3.FullyConnectedTensorProduct( + self.node_feats_irreps, self.node_attrs_irreps, self.irreps_out + ) + + def forward( + self, + node_attrs: torch.Tensor, + node_feats: torch.Tensor, + edge_attrs: torch.Tensor, + edge_feats: torch.Tensor, + edge_index: torch.Tensor, + ) -> torch.Tensor: + sender = edge_index[0] + receiver = edge_index[1] + num_nodes = node_feats.shape[0] + sc = self.skip_tp(node_feats, node_attrs) + node_feats = self.linear_up(node_feats) + tp_weights = self.conv_tp_weights(edge_feats) + mji = self.conv_tp( + node_feats[sender], edge_attrs, tp_weights + ) # [n_edges, irreps] + message = scatter_sum( + src=mji, index=receiver, dim=0, dim_size=num_nodes + ) # [n_nodes, irreps] + message = self.linear(message) / self.avg_num_neighbors + message = message + sc + return message # [n_nodes, irreps] + + +@compile_mode("script") +class RealAgnosticInteractionBlock(InteractionBlock): + def _setup(self) -> None: + # First linear + self.linear_up = o3.Linear( + self.node_feats_irreps, + self.node_feats_irreps, + internal_weights=True, + shared_weights=True, + ) + # TensorProduct + irreps_mid, instructions = tp_out_irreps_with_instructions( + self.node_feats_irreps, + self.edge_attrs_irreps, + self.target_irreps, + ) + self.conv_tp = o3.TensorProduct( + self.node_feats_irreps, + self.edge_attrs_irreps, + irreps_mid, + instructions=instructions, + shared_weights=False, + internal_weights=False, + ) + + # Convolution weights + input_dim = self.edge_feats_irreps.num_irreps + self.conv_tp_weights = nn.FullyConnectedNet( + [input_dim] + self.radial_MLP + [self.conv_tp.weight_numel], + torch.nn.functional.silu, + ) + + # Linear + irreps_mid = irreps_mid.simplify() + self.irreps_out = self.target_irreps + self.linear = o3.Linear( + irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True + ) + + # Selector TensorProduct + self.skip_tp = o3.FullyConnectedTensorProduct( + self.irreps_out, self.node_attrs_irreps, self.irreps_out + ) + self.reshape = reshape_irreps(self.irreps_out) + + def forward( + self, + node_attrs: torch.Tensor, + node_feats: torch.Tensor, + edge_attrs: torch.Tensor, + edge_feats: torch.Tensor, + edge_index: torch.Tensor, + ) -> Tuple[torch.Tensor, None]: + sender = edge_index[0] + receiver = edge_index[1] + num_nodes = node_feats.shape[0] + node_feats = self.linear_up(node_feats) + tp_weights = self.conv_tp_weights(edge_feats) + mji = self.conv_tp( + node_feats[sender], edge_attrs, tp_weights + ) # [n_edges, irreps] + message = scatter_sum( + src=mji, index=receiver, dim=0, dim_size=num_nodes + ) # [n_nodes, irreps] + message = self.linear(message) / self.avg_num_neighbors + message = self.skip_tp(message, node_attrs) + return ( + self.reshape(message), + None, + ) # [n_nodes, channels, (lmax + 1)**2] + + +@compile_mode("script") +class RealAgnosticResidualInteractionBlock(InteractionBlock): + def _setup(self) -> None: + # First linear + self.linear_up = o3.Linear( + self.node_feats_irreps, + self.node_feats_irreps, + internal_weights=True, + shared_weights=True, + ) + # TensorProduct + irreps_mid, instructions = tp_out_irreps_with_instructions( + self.node_feats_irreps, + self.edge_attrs_irreps, + self.target_irreps, + ) + self.conv_tp = o3.TensorProduct( + self.node_feats_irreps, + self.edge_attrs_irreps, + irreps_mid, + instructions=instructions, + shared_weights=False, + internal_weights=False, + ) + + # Convolution weights + input_dim = self.edge_feats_irreps.num_irreps + self.conv_tp_weights = nn.FullyConnectedNet( + [input_dim] + self.radial_MLP + [self.conv_tp.weight_numel], + torch.nn.functional.silu, # gate + ) + + # Linear + irreps_mid = irreps_mid.simplify() + self.irreps_out = self.target_irreps + self.linear = o3.Linear( + irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True + ) + + # Selector TensorProduct + self.skip_tp = o3.FullyConnectedTensorProduct( + self.node_feats_irreps, self.node_attrs_irreps, self.hidden_irreps + ) + self.reshape = reshape_irreps(self.irreps_out) + + def forward( + self, + node_attrs: torch.Tensor, + node_feats: torch.Tensor, + edge_attrs: torch.Tensor, + edge_feats: torch.Tensor, + edge_index: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + sender = edge_index[0] + receiver = edge_index[1] + num_nodes = node_feats.shape[0] + sc = self.skip_tp(node_feats, node_attrs) + node_feats = self.linear_up(node_feats) + tp_weights = self.conv_tp_weights(edge_feats) + mji = self.conv_tp( + node_feats[sender], edge_attrs, tp_weights + ) # [n_edges, irreps] + message = scatter_sum( + src=mji, index=receiver, dim=0, dim_size=num_nodes + ) # [n_nodes, irreps] + message = self.linear(message) / self.avg_num_neighbors + return ( + self.reshape(message), + sc, + ) # [n_nodes, channels, (lmax + 1)**2] + + +@compile_mode("script") +class RealAgnosticDensityInteractionBlock(InteractionBlock): + def _setup(self) -> None: + # First linear + self.linear_up = o3.Linear( + self.node_feats_irreps, + self.node_feats_irreps, + internal_weights=True, + shared_weights=True, + ) + # TensorProduct + irreps_mid, instructions = tp_out_irreps_with_instructions( + self.node_feats_irreps, + self.edge_attrs_irreps, + self.target_irreps, + ) + self.conv_tp = o3.TensorProduct( + self.node_feats_irreps, + self.edge_attrs_irreps, + irreps_mid, + instructions=instructions, + shared_weights=False, + internal_weights=False, + ) + + # Convolution weights + input_dim = self.edge_feats_irreps.num_irreps + self.conv_tp_weights = nn.FullyConnectedNet( + [input_dim] + self.radial_MLP + [self.conv_tp.weight_numel], + torch.nn.functional.silu, + ) + + # Linear + irreps_mid = irreps_mid.simplify() + self.irreps_out = self.target_irreps + self.linear = o3.Linear( + irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True + ) + + # Selector TensorProduct + self.skip_tp = o3.FullyConnectedTensorProduct( + self.irreps_out, self.node_attrs_irreps, self.irreps_out + ) + self.reshape = reshape_irreps(self.irreps_out) + + # Density normalization + self.density_fn = nn.FullyConnectedNet( + [input_dim] + + [ + 1, + ], + torch.nn.functional.silu, + ) + # Reshape + self.reshape = reshape_irreps(self.irreps_out) + + def forward( + self, + node_attrs: torch.Tensor, + node_feats: torch.Tensor, + edge_attrs: torch.Tensor, + edge_feats: torch.Tensor, + edge_index: torch.Tensor, + ) -> Tuple[torch.Tensor, None]: + sender = edge_index[0] + receiver = edge_index[1] + num_nodes = node_feats.shape[0] + node_feats = self.linear_up(node_feats) + tp_weights = self.conv_tp_weights(edge_feats) + edge_density = torch.tanh(self.density_fn(edge_feats) ** 2) + mji = self.conv_tp( + node_feats[sender], edge_attrs, tp_weights + ) # [n_edges, irreps] + density = scatter_sum( + src=edge_density, index=receiver, dim=0, dim_size=num_nodes + ) # [n_nodes, 1] + message = scatter_sum( + src=mji, index=receiver, dim=0, dim_size=num_nodes + ) # [n_nodes, irreps] + message = self.linear(message) / (density + 1) + message = self.skip_tp(message, node_attrs) + return ( + self.reshape(message), + None, + ) # [n_nodes, channels, (lmax + 1)**2] + + +@compile_mode("script") +class RealAgnosticDensityResidualInteractionBlock(InteractionBlock): + def _setup(self) -> None: + # First linear + self.linear_up = o3.Linear( + self.node_feats_irreps, + self.node_feats_irreps, + internal_weights=True, + shared_weights=True, + ) + # TensorProduct + irreps_mid, instructions = tp_out_irreps_with_instructions( + self.node_feats_irreps, + self.edge_attrs_irreps, + self.target_irreps, + ) + self.conv_tp = o3.TensorProduct( + self.node_feats_irreps, + self.edge_attrs_irreps, + irreps_mid, + instructions=instructions, + shared_weights=False, + internal_weights=False, + ) + + # Convolution weights + input_dim = self.edge_feats_irreps.num_irreps + self.conv_tp_weights = nn.FullyConnectedNet( + [input_dim] + self.radial_MLP + [self.conv_tp.weight_numel], + torch.nn.functional.silu, # gate + ) + + # Linear + irreps_mid = irreps_mid.simplify() + self.irreps_out = self.target_irreps + self.linear = o3.Linear( + irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True + ) + + # Selector TensorProduct + self.skip_tp = o3.FullyConnectedTensorProduct( + self.node_feats_irreps, self.node_attrs_irreps, self.hidden_irreps + ) + self.reshape = reshape_irreps(self.irreps_out) + + # Density normalization + self.density_fn = nn.FullyConnectedNet( + [input_dim] + + [ + 1, + ], + torch.nn.functional.silu, + ) + + # Reshape + self.reshape = reshape_irreps(self.irreps_out) + + def forward( + self, + node_attrs: torch.Tensor, + node_feats: torch.Tensor, + edge_attrs: torch.Tensor, + edge_feats: torch.Tensor, + edge_index: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + sender = edge_index[0] + receiver = edge_index[1] + num_nodes = node_feats.shape[0] + sc = self.skip_tp(node_feats, node_attrs) + node_feats = self.linear_up(node_feats) + tp_weights = self.conv_tp_weights(edge_feats) + edge_density = torch.tanh(self.density_fn(edge_feats) ** 2) + mji = self.conv_tp( + node_feats[sender], edge_attrs, tp_weights + ) # [n_edges, irreps] + density = scatter_sum( + src=edge_density, index=receiver, dim=0, dim_size=num_nodes + ) # [n_nodes, 1] + message = scatter_sum( + src=mji, index=receiver, dim=0, dim_size=num_nodes + ) # [n_nodes, irreps] + message = self.linear(message) / (density + 1) + return ( + self.reshape(message), + sc, + ) # [n_nodes, channels, (lmax + 1)**2] + + +@compile_mode("script") +class RealAgnosticAttResidualInteractionBlock(InteractionBlock): + def _setup(self) -> None: + self.node_feats_down_irreps = o3.Irreps("64x0e") + # First linear + self.linear_up = o3.Linear( + self.node_feats_irreps, + self.node_feats_irreps, + internal_weights=True, + shared_weights=True, + ) + # TensorProduct + irreps_mid, instructions = tp_out_irreps_with_instructions( + self.node_feats_irreps, + self.edge_attrs_irreps, + self.target_irreps, + ) + self.conv_tp = o3.TensorProduct( + self.node_feats_irreps, + self.edge_attrs_irreps, + irreps_mid, + instructions=instructions, + shared_weights=False, + internal_weights=False, + ) + + # Convolution weights + self.linear_down = o3.Linear( + self.node_feats_irreps, + self.node_feats_down_irreps, + internal_weights=True, + shared_weights=True, + ) + input_dim = ( + self.edge_feats_irreps.num_irreps + + 2 * self.node_feats_down_irreps.num_irreps + ) + self.conv_tp_weights = nn.FullyConnectedNet( + [input_dim] + 3 * [256] + [self.conv_tp.weight_numel], + torch.nn.functional.silu, + ) + + # Linear + irreps_mid = irreps_mid.simplify() + self.irreps_out = self.target_irreps + self.linear = o3.Linear( + irreps_mid, + self.irreps_out, + internal_weights=True, + shared_weights=True, + ) + + self.reshape = reshape_irreps(self.irreps_out) + + # Skip connection. + self.skip_linear = o3.Linear(self.node_feats_irreps, self.hidden_irreps) + + def forward( + self, + node_attrs: torch.Tensor, + node_feats: torch.Tensor, + edge_attrs: torch.Tensor, + edge_feats: torch.Tensor, + edge_index: torch.Tensor, + ) -> Tuple[torch.Tensor, None]: + sender = edge_index[0] + receiver = edge_index[1] + num_nodes = node_feats.shape[0] + sc = self.skip_linear(node_feats) + node_feats_up = self.linear_up(node_feats) + node_feats_down = self.linear_down(node_feats) + augmented_edge_feats = torch.cat( + [ + edge_feats, + node_feats_down[sender], + node_feats_down[receiver], + ], + dim=-1, + ) + tp_weights = self.conv_tp_weights(augmented_edge_feats) + mji = self.conv_tp( + node_feats_up[sender], edge_attrs, tp_weights + ) # [n_edges, irreps] + message = scatter_sum( + src=mji, index=receiver, dim=0, dim_size=num_nodes + ) # [n_nodes, irreps] + message = self.linear(message) / self.avg_num_neighbors + return ( + self.reshape(message), + sc, + ) # [n_nodes, channels, (lmax + 1)**2] + + +@compile_mode("script") +class ScaleShiftBlock(torch.nn.Module): + def __init__(self, scale: float, shift: float): + super().__init__() + self.register_buffer( + "scale", + torch.tensor(scale, dtype=torch.get_default_dtype()), + ) + self.register_buffer( + "shift", + torch.tensor(shift, dtype=torch.get_default_dtype()), + ) + + def forward(self, x: torch.Tensor, head: torch.Tensor) -> torch.Tensor: + return ( + torch.atleast_1d(self.scale)[head] * x + torch.atleast_1d(self.shift)[head] + ) + + def __repr__(self): + formatted_scale = ( + ", ".join([f"{x:.4f}" for x in self.scale]) + if self.scale.numel() > 1 + else f"{self.scale.item():.4f}" + ) + formatted_shift = ( + ", ".join([f"{x:.4f}" for x in self.shift]) + if self.shift.numel() > 1 + else f"{self.shift.item():.4f}" + ) + return f"{self.__class__.__name__}(scale={formatted_scale}, shift={formatted_shift})" diff --git a/mace/modules/irreps_tools.py b/mace/modules/irreps_tools.py new file mode 100644 index 00000000..b0960193 --- /dev/null +++ b/mace/modules/irreps_tools.py @@ -0,0 +1,94 @@ +########################################################################################### +# Elementary tools for handling irreducible representations +# Authors: Ilyes Batatia, Gregor Simm +# This program is distributed under the MIT License (see MIT.md) +########################################################################################### + +from typing import List, Tuple + +import torch +from e3nn import o3 +from e3nn.util.jit import compile_mode + + +# Based on mir-group/nequip +def tp_out_irreps_with_instructions( + irreps1: o3.Irreps, irreps2: o3.Irreps, target_irreps: o3.Irreps +) -> Tuple[o3.Irreps, List]: + trainable = True + + # Collect possible irreps and their instructions + irreps_out_list: List[Tuple[int, o3.Irreps]] = [] + instructions = [] + for i, (mul, ir_in) in enumerate(irreps1): + for j, (_, ir_edge) in enumerate(irreps2): + for ir_out in ir_in * ir_edge: # | l1 - l2 | <= l <= l1 + l2 + if ir_out in target_irreps: + k = len(irreps_out_list) # instruction index + irreps_out_list.append((mul, ir_out)) + instructions.append((i, j, k, "uvu", trainable)) + + # We sort the output irreps of the tensor product so that we can simplify them + # when they are provided to the second o3.Linear + irreps_out = o3.Irreps(irreps_out_list) + irreps_out, permut, _ = irreps_out.sort() + + # Permute the output indexes of the instructions to match the sorted irreps: + instructions = [ + (i_in1, i_in2, permut[i_out], mode, train) + for i_in1, i_in2, i_out, mode, train in instructions + ] + + instructions = sorted(instructions, key=lambda x: x[2]) + + return irreps_out, instructions + + +def linear_out_irreps(irreps: o3.Irreps, target_irreps: o3.Irreps) -> o3.Irreps: + # Assuming simplified irreps + irreps_mid = [] + for _, ir_in in irreps: + found = False + + for mul, ir_out in target_irreps: + if ir_in == ir_out: + irreps_mid.append((mul, ir_out)) + found = True + break + + if not found: + raise RuntimeError(f"{ir_in} not in {target_irreps}") + + return o3.Irreps(irreps_mid) + + +@compile_mode("script") +class reshape_irreps(torch.nn.Module): + def __init__(self, irreps: o3.Irreps) -> None: + super().__init__() + self.irreps = o3.Irreps(irreps) + self.dims = [] + self.muls = [] + for mul, ir in self.irreps: + d = ir.dim + self.dims.append(d) + self.muls.append(mul) + + def forward(self, tensor: torch.Tensor) -> torch.Tensor: + ix = 0 + out = [] + batch, _ = tensor.shape + for mul, d in zip(self.muls, self.dims): + field = tensor[:, ix : ix + mul * d] # [batch, sample, mul * repr] + ix += mul * d + field = field.reshape(batch, mul, d) + out.append(field) + return torch.cat(out, dim=-1) + + +def mask_head(x: torch.Tensor, head: torch.Tensor, num_heads: int) -> torch.Tensor: + mask = torch.zeros(x.shape[0], x.shape[1] // num_heads, num_heads, device=x.device) + idx = torch.arange(mask.shape[0], device=x.device) + mask[idx, :, head] = 1 + mask = mask.permute(0, 2, 1).reshape(x.shape) + return x * mask diff --git a/mace/modules/loss.py b/mace/modules/loss.py new file mode 100644 index 00000000..a7e28c55 --- /dev/null +++ b/mace/modules/loss.py @@ -0,0 +1,383 @@ +########################################################################################### +# Implementation of different loss functions +# Authors: Ilyes Batatia, Gregor Simm +# This program is distributed under the MIT License (see MIT.md) +########################################################################################### + +import torch + +from mace.tools import TensorDict +from mace.tools.torch_geometric import Batch + + +def mean_squared_error_energy(ref: Batch, pred: TensorDict) -> torch.Tensor: + # energy: [n_graphs, ] + return torch.mean(torch.square(ref["energy"] - pred["energy"])) # [] + + +def weighted_mean_squared_error_energy(ref: Batch, pred: TensorDict) -> torch.Tensor: + # energy: [n_graphs, ] + configs_weight = ref.weight # [n_graphs, ] + configs_energy_weight = ref.energy_weight # [n_graphs, ] + num_atoms = ref.ptr[1:] - ref.ptr[:-1] # [n_graphs,] + return torch.mean( + configs_weight + * configs_energy_weight + * torch.square((ref["energy"] - pred["energy"]) / num_atoms) + ) # [] + + +def weighted_mean_squared_stress(ref: Batch, pred: TensorDict) -> torch.Tensor: + # energy: [n_graphs, ] + configs_weight = ref.weight.view(-1, 1, 1) # [n_graphs, ] + configs_stress_weight = ref.stress_weight.view(-1, 1, 1) # [n_graphs, ] + return torch.mean( + configs_weight + * configs_stress_weight + * torch.square(ref["stress"] - pred["stress"]) + ) # [] + + +def weighted_mean_squared_virials(ref: Batch, pred: TensorDict) -> torch.Tensor: + # energy: [n_graphs, ] + configs_weight = ref.weight.view(-1, 1, 1) # [n_graphs, ] + configs_virials_weight = ref.virials_weight.view(-1, 1, 1) # [n_graphs, ] + num_atoms = (ref.ptr[1:] - ref.ptr[:-1]).view(-1, 1, 1) # [n_graphs,] + return torch.mean( + configs_weight + * configs_virials_weight + * torch.square((ref["virials"] - pred["virials"]) / num_atoms) + ) # [] + + +def mean_squared_error_forces(ref: Batch, pred: TensorDict) -> torch.Tensor: + # forces: [n_atoms, 3] + configs_weight = torch.repeat_interleave( + ref.weight, ref.ptr[1:] - ref.ptr[:-1] + ).unsqueeze( + -1 + ) # [n_atoms, 1] + configs_forces_weight = torch.repeat_interleave( + ref.forces_weight, ref.ptr[1:] - ref.ptr[:-1] + ).unsqueeze( + -1 + ) # [n_atoms, 1] + return torch.mean( + configs_weight + * configs_forces_weight + * torch.square(ref["forces"] - pred["forces"]) + ) # [] + + +def weighted_mean_squared_error_dipole(ref: Batch, pred: TensorDict) -> torch.Tensor: + # dipole: [n_graphs, ] + num_atoms = (ref.ptr[1:] - ref.ptr[:-1]).unsqueeze(-1) # [n_graphs,1] + return torch.mean(torch.square((ref["dipole"] - pred["dipole"]) / num_atoms)) # [] + # return torch.mean(torch.square((torch.reshape(ref['dipole'], pred["dipole"].shape) - pred['dipole']) / num_atoms)) # [] + + +def conditional_mse_forces(ref: Batch, pred: TensorDict) -> torch.Tensor: + # forces: [n_atoms, 3] + configs_weight = torch.repeat_interleave( + ref.weight, ref.ptr[1:] - ref.ptr[:-1] + ).unsqueeze( + -1 + ) # [n_atoms, 1] + configs_forces_weight = torch.repeat_interleave( + ref.forces_weight, ref.ptr[1:] - ref.ptr[:-1] + ).unsqueeze( + -1 + ) # [n_atoms, 1] + + # Define the multiplication factors for each condition + factors = torch.tensor([1.0, 0.7, 0.4, 0.1]) + + # Apply multiplication factors based on conditions + c1 = torch.norm(ref["forces"], dim=-1) < 100 + c2 = (torch.norm(ref["forces"], dim=-1) >= 100) & ( + torch.norm(ref["forces"], dim=-1) < 200 + ) + c3 = (torch.norm(ref["forces"], dim=-1) >= 200) & ( + torch.norm(ref["forces"], dim=-1) < 300 + ) + + err = ref["forces"] - pred["forces"] + + se = torch.zeros_like(err) + + se[c1] = torch.square(err[c1]) * factors[0] + se[c2] = torch.square(err[c2]) * factors[1] + se[c3] = torch.square(err[c3]) * factors[2] + se[~(c1 | c2 | c3)] = torch.square(err[~(c1 | c2 | c3)]) * factors[3] + + return torch.mean(configs_weight * configs_forces_weight * se) + + +def conditional_huber_forces( + ref_forces: Batch, pred_forces: TensorDict, huber_delta: float +) -> torch.Tensor: + # Define the multiplication factors for each condition + factors = huber_delta * torch.tensor([1.0, 0.7, 0.4, 0.1]) + + # Apply multiplication factors based on conditions + c1 = torch.norm(ref_forces, dim=-1) < 100 + c2 = (torch.norm(ref_forces, dim=-1) >= 100) & ( + torch.norm(ref_forces, dim=-1) < 200 + ) + c3 = (torch.norm(ref_forces, dim=-1) >= 200) & ( + torch.norm(ref_forces, dim=-1) < 300 + ) + c4 = ~(c1 | c2 | c3) + + se = torch.zeros_like(pred_forces) + + se[c1] = torch.nn.functional.huber_loss( + ref_forces[c1], pred_forces[c1], reduction="none", delta=factors[0] + ) + se[c2] = torch.nn.functional.huber_loss( + ref_forces[c2], pred_forces[c2], reduction="none", delta=factors[1] + ) + se[c3] = torch.nn.functional.huber_loss( + ref_forces[c3], pred_forces[c3], reduction="none", delta=factors[2] + ) + se[c4] = torch.nn.functional.huber_loss( + ref_forces[c4], pred_forces[c4], reduction="none", delta=factors[3] + ) + + return torch.mean(se) + + +class WeightedEnergyForcesLoss(torch.nn.Module): + def __init__(self, energy_weight=1.0, forces_weight=1.0) -> None: + super().__init__() + self.register_buffer( + "energy_weight", + torch.tensor(energy_weight, dtype=torch.get_default_dtype()), + ) + self.register_buffer( + "forces_weight", + torch.tensor(forces_weight, dtype=torch.get_default_dtype()), + ) + + def forward(self, ref: Batch, pred: TensorDict) -> torch.Tensor: + return self.energy_weight * weighted_mean_squared_error_energy( + ref, pred + ) + self.forces_weight * mean_squared_error_forces(ref, pred) + + def __repr__(self): + return ( + f"{self.__class__.__name__}(energy_weight={self.energy_weight:.3f}, " + f"forces_weight={self.forces_weight:.3f})" + ) + + +class WeightedForcesLoss(torch.nn.Module): + def __init__(self, forces_weight=1.0) -> None: + super().__init__() + self.register_buffer( + "forces_weight", + torch.tensor(forces_weight, dtype=torch.get_default_dtype()), + ) + + def forward(self, ref: Batch, pred: TensorDict) -> torch.Tensor: + return self.forces_weight * mean_squared_error_forces(ref, pred) + + def __repr__(self): + return f"{self.__class__.__name__}(" f"forces_weight={self.forces_weight:.3f})" + + +class WeightedEnergyForcesStressLoss(torch.nn.Module): + def __init__(self, energy_weight=1.0, forces_weight=1.0, stress_weight=1.0) -> None: + super().__init__() + self.register_buffer( + "energy_weight", + torch.tensor(energy_weight, dtype=torch.get_default_dtype()), + ) + self.register_buffer( + "forces_weight", + torch.tensor(forces_weight, dtype=torch.get_default_dtype()), + ) + self.register_buffer( + "stress_weight", + torch.tensor(stress_weight, dtype=torch.get_default_dtype()), + ) + + def forward(self, ref: Batch, pred: TensorDict) -> torch.Tensor: + return ( + self.energy_weight * weighted_mean_squared_error_energy(ref, pred) + + self.forces_weight * mean_squared_error_forces(ref, pred) + + self.stress_weight * weighted_mean_squared_stress(ref, pred) + ) + + def __repr__(self): + return ( + f"{self.__class__.__name__}(energy_weight={self.energy_weight:.3f}, " + f"forces_weight={self.forces_weight:.3f}, stress_weight={self.stress_weight:.3f})" + ) + + +class WeightedHuberEnergyForcesStressLoss(torch.nn.Module): + def __init__( + self, energy_weight=1.0, forces_weight=1.0, stress_weight=1.0, huber_delta=0.01 + ) -> None: + super().__init__() + self.huber_loss = torch.nn.HuberLoss(reduction="mean", delta=huber_delta) + self.register_buffer( + "energy_weight", + torch.tensor(energy_weight, dtype=torch.get_default_dtype()), + ) + self.register_buffer( + "forces_weight", + torch.tensor(forces_weight, dtype=torch.get_default_dtype()), + ) + self.register_buffer( + "stress_weight", + torch.tensor(stress_weight, dtype=torch.get_default_dtype()), + ) + + def forward(self, ref: Batch, pred: TensorDict) -> torch.Tensor: + num_atoms = ref.ptr[1:] - ref.ptr[:-1] + return ( + self.energy_weight + * self.huber_loss(ref["energy"] / num_atoms, pred["energy"] / num_atoms) + + self.forces_weight * self.huber_loss(ref["forces"], pred["forces"]) + + self.stress_weight * self.huber_loss(ref["stress"], pred["stress"]) + ) + + def __repr__(self): + return ( + f"{self.__class__.__name__}(energy_weight={self.energy_weight:.3f}, " + f"forces_weight={self.forces_weight:.3f}, stress_weight={self.stress_weight:.3f})" + ) + + +class UniversalLoss(torch.nn.Module): + def __init__( + self, energy_weight=1.0, forces_weight=1.0, stress_weight=1.0, huber_delta=0.01 + ) -> None: + super().__init__() + self.huber_delta = huber_delta + self.huber_loss = torch.nn.HuberLoss(reduction="mean", delta=huber_delta) + self.register_buffer( + "energy_weight", + torch.tensor(energy_weight, dtype=torch.get_default_dtype()), + ) + self.register_buffer( + "forces_weight", + torch.tensor(forces_weight, dtype=torch.get_default_dtype()), + ) + self.register_buffer( + "stress_weight", + torch.tensor(stress_weight, dtype=torch.get_default_dtype()), + ) + + def forward(self, ref: Batch, pred: TensorDict) -> torch.Tensor: + num_atoms = ref.ptr[1:] - ref.ptr[:-1] + configs_stress_weight = ref.stress_weight.view(-1, 1, 1) # [n_graphs, ] + configs_energy_weight = ref.energy_weight # [n_graphs, ] + configs_forces_weight = torch.repeat_interleave( + ref.forces_weight, ref.ptr[1:] - ref.ptr[:-1] + ).unsqueeze(-1) + return ( + self.energy_weight + * self.huber_loss( + configs_energy_weight * ref["energy"] / num_atoms, + configs_energy_weight * pred["energy"] / num_atoms, + ) + + self.forces_weight + * conditional_huber_forces( + configs_forces_weight * ref["forces"], + configs_forces_weight * pred["forces"], + huber_delta=self.huber_delta, + ) + + self.stress_weight + * self.huber_loss( + configs_stress_weight * ref["stress"], + configs_stress_weight * pred["stress"], + ) + ) + + def __repr__(self): + return ( + f"{self.__class__.__name__}(energy_weight={self.energy_weight:.3f}, " + f"forces_weight={self.forces_weight:.3f}, stress_weight={self.stress_weight:.3f})" + ) + + +class WeightedEnergyForcesVirialsLoss(torch.nn.Module): + def __init__( + self, energy_weight=1.0, forces_weight=1.0, virials_weight=1.0 + ) -> None: + super().__init__() + self.register_buffer( + "energy_weight", + torch.tensor(energy_weight, dtype=torch.get_default_dtype()), + ) + self.register_buffer( + "forces_weight", + torch.tensor(forces_weight, dtype=torch.get_default_dtype()), + ) + self.register_buffer( + "virials_weight", + torch.tensor(virials_weight, dtype=torch.get_default_dtype()), + ) + + def forward(self, ref: Batch, pred: TensorDict) -> torch.Tensor: + return ( + self.energy_weight * weighted_mean_squared_error_energy(ref, pred) + + self.forces_weight * mean_squared_error_forces(ref, pred) + + self.virials_weight * weighted_mean_squared_virials(ref, pred) + ) + + def __repr__(self): + return ( + f"{self.__class__.__name__}(energy_weight={self.energy_weight:.3f}, " + f"forces_weight={self.forces_weight:.3f}, virials_weight={self.virials_weight:.3f})" + ) + + +class DipoleSingleLoss(torch.nn.Module): + def __init__(self, dipole_weight=1.0) -> None: + super().__init__() + self.register_buffer( + "dipole_weight", + torch.tensor(dipole_weight, dtype=torch.get_default_dtype()), + ) + + def forward(self, ref: Batch, pred: TensorDict) -> torch.Tensor: + return ( + self.dipole_weight * weighted_mean_squared_error_dipole(ref, pred) * 100.0 + ) # multiply by 100 to have the right scale for the loss + + def __repr__(self): + return f"{self.__class__.__name__}(" f"dipole_weight={self.dipole_weight:.3f})" + + +class WeightedEnergyForcesDipoleLoss(torch.nn.Module): + def __init__(self, energy_weight=1.0, forces_weight=1.0, dipole_weight=1.0) -> None: + super().__init__() + self.register_buffer( + "energy_weight", + torch.tensor(energy_weight, dtype=torch.get_default_dtype()), + ) + self.register_buffer( + "forces_weight", + torch.tensor(forces_weight, dtype=torch.get_default_dtype()), + ) + self.register_buffer( + "dipole_weight", + torch.tensor(dipole_weight, dtype=torch.get_default_dtype()), + ) + + def forward(self, ref: Batch, pred: TensorDict) -> torch.Tensor: + return ( + self.energy_weight * weighted_mean_squared_error_energy(ref, pred) + + self.forces_weight * mean_squared_error_forces(ref, pred) + + self.dipole_weight * weighted_mean_squared_error_dipole(ref, pred) * 100 + ) + + def __repr__(self): + return ( + f"{self.__class__.__name__}(energy_weight={self.energy_weight:.3f}, " + f"forces_weight={self.forces_weight:.3f}, dipole_weight={self.dipole_weight:.3f})" + ) diff --git a/mace/modules/models.py b/mace/modules/models.py new file mode 100644 index 00000000..c0d8ab43 --- /dev/null +++ b/mace/modules/models.py @@ -0,0 +1,1109 @@ +########################################################################################### +# Implementation of MACE models and other models based E(3)-Equivariant MPNNs +# Authors: Ilyes Batatia, Gregor Simm +# This program is distributed under the MIT License (see MIT.md) +########################################################################################### + +from typing import Any, Callable, Dict, List, Optional, Type, Union + +import numpy as np +import torch +from e3nn import o3 +from e3nn.util.jit import compile_mode + +from mace.data import AtomicData +from mace.modules.radial import ZBLBasis +from mace.tools.scatter import scatter_sum + +from .blocks import ( + AtomicEnergiesBlock, + EquivariantProductBasisBlock, + InteractionBlock, + LinearDipoleReadoutBlock, + LinearNodeEmbeddingBlock, + LinearReadoutBlock, + NonLinearDipoleReadoutBlock, + NonLinearReadoutBlock, + RadialEmbeddingBlock, + ScaleShiftBlock, +) +from .utils import ( + compute_fixed_charge_dipole, + compute_forces, + get_edge_vectors_and_lengths, + get_outputs, + get_symmetric_displacement, +) + +# pylint: disable=C0302 + + +@compile_mode("script") +class MACE(torch.nn.Module): + def __init__( + self, + r_max: float, + num_bessel: int, + num_polynomial_cutoff: int, + max_ell: int, + interaction_cls: Type[InteractionBlock], + interaction_cls_first: Type[InteractionBlock], + num_interactions: int, + num_elements: int, + hidden_irreps: o3.Irreps, + MLP_irreps: o3.Irreps, + atomic_energies: np.ndarray, + avg_num_neighbors: float, + atomic_numbers: List[int], + correlation: Union[int, List[int]], + gate: Optional[Callable], + pair_repulsion: bool = False, + distance_transform: str = "None", + radial_MLP: Optional[List[int]] = None, + radial_type: Optional[str] = "bessel", + heads: Optional[List[str]] = None, + ): + super().__init__() + self.register_buffer( + "atomic_numbers", torch.tensor(atomic_numbers, dtype=torch.int64) + ) + self.register_buffer( + "r_max", torch.tensor(r_max, dtype=torch.get_default_dtype()) + ) + self.register_buffer( + "num_interactions", torch.tensor(num_interactions, dtype=torch.int64) + ) + if heads is None: + heads = ["default"] + self.heads = heads + if isinstance(correlation, int): + correlation = [correlation] * num_interactions + # Embedding + node_attr_irreps = o3.Irreps([(num_elements, (0, 1))]) + node_feats_irreps = o3.Irreps([(hidden_irreps.count(o3.Irrep(0, 1)), (0, 1))]) + self.node_embedding = LinearNodeEmbeddingBlock( + irreps_in=node_attr_irreps, irreps_out=node_feats_irreps + ) + self.radial_embedding = RadialEmbeddingBlock( + r_max=r_max, + num_bessel=num_bessel, + num_polynomial_cutoff=num_polynomial_cutoff, + radial_type=radial_type, + distance_transform=distance_transform, + ) + edge_feats_irreps = o3.Irreps(f"{self.radial_embedding.out_dim}x0e") + if pair_repulsion: + self.pair_repulsion_fn = ZBLBasis(r_max=r_max, p=num_polynomial_cutoff) + self.pair_repulsion = True + + sh_irreps = o3.Irreps.spherical_harmonics(max_ell) + num_features = hidden_irreps.count(o3.Irrep(0, 1)) + interaction_irreps = (sh_irreps * num_features).sort()[0].simplify() + self.spherical_harmonics = o3.SphericalHarmonics( + sh_irreps, normalize=True, normalization="component" + ) + if radial_MLP is None: + radial_MLP = [64, 64, 64] + # Interactions and readout + self.atomic_energies_fn = AtomicEnergiesBlock(atomic_energies) + + inter = interaction_cls_first( + node_attrs_irreps=node_attr_irreps, + node_feats_irreps=node_feats_irreps, + edge_attrs_irreps=sh_irreps, + edge_feats_irreps=edge_feats_irreps, + target_irreps=interaction_irreps, + hidden_irreps=hidden_irreps, + avg_num_neighbors=avg_num_neighbors, + radial_MLP=radial_MLP, + ) + self.interactions = torch.nn.ModuleList([inter]) + + # Use the appropriate self connection at the first layer for proper E0 + use_sc_first = False + if "Residual" in str(interaction_cls_first): + use_sc_first = True + + node_feats_irreps_out = inter.target_irreps + prod = EquivariantProductBasisBlock( + node_feats_irreps=node_feats_irreps_out, + target_irreps=hidden_irreps, + correlation=correlation[0], + num_elements=num_elements, + use_sc=use_sc_first, + ) + self.products = torch.nn.ModuleList([prod]) + + self.readouts = torch.nn.ModuleList() + self.readouts.append( + LinearReadoutBlock(hidden_irreps, o3.Irreps(f"{len(heads)}x0e")) + ) + + for i in range(num_interactions - 1): + if i == num_interactions - 2: + hidden_irreps_out = str( + hidden_irreps[0] + ) # Select only scalars for last layer + else: + hidden_irreps_out = hidden_irreps + inter = interaction_cls( + node_attrs_irreps=node_attr_irreps, + node_feats_irreps=hidden_irreps, + edge_attrs_irreps=sh_irreps, + edge_feats_irreps=edge_feats_irreps, + target_irreps=interaction_irreps, + hidden_irreps=hidden_irreps_out, + avg_num_neighbors=avg_num_neighbors, + radial_MLP=radial_MLP, + ) + self.interactions.append(inter) + prod = EquivariantProductBasisBlock( + node_feats_irreps=interaction_irreps, + target_irreps=hidden_irreps_out, + correlation=correlation[i + 1], + num_elements=num_elements, + use_sc=True, + ) + self.products.append(prod) + if i == num_interactions - 2: + self.readouts.append( + NonLinearReadoutBlock( + hidden_irreps_out, + (len(heads) * MLP_irreps).simplify(), + gate, + o3.Irreps(f"{len(heads)}x0e"), + len(heads), + ) + ) + else: + self.readouts.append( + LinearReadoutBlock(hidden_irreps, o3.Irreps(f"{len(heads)}x0e")) + ) + + def forward( + self, + data: Dict[str, torch.Tensor], + training: bool = False, + compute_force: bool = True, + compute_virials: bool = False, + compute_stress: bool = False, + compute_displacement: bool = False, + compute_hessian: bool = False, + ) -> Dict[str, Optional[torch.Tensor]]: + # Setup + data["node_attrs"].requires_grad_(True) + data["positions"].requires_grad_(True) + num_atoms_arange = torch.arange(data["positions"].shape[0]) + num_graphs = data["ptr"].numel() - 1 + node_heads = ( + data["head"][data["batch"]] + if "head" in data + else torch.zeros_like(data["batch"]) + ) + displacement = torch.zeros( + (num_graphs, 3, 3), + dtype=data["positions"].dtype, + device=data["positions"].device, + ) + if compute_virials or compute_stress or compute_displacement: + ( + data["positions"], + data["shifts"], + displacement, + ) = get_symmetric_displacement( + positions=data["positions"], + unit_shifts=data["unit_shifts"], + cell=data["cell"], + edge_index=data["edge_index"], + num_graphs=num_graphs, + batch=data["batch"], + ) + + # Atomic energies + node_e0 = self.atomic_energies_fn(data["node_attrs"])[ + num_atoms_arange, node_heads + ] + e0 = scatter_sum( + src=node_e0, index=data["batch"], dim=0, dim_size=num_graphs + ) # [n_graphs, n_heads] + # Embeddings + node_feats = self.node_embedding(data["node_attrs"]) + vectors, lengths = get_edge_vectors_and_lengths( + positions=data["positions"], + edge_index=data["edge_index"], + shifts=data["shifts"], + ) + edge_attrs = self.spherical_harmonics(vectors) + edge_feats = self.radial_embedding( + lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers + ) + if hasattr(self, "pair_repulsion"): + pair_node_energy = self.pair_repulsion_fn( + lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers + ) + pair_energy = scatter_sum( + src=pair_node_energy, index=data["batch"], dim=-1, dim_size=num_graphs + ) # [n_graphs,] + else: + pair_node_energy = torch.zeros_like(node_e0) + pair_energy = torch.zeros_like(e0) + + # Interactions + energies = [e0, pair_energy] + node_energies_list = [node_e0, pair_node_energy] + node_feats_list = [] + for interaction, product, readout in zip( + self.interactions, self.products, self.readouts + ): + node_feats, sc = interaction( + node_attrs=data["node_attrs"], + node_feats=node_feats, + edge_attrs=edge_attrs, + edge_feats=edge_feats, + edge_index=data["edge_index"], + ) + node_feats = product( + node_feats=node_feats, + sc=sc, + node_attrs=data["node_attrs"], + ) + node_feats_list.append(node_feats) + node_energies = readout(node_feats, node_heads)[ + num_atoms_arange, node_heads + ] # [n_nodes, len(heads)] + energy = scatter_sum( + src=node_energies, + index=data["batch"], + dim=0, + dim_size=num_graphs, + ) # [n_graphs,] + energies.append(energy) + node_energies_list.append(node_energies) + # Concatenate node features + node_feats_out = torch.cat(node_feats_list, dim=-1) + + # Sum over energy contributions + contributions = torch.stack(energies, dim=-1) + total_energy = torch.sum(contributions, dim=-1) # [n_graphs, ] + node_energy_contributions = torch.stack(node_energies_list, dim=-1) + node_energy = torch.sum(node_energy_contributions, dim=-1) # [n_nodes, ] + + # Outputs + forces, virials, stress, hessian = get_outputs( + energy=total_energy, + positions=data["positions"], + displacement=displacement, + cell=data["cell"], + training=training, + compute_force=compute_force, + compute_virials=compute_virials, + compute_stress=compute_stress, + compute_hessian=compute_hessian, + ) + + return { + "energy": total_energy, + "node_energy": node_energy, + "contributions": contributions, + "forces": forces, + "virials": virials, + "stress": stress, + "displacement": displacement, + "hessian": hessian, + "node_feats": node_feats_out, + } + + +@compile_mode("script") +class ScaleShiftMACE(MACE): + def __init__( + self, + atomic_inter_scale: float, + atomic_inter_shift: float, + **kwargs, + ): + super().__init__(**kwargs) + self.scale_shift = ScaleShiftBlock( + scale=atomic_inter_scale, shift=atomic_inter_shift + ) + + def forward( + self, + data: Dict[str, torch.Tensor], + training: bool = False, + compute_force: bool = True, + compute_virials: bool = False, + compute_stress: bool = False, + compute_displacement: bool = False, + compute_hessian: bool = False, + ) -> Dict[str, Optional[torch.Tensor]]: + # Setup + data["positions"].requires_grad_(True) + data["node_attrs"].requires_grad_(True) + num_graphs = data["ptr"].numel() - 1 + num_atoms_arange = torch.arange(data["positions"].shape[0]) + node_heads = ( + data["head"][data["batch"]] + if "head" in data + else torch.zeros_like(data["batch"]) + ) + displacement = torch.zeros( + (num_graphs, 3, 3), + dtype=data["positions"].dtype, + device=data["positions"].device, + ) + if compute_virials or compute_stress or compute_displacement: + ( + data["positions"], + data["shifts"], + displacement, + ) = get_symmetric_displacement( + positions=data["positions"], + unit_shifts=data["unit_shifts"], + cell=data["cell"], + edge_index=data["edge_index"], + num_graphs=num_graphs, + batch=data["batch"], + ) + + # Atomic energies + node_e0 = self.atomic_energies_fn(data["node_attrs"])[ + num_atoms_arange, node_heads + ] + e0 = scatter_sum( + src=node_e0, index=data["batch"], dim=0, dim_size=num_graphs + ) # [n_graphs, num_heads] + + # Embeddings + node_feats = self.node_embedding(data["node_attrs"]) + vectors, lengths = get_edge_vectors_and_lengths( + positions=data["positions"], + edge_index=data["edge_index"], + shifts=data["shifts"], + ) + edge_attrs = self.spherical_harmonics(vectors) + edge_feats = self.radial_embedding( + lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers + ) + if hasattr(self, "pair_repulsion"): + pair_node_energy = self.pair_repulsion_fn( + lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers + ) + else: + pair_node_energy = torch.zeros_like(node_e0) + # Interactions + node_es_list = [pair_node_energy] + node_feats_list = [] + for interaction, product, readout in zip( + self.interactions, self.products, self.readouts + ): + node_feats, sc = interaction( + node_attrs=data["node_attrs"], + node_feats=node_feats, + edge_attrs=edge_attrs, + edge_feats=edge_feats, + edge_index=data["edge_index"], + ) + node_feats = product( + node_feats=node_feats, sc=sc, node_attrs=data["node_attrs"] + ) + node_feats_list.append(node_feats) + node_es_list.append( + readout(node_feats, node_heads)[num_atoms_arange, node_heads] + ) # {[n_nodes, ], } + + # Concatenate node features + node_feats_out = torch.cat(node_feats_list, dim=-1) + # Sum over interactions + node_inter_es = torch.sum( + torch.stack(node_es_list, dim=0), dim=0 + ) # [n_nodes, ] + node_inter_es = self.scale_shift(node_inter_es, node_heads) + + # Sum over nodes in graph + inter_e = scatter_sum( + src=node_inter_es, index=data["batch"], dim=-1, dim_size=num_graphs + ) # [n_graphs,] + + # Add E_0 and (scaled) interaction energy + total_energy = e0 + inter_e + node_energy = node_e0 + node_inter_es + forces, virials, stress, hessian = get_outputs( + energy=inter_e, + positions=data["positions"], + displacement=displacement, + cell=data["cell"], + training=training, + compute_force=compute_force, + compute_virials=compute_virials, + compute_stress=compute_stress, + compute_hessian=compute_hessian, + ) + output = { + "energy": total_energy, + "node_energy": node_energy, + "interaction_energy": inter_e, + "forces": forces, + "virials": virials, + "stress": stress, + "hessian": hessian, + "displacement": displacement, + "node_feats": node_feats_out, + } + + return output + + +class BOTNet(torch.nn.Module): + def __init__( + self, + r_max: float, + num_bessel: int, + num_polynomial_cutoff: int, + max_ell: int, + interaction_cls: Type[InteractionBlock], + interaction_cls_first: Type[InteractionBlock], + num_interactions: int, + num_elements: int, + hidden_irreps: o3.Irreps, + MLP_irreps: o3.Irreps, + atomic_energies: np.ndarray, + gate: Optional[Callable], + avg_num_neighbors: float, + atomic_numbers: List[int], + ): + super().__init__() + self.r_max = r_max + self.atomic_numbers = atomic_numbers + # Embedding + node_attr_irreps = o3.Irreps([(num_elements, (0, 1))]) + node_feats_irreps = o3.Irreps([(hidden_irreps.count(o3.Irrep(0, 1)), (0, 1))]) + self.node_embedding = LinearNodeEmbeddingBlock( + irreps_in=node_attr_irreps, irreps_out=node_feats_irreps + ) + self.radial_embedding = RadialEmbeddingBlock( + r_max=r_max, + num_bessel=num_bessel, + num_polynomial_cutoff=num_polynomial_cutoff, + ) + edge_feats_irreps = o3.Irreps(f"{self.radial_embedding.out_dim}x0e") + + sh_irreps = o3.Irreps.spherical_harmonics(max_ell) + self.spherical_harmonics = o3.SphericalHarmonics( + sh_irreps, normalize=True, normalization="component" + ) + + # Interactions and readouts + self.atomic_energies_fn = AtomicEnergiesBlock(atomic_energies) + + self.interactions = torch.nn.ModuleList() + self.readouts = torch.nn.ModuleList() + + inter = interaction_cls_first( + node_attrs_irreps=node_attr_irreps, + node_feats_irreps=node_feats_irreps, + edge_attrs_irreps=sh_irreps, + edge_feats_irreps=edge_feats_irreps, + target_irreps=hidden_irreps, + avg_num_neighbors=avg_num_neighbors, + ) + self.interactions.append(inter) + self.readouts.append(LinearReadoutBlock(inter.irreps_out)) + + for i in range(num_interactions - 1): + inter = interaction_cls( + node_attrs_irreps=node_attr_irreps, + node_feats_irreps=inter.irreps_out, + edge_attrs_irreps=sh_irreps, + edge_feats_irreps=edge_feats_irreps, + target_irreps=hidden_irreps, + avg_num_neighbors=avg_num_neighbors, + ) + self.interactions.append(inter) + if i == num_interactions - 2: + self.readouts.append( + NonLinearReadoutBlock(inter.irreps_out, MLP_irreps, gate) + ) + else: + self.readouts.append(LinearReadoutBlock(inter.irreps_out)) + + def forward(self, data: AtomicData, training=False) -> Dict[str, Any]: + # Setup + data.positions.requires_grad = True + num_atoms_arange = torch.arange(data.positions.shape[0]) + + # Atomic energies + node_e0 = self.atomic_energies_fn(data["node_attrs"])[ + num_atoms_arange, data["head"][data["batch"]] + ] + e0 = scatter_sum( + src=node_e0, index=data.batch, dim=-1, dim_size=data.num_graphs + ) # [n_graphs, n_heads] + + # Embeddings + node_feats = self.node_embedding(data.node_attrs) + vectors, lengths = get_edge_vectors_and_lengths( + positions=data.positions, edge_index=data.edge_index, shifts=data.shifts + ) + edge_attrs = self.spherical_harmonics(vectors) + edge_feats = self.radial_embedding( + lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers + ) + + # Interactions + energies = [e0] + for interaction, readout in zip(self.interactions, self.readouts): + node_feats = interaction( + node_attrs=data.node_attrs, + node_feats=node_feats, + edge_attrs=edge_attrs, + edge_feats=edge_feats, + edge_index=data.edge_index, + ) + node_energies = readout(node_feats).squeeze(-1) # [n_nodes, ] + energy = scatter_sum( + src=node_energies, index=data.batch, dim=-1, dim_size=data.num_graphs + ) # [n_graphs,] + energies.append(energy) + + # Sum over energy contributions + contributions = torch.stack(energies, dim=-1) + total_energy = torch.sum(contributions, dim=-1) # [n_graphs, ] + + output = { + "energy": total_energy, + "contributions": contributions, + "forces": compute_forces( + energy=total_energy, positions=data.positions, training=training + ), + } + + return output + + +class ScaleShiftBOTNet(BOTNet): + def __init__( + self, + atomic_inter_scale: float, + atomic_inter_shift: float, + **kwargs, + ): + super().__init__(**kwargs) + self.scale_shift = ScaleShiftBlock( + scale=atomic_inter_scale, shift=atomic_inter_shift + ) + + def forward(self, data: AtomicData, training=False) -> Dict[str, Any]: + # Setup + data.positions.requires_grad = True + num_atoms_arange = torch.arange(data.positions.shape[0]) + # Atomic energies + node_e0 = self.atomic_energies_fn(data["node_attrs"])[ + num_atoms_arange, data["head"][data["batch"]] + ] + e0 = scatter_sum( + src=node_e0, index=data.batch, dim=-1, dim_size=data.num_graphs + ) # [n_graphs,] + + # Embeddings + node_feats = self.node_embedding(data.node_attrs) + vectors, lengths = get_edge_vectors_and_lengths( + positions=data.positions, edge_index=data.edge_index, shifts=data.shifts + ) + edge_attrs = self.spherical_harmonics(vectors) + edge_feats = self.radial_embedding( + lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers + ) + + # Interactions + node_es_list = [] + for interaction, readout in zip(self.interactions, self.readouts): + node_feats = interaction( + node_attrs=data.node_attrs, + node_feats=node_feats, + edge_attrs=edge_attrs, + edge_feats=edge_feats, + edge_index=data.edge_index, + ) + + node_es_list.append(readout(node_feats).squeeze(-1)) # {[n_nodes, ], } + + # Sum over interactions + node_inter_es = torch.sum( + torch.stack(node_es_list, dim=0), dim=0 + ) # [n_nodes, ] + node_inter_es = self.scale_shift(node_inter_es, data["head"][data["batch"]]) + + # Sum over nodes in graph + inter_e = scatter_sum( + src=node_inter_es, index=data.batch, dim=-1, dim_size=data.num_graphs + ) # [n_graphs,] + + # Add E_0 and (scaled) interaction energy + total_e = e0 + inter_e + + output = { + "energy": total_e, + "forces": compute_forces( + energy=inter_e, positions=data.positions, training=training + ), + } + + return output + + +@compile_mode("script") +class AtomicDipolesMACE(torch.nn.Module): + def __init__( + self, + r_max: float, + num_bessel: int, + num_polynomial_cutoff: int, + max_ell: int, + interaction_cls: Type[InteractionBlock], + interaction_cls_first: Type[InteractionBlock], + num_interactions: int, + num_elements: int, + hidden_irreps: o3.Irreps, + MLP_irreps: o3.Irreps, + avg_num_neighbors: float, + atomic_numbers: List[int], + correlation: int, + gate: Optional[Callable], + atomic_energies: Optional[ + None + ], # Just here to make it compatible with energy models, MUST be None + radial_type: Optional[str] = "bessel", + radial_MLP: Optional[List[int]] = None, + ): + super().__init__() + self.register_buffer( + "atomic_numbers", torch.tensor(atomic_numbers, dtype=torch.int64) + ) + self.register_buffer("r_max", torch.tensor(r_max, dtype=torch.float64)) + self.register_buffer( + "num_interactions", torch.tensor(num_interactions, dtype=torch.int64) + ) + assert atomic_energies is None + + # Embedding + node_attr_irreps = o3.Irreps([(num_elements, (0, 1))]) + node_feats_irreps = o3.Irreps([(hidden_irreps.count(o3.Irrep(0, 1)), (0, 1))]) + self.node_embedding = LinearNodeEmbeddingBlock( + irreps_in=node_attr_irreps, irreps_out=node_feats_irreps + ) + self.radial_embedding = RadialEmbeddingBlock( + r_max=r_max, + num_bessel=num_bessel, + num_polynomial_cutoff=num_polynomial_cutoff, + radial_type=radial_type, + ) + edge_feats_irreps = o3.Irreps(f"{self.radial_embedding.out_dim}x0e") + + sh_irreps = o3.Irreps.spherical_harmonics(max_ell) + num_features = hidden_irreps.count(o3.Irrep(0, 1)) + interaction_irreps = (sh_irreps * num_features).sort()[0].simplify() + self.spherical_harmonics = o3.SphericalHarmonics( + sh_irreps, normalize=True, normalization="component" + ) + if radial_MLP is None: + radial_MLP = [64, 64, 64] + + # Interactions and readouts + inter = interaction_cls_first( + node_attrs_irreps=node_attr_irreps, + node_feats_irreps=node_feats_irreps, + edge_attrs_irreps=sh_irreps, + edge_feats_irreps=edge_feats_irreps, + target_irreps=interaction_irreps, + hidden_irreps=hidden_irreps, + avg_num_neighbors=avg_num_neighbors, + radial_MLP=radial_MLP, + ) + self.interactions = torch.nn.ModuleList([inter]) + + # Use the appropriate self connection at the first layer + use_sc_first = False + if "Residual" in str(interaction_cls_first): + use_sc_first = True + + node_feats_irreps_out = inter.target_irreps + prod = EquivariantProductBasisBlock( + node_feats_irreps=node_feats_irreps_out, + target_irreps=hidden_irreps, + correlation=correlation, + num_elements=num_elements, + use_sc=use_sc_first, + ) + self.products = torch.nn.ModuleList([prod]) + + self.readouts = torch.nn.ModuleList() + self.readouts.append(LinearDipoleReadoutBlock(hidden_irreps, dipole_only=True)) + + for i in range(num_interactions - 1): + if i == num_interactions - 2: + assert ( + len(hidden_irreps) > 1 + ), "To predict dipoles use at least l=1 hidden_irreps" + hidden_irreps_out = str( + hidden_irreps[1] + ) # Select only l=1 vectors for last layer + else: + hidden_irreps_out = hidden_irreps + inter = interaction_cls( + node_attrs_irreps=node_attr_irreps, + node_feats_irreps=hidden_irreps, + edge_attrs_irreps=sh_irreps, + edge_feats_irreps=edge_feats_irreps, + target_irreps=interaction_irreps, + hidden_irreps=hidden_irreps_out, + avg_num_neighbors=avg_num_neighbors, + radial_MLP=radial_MLP, + ) + self.interactions.append(inter) + prod = EquivariantProductBasisBlock( + node_feats_irreps=interaction_irreps, + target_irreps=hidden_irreps_out, + correlation=correlation, + num_elements=num_elements, + use_sc=True, + ) + self.products.append(prod) + if i == num_interactions - 2: + self.readouts.append( + NonLinearDipoleReadoutBlock( + hidden_irreps_out, MLP_irreps, gate, dipole_only=True + ) + ) + else: + self.readouts.append( + LinearDipoleReadoutBlock(hidden_irreps, dipole_only=True) + ) + + def forward( + self, + data: Dict[str, torch.Tensor], + training: bool = False, # pylint: disable=W0613 + compute_force: bool = False, + compute_virials: bool = False, + compute_stress: bool = False, + compute_displacement: bool = False, + ) -> Dict[str, Optional[torch.Tensor]]: + assert compute_force is False + assert compute_virials is False + assert compute_stress is False + assert compute_displacement is False + # Setup + data["node_attrs"].requires_grad_(True) + data["positions"].requires_grad_(True) + num_graphs = data["ptr"].numel() - 1 + + # Embeddings + node_feats = self.node_embedding(data["node_attrs"]) + vectors, lengths = get_edge_vectors_and_lengths( + positions=data["positions"], + edge_index=data["edge_index"], + shifts=data["shifts"], + ) + edge_attrs = self.spherical_harmonics(vectors) + edge_feats = self.radial_embedding( + lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers + ) + + # Interactions + dipoles = [] + for interaction, product, readout in zip( + self.interactions, self.products, self.readouts + ): + node_feats, sc = interaction( + node_attrs=data["node_attrs"], + node_feats=node_feats, + edge_attrs=edge_attrs, + edge_feats=edge_feats, + edge_index=data["edge_index"], + ) + node_feats = product( + node_feats=node_feats, + sc=sc, + node_attrs=data["node_attrs"], + ) + node_dipoles = readout(node_feats).squeeze(-1) # [n_nodes,3] + dipoles.append(node_dipoles) + + # Compute the dipoles + contributions_dipoles = torch.stack( + dipoles, dim=-1 + ) # [n_nodes,3,n_contributions] + atomic_dipoles = torch.sum(contributions_dipoles, dim=-1) # [n_nodes,3] + total_dipole = scatter_sum( + src=atomic_dipoles, + index=data["batch"], + dim=0, + dim_size=num_graphs, + ) # [n_graphs,3] + baseline = compute_fixed_charge_dipole( + charges=data["charges"], + positions=data["positions"], + batch=data["batch"], + num_graphs=num_graphs, + ) # [n_graphs,3] + total_dipole = total_dipole + baseline + + output = { + "dipole": total_dipole, + "atomic_dipoles": atomic_dipoles, + } + return output + + +@compile_mode("script") +class EnergyDipolesMACE(torch.nn.Module): + def __init__( + self, + r_max: float, + num_bessel: int, + num_polynomial_cutoff: int, + max_ell: int, + interaction_cls: Type[InteractionBlock], + interaction_cls_first: Type[InteractionBlock], + num_interactions: int, + num_elements: int, + hidden_irreps: o3.Irreps, + MLP_irreps: o3.Irreps, + avg_num_neighbors: float, + atomic_numbers: List[int], + correlation: int, + gate: Optional[Callable], + atomic_energies: Optional[np.ndarray], + radial_MLP: Optional[List[int]] = None, + ): + super().__init__() + self.register_buffer( + "atomic_numbers", torch.tensor(atomic_numbers, dtype=torch.int64) + ) + self.register_buffer("r_max", torch.tensor(r_max, dtype=torch.float64)) + self.register_buffer( + "num_interactions", torch.tensor(num_interactions, dtype=torch.int64) + ) + # Embedding + node_attr_irreps = o3.Irreps([(num_elements, (0, 1))]) + node_feats_irreps = o3.Irreps([(hidden_irreps.count(o3.Irrep(0, 1)), (0, 1))]) + self.node_embedding = LinearNodeEmbeddingBlock( + irreps_in=node_attr_irreps, irreps_out=node_feats_irreps + ) + self.radial_embedding = RadialEmbeddingBlock( + r_max=r_max, + num_bessel=num_bessel, + num_polynomial_cutoff=num_polynomial_cutoff, + ) + edge_feats_irreps = o3.Irreps(f"{self.radial_embedding.out_dim}x0e") + + sh_irreps = o3.Irreps.spherical_harmonics(max_ell) + num_features = hidden_irreps.count(o3.Irrep(0, 1)) + interaction_irreps = (sh_irreps * num_features).sort()[0].simplify() + self.spherical_harmonics = o3.SphericalHarmonics( + sh_irreps, normalize=True, normalization="component" + ) + if radial_MLP is None: + radial_MLP = [64, 64, 64] + # Interactions and readouts + self.atomic_energies_fn = AtomicEnergiesBlock(atomic_energies) + + inter = interaction_cls_first( + node_attrs_irreps=node_attr_irreps, + node_feats_irreps=node_feats_irreps, + edge_attrs_irreps=sh_irreps, + edge_feats_irreps=edge_feats_irreps, + target_irreps=interaction_irreps, + hidden_irreps=hidden_irreps, + avg_num_neighbors=avg_num_neighbors, + radial_MLP=radial_MLP, + ) + self.interactions = torch.nn.ModuleList([inter]) + + # Use the appropriate self connection at the first layer + use_sc_first = False + if "Residual" in str(interaction_cls_first): + use_sc_first = True + + node_feats_irreps_out = inter.target_irreps + prod = EquivariantProductBasisBlock( + node_feats_irreps=node_feats_irreps_out, + target_irreps=hidden_irreps, + correlation=correlation, + num_elements=num_elements, + use_sc=use_sc_first, + ) + self.products = torch.nn.ModuleList([prod]) + + self.readouts = torch.nn.ModuleList() + self.readouts.append(LinearDipoleReadoutBlock(hidden_irreps, dipole_only=False)) + + for i in range(num_interactions - 1): + if i == num_interactions - 2: + assert ( + len(hidden_irreps) > 1 + ), "To predict dipoles use at least l=1 hidden_irreps" + hidden_irreps_out = str( + hidden_irreps[:2] + ) # Select scalars and l=1 vectors for last layer + else: + hidden_irreps_out = hidden_irreps + inter = interaction_cls( + node_attrs_irreps=node_attr_irreps, + node_feats_irreps=hidden_irreps, + edge_attrs_irreps=sh_irreps, + edge_feats_irreps=edge_feats_irreps, + target_irreps=interaction_irreps, + hidden_irreps=hidden_irreps_out, + avg_num_neighbors=avg_num_neighbors, + radial_MLP=radial_MLP, + ) + self.interactions.append(inter) + prod = EquivariantProductBasisBlock( + node_feats_irreps=interaction_irreps, + target_irreps=hidden_irreps_out, + correlation=correlation, + num_elements=num_elements, + use_sc=True, + ) + self.products.append(prod) + if i == num_interactions - 2: + self.readouts.append( + NonLinearDipoleReadoutBlock( + hidden_irreps_out, MLP_irreps, gate, dipole_only=False + ) + ) + else: + self.readouts.append( + LinearDipoleReadoutBlock(hidden_irreps, dipole_only=False) + ) + + def forward( + self, + data: Dict[str, torch.Tensor], + training: bool = False, + compute_force: bool = True, + compute_virials: bool = False, + compute_stress: bool = False, + compute_displacement: bool = False, + ) -> Dict[str, Optional[torch.Tensor]]: + # Setup + data["node_attrs"].requires_grad_(True) + data["positions"].requires_grad_(True) + num_graphs = data["ptr"].numel() - 1 + num_atoms_arange = torch.arange(data["positions"].shape[0]) + displacement = torch.zeros( + (num_graphs, 3, 3), + dtype=data["positions"].dtype, + device=data["positions"].device, + ) + if compute_virials or compute_stress or compute_displacement: + ( + data["positions"], + data["shifts"], + displacement, + ) = get_symmetric_displacement( + positions=data["positions"], + unit_shifts=data["unit_shifts"], + cell=data["cell"], + edge_index=data["edge_index"], + num_graphs=num_graphs, + batch=data["batch"], + ) + + # Atomic energies + node_e0 = self.atomic_energies_fn(data["node_attrs"])[ + num_atoms_arange, data["head"][data["batch"]] + ] + e0 = scatter_sum( + src=node_e0, index=data["batch"], dim=-1, dim_size=num_graphs + ) # [n_graphs,] + + # Embeddings + node_feats = self.node_embedding(data["node_attrs"]) + vectors, lengths = get_edge_vectors_and_lengths( + positions=data["positions"], + edge_index=data["edge_index"], + shifts=data["shifts"], + ) + edge_attrs = self.spherical_harmonics(vectors) + edge_feats = self.radial_embedding( + lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers + ) + + # Interactions + energies = [e0] + node_energies_list = [node_e0] + dipoles = [] + for interaction, product, readout in zip( + self.interactions, self.products, self.readouts + ): + node_feats, sc = interaction( + node_attrs=data["node_attrs"], + node_feats=node_feats, + edge_attrs=edge_attrs, + edge_feats=edge_feats, + edge_index=data["edge_index"], + ) + node_feats = product( + node_feats=node_feats, + sc=sc, + node_attrs=data["node_attrs"], + ) + node_out = readout(node_feats).squeeze(-1) # [n_nodes, ] + # node_energies = readout(node_feats).squeeze(-1) # [n_nodes, ] + node_energies = node_out[:, 0] + energy = scatter_sum( + src=node_energies, index=data["batch"], dim=-1, dim_size=num_graphs + ) # [n_graphs,] + energies.append(energy) + node_dipoles = node_out[:, 1:] + dipoles.append(node_dipoles) + + # Compute the energies and dipoles + contributions = torch.stack(energies, dim=-1) + total_energy = torch.sum(contributions, dim=-1) # [n_graphs, ] + node_energy_contributions = torch.stack(node_energies_list, dim=-1) + node_energy = torch.sum(node_energy_contributions, dim=-1) # [n_nodes, ] + contributions_dipoles = torch.stack( + dipoles, dim=-1 + ) # [n_nodes,3,n_contributions] + atomic_dipoles = torch.sum(contributions_dipoles, dim=-1) # [n_nodes,3] + total_dipole = scatter_sum( + src=atomic_dipoles, + index=data["batch"].unsqueeze(-1), + dim=0, + dim_size=num_graphs, + ) # [n_graphs,3] + baseline = compute_fixed_charge_dipole( + charges=data["charges"], + positions=data["positions"], + batch=data["batch"], + num_graphs=num_graphs, + ) # [n_graphs,3] + total_dipole = total_dipole + baseline + + forces, virials, stress, _ = get_outputs( + energy=total_energy, + positions=data["positions"], + displacement=displacement, + cell=data["cell"], + training=training, + compute_force=compute_force, + compute_virials=compute_virials, + compute_stress=compute_stress, + ) + + output = { + "energy": total_energy, + "node_energy": node_energy, + "contributions": contributions, + "forces": forces, + "virials": virials, + "stress": stress, + "displacement": displacement, + "dipole": total_dipole, + "atomic_dipoles": atomic_dipoles, + } + return output diff --git a/mace/modules/radial.py b/mace/modules/radial.py new file mode 100644 index 00000000..a928c184 --- /dev/null +++ b/mace/modules/radial.py @@ -0,0 +1,323 @@ +########################################################################################### +# Radial basis and cutoff +# Authors: Ilyes Batatia, Gregor Simm +# This program is distributed under the MIT License (see MIT.md) +########################################################################################### + +import ase +import numpy as np +import torch +from e3nn.util.jit import compile_mode + +from mace.tools.compile import simplify_if_compile +from mace.tools.scatter import scatter_sum + + +@compile_mode("script") +class BesselBasis(torch.nn.Module): + """ + Equation (7) + """ + + def __init__(self, r_max: float, num_basis=8, trainable=False): + super().__init__() + + bessel_weights = ( + np.pi + / r_max + * torch.linspace( + start=1.0, + end=num_basis, + steps=num_basis, + dtype=torch.get_default_dtype(), + ) + ) + if trainable: + self.bessel_weights = torch.nn.Parameter(bessel_weights) + else: + self.register_buffer("bessel_weights", bessel_weights) + + self.register_buffer( + "r_max", torch.tensor(r_max, dtype=torch.get_default_dtype()) + ) + self.register_buffer( + "prefactor", + torch.tensor(np.sqrt(2.0 / r_max), dtype=torch.get_default_dtype()), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: # [..., 1] + numerator = torch.sin(self.bessel_weights * x) # [..., num_basis] + return self.prefactor * (numerator / x) + + def __repr__(self): + return ( + f"{self.__class__.__name__}(r_max={self.r_max}, num_basis={len(self.bessel_weights)}, " + f"trainable={self.bessel_weights.requires_grad})" + ) + + +@compile_mode("script") +class ChebychevBasis(torch.nn.Module): + """ + Equation (7) + """ + + def __init__(self, r_max: float, num_basis=8): + super().__init__() + self.register_buffer( + "n", + torch.arange(1, num_basis + 1, dtype=torch.get_default_dtype()).unsqueeze( + 0 + ), + ) + self.num_basis = num_basis + self.r_max = r_max + + def forward(self, x: torch.Tensor) -> torch.Tensor: # [..., 1] + x = x.repeat(1, self.num_basis) + n = self.n.repeat(len(x), 1) + return torch.special.chebyshev_polynomial_t(x, n) + + def __repr__(self): + return ( + f"{self.__class__.__name__}(r_max={self.r_max}, num_basis={self.num_basis}," + ) + + +@compile_mode("script") +class GaussianBasis(torch.nn.Module): + """ + Gaussian basis functions + """ + + def __init__(self, r_max: float, num_basis=128, trainable=False): + super().__init__() + gaussian_weights = torch.linspace( + start=0.0, end=r_max, steps=num_basis, dtype=torch.get_default_dtype() + ) + if trainable: + self.gaussian_weights = torch.nn.Parameter( + gaussian_weights, requires_grad=True + ) + else: + self.register_buffer("gaussian_weights", gaussian_weights) + self.coeff = -0.5 / (r_max / (num_basis - 1)) ** 2 + + def forward(self, x: torch.Tensor) -> torch.Tensor: # [..., 1] + x = x - self.gaussian_weights + return torch.exp(self.coeff * torch.pow(x, 2)) + + +@compile_mode("script") +class PolynomialCutoff(torch.nn.Module): + """ + Equation (8) + """ + + p: torch.Tensor + r_max: torch.Tensor + + def __init__(self, r_max: float, p=6): + super().__init__() + self.register_buffer("p", torch.tensor(p, dtype=torch.get_default_dtype())) + self.register_buffer( + "r_max", torch.tensor(r_max, dtype=torch.get_default_dtype()) + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # yapf: disable + envelope = ( + 1.0 + - ((self.p + 1.0) * (self.p + 2.0) / 2.0) * torch.pow(x / self.r_max, self.p) + + self.p * (self.p + 2.0) * torch.pow(x / self.r_max, self.p + 1) + - (self.p * (self.p + 1.0) / 2) * torch.pow(x / self.r_max, self.p + 2) + ) + # yapf: enable + + # noinspection PyUnresolvedReferences + return envelope * (x < self.r_max) + + def __repr__(self): + return f"{self.__class__.__name__}(p={self.p}, r_max={self.r_max})" + + +@compile_mode("script") +class ZBLBasis(torch.nn.Module): + """ + Implementation of the Ziegler-Biersack-Littmark (ZBL) potential + """ + + p: torch.Tensor + r_max: torch.Tensor + + def __init__(self, r_max: float, p=6, trainable=False): + super().__init__() + self.register_buffer( + "r_max", torch.tensor(r_max, dtype=torch.get_default_dtype()) + ) + # Pre-calculate the p coefficients for the ZBL potential + self.register_buffer( + "c", + torch.tensor( + [0.1818, 0.5099, 0.2802, 0.02817], dtype=torch.get_default_dtype() + ), + ) + self.register_buffer("p", torch.tensor(p, dtype=torch.get_default_dtype())) + self.register_buffer( + "covalent_radii", + torch.tensor( + ase.data.covalent_radii, + dtype=torch.get_default_dtype(), + ), + ) + self.cutoff = PolynomialCutoff(r_max, p) + if trainable: + self.a_exp = torch.nn.Parameter(torch.tensor(0.300, requires_grad=True)) + self.a_prefactor = torch.nn.Parameter( + torch.tensor(0.4543, requires_grad=True) + ) + else: + self.register_buffer("a_exp", torch.tensor(0.300)) + self.register_buffer("a_prefactor", torch.tensor(0.4543)) + + def forward( + self, + x: torch.Tensor, + node_attrs: torch.Tensor, + edge_index: torch.Tensor, + atomic_numbers: torch.Tensor, + ) -> torch.Tensor: + sender = edge_index[0] + receiver = edge_index[1] + node_atomic_numbers = atomic_numbers[torch.argmax(node_attrs, dim=1)].unsqueeze( + -1 + ) + Z_u = node_atomic_numbers[sender] + Z_v = node_atomic_numbers[receiver] + a = ( + self.a_prefactor + * 0.529 + / (torch.pow(Z_u, self.a_exp) + torch.pow(Z_v, self.a_exp)) + ) + r_over_a = x / a + phi = ( + self.c[0] * torch.exp(-3.2 * r_over_a) + + self.c[1] * torch.exp(-0.9423 * r_over_a) + + self.c[2] * torch.exp(-0.4028 * r_over_a) + + self.c[3] * torch.exp(-0.2016 * r_over_a) + ) + v_edges = (14.3996 * Z_u * Z_v) / x * phi + r_max = self.covalent_radii[Z_u] + self.covalent_radii[Z_v] + envelope = ( + 1.0 + - ((self.p + 1.0) * (self.p + 2.0) / 2.0) * torch.pow(x / r_max, self.p) + + self.p * (self.p + 2.0) * torch.pow(x / r_max, self.p + 1) + - (self.p * (self.p + 1.0) / 2) * torch.pow(x / r_max, self.p + 2) + ) * (x < r_max) + v_edges = 0.5 * v_edges * envelope + V_ZBL = scatter_sum(v_edges, receiver, dim=0, dim_size=node_attrs.size(0)) + return V_ZBL.squeeze(-1) + + def __repr__(self): + return f"{self.__class__.__name__}(r_max={self.r_max}, c={self.c})" + + +@compile_mode("script") +class AgnesiTransform(torch.nn.Module): + """ + Agnesi transform see ACEpotentials.jl, JCP 2023, p. 160 + """ + + def __init__( + self, + q: float = 0.9183, + p: float = 4.5791, + a: float = 1.0805, + trainable=False, + ): + super().__init__() + self.register_buffer("q", torch.tensor(q, dtype=torch.get_default_dtype())) + self.register_buffer("p", torch.tensor(p, dtype=torch.get_default_dtype())) + self.register_buffer("a", torch.tensor(a, dtype=torch.get_default_dtype())) + self.register_buffer( + "covalent_radii", + torch.tensor( + ase.data.covalent_radii, + dtype=torch.get_default_dtype(), + ), + ) + if trainable: + self.a = torch.nn.Parameter(torch.tensor(1.0805, requires_grad=True)) + self.q = torch.nn.Parameter(torch.tensor(0.9183, requires_grad=True)) + self.p = torch.nn.Parameter(torch.tensor(4.5791, requires_grad=True)) + + def forward( + self, + x: torch.Tensor, + node_attrs: torch.Tensor, + edge_index: torch.Tensor, + atomic_numbers: torch.Tensor, + ) -> torch.Tensor: + sender = edge_index[0] + receiver = edge_index[1] + node_atomic_numbers = atomic_numbers[torch.argmax(node_attrs, dim=1)].unsqueeze( + -1 + ) + Z_u = node_atomic_numbers[sender] + Z_v = node_atomic_numbers[receiver] + r_0 = 0.5 * (self.covalent_radii[Z_u] + self.covalent_radii[Z_v]) + return ( + 1 + (self.a * ((x / r_0) ** self.q) / (1 + (x / r_0) ** (self.q - self.p))) + ) ** (-1) + + def __repr__(self): + return f"{self.__class__.__name__}(a={self.a}, q={self.q}, p={self.p})" + + +@simplify_if_compile +@compile_mode("script") +class SoftTransform(torch.nn.Module): + """ + Soft Transform + """ + + def __init__(self, a: float = 0.2, b: float = 3.0, trainable=False): + super().__init__() + self.register_buffer( + "covalent_radii", + torch.tensor( + ase.data.covalent_radii, + dtype=torch.get_default_dtype(), + ), + ) + if trainable: + self.a = torch.nn.Parameter(torch.tensor(a, requires_grad=True)) + self.b = torch.nn.Parameter(torch.tensor(b, requires_grad=True)) + else: + self.register_buffer("a", torch.tensor(a)) + self.register_buffer("b", torch.tensor(b)) + + def forward( + self, + x: torch.Tensor, + node_attrs: torch.Tensor, + edge_index: torch.Tensor, + atomic_numbers: torch.Tensor, + ) -> torch.Tensor: + sender = edge_index[0] + receiver = edge_index[1] + node_atomic_numbers = atomic_numbers[torch.argmax(node_attrs, dim=1)].unsqueeze( + -1 + ) + Z_u = node_atomic_numbers[sender] + Z_v = node_atomic_numbers[receiver] + r_0 = (self.covalent_radii[Z_u] + self.covalent_radii[Z_v]) / 4 + y = ( + x + + (1 / 2) * torch.tanh(-(x / r_0) - self.a * ((x / r_0) ** self.b)) + + 1 / 2 + ) + return y + + def __repr__(self): + return f"{self.__class__.__name__}(a={self.a.item()}, b={self.b.item()})" diff --git a/mace/modules/symmetric_contraction.py b/mace/modules/symmetric_contraction.py new file mode 100644 index 00000000..9db75da0 --- /dev/null +++ b/mace/modules/symmetric_contraction.py @@ -0,0 +1,233 @@ +########################################################################################### +# Implementation of the symmetric contraction algorithm presented in the MACE paper +# (Batatia et al, MACE: Higher Order Equivariant Message Passing Neural Networks for Fast and Accurate Force Fields , Eq.10 and 11) +# Authors: Ilyes Batatia +# This program is distributed under the MIT License (see MIT.md) +########################################################################################### + +from typing import Dict, Optional, Union + +import opt_einsum_fx +import torch +import torch.fx +from e3nn import o3 +from e3nn.util.codegen import CodeGenMixin +from e3nn.util.jit import compile_mode + +from mace.tools.cg import U_matrix_real + +BATCH_EXAMPLE = 10 +ALPHABET = ["w", "x", "v", "n", "z", "r", "t", "y", "u", "o", "p", "s"] + + +@compile_mode("script") +class SymmetricContraction(CodeGenMixin, torch.nn.Module): + def __init__( + self, + irreps_in: o3.Irreps, + irreps_out: o3.Irreps, + correlation: Union[int, Dict[str, int]], + irrep_normalization: str = "component", + path_normalization: str = "element", + internal_weights: Optional[bool] = None, + shared_weights: Optional[bool] = None, + num_elements: Optional[int] = None, + ) -> None: + super().__init__() + + if irrep_normalization is None: + irrep_normalization = "component" + + if path_normalization is None: + path_normalization = "element" + + assert irrep_normalization in ["component", "norm", "none"] + assert path_normalization in ["element", "path", "none"] + + self.irreps_in = o3.Irreps(irreps_in) + self.irreps_out = o3.Irreps(irreps_out) + + del irreps_in, irreps_out + + if not isinstance(correlation, tuple): + corr = correlation + correlation = {} + for irrep_out in self.irreps_out: + correlation[irrep_out] = corr + + assert shared_weights or not internal_weights + + if internal_weights is None: + internal_weights = True + + self.internal_weights = internal_weights + self.shared_weights = shared_weights + + del internal_weights, shared_weights + + self.contractions = torch.nn.ModuleList() + for irrep_out in self.irreps_out: + self.contractions.append( + Contraction( + irreps_in=self.irreps_in, + irrep_out=o3.Irreps(str(irrep_out.ir)), + correlation=correlation[irrep_out], + internal_weights=self.internal_weights, + num_elements=num_elements, + weights=self.shared_weights, + ) + ) + + def forward(self, x: torch.Tensor, y: torch.Tensor): + outs = [contraction(x, y) for contraction in self.contractions] + return torch.cat(outs, dim=-1) + + +@compile_mode("script") +class Contraction(torch.nn.Module): + def __init__( + self, + irreps_in: o3.Irreps, + irrep_out: o3.Irreps, + correlation: int, + internal_weights: bool = True, + num_elements: Optional[int] = None, + weights: Optional[torch.Tensor] = None, + ) -> None: + super().__init__() + + self.num_features = irreps_in.count((0, 1)) + self.coupling_irreps = o3.Irreps([irrep.ir for irrep in irreps_in]) + self.correlation = correlation + dtype = torch.get_default_dtype() + for nu in range(1, correlation + 1): + U_matrix = U_matrix_real( + irreps_in=self.coupling_irreps, + irreps_out=irrep_out, + correlation=nu, + dtype=dtype, + )[-1] + self.register_buffer(f"U_matrix_{nu}", U_matrix) + + # Tensor contraction equations + self.contractions_weighting = torch.nn.ModuleList() + self.contractions_features = torch.nn.ModuleList() + + # Create weight for product basis + self.weights = torch.nn.ParameterList([]) + + for i in range(correlation, 0, -1): + # Shapes definying + num_params = self.U_tensors(i).size()[-1] + num_equivariance = 2 * irrep_out.lmax + 1 + num_ell = self.U_tensors(i).size()[-2] + + if i == correlation: + parse_subscript_main = ( + [ALPHABET[j] for j in range(i + min(irrep_out.lmax, 1) - 1)] + + ["ik,ekc,bci,be -> bc"] + + [ALPHABET[j] for j in range(i + min(irrep_out.lmax, 1) - 1)] + ) + graph_module_main = torch.fx.symbolic_trace( + lambda x, y, w, z: torch.einsum( + "".join(parse_subscript_main), x, y, w, z + ) + ) + + # Optimizing the contractions + self.graph_opt_main = opt_einsum_fx.optimize_einsums_full( + model=graph_module_main, + example_inputs=( + torch.randn( + [num_equivariance] + [num_ell] * i + [num_params] + ).squeeze(0), + torch.randn((num_elements, num_params, self.num_features)), + torch.randn((BATCH_EXAMPLE, self.num_features, num_ell)), + torch.randn((BATCH_EXAMPLE, num_elements)), + ), + ) + # Parameters for the product basis + w = torch.nn.Parameter( + torch.randn((num_elements, num_params, self.num_features)) + / num_params + ) + self.weights_max = w + else: + # Generate optimized contractions equations + parse_subscript_weighting = ( + [ALPHABET[j] for j in range(i + min(irrep_out.lmax, 1))] + + ["k,ekc,be->bc"] + + [ALPHABET[j] for j in range(i + min(irrep_out.lmax, 1))] + ) + parse_subscript_features = ( + ["bc"] + + [ALPHABET[j] for j in range(i - 1 + min(irrep_out.lmax, 1))] + + ["i,bci->bc"] + + [ALPHABET[j] for j in range(i - 1 + min(irrep_out.lmax, 1))] + ) + + # Symbolic tracing of contractions + graph_module_weighting = torch.fx.symbolic_trace( + lambda x, y, z: torch.einsum( + "".join(parse_subscript_weighting), x, y, z + ) + ) + graph_module_features = torch.fx.symbolic_trace( + lambda x, y: torch.einsum("".join(parse_subscript_features), x, y) + ) + + # Optimizing the contractions + graph_opt_weighting = opt_einsum_fx.optimize_einsums_full( + model=graph_module_weighting, + example_inputs=( + torch.randn( + [num_equivariance] + [num_ell] * i + [num_params] + ).squeeze(0), + torch.randn((num_elements, num_params, self.num_features)), + torch.randn((BATCH_EXAMPLE, num_elements)), + ), + ) + graph_opt_features = opt_einsum_fx.optimize_einsums_full( + model=graph_module_features, + example_inputs=( + torch.randn( + [BATCH_EXAMPLE, self.num_features, num_equivariance] + + [num_ell] * i + ).squeeze(2), + torch.randn((BATCH_EXAMPLE, self.num_features, num_ell)), + ), + ) + self.contractions_weighting.append(graph_opt_weighting) + self.contractions_features.append(graph_opt_features) + # Parameters for the product basis + w = torch.nn.Parameter( + torch.randn((num_elements, num_params, self.num_features)) + / num_params + ) + self.weights.append(w) + if not internal_weights: + self.weights = weights[:-1] + self.weights_max = weights[-1] + + def forward(self, x: torch.Tensor, y: torch.Tensor): + out = self.graph_opt_main( + self.U_tensors(self.correlation), + self.weights_max, + x, + y, + ) + for i, (weight, contract_weights, contract_features) in enumerate( + zip(self.weights, self.contractions_weighting, self.contractions_features) + ): + c_tensor = contract_weights( + self.U_tensors(self.correlation - i - 1), + weight, + y, + ) + c_tensor = c_tensor + out + out = contract_features(c_tensor, x) + + return out.view(out.shape[0], -1) + + def U_tensors(self, nu: int): + return dict(self.named_buffers())[f"U_matrix_{nu}"] diff --git a/mace/modules/utils.py b/mace/modules/utils.py new file mode 100644 index 00000000..d0a1e5f6 --- /dev/null +++ b/mace/modules/utils.py @@ -0,0 +1,442 @@ +########################################################################################### +# Utilities +# Authors: Ilyes Batatia, Gregor Simm and David Kovacs +# This program is distributed under the MIT License (see MIT.md) +########################################################################################### + +import logging +from typing import List, Optional, Tuple + +import numpy as np +import torch +import torch.nn +import torch.utils.data +from scipy.constants import c, e + +from mace.tools import to_numpy +from mace.tools.scatter import scatter_mean, scatter_std, scatter_sum +from mace.tools.torch_geometric.batch import Batch + +from .blocks import AtomicEnergiesBlock + + +def compute_forces( + energy: torch.Tensor, positions: torch.Tensor, training: bool = True +) -> torch.Tensor: + grad_outputs: List[Optional[torch.Tensor]] = [torch.ones_like(energy)] + gradient = torch.autograd.grad( + outputs=[energy], # [n_graphs, ] + inputs=[positions], # [n_nodes, 3] + grad_outputs=grad_outputs, + retain_graph=training, # Make sure the graph is not destroyed during training + create_graph=training, # Create graph for second derivative + allow_unused=True, # For complete dissociation turn to true + )[ + 0 + ] # [n_nodes, 3] + if gradient is None: + return torch.zeros_like(positions) + return -1 * gradient + + +def compute_forces_virials( + energy: torch.Tensor, + positions: torch.Tensor, + displacement: torch.Tensor, + cell: torch.Tensor, + training: bool = True, + compute_stress: bool = False, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: + grad_outputs: List[Optional[torch.Tensor]] = [torch.ones_like(energy)] + forces, virials = torch.autograd.grad( + outputs=[energy], # [n_graphs, ] + inputs=[positions, displacement], # [n_nodes, 3] + grad_outputs=grad_outputs, + retain_graph=training, # Make sure the graph is not destroyed during training + create_graph=training, # Create graph for second derivative + allow_unused=True, + ) + stress = torch.zeros_like(displacement) + if compute_stress and virials is not None: + cell = cell.view(-1, 3, 3) + volume = torch.linalg.det(cell).abs().unsqueeze(-1) + stress = virials / volume.view(-1, 1, 1) + stress = torch.where(torch.abs(stress) < 1e10, stress, torch.zeros_like(stress)) + if forces is None: + forces = torch.zeros_like(positions) + if virials is None: + virials = torch.zeros((1, 3, 3)) + + return -1 * forces, -1 * virials, stress + + +def get_symmetric_displacement( + positions: torch.Tensor, + unit_shifts: torch.Tensor, + cell: Optional[torch.Tensor], + edge_index: torch.Tensor, + num_graphs: int, + batch: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if cell is None: + cell = torch.zeros( + num_graphs * 3, + 3, + dtype=positions.dtype, + device=positions.device, + ) + sender = edge_index[0] + displacement = torch.zeros( + (num_graphs, 3, 3), + dtype=positions.dtype, + device=positions.device, + ) + displacement.requires_grad_(True) + symmetric_displacement = 0.5 * ( + displacement + displacement.transpose(-1, -2) + ) # From https://github.com/mir-group/nequip + positions = positions + torch.einsum( + "be,bec->bc", positions, symmetric_displacement[batch] + ) + cell = cell.view(-1, 3, 3) + cell = cell + torch.matmul(cell, symmetric_displacement) + shifts = torch.einsum( + "be,bec->bc", + unit_shifts, + cell[batch[sender]], + ) + return positions, shifts, displacement + + +@torch.jit.unused +def compute_hessians_vmap( + forces: torch.Tensor, + positions: torch.Tensor, +) -> torch.Tensor: + forces_flatten = forces.view(-1) + num_elements = forces_flatten.shape[0] + + def get_vjp(v): + return torch.autograd.grad( + -1 * forces_flatten, + positions, + v, + retain_graph=True, + create_graph=False, + allow_unused=False, + ) + + I_N = torch.eye(num_elements).to(forces.device) + try: + chunk_size = 1 if num_elements < 64 else 16 + gradient = torch.vmap(get_vjp, in_dims=0, out_dims=0, chunk_size=chunk_size)( + I_N + )[0] + except RuntimeError: + gradient = compute_hessians_loop(forces, positions) + if gradient is None: + return torch.zeros((positions.shape[0], forces.shape[0], 3, 3)) + return gradient + + +@torch.jit.unused +def compute_hessians_loop( + forces: torch.Tensor, + positions: torch.Tensor, +) -> torch.Tensor: + hessian = [] + for grad_elem in forces.view(-1): + hess_row = torch.autograd.grad( + outputs=[-1 * grad_elem], + inputs=[positions], + grad_outputs=torch.ones_like(grad_elem), + retain_graph=True, + create_graph=False, + allow_unused=False, + )[0] + hess_row = hess_row.detach() # this makes it very slow? but needs less memory + if hess_row is None: + hessian.append(torch.zeros_like(positions)) + else: + hessian.append(hess_row) + hessian = torch.stack(hessian) + return hessian + + +def get_outputs( + energy: torch.Tensor, + positions: torch.Tensor, + displacement: Optional[torch.Tensor], + cell: torch.Tensor, + training: bool = False, + compute_force: bool = True, + compute_virials: bool = True, + compute_stress: bool = True, + compute_hessian: bool = False, +) -> Tuple[ + Optional[torch.Tensor], + Optional[torch.Tensor], + Optional[torch.Tensor], + Optional[torch.Tensor], +]: + if (compute_virials or compute_stress) and displacement is not None: + forces, virials, stress = compute_forces_virials( + energy=energy, + positions=positions, + displacement=displacement, + cell=cell, + compute_stress=compute_stress, + training=(training or compute_hessian), + ) + elif compute_force: + forces, virials, stress = ( + compute_forces( + energy=energy, + positions=positions, + training=(training or compute_hessian), + ), + None, + None, + ) + else: + forces, virials, stress = (None, None, None) + if compute_hessian: + assert forces is not None, "Forces must be computed to get the hessian" + hessian = compute_hessians_vmap(forces, positions) + else: + hessian = None + return forces, virials, stress, hessian + + +def get_edge_vectors_and_lengths( + positions: torch.Tensor, # [n_nodes, 3] + edge_index: torch.Tensor, # [2, n_edges] + shifts: torch.Tensor, # [n_edges, 3] + normalize: bool = False, + eps: float = 1e-9, +) -> Tuple[torch.Tensor, torch.Tensor]: + sender = edge_index[0] + receiver = edge_index[1] + vectors = positions[receiver] - positions[sender] + shifts # [n_edges, 3] + lengths = torch.linalg.norm(vectors, dim=-1, keepdim=True) # [n_edges, 1] + if normalize: + vectors_normed = vectors / (lengths + eps) + return vectors_normed, lengths + + return vectors, lengths + + +def _check_non_zero(std): + if np.any(std == 0): + logging.warning( + "Standard deviation of the scaling is zero, Changing to no scaling" + ) + std[std == 0] = 1 + return std + + +def extract_invariant(x: torch.Tensor, num_layers: int, num_features: int, l_max: int): + out = [] + for i in range(num_layers - 1): + out.append( + x[ + :, + i + * (l_max + 1) ** 2 + * num_features : (i * (l_max + 1) ** 2 + 1) + * num_features, + ] + ) + out.append(x[:, -num_features:]) + return torch.cat(out, dim=-1) + + +def compute_mean_std_atomic_inter_energy( + data_loader: torch.utils.data.DataLoader, + atomic_energies: np.ndarray, +) -> Tuple[float, float]: + atomic_energies_fn = AtomicEnergiesBlock(atomic_energies=atomic_energies) + + avg_atom_inter_es_list = [] + head_list = [] + + for batch in data_loader: + node_e0 = atomic_energies_fn(batch.node_attrs) + graph_e0s = scatter_sum( + src=node_e0, index=batch.batch, dim=0, dim_size=batch.num_graphs + )[torch.arange(batch.num_graphs), batch.head] + graph_sizes = batch.ptr[1:] - batch.ptr[:-1] + avg_atom_inter_es_list.append( + (batch.energy - graph_e0s) / graph_sizes + ) # {[n_graphs], } + head_list.append(batch.head) + + avg_atom_inter_es = torch.cat(avg_atom_inter_es_list) # [total_n_graphs] + head = torch.cat(head_list, dim=0) # [total_n_graphs] + # mean = to_numpy(torch.mean(avg_atom_inter_es)).item() + # std = to_numpy(torch.std(avg_atom_inter_es)).item() + mean = to_numpy(scatter_mean(src=avg_atom_inter_es, index=head, dim=0).squeeze(-1)) + std = to_numpy(scatter_std(src=avg_atom_inter_es, index=head, dim=0).squeeze(-1)) + std = _check_non_zero(std) + + return mean, std + + +def _compute_mean_std_atomic_inter_energy( + batch: Batch, + atomic_energies_fn: AtomicEnergiesBlock, +) -> Tuple[torch.Tensor, torch.Tensor]: + head = batch.head + node_e0 = atomic_energies_fn(batch.node_attrs) + graph_e0s = scatter_sum( + src=node_e0, index=batch.batch, dim=0, dim_size=batch.num_graphs + )[torch.arange(batch.num_graphs), head] + graph_sizes = batch.ptr[1:] - batch.ptr[:-1] + atom_energies = (batch.energy - graph_e0s) / graph_sizes + return atom_energies + + +def compute_mean_rms_energy_forces( + data_loader: torch.utils.data.DataLoader, + atomic_energies: np.ndarray, +) -> Tuple[float, float]: + atomic_energies_fn = AtomicEnergiesBlock(atomic_energies=atomic_energies) + + atom_energy_list = [] + forces_list = [] + head_list = [] + head_batch = [] + + for batch in data_loader: + head = batch.head + node_e0 = atomic_energies_fn(batch.node_attrs) + graph_e0s = scatter_sum( + src=node_e0, index=batch.batch, dim=0, dim_size=batch.num_graphs + )[torch.arange(batch.num_graphs), head] + graph_sizes = batch.ptr[1:] - batch.ptr[:-1] + atom_energy_list.append( + (batch.energy - graph_e0s) / graph_sizes + ) # {[n_graphs], } + forces_list.append(batch.forces) # {[n_graphs*n_atoms,3], } + head_list.append(head) + head_batch.append(head[batch.batch]) + + atom_energies = torch.cat(atom_energy_list, dim=0) # [total_n_graphs] + forces = torch.cat(forces_list, dim=0) # {[total_n_graphs*n_atoms,3], } + head = torch.cat(head_list, dim=0) # [total_n_graphs] + head_batch = torch.cat(head_batch, dim=0) # [total_n_graphs] + + # mean = to_numpy(torch.mean(atom_energies)).item() + # rms = to_numpy(torch.sqrt(torch.mean(torch.square(forces)))).item() + mean = to_numpy(scatter_mean(src=atom_energies, index=head, dim=0).squeeze(-1)) + rms = to_numpy( + torch.sqrt( + scatter_mean(src=torch.square(forces), index=head_batch, dim=0).mean(-1) + ) + ) + rms = _check_non_zero(rms) + + return mean, rms + + +def _compute_mean_rms_energy_forces( + batch: Batch, + atomic_energies_fn: AtomicEnergiesBlock, +) -> Tuple[torch.Tensor, torch.Tensor]: + head = batch.head + node_e0 = atomic_energies_fn(batch.node_attrs) + graph_e0s = scatter_sum( + src=node_e0, index=batch.batch, dim=0, dim_size=batch.num_graphs + )[torch.arange(batch.num_graphs), head] + graph_sizes = batch.ptr[1:] - batch.ptr[:-1] + atom_energies = (batch.energy - graph_e0s) / graph_sizes # {[n_graphs], } + forces = batch.forces # {[n_graphs*n_atoms,3], } + + return atom_energies, forces + + +def compute_avg_num_neighbors(data_loader: torch.utils.data.DataLoader) -> float: + num_neighbors = [] + for batch in data_loader: + _, receivers = batch.edge_index + _, counts = torch.unique(receivers, return_counts=True) + num_neighbors.append(counts) + + avg_num_neighbors = torch.mean( + torch.cat(num_neighbors, dim=0).type(torch.get_default_dtype()) + ) + return to_numpy(avg_num_neighbors).item() + + +def compute_statistics( + data_loader: torch.utils.data.DataLoader, + atomic_energies: np.ndarray, +) -> Tuple[float, float, float, float]: + atomic_energies_fn = AtomicEnergiesBlock(atomic_energies=atomic_energies) + + atom_energy_list = [] + forces_list = [] + num_neighbors = [] + head_list = [] + head_batch = [] + + for batch in data_loader: + head = batch.head + node_e0 = atomic_energies_fn(batch.node_attrs) + graph_e0s = scatter_sum( + src=node_e0, index=batch.batch, dim=0, dim_size=batch.num_graphs + )[torch.arange(batch.num_graphs), head] + graph_sizes = batch.ptr[1:] - batch.ptr[:-1] + atom_energy_list.append( + (batch.energy - graph_e0s) / graph_sizes + ) # {[n_graphs], } + forces_list.append(batch.forces) # {[n_graphs*n_atoms,3], } + head_list.append(head) # {[n_graphs], } + head_batch.append(head[batch.batch]) + _, receivers = batch.edge_index + _, counts = torch.unique(receivers, return_counts=True) + num_neighbors.append(counts) + + atom_energies = torch.cat(atom_energy_list, dim=0) # [total_n_graphs] + forces = torch.cat(forces_list, dim=0) # {[total_n_graphs*n_atoms,3], } + head = torch.cat(head_list, dim=0) # [total_n_graphs] + head_batch = torch.cat(head_batch, dim=0) # [total_n_graphs] + + # mean = to_numpy(torch.mean(atom_energies)).item() + mean = to_numpy(scatter_mean(src=atom_energies, index=head, dim=0).squeeze(-1)) + rms = to_numpy( + torch.sqrt( + scatter_mean(src=torch.square(forces), index=head_batch, dim=0).mean(-1) + ) + ) + + avg_num_neighbors = torch.mean( + torch.cat(num_neighbors, dim=0).type(torch.get_default_dtype()) + ) + + return to_numpy(avg_num_neighbors).item(), mean, rms + + +def compute_rms_dipoles( + data_loader: torch.utils.data.DataLoader, +) -> Tuple[float, float]: + dipoles_list = [] + for batch in data_loader: + dipoles_list.append(batch.dipole) # {[n_graphs,3], } + + dipoles = torch.cat(dipoles_list, dim=0) # {[total_n_graphs,3], } + rms = to_numpy(torch.sqrt(torch.mean(torch.square(dipoles)))).item() + rms = _check_non_zero(rms) + return rms + + +def compute_fixed_charge_dipole( + charges: torch.Tensor, + positions: torch.Tensor, + batch: torch.Tensor, + num_graphs: int, +) -> torch.Tensor: + mu = positions * charges.unsqueeze(-1) / (1e-11 / c / e) # [N_atoms,3] + return scatter_sum( + src=mu, index=batch.unsqueeze(-1), dim=0, dim_size=num_graphs + ) # [N_graphs,3] diff --git a/mace/py.typed b/mace/py.typed new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/mace/py.typed @@ -0,0 +1 @@ + diff --git a/mace/tools/__init__.py b/mace/tools/__init__.py new file mode 100644 index 00000000..8ad80243 --- /dev/null +++ b/mace/tools/__init__.py @@ -0,0 +1,71 @@ +from .arg_parser import build_default_arg_parser, build_preprocess_arg_parser +from .arg_parser_tools import check_args +from .cg import U_matrix_real +from .checkpoint import CheckpointHandler, CheckpointIO, CheckpointState +from .finetuning_utils import load_foundations, load_foundations_elements +from .torch_tools import ( + TensorDict, + cartesian_to_spherical, + count_parameters, + init_device, + init_wandb, + set_default_dtype, + set_seeds, + spherical_to_cartesian, + to_numpy, + to_one_hot, + voigt_to_matrix, +) +from .train import SWAContainer, evaluate, train +from .utils import ( + AtomicNumberTable, + MetricsLogger, + atomic_numbers_to_indices, + compute_c, + compute_mae, + compute_q95, + compute_rel_mae, + compute_rel_rmse, + compute_rmse, + get_atomic_number_table_from_zs, + get_tag, + setup_logger, +) + +__all__ = [ + "TensorDict", + "AtomicNumberTable", + "atomic_numbers_to_indices", + "to_numpy", + "to_one_hot", + "build_default_arg_parser", + "check_args", + "set_seeds", + "init_device", + "setup_logger", + "get_tag", + "count_parameters", + "MetricsLogger", + "get_atomic_number_table_from_zs", + "train", + "evaluate", + "SWAContainer", + "CheckpointHandler", + "CheckpointIO", + "CheckpointState", + "set_default_dtype", + "compute_mae", + "compute_rel_mae", + "compute_rmse", + "compute_rel_rmse", + "compute_q95", + "compute_c", + "U_matrix_real", + "spherical_to_cartesian", + "cartesian_to_spherical", + "voigt_to_matrix", + "init_wandb", + "load_foundations", + "load_foundations_elements", + "build_preprocess_arg_parser", +] diff --git a/mace/tools/arg_parser.py b/mace/tools/arg_parser.py new file mode 100644 index 00000000..cb4f8ac5 --- /dev/null +++ b/mace/tools/arg_parser.py @@ -0,0 +1,878 @@ +########################################################################################### +# Parsing functionalities +# Authors: Ilyes Batatia, Gregor Simm, David Kovacs +# This program is distributed under the MIT License (see MIT.md) +########################################################################################### + +import argparse +import os +from typing import Optional + + +def build_default_arg_parser() -> argparse.ArgumentParser: + try: + import configargparse + + parser = configargparse.ArgumentParser( + config_file_parser_class=configargparse.YAMLConfigFileParser, + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add( + "--config", + type=str, + is_config_file=True, + help="config file to agregate options", + ) + except ImportError: + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + # Name and seed + parser.add_argument("--name", help="experiment name", required=True) + parser.add_argument("--seed", help="random seed", type=int, default=123) + + # Directories + parser.add_argument( + "--work_dir", + help="set directory for all files and folders", + type=str, + default=".", + ) + parser.add_argument( + "--log_dir", help="directory for log files", type=str, default=None + ) + parser.add_argument( + "--model_dir", help="directory for final model", type=str, default=None + ) + parser.add_argument( + "--checkpoints_dir", + help="directory for checkpoint files", + type=str, + default=None, + ) + parser.add_argument( + "--results_dir", help="directory for results", type=str, default=None + ) + parser.add_argument( + "--downloads_dir", help="directory for downloads", type=str, default=None + ) + + # Device and logging + parser.add_argument( + "--device", + help="select device", + type=str, + choices=["cpu", "cuda", "mps", "xpu"], + default="cpu", + ) + parser.add_argument( + "--default_dtype", + help="set default dtype", + type=str, + choices=["float32", "float64"], + default="float64", + ) + parser.add_argument( + "--distributed", + help="train in multi-GPU data parallel mode", + action="store_true", + default=False, + ) + parser.add_argument("--log_level", help="log level", type=str, default="INFO") + + parser.add_argument( + "--error_table", + help="Type of error table produced at the end of the training", + type=str, + choices=[ + "PerAtomRMSE", + "TotalRMSE", + "PerAtomRMSEstressvirials", + "PerAtomMAEstressvirials", + "PerAtomMAE", + "TotalMAE", + "DipoleRMSE", + "DipoleMAE", + "EnergyDipoleRMSE", + ], + default="PerAtomRMSE", + ) + + # Model + parser.add_argument( + "--model", + help="model type", + default="MACE", + choices=[ + "BOTNet", + "MACE", + "ScaleShiftMACE", + "ScaleShiftBOTNet", + "AtomicDipolesMACE", + "EnergyDipolesMACE", + ], + ) + parser.add_argument( + "--r_max", help="distance cutoff (in Ang)", type=float, default=5.0 + ) + parser.add_argument( + "--radial_type", + help="type of radial basis functions", + type=str, + default="bessel", + choices=["bessel", "gaussian", "chebyshev"], + ) + parser.add_argument( + "--num_radial_basis", + help="number of radial basis functions", + type=int, + default=8, + ) + parser.add_argument( + "--num_cutoff_basis", + help="number of basis functions for smooth cutoff", + type=int, + default=5, + ) + parser.add_argument( + "--pair_repulsion", + help="use pair repulsion term with ZBL potential", + action="store_true", + default=False, + ) + parser.add_argument( + "--distance_transform", + help="use distance transform for radial basis functions", + default="None", + choices=["None", "Agnesi", "Soft"], + ) + parser.add_argument( + "--interaction", + help="name of interaction block", + type=str, + default="RealAgnosticResidualInteractionBlock", + choices=[ + "RealAgnosticResidualInteractionBlock", + "RealAgnosticAttResidualInteractionBlock", + "RealAgnosticInteractionBlock", + "RealAgnosticDensityInteractionBlock", + "RealAgnosticDensityResidualInteractionBlock", + ], + ) + parser.add_argument( + "--interaction_first", + help="name of interaction block", + type=str, + default="RealAgnosticResidualInteractionBlock", + choices=[ + "RealAgnosticResidualInteractionBlock", + "RealAgnosticInteractionBlock", + "RealAgnosticDensityInteractionBlock", + "RealAgnosticDensityResidualInteractionBlock", + ], + ) + parser.add_argument( + "--max_ell", help=r"highest \ell of spherical harmonics", type=int, default=3 + ) + parser.add_argument( + "--correlation", help="correlation order at each layer", type=int, default=3 + ) + parser.add_argument( + "--num_interactions", help="number of interactions", type=int, default=2 + ) + parser.add_argument( + "--MLP_irreps", + help="hidden irreps of the MLP in last readout", + type=str, + default="16x0e", + ) + parser.add_argument( + "--radial_MLP", + help="width of the radial MLP", + type=str, + default="[64, 64, 64]", + ) + parser.add_argument( + "--hidden_irreps", + help="irreps for hidden node states", + type=str, + default=None, + ) + # add option to specify irreps by channel number and max L + parser.add_argument( + "--num_channels", + help="number of embedding channels", + type=int, + default=None, + ) + parser.add_argument( + "--max_L", + help="max L equivariance of the message", + type=int, + default=None, + ) + parser.add_argument( + "--gate", + help="non linearity for last readout", + type=str, + default="silu", + choices=["silu", "tanh", "abs", "None"], + ) + parser.add_argument( + "--scaling", + help="type of scaling to the output", + type=str, + default="rms_forces_scaling", + choices=["std_scaling", "rms_forces_scaling", "no_scaling"], + ) + parser.add_argument( + "--avg_num_neighbors", + help="normalization factor for the message", + type=float, + default=1, + ) + parser.add_argument( + "--compute_avg_num_neighbors", + help="normalization factor for the message", + type=str2bool, + default=True, + ) + parser.add_argument( + "--compute_stress", + help="Select True to compute stress", + type=str2bool, + default=False, + ) + parser.add_argument( + "--compute_forces", + help="Select True to compute forces", + type=str2bool, + default=True, + ) + + # Dataset + parser.add_argument( + "--train_file", + help="Training set file, format is .xyz or .h5", + type=str, + required=False, + ) + parser.add_argument( + "--valid_file", + help="Validation set .xyz or .h5 file", + default=None, + type=str, + required=False, + ) + parser.add_argument( + "--valid_fraction", + help="Fraction of training set used for validation", + type=float, + default=0.1, + required=False, + ) + parser.add_argument( + "--test_file", + help="Test set .xyz pt .h5 file", + type=str, + ) + parser.add_argument( + "--test_dir", + help="Path to directory with test files named as test_*.h5", + type=str, + default=None, + required=False, + ) + parser.add_argument( + "--multi_processed_test", + help="Boolean value for whether the test data was multiprocessed", + type=str2bool, + default=False, + required=False, + ) + parser.add_argument( + "--num_workers", + help="Number of workers for data loading", + type=int, + default=0, + ) + parser.add_argument( + "--pin_memory", + help="Pin memory for data loading", + default=True, + type=str2bool, + ) + parser.add_argument( + "--atomic_numbers", + help="List of atomic numbers", + type=str, + default=None, + required=False, + ) + parser.add_argument( + "--mean", + help="Mean energy per atom of training set", + type=float, + default=None, + required=False, + ) + parser.add_argument( + "--std", + help="Standard deviation of force components in the training set", + type=float, + default=None, + required=False, + ) + parser.add_argument( + "--statistics_file", + help="json file containing statistics of training set", + type=str, + default=None, + required=False, + ) + parser.add_argument( + "--E0s", + help="Dictionary of isolated atom energies", + type=str, + default=None, + required=False, + ) + + # Fine-tuning + parser.add_argument( + "--foundation_filter_elements", + help="Filter element during fine-tuning", + type=str2bool, + default=True, + required=False, + ) + parser.add_argument( + "--heads", + help="Dict of heads: containing individual files and E0s", + type=str, + default=None, + required=False, + ) + parser.add_argument( + "--multiheads_finetuning", + help="Boolean value for whether the model is multiheaded", + type=str2bool, + default=True, + ) + parser.add_argument( + "--foundation_head", + help="Name of the head to use for fine-tuning", + type=str, + default=None, + required=False, + ) + parser.add_argument( + "--weight_pt_head", + help="Weight of the pretrained head in the loss function", + type=float, + default=1.0, + ) + parser.add_argument( + "--num_samples_pt", + help="Number of samples in the pretrained head", + type=int, + default=1000, + ) + parser.add_argument( + "--subselect_pt", + help="Method to subselect the configurations of the pretraining set", + choices=["fps", "random"], + default="random", + ) + parser.add_argument( + "--pt_train_file", + help="Training set file for the pretrained head", + type=str, + default=None, + ) + parser.add_argument( + "--pt_valid_file", + help="Validation set file for the pretrained head", + type=str, + default=None, + ) + parser.add_argument( + "--keep_isolated_atoms", + help="Keep isolated atoms in the dataset, useful for transfer learning", + type=str2bool, + default=False, + ) + + # Keys + parser.add_argument( + "--energy_key", + help="Key of reference energies in training xyz", + type=str, + default="REF_energy", + ) + parser.add_argument( + "--forces_key", + help="Key of reference forces in training xyz", + type=str, + default="REF_forces", + ) + parser.add_argument( + "--virials_key", + help="Key of reference virials in training xyz", + type=str, + default="REF_virials", + ) + parser.add_argument( + "--stress_key", + help="Key of reference stress in training xyz", + type=str, + default="REF_stress", + ) + parser.add_argument( + "--dipole_key", + help="Key of reference dipoles in training xyz", + type=str, + default="REF_dipole", + ) + parser.add_argument( + "--charges_key", + help="Key of atomic charges in training xyz", + type=str, + default="REF_charges", + ) + + # Loss and optimization + parser.add_argument( + "--loss", + help="type of loss", + default="weighted", + choices=[ + "ef", + "weighted", + "forces_only", + "virials", + "stress", + "dipole", + "huber", + "universal", + "energy_forces_dipole", + ], + ) + parser.add_argument( + "--forces_weight", help="weight of forces loss", type=float, default=100.0 + ) + parser.add_argument( + "--swa_forces_weight", + "--stage_two_forces_weight", + help="weight of forces loss after starting Stage Two (previously called swa)", + type=float, + default=100.0, + dest="swa_forces_weight", + ) + parser.add_argument( + "--energy_weight", help="weight of energy loss", type=float, default=1.0 + ) + parser.add_argument( + "--swa_energy_weight", + "--stage_two_energy_weight", + help="weight of energy loss after starting Stage Two (previously called swa)", + type=float, + default=1000.0, + dest="swa_energy_weight", + ) + parser.add_argument( + "--virials_weight", help="weight of virials loss", type=float, default=1.0 + ) + parser.add_argument( + "--swa_virials_weight", + "--stage_two_virials_weight", + help="weight of virials loss after starting Stage Two (previously called swa)", + type=float, + default=10.0, + dest="swa_virials_weight", + ) + parser.add_argument( + "--stress_weight", help="weight of virials loss", type=float, default=1.0 + ) + parser.add_argument( + "--swa_stress_weight", + "--stage_two_stress_weight", + help="weight of stress loss after starting Stage Two (previously called swa)", + type=float, + default=10.0, + dest="swa_stress_weight", + ) + parser.add_argument( + "--dipole_weight", help="weight of dipoles loss", type=float, default=1.0 + ) + parser.add_argument( + "--swa_dipole_weight", + "--stage_two_dipole_weight", + help="weight of dipoles after starting Stage Two (previously called swa)", + type=float, + default=1.0, + dest="swa_dipole_weight", + ) + parser.add_argument( + "--config_type_weights", + help="String of dictionary containing the weights for each config type", + type=str, + default='{"Default":1.0}', + ) + parser.add_argument( + "--huber_delta", + help="delta parameter for huber loss", + type=float, + default=0.01, + ) + parser.add_argument( + "--optimizer", + help="Optimizer for parameter optimization", + type=str, + default="adam", + choices=["adam", "adamw", "schedulefree"], + ) + parser.add_argument( + "--beta", + help="Beta parameter for the optimizer", + type=float, + default=0.9, + ) + parser.add_argument("--batch_size", help="batch size", type=int, default=10) + parser.add_argument( + "--valid_batch_size", help="Validation batch size", type=int, default=10 + ) + parser.add_argument( + "--lr", help="Learning rate of optimizer", type=float, default=0.01 + ) + parser.add_argument( + "--swa_lr", + "--stage_two_lr", + help="Learning rate of optimizer in Stage Two (previously called swa)", + type=float, + default=1e-3, + dest="swa_lr", + ) + parser.add_argument( + "--weight_decay", help="weight decay (L2 penalty)", type=float, default=5e-7 + ) + parser.add_argument( + "--amsgrad", + help="use amsgrad variant of optimizer", + action="store_true", + default=True, + ) + parser.add_argument( + "--scheduler", help="Type of scheduler", type=str, default="ReduceLROnPlateau" + ) + parser.add_argument( + "--lr_factor", help="Learning rate factor", type=float, default=0.8 + ) + parser.add_argument( + "--scheduler_patience", help="Learning rate factor", type=int, default=50 + ) + parser.add_argument( + "--lr_scheduler_gamma", + help="Gamma of learning rate scheduler", + type=float, + default=0.9993, + ) + parser.add_argument( + "--swa", + "--stage_two", + help="use Stage Two loss weight, which decreases the learning rate and increases the energy weight at the end of the training to help converge them", + action="store_true", + default=False, + dest="swa", + ) + parser.add_argument( + "--start_swa", + "--start_stage_two", + help="Number of epochs before changing to Stage Two loss weights", + type=int, + default=None, + dest="start_swa", + ) + parser.add_argument( + "--ema", + help="use Exponential Moving Average", + action="store_true", + default=False, + ) + parser.add_argument( + "--ema_decay", + help="Exponential Moving Average decay", + type=float, + default=0.99, + ) + parser.add_argument( + "--max_num_epochs", help="Maximum number of epochs", type=int, default=2048 + ) + parser.add_argument( + "--patience", + help="Maximum number of consecutive epochs of increasing loss", + type=int, + default=2048, + ) + parser.add_argument( + "--foundation_model", + help="Path to the foundation model for transfer learning", + type=str, + default=None, + ) + parser.add_argument( + "--foundation_model_readout", + help="Use readout of foundation model for transfer learning", + action="store_false", + default=True, + ) + parser.add_argument( + "--eval_interval", help="evaluate model every epochs", type=int, default=1 + ) + parser.add_argument( + "--keep_checkpoints", + help="keep all checkpoints", + action="store_true", + default=False, + ) + parser.add_argument( + "--save_all_checkpoints", + help="save all checkpoints", + action="store_true", + default=False, + ) + parser.add_argument( + "--restart_latest", + help="restart optimizer from latest checkpoint", + action="store_true", + default=False, + ) + parser.add_argument( + "--save_cpu", + help="Save a model to be loaded on cpu", + action="store_true", + default=False, + ) + parser.add_argument( + "--clip_grad", + help="Gradient Clipping Value", + type=check_float_or_none, + default=10.0, + ) + # options for using Weights and Biases for experiment tracking + # to install see https://wandb.ai + parser.add_argument( + "--wandb", + help="Use Weights and Biases for experiment tracking", + action="store_true", + default=False, + ) + parser.add_argument( + "--wandb_dir", + help="An absolute path to a directory where Weights and Biases metadata will be stored", + type=str, + default=None, + ) + parser.add_argument( + "--wandb_project", + help="Weights and Biases project name", + type=str, + default="", + ) + parser.add_argument( + "--wandb_entity", + help="Weights and Biases entity name", + type=str, + default="", + ) + parser.add_argument( + "--wandb_name", + help="Weights and Biases experiment name", + type=str, + default="", + ) + parser.add_argument( + "--wandb_log_hypers", + help="The hyperparameters to log in Weights and Biases", + type=list, + default=[ + "num_channels", + "max_L", + "correlation", + "lr", + "swa_lr", + "weight_decay", + "batch_size", + "max_num_epochs", + "start_swa", + "energy_weight", + "forces_weight", + ], + ) + return parser + + +def build_preprocess_arg_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--train_file", + help="Training set h5 file", + type=str, + default=None, + required=True, + ) + parser.add_argument( + "--valid_file", + help="Training set xyz file", + type=str, + default=None, + required=False, + ) + parser.add_argument( + "--num_process", + help="The user defined number of processes to use, as well as the number of files created.", + type=int, + default=int(os.cpu_count() / 4), + ) + parser.add_argument( + "--valid_fraction", + help="Fraction of training set used for validation", + type=float, + default=0.1, + required=False, + ) + parser.add_argument( + "--test_file", + help="Test set xyz file", + type=str, + default=None, + required=False, + ) + parser.add_argument( + "--work_dir", + help="set directory for all files and folders", + type=str, + default=".", + ) + parser.add_argument( + "--h5_prefix", + help="Prefix for h5 files when saving", + type=str, + default="", + ) + parser.add_argument( + "--r_max", help="distance cutoff (in Ang)", type=float, default=5.0 + ) + parser.add_argument( + "--config_type_weights", + help="String of dictionary containing the weights for each config type", + type=str, + default='{"Default":1.0}', + ) + parser.add_argument( + "--energy_key", + help="Key of reference energies in training xyz", + type=str, + default="REF_energy", + ) + parser.add_argument( + "--forces_key", + help="Key of reference forces in training xyz", + type=str, + default="REF_forces", + ) + parser.add_argument( + "--virials_key", + help="Key of reference virials in training xyz", + type=str, + default="REF_virials", + ) + parser.add_argument( + "--stress_key", + help="Key of reference stress in training xyz", + type=str, + default="REF_stress", + ) + parser.add_argument( + "--dipole_key", + help="Key of reference dipoles in training xyz", + type=str, + default="REF_dipole", + ) + parser.add_argument( + "--charges_key", + help="Key of atomic charges in training xyz", + type=str, + default="REF_charges", + ) + parser.add_argument( + "--atomic_numbers", + help="List of atomic numbers", + type=str, + default=None, + required=False, + ) + parser.add_argument( + "--compute_statistics", + help="Compute statistics for the dataset", + action="store_true", + default=False, + ) + parser.add_argument( + "--batch_size", + help="batch size to compute average number of neighbours", + type=int, + default=16, + ) + + parser.add_argument( + "--scaling", + help="type of scaling to the output", + type=str, + default="rms_forces_scaling", + choices=["std_scaling", "rms_forces_scaling", "no_scaling"], + ) + parser.add_argument( + "--E0s", + help="Dictionary of isolated atom energies", + type=str, + default=None, + required=False, + ) + parser.add_argument( + "--shuffle", + help="Shuffle the training dataset", + type=str2bool, + default=True, + ) + parser.add_argument( + "--seed", + help="Random seed for splitting training and validation sets", + type=int, + default=123, + ) + return parser + + +def check_float_or_none(value: str) -> Optional[float]: + try: + return float(value) + except ValueError: + if value != "None": + raise argparse.ArgumentTypeError( + f"{value} is an invalid value (float or None)" + ) from None + return None + + +def str2bool(value): + if isinstance(value, bool): + return value + if value.lower() in ("yes", "true", "t", "y", "1"): + return True + if value.lower() in ("no", "false", "f", "n", "0"): + return False + raise argparse.ArgumentTypeError("Boolean value expected.") diff --git a/mace/tools/arg_parser_tools.py b/mace/tools/arg_parser_tools.py new file mode 100644 index 00000000..da64806a --- /dev/null +++ b/mace/tools/arg_parser_tools.py @@ -0,0 +1,113 @@ +import logging +import os + +from e3nn import o3 + + +def check_args(args): + """ + Check input arguments, update them if necessary for valid and consistent inputs, and return a tuple containing + the (potentially) modified args and a list of log messages. + """ + log_messages = [] + + # Directories + # Use work_dir for all other directories as well, unless they were specified by the user + if args.log_dir is None: + args.log_dir = os.path.join(args.work_dir, "logs") + if args.model_dir is None: + args.model_dir = args.work_dir + if args.checkpoints_dir is None: + args.checkpoints_dir = os.path.join(args.work_dir, "checkpoints") + if args.results_dir is None: + args.results_dir = os.path.join(args.work_dir, "results") + if args.downloads_dir is None: + args.downloads_dir = os.path.join(args.work_dir, "downloads") + + # Model + # Check if hidden_irreps, num_channels and max_L are consistent + if args.hidden_irreps is None and args.num_channels is None and args.max_L is None: + args.hidden_irreps, args.num_channels, args.max_L = "128x0e + 128x1o", 128, 1 + elif ( + args.hidden_irreps is not None + and args.num_channels is not None + and args.max_L is not None + ): + args.hidden_irreps = o3.Irreps( + (args.num_channels * o3.Irreps.spherical_harmonics(args.max_L)) + .sort() + .irreps.simplify() + ) + log_messages.append( + ( + "All of hidden_irreps, num_channels and max_L are specified", + logging.WARNING, + ) + ) + log_messages.append( + ( + f"Using num_channels and max_L to create hidden_irreps: {args.hidden_irreps}.", + logging.WARNING, + ) + ) + assert ( + len({irrep.mul for irrep in o3.Irreps(args.hidden_irreps)}) == 1 + ), "All channels must have the same dimension, use the num_channels and max_L keywords to specify the number of channels and the maximum L" + elif args.num_channels is not None and args.max_L is not None: + assert args.num_channels > 0, "num_channels must be positive integer" + assert args.max_L >= 0, "max_L must be non-negative integer" + args.hidden_irreps = o3.Irreps( + (args.num_channels * o3.Irreps.spherical_harmonics(args.max_L)) + .sort() + .irreps.simplify() + ) + assert ( + len({irrep.mul for irrep in o3.Irreps(args.hidden_irreps)}) == 1 + ), "All channels must have the same dimension, use the num_channels and max_L keywords to specify the number of channels and the maximum L" + elif args.hidden_irreps is not None: + assert ( + len({irrep.mul for irrep in o3.Irreps(args.hidden_irreps)}) == 1 + ), "All channels must have the same dimension, use the num_channels and max_L keywords to specify the number of channels and the maximum L" + + args.num_channels = list( + {irrep.mul for irrep in o3.Irreps(args.hidden_irreps)} + )[0] + args.max_L = o3.Irreps(args.hidden_irreps).lmax + elif args.max_L is not None and args.num_channels is None: + assert args.max_L >= 0, "max_L must be non-negative integer" + args.num_channels = 128 + args.hidden_irreps = o3.Irreps( + (args.num_channels * o3.Irreps.spherical_harmonics(args.max_L)) + .sort() + .irreps.simplify() + ) + elif args.max_L is None and args.num_channels is not None: + assert args.num_channels > 0, "num_channels must be positive integer" + args.max_L = 1 + args.hidden_irreps = o3.Irreps( + (args.num_channels * o3.Irreps.spherical_harmonics(args.max_L)) + .sort() + .irreps.simplify() + ) + + # Loss and optimization + # Check Stage Two loss start + if args.swa: + if args.start_swa is None: + args.start_swa = max(1, args.max_num_epochs // 4 * 3) + if args.start_swa > args.max_num_epochs: + log_messages.append( + ( + f"start_stage_two must be less than max_num_epochs, got {args.start_swa} > {args.max_num_epochs}", + logging.WARNING, + ) + ) + log_messages.append( + ( + "Stage Two will not start, as start_stage_two > max_num_epochs", + logging.WARNING, + ) + ) + args.swa = False + + return args, log_messages diff --git a/mace/tools/cg.py b/mace/tools/cg.py new file mode 100644 index 00000000..2cca09c9 --- /dev/null +++ b/mace/tools/cg.py @@ -0,0 +1,131 @@ +########################################################################################### +# Higher Order Real Clebsch Gordan (based on e3nn by Mario Geiger) +# Authors: Ilyes Batatia +# This program is distributed under the MIT License (see MIT.md) +########################################################################################### + +import collections +from typing import List, Union + +import torch +from e3nn import o3 + +_TP = collections.namedtuple("_TP", "op, args") +_INPUT = collections.namedtuple("_INPUT", "tensor, start, stop") + + +def _wigner_nj( + irrepss: List[o3.Irreps], + normalization: str = "component", + filter_ir_mid=None, + dtype=None, +): + irrepss = [o3.Irreps(irreps) for irreps in irrepss] + if filter_ir_mid is not None: + filter_ir_mid = [o3.Irrep(ir) for ir in filter_ir_mid] + + if len(irrepss) == 1: + (irreps,) = irrepss + ret = [] + e = torch.eye(irreps.dim, dtype=dtype) + i = 0 + for mul, ir in irreps: + for _ in range(mul): + sl = slice(i, i + ir.dim) + ret += [(ir, _INPUT(0, sl.start, sl.stop), e[sl])] + i += ir.dim + return ret + + *irrepss_left, irreps_right = irrepss + ret = [] + for ir_left, path_left, C_left in _wigner_nj( + irrepss_left, + normalization=normalization, + filter_ir_mid=filter_ir_mid, + dtype=dtype, + ): + i = 0 + for mul, ir in irreps_right: + for ir_out in ir_left * ir: + if filter_ir_mid is not None and ir_out not in filter_ir_mid: + continue + + C = o3.wigner_3j(ir_out.l, ir_left.l, ir.l, dtype=dtype) + if normalization == "component": + C *= ir_out.dim**0.5 + if normalization == "norm": + C *= ir_left.dim**0.5 * ir.dim**0.5 + + C = torch.einsum("jk,ijl->ikl", C_left.flatten(1), C) + C = C.reshape( + ir_out.dim, *(irreps.dim for irreps in irrepss_left), ir.dim + ) + for u in range(mul): + E = torch.zeros( + ir_out.dim, + *(irreps.dim for irreps in irrepss_left), + irreps_right.dim, + dtype=dtype, + ) + sl = slice(i + u * ir.dim, i + (u + 1) * ir.dim) + E[..., sl] = C + ret += [ + ( + ir_out, + _TP( + op=(ir_left, ir, ir_out), + args=( + path_left, + _INPUT(len(irrepss_left), sl.start, sl.stop), + ), + ), + E, + ) + ] + i += mul * ir.dim + return sorted(ret, key=lambda x: x[0]) + + +def U_matrix_real( + irreps_in: Union[str, o3.Irreps], + irreps_out: Union[str, o3.Irreps], + correlation: int, + normalization: str = "component", + filter_ir_mid=None, + dtype=None, +): + irreps_out = o3.Irreps(irreps_out) + irrepss = [o3.Irreps(irreps_in)] * correlation + if correlation == 4: + filter_ir_mid = [ + (0, 1), + (1, -1), + (2, 1), + (3, -1), + (4, 1), + (5, -1), + (6, 1), + (7, -1), + (8, 1), + (9, -1), + (10, 1), + (11, -1), + ] + wigners = _wigner_nj(irrepss, normalization, filter_ir_mid, dtype) + current_ir = wigners[0][0] + out = [] + stack = torch.tensor([]) + + for ir, _, base_o3 in wigners: + if ir in irreps_out and ir == current_ir: + stack = torch.cat((stack, base_o3.squeeze().unsqueeze(-1)), dim=-1) + last_ir = current_ir + elif ir in irreps_out and ir != current_ir: + if len(stack) != 0: + out += [last_ir, stack] + stack = base_o3.squeeze().unsqueeze(-1) + current_ir, last_ir = ir, ir + else: + current_ir = ir + out += [last_ir, stack] + return out diff --git a/mace/tools/checkpoint.py b/mace/tools/checkpoint.py new file mode 100644 index 00000000..81161ccc --- /dev/null +++ b/mace/tools/checkpoint.py @@ -0,0 +1,227 @@ +########################################################################################### +# Checkpointing +# Authors: Gregor Simm +# This program is distributed under the MIT License (see MIT.md) +########################################################################################### + +import dataclasses +import logging +import os +import re +from typing import Dict, List, Optional, Tuple + +import torch + +from .torch_tools import TensorDict + +Checkpoint = Dict[str, TensorDict] + + +@dataclasses.dataclass +class CheckpointState: + model: torch.nn.Module + optimizer: torch.optim.Optimizer + lr_scheduler: torch.optim.lr_scheduler.ExponentialLR + + +class CheckpointBuilder: + @staticmethod + def create_checkpoint(state: CheckpointState) -> Checkpoint: + return { + "model": state.model.state_dict(), + "optimizer": state.optimizer.state_dict(), + "lr_scheduler": state.lr_scheduler.state_dict(), + } + + @staticmethod + def load_checkpoint( + state: CheckpointState, checkpoint: Checkpoint, strict: bool + ) -> None: + state.model.load_state_dict(checkpoint["model"], strict=strict) # type: ignore + state.optimizer.load_state_dict(checkpoint["optimizer"]) + state.lr_scheduler.load_state_dict(checkpoint["lr_scheduler"]) + + +@dataclasses.dataclass +class CheckpointPathInfo: + path: str + tag: str + epochs: int + swa: bool + + +class CheckpointIO: + def __init__( + self, directory: str, tag: str, keep: bool = False, swa_start: int = None + ) -> None: + self.directory = directory + self.tag = tag + self.keep = keep + self.old_path: Optional[str] = None + self.swa_start = swa_start + + self._epochs_string = "_epoch-" + self._filename_extension = "pt" + + def _get_checkpoint_filename(self, epochs: int, swa_start=None) -> str: + if swa_start is not None and epochs >= swa_start: + return ( + self.tag + + self._epochs_string + + str(epochs) + + "_swa" + + "." + + self._filename_extension + ) + return ( + self.tag + + self._epochs_string + + str(epochs) + + "." + + self._filename_extension + ) + + def _list_file_paths(self) -> List[str]: + if not os.path.isdir(self.directory): + return [] + all_paths = [ + os.path.join(self.directory, f) for f in os.listdir(self.directory) + ] + return [path for path in all_paths if os.path.isfile(path)] + + def _parse_checkpoint_path(self, path: str) -> Optional[CheckpointPathInfo]: + filename = os.path.basename(path) + regex = re.compile( + rf"^(?P.+){self._epochs_string}(?P\d+)\.{self._filename_extension}$" + ) + regex2 = re.compile( + rf"^(?P.+){self._epochs_string}(?P\d+)_swa\.{self._filename_extension}$" + ) + match = regex.match(filename) + match2 = regex2.match(filename) + swa = False + if not match: + if not match2: + return None + match = match2 + swa = True + + return CheckpointPathInfo( + path=path, + tag=match.group("tag"), + epochs=int(match.group("epochs")), + swa=swa, + ) + + def _get_latest_checkpoint_path(self, swa) -> Optional[str]: + all_file_paths = self._list_file_paths() + checkpoint_info_list = [ + self._parse_checkpoint_path(path) for path in all_file_paths + ] + selected_checkpoint_info_list = [ + info for info in checkpoint_info_list if info and info.tag == self.tag + ] + + if len(selected_checkpoint_info_list) == 0: + logging.warning( + f"Cannot find checkpoint with tag '{self.tag}' in '{self.directory}'" + ) + return None + + selected_checkpoint_info_list_swa = [] + selected_checkpoint_info_list_no_swa = [] + + for ckp in selected_checkpoint_info_list: + if ckp.swa: + selected_checkpoint_info_list_swa.append(ckp) + else: + selected_checkpoint_info_list_no_swa.append(ckp) + if swa: + try: + latest_checkpoint_info = max( + selected_checkpoint_info_list_swa, key=lambda info: info.epochs + ) + except ValueError: + logging.warning( + "No SWA checkpoint found, while SWA is enabled. Compare the swa_start parameter and the latest checkpoint." + ) + else: + latest_checkpoint_info = max( + selected_checkpoint_info_list_no_swa, key=lambda info: info.epochs + ) + return latest_checkpoint_info.path + + def save( + self, checkpoint: Checkpoint, epochs: int, keep_last: bool = False + ) -> None: + if not self.keep and self.old_path and not keep_last: + logging.debug(f"Deleting old checkpoint file: {self.old_path}") + os.remove(self.old_path) + + filename = self._get_checkpoint_filename(epochs, self.swa_start) + path = os.path.join(self.directory, filename) + logging.debug(f"Saving checkpoint: {path}") + os.makedirs(self.directory, exist_ok=True) + torch.save(obj=checkpoint, f=path) + self.old_path = path + + def load_latest( + self, swa: Optional[bool] = False, device: Optional[torch.device] = None + ) -> Optional[Tuple[Checkpoint, int]]: + path = self._get_latest_checkpoint_path(swa=swa) + if path is None: + return None + + return self.load(path, device=device) + + def load( + self, path: str, device: Optional[torch.device] = None + ) -> Tuple[Checkpoint, int]: + checkpoint_info = self._parse_checkpoint_path(path) + + if checkpoint_info is None: + raise RuntimeError(f"Cannot find path '{path}'") + + logging.info(f"Loading checkpoint: {checkpoint_info.path}") + return ( + torch.load(f=checkpoint_info.path, map_location=device), + checkpoint_info.epochs, + ) + + +class CheckpointHandler: + def __init__(self, *args, **kwargs) -> None: + self.io = CheckpointIO(*args, **kwargs) + self.builder = CheckpointBuilder() + + def save( + self, state: CheckpointState, epochs: int, keep_last: bool = False + ) -> None: + checkpoint = self.builder.create_checkpoint(state) + self.io.save(checkpoint, epochs, keep_last) + + def load_latest( + self, + state: CheckpointState, + swa: Optional[bool] = False, + device: Optional[torch.device] = None, + strict=False, + ) -> Optional[int]: + result = self.io.load_latest(swa=swa, device=device) + if result is None: + return None + + checkpoint, epochs = result + self.builder.load_checkpoint(state=state, checkpoint=checkpoint, strict=strict) + return epochs + + def load( + self, + state: CheckpointState, + path: str, + strict=False, + device: Optional[torch.device] = None, + ) -> int: + checkpoint, epochs = self.io.load(path, device=device) + self.builder.load_checkpoint(state=state, checkpoint=checkpoint, strict=strict) + return epochs diff --git a/mace/tools/compile.py b/mace/tools/compile.py new file mode 100644 index 00000000..03282067 --- /dev/null +++ b/mace/tools/compile.py @@ -0,0 +1,95 @@ +from contextlib import contextmanager +from functools import wraps +from typing import Callable, Tuple + +try: + import torch._dynamo as dynamo +except ImportError: + dynamo = None +from e3nn import get_optimization_defaults, set_optimization_defaults +from torch import autograd, nn +from torch.fx import symbolic_trace + +ModuleFactory = Callable[..., nn.Module] +TypeTuple = Tuple[type, ...] + + +@contextmanager +def disable_e3nn_codegen(): + """Context manager that disables the legacy PyTorch code generation used in e3nn.""" + init_val = get_optimization_defaults()["jit_script_fx"] + set_optimization_defaults(jit_script_fx=False) + yield + set_optimization_defaults(jit_script_fx=init_val) + + +def prepare(func: ModuleFactory, allow_autograd: bool = True) -> ModuleFactory: + """Function transform that prepares a MACE module for torch.compile + + Args: + func (ModuleFactory): A function that creates an nn.Module + allow_autograd (bool, optional): Force inductor compiler to inline call to + `torch.autograd.grad`. Defaults to True. + + Returns: + ModuleFactory: Decorated function that creates a torch.compile compatible module + """ + if allow_autograd: + dynamo.allow_in_graph(autograd.grad) + else: + dynamo.disallow_in_graph(autograd.grad) + + @wraps(func) + def wrapper(*args, **kwargs): + with disable_e3nn_codegen(): + model = func(*args, **kwargs) + + model = simplify(model) + return model + + return wrapper + + +_SIMPLIFY_REGISTRY = set() + + +def simplify_if_compile(module: nn.Module) -> nn.Module: + """Decorator to register a module for symbolic simplification + + The decorated module will be simplifed using `torch.fx.symbolic_trace`. + This constrains the module to not have any dynamic control flow, see: + + https://pytorch.org/docs/stable/fx.html#limitations-of-symbolic-tracing + + Args: + module (nn.Module): the module to register + + Returns: + nn.Module: registered module + """ + _SIMPLIFY_REGISTRY.add(module) + return module + + +def simplify(module: nn.Module) -> nn.Module: + """Recursively searches for registered modules to simplify with + `torch.fx.symbolic_trace` to support compiling with the PyTorch Dynamo compiler. + + Modules are registered with the `simplify_if_compile` decorator and + + Args: + module (nn.Module): the module to simplify + + Returns: + nn.Module: the simplified module + """ + simplify_types = tuple(_SIMPLIFY_REGISTRY) + + for name, child in module.named_children(): + if isinstance(child, simplify_types): + traced = symbolic_trace(child) + setattr(module, name, traced) + else: + simplify(child) + + return module diff --git a/mace/tools/finetuning_utils.py b/mace/tools/finetuning_utils.py new file mode 100644 index 00000000..8df0b0d1 --- /dev/null +++ b/mace/tools/finetuning_utils.py @@ -0,0 +1,204 @@ +import torch + +from mace.tools.utils import AtomicNumberTable + + +def load_foundations_elements( + model: torch.nn.Module, + model_foundations: torch.nn.Module, + table: AtomicNumberTable, + load_readout=False, + use_shift=True, + use_scale=True, + max_L=2, +): + """ + Load the foundations of a model into a model for fine-tuning. + """ + assert model_foundations.r_max == model.r_max + z_table = AtomicNumberTable([int(z) for z in model_foundations.atomic_numbers]) + model_heads = model.heads + new_z_table = table + num_species_foundations = len(z_table.zs) + num_channels_foundation = ( + model_foundations.node_embedding.linear.weight.shape[0] + // num_species_foundations + ) + indices_weights = [z_table.z_to_index(z) for z in new_z_table.zs] + num_radial = model.radial_embedding.out_dim + num_species = len(indices_weights) + max_ell = model.spherical_harmonics._lmax # pylint: disable=protected-access + model.node_embedding.linear.weight = torch.nn.Parameter( + model_foundations.node_embedding.linear.weight.view( + num_species_foundations, -1 + )[indices_weights, :] + .flatten() + .clone() + / (num_species_foundations / num_species) ** 0.5 + ) + if model.radial_embedding.bessel_fn.__class__.__name__ == "BesselBasis": + model.radial_embedding.bessel_fn.bessel_weights = torch.nn.Parameter( + model_foundations.radial_embedding.bessel_fn.bessel_weights.clone() + ) + for i in range(int(model.num_interactions)): + model.interactions[i].linear_up.weight = torch.nn.Parameter( + model_foundations.interactions[i].linear_up.weight.clone() + ) + model.interactions[i].avg_num_neighbors = model_foundations.interactions[ + i + ].avg_num_neighbors + for j in range(4): # Assuming 4 layers in conv_tp_weights, + layer_name = f"layer{j}" + if j == 0: + getattr(model.interactions[i].conv_tp_weights, layer_name).weight = ( + torch.nn.Parameter( + getattr( + model_foundations.interactions[i].conv_tp_weights, + layer_name, + ) + .weight[:num_radial, :] + .clone() + ) + ) + else: + getattr(model.interactions[i].conv_tp_weights, layer_name).weight = ( + torch.nn.Parameter( + getattr( + model_foundations.interactions[i].conv_tp_weights, + layer_name, + ).weight.clone() + ) + ) + + model.interactions[i].linear.weight = torch.nn.Parameter( + model_foundations.interactions[i].linear.weight.clone() + ) + if model.interactions[i].__class__.__name__ in [ + "RealAgnosticResidualInteractionBlock", + "RealAgnosticDensityResidualInteractionBlock", + ]: + model.interactions[i].skip_tp.weight = torch.nn.Parameter( + model_foundations.interactions[i] + .skip_tp.weight.reshape( + num_channels_foundation, + num_species_foundations, + num_channels_foundation, + )[:, indices_weights, :] + .flatten() + .clone() + / (num_species_foundations / num_species) ** 0.5 + ) + else: + model.interactions[i].skip_tp.weight = torch.nn.Parameter( + model_foundations.interactions[i] + .skip_tp.weight.reshape( + num_channels_foundation, + (max_ell + 1), + num_species_foundations, + num_channels_foundation, + )[:, :, indices_weights, :] + .flatten() + .clone() + / (num_species_foundations / num_species) ** 0.5 + ) + if model.interactions[i].__class__.__name__ in [ + "RealAgnosticDensityInteractionBlock", + "RealAgnosticDensityResidualInteractionBlock", + ]: + # Assuming only 1 layer in density_fn + getattr(model.interactions[i].density_fn, "layer0").weight = ( + torch.nn.Parameter( + getattr( + model_foundations.interactions[i].density_fn, + "layer0", + ).weight.clone() + ) + ) + # Transferring products + for i in range(2): # Assuming 2 products modules + max_range = max_L + 1 if i == 0 else 1 + for j in range(max_range): # Assuming 3 contractions in symmetric_contractions + model.products[i].symmetric_contractions.contractions[j].weights_max = ( + torch.nn.Parameter( + model_foundations.products[i] + .symmetric_contractions.contractions[j] + .weights_max[indices_weights, :, :] + .clone() + ) + ) + + for k in range(2): # Assuming 2 weights in each contraction + model.products[i].symmetric_contractions.contractions[j].weights[k] = ( + torch.nn.Parameter( + model_foundations.products[i] + .symmetric_contractions.contractions[j] + .weights[k][indices_weights, :, :] + .clone() + ) + ) + + model.products[i].linear.weight = torch.nn.Parameter( + model_foundations.products[i].linear.weight.clone() + ) + + if load_readout: + # Transferring readouts + model_readouts_zero_linear_weight = model.readouts[0].linear.weight.clone() + model_readouts_zero_linear_weight = ( + model_foundations.readouts[0] + .linear.weight.view(num_channels_foundation, -1) + .repeat(1, len(model_heads)) + .flatten() + .clone() + ) + model.readouts[0].linear.weight = torch.nn.Parameter( + model_readouts_zero_linear_weight + ) + + shape_input_1 = ( + model_foundations.readouts[1].linear_1.__dict__["irreps_out"].num_irreps + ) + shape_output_1 = model.readouts[1].linear_1.__dict__["irreps_out"].num_irreps + model_readouts_one_linear_1_weight = model.readouts[1].linear_1.weight.clone() + model_readouts_one_linear_1_weight = ( + model_foundations.readouts[1] + .linear_1.weight.view(num_channels_foundation, -1) + .repeat(1, len(model_heads)) + .flatten() + .clone() + ) + model.readouts[1].linear_1.weight = torch.nn.Parameter( + model_readouts_one_linear_1_weight + ) + model_readouts_one_linear_2_weight = model.readouts[1].linear_2.weight.clone() + model_readouts_one_linear_2_weight = model_foundations.readouts[ + 1 + ].linear_2.weight.view(shape_input_1, -1).repeat( + len(model_heads), len(model_heads) + ).flatten().clone() / ( + ((shape_input_1) / (shape_output_1)) ** 0.5 + ) + model.readouts[1].linear_2.weight = torch.nn.Parameter( + model_readouts_one_linear_2_weight + ) + if model_foundations.scale_shift is not None: + if use_scale: + model.scale_shift.scale = model_foundations.scale_shift.scale.repeat( + len(model_heads) + ).clone() + if use_shift: + model.scale_shift.shift = model_foundations.scale_shift.shift.repeat( + len(model_heads) + ).clone() + return model + + +def load_foundations( + model, + model_foundations, +): + for name, param in model_foundations.named_parameters(): + if name in model.state_dict().keys(): + if "readouts" not in name: + model.state_dict()[name].copy_(param) + return model diff --git a/mace/tools/model_script_utils.py b/mace/tools/model_script_utils.py new file mode 100644 index 00000000..3f49eb41 --- /dev/null +++ b/mace/tools/model_script_utils.py @@ -0,0 +1,228 @@ +import ast +import logging + +import numpy as np +from e3nn import o3 + +from mace import modules +from mace.tools.finetuning_utils import load_foundations_elements +from mace.tools.scripts_utils import extract_config_mace_model + + +def configure_model( + args, train_loader, atomic_energies, model_foundation=None, heads=None, z_table=None +): + # Selecting outputs + compute_virials = args.loss in ("stress", "virials", "huber", "universal") + if compute_virials: + args.compute_stress = True + args.error_table = "PerAtomRMSEstressvirials" + + output_args = { + "energy": args.compute_energy, + "forces": args.compute_forces, + "virials": compute_virials, + "stress": args.compute_stress, + "dipoles": args.compute_dipole, + } + logging.info( + f"During training the following quantities will be reported: {', '.join([f'{report}' for report, value in output_args.items() if value])}" + ) + logging.info("===========MODEL DETAILS===========") + + if args.scaling == "no_scaling": + args.std = 1.0 + logging.info("No scaling selected") + elif (args.mean is None or args.std is None) and args.model != "AtomicDipolesMACE": + args.mean, args.std = modules.scaling_classes[args.scaling]( + train_loader, atomic_energies + ) + + # Build model + if model_foundation is not None and args.model in ["MACE", "ScaleShiftMACE"]: + logging.info("Loading FOUNDATION model") + model_config_foundation = extract_config_mace_model(model_foundation) + model_config_foundation["atomic_energies"] = atomic_energies + model_config_foundation["atomic_numbers"] = z_table.zs + model_config_foundation["num_elements"] = len(z_table) + args.max_L = model_config_foundation["hidden_irreps"].lmax + + if args.model == "MACE" and model_foundation.__class__.__name__ == "MACE": + model_config_foundation["atomic_inter_shift"] = [0.0] * len(heads) + else: + model_config_foundation["atomic_inter_shift"] = ( + _determine_atomic_inter_shift(args.mean, heads) + ) + model_config_foundation["atomic_inter_scale"] = [1.0] * len(heads) + args.avg_num_neighbors = model_config_foundation["avg_num_neighbors"] + args.model = "FoundationMACE" + model_config_foundation["heads"] = heads + model_config = model_config_foundation + + logging.info("Model configuration extracted from foundation model") + logging.info("Using universal loss function for fine-tuning") + logging.info( + f"Message passing with hidden irreps {model_config_foundation['hidden_irreps']})" + ) + logging.info( + f"{model_config_foundation['num_interactions']} layers, each with correlation order: {model_config_foundation['correlation']} (body order: {model_config_foundation['correlation']+1}) and spherical harmonics up to: l={model_config_foundation['max_ell']}" + ) + logging.info( + f"Radial cutoff: {model_config_foundation['r_max']} A (total receptive field for each atom: {model_config_foundation['r_max'] * model_config_foundation['num_interactions']} A)" + ) + logging.info( + f"Distance transform for radial basis functions: {model_config_foundation['distance_transform']}" + ) + else: + logging.info("Building model") + logging.info( + f"Message passing with {args.num_channels} channels and max_L={args.max_L} ({args.hidden_irreps})" + ) + logging.info( + f"{args.num_interactions} layers, each with correlation order: {args.correlation} (body order: {args.correlation+1}) and spherical harmonics up to: l={args.max_ell}" + ) + logging.info( + f"{args.num_radial_basis} radial and {args.num_cutoff_basis} basis functions" + ) + logging.info( + f"Radial cutoff: {args.r_max} A (total receptive field for each atom: {args.r_max * args.num_interactions} A)" + ) + logging.info( + f"Distance transform for radial basis functions: {args.distance_transform}" + ) + + assert ( + len({irrep.mul for irrep in o3.Irreps(args.hidden_irreps)}) == 1 + ), "All channels must have the same dimension, use the num_channels and max_L keywords to specify the number of channels and the maximum L" + + logging.info(f"Hidden irreps: {args.hidden_irreps}") + + model_config = dict( + r_max=args.r_max, + num_bessel=args.num_radial_basis, + num_polynomial_cutoff=args.num_cutoff_basis, + max_ell=args.max_ell, + interaction_cls=modules.interaction_classes[args.interaction], + num_interactions=args.num_interactions, + num_elements=len(z_table), + hidden_irreps=o3.Irreps(args.hidden_irreps), + atomic_energies=atomic_energies, + avg_num_neighbors=args.avg_num_neighbors, + atomic_numbers=z_table.zs, + ) + model_config_foundation = None + + model = _build_model(args, model_config, model_config_foundation, heads) + + if model_foundation is not None: + model = load_foundations_elements( + model, + model_foundation, + z_table, + load_readout=args.foundation_filter_elements, + max_L=args.max_L, + ) + + return model, output_args + + +def _determine_atomic_inter_shift(mean, heads): + if isinstance(mean, np.ndarray): + if mean.size == 1: + return mean.item() + if mean.size == len(heads): + return mean.tolist() + logging.info("Mean not in correct format, using default value of 0.0") + return [0.0] * len(heads) + if isinstance(mean, list) and len(mean) == len(heads): + return mean + if isinstance(mean, float): + return [mean] * len(heads) + logging.info("Mean not in correct format, using default value of 0.0") + return [0.0] * len(heads) + + +def _build_model( + args, model_config, model_config_foundation, heads +): # pylint: disable=too-many-return-statements + if args.model == "MACE": + return modules.ScaleShiftMACE( + **model_config, + pair_repulsion=args.pair_repulsion, + distance_transform=args.distance_transform, + correlation=args.correlation, + gate=modules.gate_dict[args.gate], + interaction_cls_first=modules.interaction_classes[ + "RealAgnosticInteractionBlock" + ], + MLP_irreps=o3.Irreps(args.MLP_irreps), + atomic_inter_scale=args.std, + atomic_inter_shift=[0.0] * len(heads), + radial_MLP=ast.literal_eval(args.radial_MLP), + radial_type=args.radial_type, + heads=heads, + ) + if args.model == "ScaleShiftMACE": + return modules.ScaleShiftMACE( + **model_config, + pair_repulsion=args.pair_repulsion, + distance_transform=args.distance_transform, + correlation=args.correlation, + gate=modules.gate_dict[args.gate], + interaction_cls_first=modules.interaction_classes[args.interaction_first], + MLP_irreps=o3.Irreps(args.MLP_irreps), + atomic_inter_scale=args.std, + atomic_inter_shift=args.mean, + radial_MLP=ast.literal_eval(args.radial_MLP), + radial_type=args.radial_type, + heads=heads, + ) + if args.model == "FoundationMACE": + return modules.ScaleShiftMACE(**model_config_foundation) + if args.model == "ScaleShiftBOTNet": + return modules.ScaleShiftBOTNet( + **model_config, + gate=modules.gate_dict[args.gate], + interaction_cls_first=modules.interaction_classes[args.interaction_first], + MLP_irreps=o3.Irreps(args.MLP_irreps), + atomic_inter_scale=args.std, + atomic_inter_shift=args.mean, + ) + if args.model == "BOTNet": + return modules.BOTNet( + **model_config, + gate=modules.gate_dict[args.gate], + interaction_cls_first=modules.interaction_classes[args.interaction_first], + MLP_irreps=o3.Irreps(args.MLP_irreps), + ) + if args.model == "AtomicDipolesMACE": + assert args.loss == "dipole", "Use dipole loss with AtomicDipolesMACE model" + assert ( + args.error_table == "DipoleRMSE" + ), "Use error_table DipoleRMSE with AtomicDipolesMACE model" + return modules.AtomicDipolesMACE( + **model_config, + correlation=args.correlation, + gate=modules.gate_dict[args.gate], + interaction_cls_first=modules.interaction_classes[ + "RealAgnosticInteractionBlock" + ], + MLP_irreps=o3.Irreps(args.MLP_irreps), + ) + if args.model == "EnergyDipolesMACE": + assert ( + args.loss == "energy_forces_dipole" + ), "Use energy_forces_dipole loss with EnergyDipolesMACE model" + assert ( + args.error_table == "EnergyDipoleRMSE" + ), "Use error_table EnergyDipoleRMSE with AtomicDipolesMACE model" + return modules.EnergyDipolesMACE( + **model_config, + correlation=args.correlation, + gate=modules.gate_dict[args.gate], + interaction_cls_first=modules.interaction_classes[ + "RealAgnosticInteractionBlock" + ], + MLP_irreps=o3.Irreps(args.MLP_irreps), + ) + raise RuntimeError(f"Unknown model: '{args.model}'") diff --git a/mace/tools/multihead_tools.py b/mace/tools/multihead_tools.py new file mode 100644 index 00000000..ffde107f --- /dev/null +++ b/mace/tools/multihead_tools.py @@ -0,0 +1,185 @@ +import argparse +import dataclasses +import logging +import os +import urllib.request +from typing import Any, Dict, List, Optional, Union + +import torch + +from mace.cli.fine_tuning_select import select_samples +from mace.tools.scripts_utils import ( + SubsetCollection, + dict_to_namespace, + get_dataset_from_xyz, +) + + +@dataclasses.dataclass +class HeadConfig: + head_name: str + train_file: Optional[str] = None + valid_file: Optional[str] = None + test_file: Optional[str] = None + test_dir: Optional[str] = None + E0s: Optional[Any] = None + statistics_file: Optional[str] = None + valid_fraction: Optional[float] = None + config_type_weights: Optional[Dict[str, float]] = None + energy_key: Optional[str] = None + forces_key: Optional[str] = None + stress_key: Optional[str] = None + virials_key: Optional[str] = None + dipole_key: Optional[str] = None + charges_key: Optional[str] = None + keep_isolated_atoms: Optional[bool] = None + atomic_numbers: Optional[Union[List[int], List[str]]] = None + mean: Optional[float] = None + std: Optional[float] = None + avg_num_neighbors: Optional[float] = None + compute_avg_num_neighbors: Optional[bool] = None + collections: Optional[SubsetCollection] = None + train_loader: torch.utils.data.DataLoader = None + z_table: Optional[Any] = None + atomic_energies_dict: Optional[Dict[str, float]] = None + + +def dict_head_to_dataclass( + head: Dict[str, Any], head_name: str, args: argparse.Namespace +) -> HeadConfig: + + return HeadConfig( + head_name=head_name, + train_file=head.get("train_file", args.train_file), + valid_file=head.get("valid_file", args.valid_file), + test_file=head.get("test_file", None), + test_dir=head.get("test_dir", None), + E0s=head.get("E0s", args.E0s), + statistics_file=head.get("statistics_file", args.statistics_file), + valid_fraction=head.get("valid_fraction", args.valid_fraction), + config_type_weights=head.get("config_type_weights", args.config_type_weights), + compute_avg_num_neighbors=head.get( + "compute_avg_num_neighbors", args.compute_avg_num_neighbors + ), + atomic_numbers=head.get("atomic_numbers", args.atomic_numbers), + mean=head.get("mean", args.mean), + std=head.get("std", args.std), + avg_num_neighbors=head.get("avg_num_neighbors", args.avg_num_neighbors), + energy_key=head.get("energy_key", args.energy_key), + forces_key=head.get("forces_key", args.forces_key), + stress_key=head.get("stress_key", args.stress_key), + virials_key=head.get("virials_key", args.virials_key), + dipole_key=head.get("dipole_key", args.dipole_key), + charges_key=head.get("charges_key", args.charges_key), + keep_isolated_atoms=head.get("keep_isolated_atoms", args.keep_isolated_atoms), + ) + + +def prepare_default_head(args: argparse.Namespace) -> Dict[str, Any]: + return { + "default": { + "train_file": args.train_file, + "valid_file": args.valid_file, + "test_file": args.test_file, + "test_dir": args.test_dir, + "E0s": args.E0s, + "statistics_file": args.statistics_file, + "valid_fraction": args.valid_fraction, + "config_type_weights": args.config_type_weights, + "energy_key": args.energy_key, + "forces_key": args.forces_key, + "stress_key": args.stress_key, + "virials_key": args.virials_key, + "dipole_key": args.dipole_key, + "charges_key": args.charges_key, + "keep_isolated_atoms": args.keep_isolated_atoms, + } + } + + +def assemble_mp_data( + args: argparse.Namespace, tag: str, head_configs: List[HeadConfig] +) -> Dict[str, Any]: + try: + checkpoint_url = "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b/mp_traj_combined.xyz" + descriptors_url = "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b/descriptors.npy" + cache_dir = os.path.expanduser("~/.cache/mace") + checkpoint_url_name = "".join( + c for c in os.path.basename(checkpoint_url) if c.isalnum() or c in "_" + ) + cached_dataset_path = f"{cache_dir}/{checkpoint_url_name}" + descriptors_url_name = "".join( + c for c in os.path.basename(descriptors_url) if c.isalnum() or c in "_" + ) + cached_descriptors_path = f"{cache_dir}/{descriptors_url_name}" + if not os.path.isfile(cached_dataset_path): + os.makedirs(cache_dir, exist_ok=True) + # download and save to disk + logging.info("Downloading MP structures for finetuning") + _, http_msg = urllib.request.urlretrieve( + checkpoint_url, cached_dataset_path + ) + if "Content-Type: text/html" in http_msg: + raise RuntimeError( + f"Dataset download failed, please check the URL {checkpoint_url}" + ) + logging.info(f"Materials Project dataset to {cached_dataset_path}") + if not os.path.isfile(cached_descriptors_path): + os.makedirs(cache_dir, exist_ok=True) + # download and save to disk + logging.info("Downloading MP descriptors for finetuning") + _, http_msg = urllib.request.urlretrieve( + descriptors_url, cached_descriptors_path + ) + if "Content-Type: text/html" in http_msg: + raise RuntimeError( + f"Descriptors download failed, please check the URL {descriptors_url}" + ) + logging.info(f"Materials Project descriptors to {cached_descriptors_path}") + dataset_mp = cached_dataset_path + descriptors_mp = cached_descriptors_path + msg = f"Using Materials Project dataset with {dataset_mp}" + logging.info(msg) + msg = f"Using Materials Project descriptors with {descriptors_mp}" + logging.info(msg) + config_pt_paths = [head.train_file for head in head_configs] + args_samples = { + "configs_pt": dataset_mp, + "configs_ft": config_pt_paths, + "num_samples": args.num_samples_pt, + "seed": args.seed, + "model": args.foundation_model, + "head_pt": "pbe_mp", + "head_ft": "Default", + "weight_pt": args.weight_pt_head, + "weight_ft": 1.0, + "filtering_type": "combination", + "output": f"mp_finetuning-{tag}.xyz", + "descriptors": descriptors_mp, + "subselect": args.subselect_pt, + "device": args.device, + "default_dtype": args.default_dtype, + } + select_samples(dict_to_namespace(args_samples)) + collections_mp, _ = get_dataset_from_xyz( + work_dir=args.work_dir, + train_path=f"mp_finetuning-{tag}.xyz", + valid_path=None, + valid_fraction=args.valid_fraction, + config_type_weights=None, + test_path=None, + seed=args.seed, + energy_key="energy", + forces_key="forces", + stress_key="stress", + head_name="pt_head", + virials_key=args.virials_key, + dipole_key=args.dipole_key, + charges_key=args.charges_key, + keep_isolated_atoms=args.keep_isolated_atoms, + ) + return collections_mp + except Exception as exc: + raise RuntimeError( + "Model or descriptors download failed and no local model found" + ) from exc diff --git a/mace/tools/scatter.py b/mace/tools/scatter.py new file mode 100644 index 00000000..7e1139a9 --- /dev/null +++ b/mace/tools/scatter.py @@ -0,0 +1,112 @@ +"""basic scatter_sum operations from torch_scatter from +https://github.com/mir-group/pytorch_runstats/blob/main/torch_runstats/scatter_sum.py +Using code from https://github.com/rusty1s/pytorch_scatter, but cut down to avoid a dependency. +PyTorch plans to move these features into the main repo, but until then, +to make installation simpler, we need this pure python set of wrappers +that don't require installing PyTorch C++ extensions. +See https://github.com/pytorch/pytorch/issues/63780. +""" + +from typing import Optional + +import torch + + +def _broadcast(src: torch.Tensor, other: torch.Tensor, dim: int): + if dim < 0: + dim = other.dim() + dim + if src.dim() == 1: + for _ in range(0, dim): + src = src.unsqueeze(0) + for _ in range(src.dim(), other.dim()): + src = src.unsqueeze(-1) + src = src.expand_as(other) + return src + + +def scatter_sum( + src: torch.Tensor, + index: torch.Tensor, + dim: int = -1, + out: Optional[torch.Tensor] = None, + dim_size: Optional[int] = None, + reduce: str = "sum", +) -> torch.Tensor: + assert reduce == "sum" # for now, TODO + index = _broadcast(index, src, dim) + if out is None: + size = list(src.size()) + if dim_size is not None: + size[dim] = dim_size + elif index.numel() == 0: + size[dim] = 0 + else: + size[dim] = int(index.max()) + 1 + out = torch.zeros(size, dtype=src.dtype, device=src.device) + return out.scatter_add_(dim, index, src) + else: + return out.scatter_add_(dim, index, src) + + +def scatter_std( + src: torch.Tensor, + index: torch.Tensor, + dim: int = -1, + out: Optional[torch.Tensor] = None, + dim_size: Optional[int] = None, + unbiased: bool = True, +) -> torch.Tensor: + if out is not None: + dim_size = out.size(dim) + + if dim < 0: + dim = src.dim() + dim + + count_dim = dim + if index.dim() <= dim: + count_dim = index.dim() - 1 + + ones = torch.ones(index.size(), dtype=src.dtype, device=src.device) + count = scatter_sum(ones, index, count_dim, dim_size=dim_size) + + index = _broadcast(index, src, dim) + tmp = scatter_sum(src, index, dim, dim_size=dim_size) + count = _broadcast(count, tmp, dim).clamp(1) + mean = tmp.div(count) + + var = src - mean.gather(dim, index) + var = var * var + out = scatter_sum(var, index, dim, out, dim_size) + + if unbiased: + count = count.sub(1).clamp_(1) + out = out.div(count + 1e-6).sqrt() + + return out + + +def scatter_mean( + src: torch.Tensor, + index: torch.Tensor, + dim: int = -1, + out: Optional[torch.Tensor] = None, + dim_size: Optional[int] = None, +) -> torch.Tensor: + out = scatter_sum(src, index, dim, out, dim_size) + dim_size = out.size(dim) + + index_dim = dim + if index_dim < 0: + index_dim = index_dim + src.dim() + if index.dim() <= index_dim: + index_dim = index.dim() - 1 + + ones = torch.ones(index.size(), dtype=src.dtype, device=src.device) + count = scatter_sum(ones, index, index_dim, None, dim_size) + count[count < 1] = 1 + count = _broadcast(count, out, dim) + if out.is_floating_point(): + out.true_divide_(count) + else: + out.div_(count, rounding_mode="floor") + return out diff --git a/mace/tools/scripts_utils.py b/mace/tools/scripts_utils.py new file mode 100644 index 00000000..be96558d --- /dev/null +++ b/mace/tools/scripts_utils.py @@ -0,0 +1,785 @@ +########################################################################################### +# Training utils +# Authors: David Kovacs, Ilyes Batatia +# This program is distributed under the MIT License (see MIT.md) +########################################################################################### + +import argparse +import ast +import dataclasses +import json +import logging +import os +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np +import torch +import torch.distributed +from e3nn import o3 +from torch.optim.swa_utils import SWALR, AveragedModel + +from mace import data, modules, tools +from mace.tools.train import SWAContainer + + +@dataclasses.dataclass +class SubsetCollection: + train: data.Configurations + valid: data.Configurations + tests: List[Tuple[str, data.Configurations]] + + +def get_dataset_from_xyz( + work_dir: str, + train_path: str, + valid_path: Optional[str], + valid_fraction: float, + config_type_weights: Dict, + test_path: str = None, + seed: int = 1234, + keep_isolated_atoms: bool = False, + head_name: str = "Default", + energy_key: str = "REF_energy", + forces_key: str = "REF_forces", + stress_key: str = "REF_stress", + virials_key: str = "virials", + dipole_key: str = "dipoles", + charges_key: str = "charges", + head_key: str = "head", +) -> Tuple[SubsetCollection, Optional[Dict[int, float]]]: + """Load training and test dataset from xyz file""" + atomic_energies_dict, all_train_configs = data.load_from_xyz( + file_path=train_path, + config_type_weights=config_type_weights, + energy_key=energy_key, + forces_key=forces_key, + stress_key=stress_key, + virials_key=virials_key, + dipole_key=dipole_key, + charges_key=charges_key, + head_key=head_key, + extract_atomic_energies=True, + keep_isolated_atoms=keep_isolated_atoms, + head_name=head_name, + ) + logging.info( + f"Training set [{len(all_train_configs)} configs, {np.sum([1 if config.energy else 0 for config in all_train_configs])} energy, {np.sum([config.forces.size for config in all_train_configs])} forces] loaded from '{train_path}'" + ) + if valid_path is not None: + _, valid_configs = data.load_from_xyz( + file_path=valid_path, + config_type_weights=config_type_weights, + energy_key=energy_key, + forces_key=forces_key, + stress_key=stress_key, + virials_key=virials_key, + dipole_key=dipole_key, + charges_key=charges_key, + head_key=head_key, + extract_atomic_energies=False, + head_name=head_name, + ) + logging.info( + f"Validation set [{len(valid_configs)} configs, {np.sum([1 if config.energy else 0 for config in valid_configs])} energy, {np.sum([config.forces.size for config in valid_configs])} forces] loaded from '{valid_path}'" + ) + train_configs = all_train_configs + else: + train_configs, valid_configs = data.random_train_valid_split( + all_train_configs, valid_fraction, seed, work_dir + ) + logging.info( + f"Validaton set contains {len(valid_configs)} configurations [{np.sum([1 if config.energy else 0 for config in valid_configs])} energy, {np.sum([config.forces.size for config in valid_configs])} forces]" + ) + + test_configs = [] + if test_path is not None: + _, all_test_configs = data.load_from_xyz( + file_path=test_path, + config_type_weights=config_type_weights, + energy_key=energy_key, + forces_key=forces_key, + dipole_key=dipole_key, + stress_key=stress_key, + virials_key=virials_key, + charges_key=charges_key, + head_key=head_key, + extract_atomic_energies=False, + head_name=head_name, + ) + # create list of tuples (config_type, list(Atoms)) + test_configs = data.test_config_types(all_test_configs) + logging.info( + f"Test set ({len(all_test_configs)} configs) loaded from '{test_path}':" + ) + for name, tmp_configs in test_configs: + logging.info( + f"{name}: {len(tmp_configs)} configs, {np.sum([1 if config.energy else 0 for config in tmp_configs])} energy, {np.sum([config.forces.size for config in tmp_configs])} forces" + ) + + return ( + SubsetCollection(train=train_configs, valid=valid_configs, tests=test_configs), + atomic_energies_dict, + ) + + +def get_config_type_weights(ct_weights): + """ + Get config type weights from command line argument + """ + try: + config_type_weights = ast.literal_eval(ct_weights) + assert isinstance(config_type_weights, dict) + except Exception as e: # pylint: disable=W0703 + logging.warning( + f"Config type weights not specified correctly ({e}), using Default" + ) + config_type_weights = {"Default": 1.0} + return config_type_weights + + +def print_git_commit(): + try: + import git + + repo = git.Repo(search_parent_directories=True) + commit = repo.head.commit.hexsha + logging.debug(f"Current Git commit: {commit}") + return commit + except Exception as e: # pylint: disable=W0703 + logging.debug(f"Error accessing Git repository: {e}") + return "None" + + +def extract_config_mace_model(model: torch.nn.Module) -> Dict[str, Any]: + if model.__class__.__name__ != "ScaleShiftMACE": + return {"error": "Model is not a ScaleShiftMACE model"} + + def radial_to_name(radial_type): + if radial_type == "BesselBasis": + return "bessel" + if radial_type == "GaussianBasis": + return "gaussian" + if radial_type == "ChebychevBasis": + return "chebyshev" + return radial_type + + def radial_to_transform(radial): + if not hasattr(radial, "distance_transform"): + return None + if radial.distance_transform.__class__.__name__ == "AgnesiTransform": + return "Agnesi" + if radial.distance_transform.__class__.__name__ == "SoftTransform": + return "Soft" + return radial.distance_transform.__class__.__name__ + + scale = model.scale_shift.scale + shift = model.scale_shift.shift + config = { + "r_max": model.r_max.item(), + "num_bessel": len(model.radial_embedding.bessel_fn.bessel_weights), + "num_polynomial_cutoff": model.radial_embedding.cutoff_fn.p.item(), + "max_ell": model.spherical_harmonics._lmax, # pylint: disable=protected-access + "interaction_cls": model.interactions[-1].__class__, + "interaction_cls_first": model.interactions[0].__class__, + "num_interactions": model.num_interactions.item(), + "num_elements": len(model.atomic_numbers), + "hidden_irreps": o3.Irreps(str(model.products[0].linear.irreps_out)), + "MLP_irreps": ( + o3.Irreps(str(model.readouts[-1].hidden_irreps)) + if model.num_interactions.item() > 1 + else 1 + ), + "gate": ( + model.readouts[-1] # pylint: disable=protected-access + .non_linearity._modules["acts"][0] + .f + if model.num_interactions.item() > 1 + else None + ), + "atomic_energies": model.atomic_energies_fn.atomic_energies.cpu().numpy(), + "avg_num_neighbors": model.interactions[0].avg_num_neighbors, + "atomic_numbers": model.atomic_numbers, + "correlation": len( + model.products[0].symmetric_contractions.contractions[0].weights + ) + + 1, + "radial_type": radial_to_name( + model.radial_embedding.bessel_fn.__class__.__name__ + ), + "radial_MLP": model.interactions[0].conv_tp_weights.hs[1:-1], + "pair_repulsion": hasattr(model, "pair_repulsion_fn"), + "distance_transform": radial_to_transform(model.radial_embedding), + "atomic_inter_scale": scale.cpu().numpy(), + "atomic_inter_shift": shift.cpu().numpy(), + } + return config + + +def extract_load(f: str, map_location: str = "cpu") -> torch.nn.Module: + return extract_model( + torch.load(f=f, map_location=map_location), map_location=map_location + ) + + +def remove_pt_head( + model: torch.nn.Module, head_to_keep: Optional[str] = None +) -> torch.nn.Module: + """Converts a multihead MACE model to a single head model by removing the pretraining head. + + Args: + model (ScaleShiftMACE): The multihead MACE model to convert + head_to_keep (Optional[str]): The name of the head to keep. If None, keeps the first non-PT head. + + Returns: + ScaleShiftMACE: A new MACE model with only the specified head + + Raises: + ValueError: If the model is not a multihead model or if the specified head is not found + """ + if not hasattr(model, "heads") or len(model.heads) <= 1: + raise ValueError("Model must be a multihead model with more than one head") + + # Get index of head to keep + if head_to_keep is None: + # Find first non-PT head + try: + head_idx = next(i for i, h in enumerate(model.heads) if h != "pt_head") + except StopIteration as e: + raise ValueError("No non-PT head found in model") from e + else: + try: + head_idx = model.heads.index(head_to_keep) + except ValueError as e: + raise ValueError(f"Head {head_to_keep} not found in model") from e + + # Extract config and modify for single head + model_config = extract_config_mace_model(model) + model_config["heads"] = [model.heads[head_idx]] + model_config["atomic_energies"] = ( + model.atomic_energies_fn.atomic_energies[head_idx] + .unsqueeze(0) + .detach() + .cpu() + .numpy() + ) + model_config["atomic_inter_scale"] = model.scale_shift.scale[head_idx].item() + model_config["atomic_inter_shift"] = model.scale_shift.shift[head_idx].item() + mlp_count_irreps = model_config["MLP_irreps"].count((0, 1)) // len(model.heads) + model_config["MLP_irreps"] = o3.Irreps(f"{mlp_count_irreps}x0e") + + new_model = model.__class__(**model_config) + state_dict = model.state_dict() + new_state_dict = {} + + for name, param in state_dict.items(): + if "atomic_energies" in name: + new_state_dict[name] = param[head_idx : head_idx + 1] + elif "scale" in name or "shift" in name: + new_state_dict[name] = param[head_idx : head_idx + 1] + elif "readouts" in name: + channels_per_head = param.shape[0] // len(model.heads) + start_idx = head_idx * channels_per_head + end_idx = start_idx + channels_per_head + if "linear_2.weight" in name: + end_idx = start_idx + channels_per_head // 2 + # if ( + # "readouts.0.linear.weight" in name + # or "readouts.1.linear_2.weight" in name + # ): + # new_state_dict[name] = param[start_idx:end_idx] / ( + # len(model.heads) ** 0.5 + # ) + if "readouts.0.linear.weight" in name: + new_state_dict[name] = param.reshape(-1, len(model.heads))[ + :, head_idx + ].flatten() + elif "readouts.1.linear_1.weight" in name: + new_state_dict[name] = param.reshape( + -1, len(model.heads), mlp_count_irreps + )[:, head_idx, :].flatten() + elif "readouts.1.linear_2.weight" in name: + new_state_dict[name] = param.reshape( + len(model.heads), -1, len(model.heads) + )[head_idx, :, head_idx].flatten() / (len(model.heads) ** 0.5) + else: + new_state_dict[name] = param[start_idx:end_idx] + + else: + new_state_dict[name] = param + + # Load state dict into new model + new_model.load_state_dict(new_state_dict) + + return new_model + + +def extract_model(model: torch.nn.Module, map_location: str = "cpu") -> torch.nn.Module: + model_copy = model.__class__(**extract_config_mace_model(model)) + model_copy.load_state_dict(model.state_dict()) + return model_copy.to(map_location) + + +def convert_to_json_format(dict_input): + for key, value in dict_input.items(): + if isinstance(value, (np.ndarray, torch.Tensor)): + dict_input[key] = value.tolist() + # # check if the value is a class and convert it to a string + elif hasattr(value, "__class__"): + dict_input[key] = str(value) + return dict_input + + +def convert_from_json_format(dict_input): + dict_output = dict_input.copy() + if ( + dict_input["interaction_cls"] + == "" + ): + dict_output["interaction_cls"] = ( + modules.blocks.RealAgnosticResidualInteractionBlock + ) + if ( + dict_input["interaction_cls"] + == "" + ): + dict_output["interaction_cls"] = modules.blocks.RealAgnosticInteractionBlock + if ( + dict_input["interaction_cls_first"] + == "" + ): + dict_output["interaction_cls_first"] = ( + modules.blocks.RealAgnosticResidualInteractionBlock + ) + if ( + dict_input["interaction_cls_first"] + == "" + ): + dict_output["interaction_cls_first"] = ( + modules.blocks.RealAgnosticInteractionBlock + ) + dict_output["r_max"] = float(dict_input["r_max"]) + dict_output["num_bessel"] = int(dict_input["num_bessel"]) + dict_output["num_polynomial_cutoff"] = float(dict_input["num_polynomial_cutoff"]) + dict_output["max_ell"] = int(dict_input["max_ell"]) + dict_output["num_interactions"] = int(dict_input["num_interactions"]) + dict_output["num_elements"] = int(dict_input["num_elements"]) + dict_output["hidden_irreps"] = o3.Irreps(dict_input["hidden_irreps"]) + dict_output["MLP_irreps"] = o3.Irreps(dict_input["MLP_irreps"]) + dict_output["avg_num_neighbors"] = float(dict_input["avg_num_neighbors"]) + dict_output["gate"] = torch.nn.functional.silu + dict_output["atomic_energies"] = np.array(dict_input["atomic_energies"]) + dict_output["atomic_numbers"] = dict_input["atomic_numbers"] + dict_output["correlation"] = int(dict_input["correlation"]) + dict_output["radial_type"] = dict_input["radial_type"] + dict_output["radial_MLP"] = ast.literal_eval(dict_input["radial_MLP"]) + dict_output["pair_repulsion"] = ast.literal_eval(dict_input["pair_repulsion"]) + dict_output["distance_transform"] = dict_input["distance_transform"] + dict_output["atomic_inter_scale"] = float(dict_input["atomic_inter_scale"]) + dict_output["atomic_inter_shift"] = float(dict_input["atomic_inter_shift"]) + + return dict_output + + +def load_from_json(f: str, map_location: str = "cpu") -> torch.nn.Module: + extra_files_extract = {"commit.txt": None, "config.json": None} + model_jit_load = torch.jit.load( + f, _extra_files=extra_files_extract, map_location=map_location + ) + model_load_yaml = modules.ScaleShiftMACE( + **convert_from_json_format(json.loads(extra_files_extract["config.json"])) + ) + model_load_yaml.load_state_dict(model_jit_load.state_dict()) + return model_load_yaml.to(map_location) + + +def get_atomic_energies(E0s, train_collection, z_table) -> dict: + if E0s is not None: + logging.info( + "Isolated Atomic Energies (E0s) not in training file, using command line argument" + ) + if E0s.lower() == "average": + logging.info( + "Computing average Atomic Energies using least squares regression" + ) + # catch if colections.train not defined above + try: + assert train_collection is not None + atomic_energies_dict = data.compute_average_E0s( + train_collection, z_table + ) + except Exception as e: + raise RuntimeError( + f"Could not compute average E0s if no training xyz given, error {e} occured" + ) from e + else: + if E0s.endswith(".json"): + logging.info(f"Loading atomic energies from {E0s}") + with open(E0s, "r", encoding="utf-8") as f: + atomic_energies_dict = json.load(f) + atomic_energies_dict = { + int(key): value for key, value in atomic_energies_dict.items() + } + else: + try: + atomic_energies_eval = ast.literal_eval(E0s) + if not all( + isinstance(value, dict) + for value in atomic_energies_eval.values() + ): + atomic_energies_dict = atomic_energies_eval + else: + atomic_energies_dict = atomic_energies_eval + assert isinstance(atomic_energies_dict, dict) + except Exception as e: + raise RuntimeError( + f"E0s specified invalidly, error {e} occured" + ) from e + else: + raise RuntimeError( + "E0s not found in training file and not specified in command line" + ) + return atomic_energies_dict + + +def get_avg_num_neighbors(head_configs, args, train_loader, device): + if all(head_config.compute_avg_num_neighbors for head_config in head_configs): + logging.info("Computing average number of neighbors") + avg_num_neighbors = modules.compute_avg_num_neighbors(train_loader) + if args.distributed: + num_graphs = torch.tensor(len(train_loader.dataset)).to(device) + num_neighbors = num_graphs * torch.tensor(avg_num_neighbors).to(device) + torch.distributed.all_reduce(num_graphs, op=torch.distributed.ReduceOp.SUM) + torch.distributed.all_reduce( + num_neighbors, op=torch.distributed.ReduceOp.SUM + ) + avg_num_neighbors_out = (num_neighbors / num_graphs).item() + else: + avg_num_neighbors_out = avg_num_neighbors + else: + assert any( + head_config.avg_num_neighbors is not None for head_config in head_configs + ), "Average number of neighbors must be provided in the configuration" + avg_num_neighbors_out = max( + head_config.avg_num_neighbors + for head_config in head_configs + if head_config.avg_num_neighbors is not None + ) + if avg_num_neighbors_out < 2 or avg_num_neighbors_out > 100: + logging.warning( + f"Unusual average number of neighbors: {avg_num_neighbors_out:.1f}" + ) + else: + logging.info(f"Average number of neighbors: {avg_num_neighbors_out}") + return avg_num_neighbors_out + + +def get_loss_fn( + args: argparse.Namespace, + dipole_only: bool, + compute_dipole: bool, +) -> torch.nn.Module: + if args.loss == "weighted": + loss_fn = modules.WeightedEnergyForcesLoss( + energy_weight=args.energy_weight, forces_weight=args.forces_weight + ) + elif args.loss == "forces_only": + loss_fn = modules.WeightedForcesLoss(forces_weight=args.forces_weight) + elif args.loss == "virials": + loss_fn = modules.WeightedEnergyForcesVirialsLoss( + energy_weight=args.energy_weight, + forces_weight=args.forces_weight, + virials_weight=args.virials_weight, + ) + elif args.loss == "stress": + loss_fn = modules.WeightedEnergyForcesStressLoss( + energy_weight=args.energy_weight, + forces_weight=args.forces_weight, + stress_weight=args.stress_weight, + ) + elif args.loss == "huber": + loss_fn = modules.WeightedHuberEnergyForcesStressLoss( + energy_weight=args.energy_weight, + forces_weight=args.forces_weight, + stress_weight=args.stress_weight, + huber_delta=args.huber_delta, + ) + elif args.loss == "universal": + loss_fn = modules.UniversalLoss( + energy_weight=args.energy_weight, + forces_weight=args.forces_weight, + stress_weight=args.stress_weight, + huber_delta=args.huber_delta, + ) + elif args.loss == "dipole": + assert ( + dipole_only is True + ), "dipole loss can only be used with AtomicDipolesMACE model" + loss_fn = modules.DipoleSingleLoss( + dipole_weight=args.dipole_weight, + ) + elif args.loss == "energy_forces_dipole": + assert dipole_only is False and compute_dipole is True + loss_fn = modules.WeightedEnergyForcesDipoleLoss( + energy_weight=args.energy_weight, + forces_weight=args.forces_weight, + dipole_weight=args.dipole_weight, + ) + else: + loss_fn = modules.WeightedEnergyForcesLoss(energy_weight=1.0, forces_weight=1.0) + return loss_fn + + +def get_swa( + args: argparse.Namespace, + model: torch.nn.Module, + optimizer: torch.optim.Optimizer, + swas: List[bool], + dipole_only: bool = False, +): + assert dipole_only is False, "Stage Two for dipole fitting not implemented" + swas.append(True) + if args.start_swa is None: + args.start_swa = max(1, args.max_num_epochs // 4 * 3) + else: + if args.start_swa >= args.max_num_epochs: + logging.warning( + f"Start Stage Two must be less than max_num_epochs, got {args.start_swa} > {args.max_num_epochs}" + ) + swas[-1] = False + if args.loss == "forces_only": + raise ValueError("Can not select Stage Two with forces only loss.") + if args.loss == "virials": + loss_fn_energy = modules.WeightedEnergyForcesVirialsLoss( + energy_weight=args.swa_energy_weight, + forces_weight=args.swa_forces_weight, + virials_weight=args.swa_virials_weight, + ) + logging.info( + f"Stage Two (after {args.start_swa} epochs) with loss function: {loss_fn_energy}, energy weight : {args.swa_energy_weight}, forces weight : {args.swa_forces_weight}, virials_weight: {args.swa_virials_weight} and learning rate : {args.swa_lr}" + ) + elif args.loss == "stress": + loss_fn_energy = modules.WeightedEnergyForcesStressLoss( + energy_weight=args.swa_energy_weight, + forces_weight=args.swa_forces_weight, + stress_weight=args.swa_stress_weight, + ) + logging.info( + f"Stage Two (after {args.start_swa} epochs) with loss function: {loss_fn_energy}, energy weight : {args.swa_energy_weight}, forces weight : {args.swa_forces_weight}, stress weight : {args.stress_weight} and learning rate : {args.swa_lr}" + ) + elif args.loss == "energy_forces_dipole": + loss_fn_energy = modules.WeightedEnergyForcesDipoleLoss( + args.swa_energy_weight, + forces_weight=args.swa_forces_weight, + dipole_weight=args.swa_dipole_weight, + ) + logging.info( + f"Stage Two (after {args.start_swa} epochs) with loss function: {loss_fn_energy}, with energy weight : {args.swa_energy_weight}, forces weight : {args.swa_forces_weight}, dipole weight : {args.swa_dipole_weight} and learning rate : {args.swa_lr}" + ) + elif args.loss == "universal": + loss_fn_energy = modules.UniversalLoss( + energy_weight=args.swa_energy_weight, + forces_weight=args.swa_forces_weight, + stress_weight=args.swa_stress_weight, + huber_delta=args.huber_delta, + ) + logging.info( + f"Stage Two (after {args.start_swa} epochs) with loss function: {loss_fn_energy}, with energy weight : {args.swa_energy_weight}, forces weight : {args.swa_forces_weight}, stress weight : {args.swa_stress_weight} and learning rate : {args.swa_lr}" + ) + else: + loss_fn_energy = modules.WeightedEnergyForcesLoss( + energy_weight=args.swa_energy_weight, + forces_weight=args.swa_forces_weight, + ) + logging.info( + f"Stage Two (after {args.start_swa} epochs) with loss function: {loss_fn_energy}, with energy weight : {args.swa_energy_weight}, forces weight : {args.swa_forces_weight} and learning rate : {args.swa_lr}" + ) + swa = SWAContainer( + model=AveragedModel(model), + scheduler=SWALR( + optimizer=optimizer, + swa_lr=args.swa_lr, + anneal_epochs=1, + anneal_strategy="linear", + ), + start=args.start_swa, + loss_fn=loss_fn_energy, + ) + return swa, swas + + +def get_params_options( + args: argparse.Namespace, model: torch.nn.Module +) -> Dict[str, Any]: + decay_interactions = {} + no_decay_interactions = {} + for name, param in model.interactions.named_parameters(): + if "linear.weight" in name or "skip_tp_full.weight" in name: + decay_interactions[name] = param + else: + no_decay_interactions[name] = param + + param_options = dict( + params=[ + { + "name": "embedding", + "params": model.node_embedding.parameters(), + "weight_decay": 0.0, + }, + { + "name": "interactions_decay", + "params": list(decay_interactions.values()), + "weight_decay": args.weight_decay, + }, + { + "name": "interactions_no_decay", + "params": list(no_decay_interactions.values()), + "weight_decay": 0.0, + }, + { + "name": "products", + "params": model.products.parameters(), + "weight_decay": args.weight_decay, + }, + { + "name": "readouts", + "params": model.readouts.parameters(), + "weight_decay": 0.0, + }, + ], + lr=args.lr, + amsgrad=args.amsgrad, + betas=(args.beta, 0.999), + ) + return param_options + + +def get_optimizer( + args: argparse.Namespace, param_options: Dict[str, Any] +) -> torch.optim.Optimizer: + if args.optimizer == "adamw": + optimizer = torch.optim.AdamW(**param_options) + elif args.optimizer == "schedulefree": + try: + from schedulefree import adamw_schedulefree + except ImportError as exc: + raise ImportError( + "`schedulefree` is not installed. Please install it via `pip install schedulefree` or `pip install mace-torch[schedulefree]`" + ) from exc + _param_options = {k: v for k, v in param_options.items() if k != "amsgrad"} + optimizer = adamw_schedulefree.AdamWScheduleFree(**_param_options) + else: + optimizer = torch.optim.Adam(**param_options) + return optimizer + + +def setup_wandb(args: argparse.Namespace): + logging.info("Using Weights and Biases for logging") + import wandb + + wandb_config = {} + args_dict = vars(args) + + for key, value in args_dict.items(): + if isinstance(value, np.ndarray): + args_dict[key] = value.tolist() + + args_dict_json = json.dumps(args_dict) + for key in args.wandb_log_hypers: + wandb_config[key] = args_dict[key] + tools.init_wandb( + project=args.wandb_project, + entity=args.wandb_entity, + name=args.wandb_name, + config=wandb_config, + directory=args.wandb_dir, + ) + wandb.run.summary["params"] = args_dict_json + + +def get_files_with_suffix(dir_path: str, suffix: str) -> List[str]: + return [ + os.path.join(dir_path, f) for f in os.listdir(dir_path) if f.endswith(suffix) + ] + + +def dict_to_array(input_data, heads): + if all(isinstance(value, np.ndarray) for value in input_data.values()): + return np.array([input_data[head] for head in heads]) + if not all(isinstance(value, dict) for value in input_data.values()): + return np.array([[input_data[head]] for head in heads]) + unique_keys = set() + for inner_dict in input_data.values(): + unique_keys.update(inner_dict.keys()) + unique_keys = list(unique_keys) + sorted_keys = sorted([int(key) for key in unique_keys]) + result_array = np.zeros((len(input_data), len(sorted_keys))) + for _, (head_name, inner_dict) in enumerate(input_data.items()): + for key, value in inner_dict.items(): + key_index = sorted_keys.index(int(key)) + head_index = heads.index(head_name) + result_array[head_index][key_index] = value + return result_array + + +class LRScheduler: + def __init__(self, optimizer, args) -> None: + self.scheduler = args.scheduler + self._optimizer_type = ( + args.optimizer + ) # Schedulefree does not need an optimizer but checkpoint handler does. + if args.scheduler == "ExponentialLR": + self.lr_scheduler = torch.optim.lr_scheduler.ExponentialLR( + optimizer=optimizer, gamma=args.lr_scheduler_gamma + ) + elif args.scheduler == "ReduceLROnPlateau": + self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( + optimizer=optimizer, + factor=args.lr_factor, + patience=args.scheduler_patience, + ) + else: + raise RuntimeError(f"Unknown scheduler: '{args.scheduler}'") + + def step(self, metrics=None, epoch=None): # pylint: disable=E1123 + if self._optimizer_type == "schedulefree": + return # In principle, schedulefree optimizer can be used with a scheduler but the paper suggests it's not necessary + if self.scheduler == "ExponentialLR": + self.lr_scheduler.step(epoch=epoch) + elif self.scheduler == "ReduceLROnPlateau": + self.lr_scheduler.step( # pylint: disable=E1123 + metrics=metrics, epoch=epoch + ) + + def __getattr__(self, name): + if name == "step": + return self.step + return getattr(self.lr_scheduler, name) + + +def check_folder_subfolder(folder_path): + entries = os.listdir(folder_path) + for entry in entries: + full_path = os.path.join(folder_path, entry) + if os.path.isdir(full_path): + return True + return False + + +def check_path_ase_read(filename: str) -> str: + filepath = Path(filename) + if filepath.is_dir(): + if len(list(filepath.glob("*.h5")) + list(filepath.glob("*.hdf5"))) == 0: + raise RuntimeError(f"Got directory {filename} with no .h5/.hdf5 files") + return False + if filepath.suffix in (".h5", ".hdf5"): + return False + return True + + +def dict_to_namespace(dictionary): + # Convert the dictionary into an argparse.Namespace + namespace = argparse.Namespace() + for key, value in dictionary.items(): + setattr(namespace, key, value) + return namespace diff --git a/mace/tools/slurm_distributed.py b/mace/tools/slurm_distributed.py new file mode 100644 index 00000000..78de52a1 --- /dev/null +++ b/mace/tools/slurm_distributed.py @@ -0,0 +1,34 @@ +########################################################################################### +# Slurm environment setup for distributed training. +# This code is refactored from rsarm's contribution at: +# https://github.com/Lumi-supercomputer/lumi-reframe-tests/blob/main/checks/apps/deeplearning/pytorch/src/pt_distr_env.py +# This program is distributed under the MIT License (see MIT.md) +########################################################################################### + +import os + +import hostlist + + +class DistributedEnvironment: + def __init__(self): + self._setup_distr_env() + self.master_addr = os.environ["MASTER_ADDR"] + self.master_port = os.environ["MASTER_PORT"] + self.world_size = int(os.environ["WORLD_SIZE"]) + self.local_rank = int(os.environ["LOCAL_RANK"]) + self.rank = int(os.environ["RANK"]) + + def _setup_distr_env(self): + hostname = hostlist.expand_hostlist(os.environ["SLURM_JOB_NODELIST"])[0] + os.environ["MASTER_ADDR"] = hostname + os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", "33333") + os.environ["WORLD_SIZE"] = os.environ.get( + "SLURM_NTASKS", + str( + int(os.environ["SLURM_NTASKS_PER_NODE"]) + * int(os.environ["SLURM_NNODES"]) + ), + ) + os.environ["LOCAL_RANK"] = os.environ["SLURM_LOCALID"] + os.environ["RANK"] = os.environ["SLURM_PROCID"] diff --git a/mace/tools/tables_utils.py b/mace/tools/tables_utils.py new file mode 100644 index 00000000..07f41401 --- /dev/null +++ b/mace/tools/tables_utils.py @@ -0,0 +1,241 @@ +import logging +from typing import Dict + +import torch +from prettytable import PrettyTable + +from mace.tools import evaluate + + +def custom_key(key): + """ + Helper function to sort the keys of the data loader dictionary + to ensure that the training set, and validation set + are evaluated first + """ + if key == "train": + return (0, key) + if key == "valid": + return (1, key) + return (2, key) + + +def create_error_table( + table_type: str, + all_data_loaders: dict, + model: torch.nn.Module, + loss_fn: torch.nn.Module, + output_args: Dict[str, bool], + log_wandb: bool, + device: str, + distributed: bool = False, +) -> PrettyTable: + if log_wandb: + import wandb + table = PrettyTable() + if table_type == "TotalRMSE": + table.field_names = [ + "config_type", + "RMSE E / meV", + "RMSE F / meV / A", + "relative F RMSE %", + ] + elif table_type == "PerAtomRMSE": + table.field_names = [ + "config_type", + "RMSE E / meV / atom", + "RMSE F / meV / A", + "relative F RMSE %", + ] + elif table_type == "PerAtomRMSEstressvirials": + table.field_names = [ + "config_type", + "RMSE E / meV / atom", + "RMSE F / meV / A", + "relative F RMSE %", + "RMSE Stress (Virials) / meV / A (A^3)", + ] + elif table_type == "PerAtomMAEstressvirials": + table.field_names = [ + "config_type", + "MAE E / meV / atom", + "MAE F / meV / A", + "relative F MAE %", + "MAE Stress (Virials) / meV / A (A^3)", + ] + elif table_type == "TotalMAE": + table.field_names = [ + "config_type", + "MAE E / meV", + "MAE F / meV / A", + "relative F MAE %", + ] + elif table_type == "PerAtomMAE": + table.field_names = [ + "config_type", + "MAE E / meV / atom", + "MAE F / meV / A", + "relative F MAE %", + ] + elif table_type == "DipoleRMSE": + table.field_names = [ + "config_type", + "RMSE MU / mDebye / atom", + "relative MU RMSE %", + ] + elif table_type == "DipoleMAE": + table.field_names = [ + "config_type", + "MAE MU / mDebye / atom", + "relative MU MAE %", + ] + elif table_type == "EnergyDipoleRMSE": + table.field_names = [ + "config_type", + "RMSE E / meV / atom", + "RMSE F / meV / A", + "rel F RMSE %", + "RMSE MU / mDebye / atom", + "rel MU RMSE %", + ] + + for name in sorted(all_data_loaders, key=custom_key): + data_loader = all_data_loaders[name] + logging.info(f"Evaluating {name} ...") + _, metrics = evaluate( + model, + loss_fn=loss_fn, + data_loader=data_loader, + output_args=output_args, + device=device, + ) + if distributed: + torch.distributed.barrier() + + del data_loader + torch.cuda.empty_cache() + if log_wandb: + wandb_log_dict = { + name + + "_final_rmse_e_per_atom": metrics["rmse_e_per_atom"] + * 1e3, # meV / atom + name + "_final_rmse_f": metrics["rmse_f"] * 1e3, # meV / A + name + "_final_rel_rmse_f": metrics["rel_rmse_f"], + } + wandb.log(wandb_log_dict) + if table_type == "TotalRMSE": + table.add_row( + [ + name, + f"{metrics['rmse_e'] * 1000:8.1f}", + f"{metrics['rmse_f'] * 1000:8.1f}", + f"{metrics['rel_rmse_f']:8.2f}", + ] + ) + elif table_type == "PerAtomRMSE": + table.add_row( + [ + name, + f"{metrics['rmse_e_per_atom'] * 1000:8.1f}", + f"{metrics['rmse_f'] * 1000:8.1f}", + f"{metrics['rel_rmse_f']:8.2f}", + ] + ) + elif ( + table_type == "PerAtomRMSEstressvirials" + and metrics["rmse_stress"] is not None + ): + table.add_row( + [ + name, + f"{metrics['rmse_e_per_atom'] * 1000:8.1f}", + f"{metrics['rmse_f'] * 1000:8.1f}", + f"{metrics['rel_rmse_f']:8.2f}", + f"{metrics['rmse_stress'] * 1000:8.1f}", + ] + ) + elif ( + table_type == "PerAtomRMSEstressvirials" + and metrics["rmse_virials"] is not None + ): + table.add_row( + [ + name, + f"{metrics['rmse_e_per_atom'] * 1000:8.1f}", + f"{metrics['rmse_f'] * 1000:8.1f}", + f"{metrics['rel_rmse_f']:8.2f}", + f"{metrics['rmse_virials'] * 1000:8.1f}", + ] + ) + elif ( + table_type == "PerAtomMAEstressvirials" + and metrics["mae_stress"] is not None + ): + table.add_row( + [ + name, + f"{metrics['mae_e_per_atom'] * 1000:8.1f}", + f"{metrics['mae_f'] * 1000:8.1f}", + f"{metrics['rel_mae_f']:8.2f}", + f"{metrics['mae_stress'] * 1000:8.1f}", + ] + ) + elif ( + table_type == "PerAtomMAEstressvirials" + and metrics["mae_virials"] is not None + ): + table.add_row( + [ + name, + f"{metrics['mae_e_per_atom'] * 1000:8.1f}", + f"{metrics['mae_f'] * 1000:8.1f}", + f"{metrics['rel_mae_f']:8.2f}", + f"{metrics['mae_virials'] * 1000:8.1f}", + ] + ) + elif table_type == "TotalMAE": + table.add_row( + [ + name, + f"{metrics['mae_e'] * 1000:8.1f}", + f"{metrics['mae_f'] * 1000:8.1f}", + f"{metrics['rel_mae_f']:8.2f}", + ] + ) + elif table_type == "PerAtomMAE": + table.add_row( + [ + name, + f"{metrics['mae_e_per_atom'] * 1000:8.1f}", + f"{metrics['mae_f'] * 1000:8.1f}", + f"{metrics['rel_mae_f']:8.2f}", + ] + ) + elif table_type == "DipoleRMSE": + table.add_row( + [ + name, + f"{metrics['rmse_mu_per_atom'] * 1000:8.2f}", + f"{metrics['rel_rmse_mu']:8.1f}", + ] + ) + elif table_type == "DipoleMAE": + table.add_row( + [ + name, + f"{metrics['mae_mu_per_atom'] * 1000:8.2f}", + f"{metrics['rel_mae_mu']:8.1f}", + ] + ) + elif table_type == "EnergyDipoleRMSE": + table.add_row( + [ + name, + f"{metrics['rmse_e_per_atom'] * 1000:8.1f}", + f"{metrics['rmse_f'] * 1000:8.1f}", + f"{metrics['rel_rmse_f']:8.1f}", + f"{metrics['rmse_mu_per_atom'] * 1000:8.1f}", + f"{metrics['rel_rmse_mu']:8.1f}", + ] + ) + return table diff --git a/mace/tools/torch_geometric/README.md b/mace/tools/torch_geometric/README.md new file mode 100644 index 00000000..261ebbbc --- /dev/null +++ b/mace/tools/torch_geometric/README.md @@ -0,0 +1,12 @@ +# Trimmed-down `pytorch_geometric` + +MACE uses [`pytorch_geometric`](https://pytorch-geometric.readthedocs.io/en/latest/) [1, 2] framework. However as only use a very limited subset of that library: the most basic graph data structures. + +We follow the same approach to NequIP (https://github.com/mir-group/nequip/tree/main/nequip) and copy their code here. + +To avoid adding a large number of unnecessary second-degree dependencies, and to simplify installation, we include and modify here the small subset of `torch_geometric` that is necessary for our code. + +We are grateful to the developers of PyTorch Geometric for their ongoing and very useful work on graph learning with PyTorch. + +[1] Fey, M., & Lenssen, J. E. (2019). Fast Graph Representation Learning with PyTorch Geometric (Version 2.0.1) [Computer software]. https://github.com/pyg-team/pytorch_geometric
+[2] https://arxiv.org/abs/1903.02428 diff --git a/mace/tools/torch_geometric/__init__.py b/mace/tools/torch_geometric/__init__.py new file mode 100644 index 00000000..486f0d09 --- /dev/null +++ b/mace/tools/torch_geometric/__init__.py @@ -0,0 +1,7 @@ +from .batch import Batch +from .data import Data +from .dataloader import DataLoader +from .dataset import Dataset +from .seed import seed_everything + +__all__ = ["Batch", "Data", "Dataset", "DataLoader", "seed_everything"] diff --git a/mace/tools/torch_geometric/batch.py b/mace/tools/torch_geometric/batch.py new file mode 100644 index 00000000..be5ec9d0 --- /dev/null +++ b/mace/tools/torch_geometric/batch.py @@ -0,0 +1,257 @@ +from collections.abc import Sequence +from typing import List + +import numpy as np +import torch +from torch import Tensor + +from .data import Data +from .dataset import IndexType + + +class Batch(Data): + r"""A plain old python object modeling a batch of graphs as one big + (disconnected) graph. With :class:`torch_geometric.data.Data` being the + base class, all its methods can also be used here. + In addition, single graphs can be reconstructed via the assignment vector + :obj:`batch`, which maps each node to its respective graph identifier. + """ + + def __init__(self, batch=None, ptr=None, **kwargs): + super(Batch, self).__init__(**kwargs) + + for key, item in kwargs.items(): + if key == "num_nodes": + self.__num_nodes__ = item + else: + self[key] = item + + self.batch = batch + self.ptr = ptr + self.__data_class__ = Data + self.__slices__ = None + self.__cumsum__ = None + self.__cat_dims__ = None + self.__num_nodes_list__ = None + self.__num_graphs__ = None + + @classmethod + def from_data_list(cls, data_list, follow_batch=[], exclude_keys=[]): + r"""Constructs a batch object from a python list holding + :class:`torch_geometric.data.Data` objects. + The assignment vector :obj:`batch` is created on the fly. + Additionally, creates assignment batch vectors for each key in + :obj:`follow_batch`. + Will exclude any keys given in :obj:`exclude_keys`.""" + + keys = list(set(data_list[0].keys) - set(exclude_keys)) + assert "batch" not in keys and "ptr" not in keys + + batch = cls() + for key in data_list[0].__dict__.keys(): + if key[:2] != "__" and key[-2:] != "__": + batch[key] = None + + batch.__num_graphs__ = len(data_list) + batch.__data_class__ = data_list[0].__class__ + for key in keys + ["batch"]: + batch[key] = [] + batch["ptr"] = [0] + + device = None + slices = {key: [0] for key in keys} + cumsum = {key: [0] for key in keys} + cat_dims = {} + num_nodes_list = [] + for i, data in enumerate(data_list): + for key in keys: + item = data[key] + + # Increase values by `cumsum` value. + cum = cumsum[key][-1] + if isinstance(item, Tensor) and item.dtype != torch.bool: + if not isinstance(cum, int) or cum != 0: + item = item + cum + elif isinstance(item, (int, float)): + item = item + cum + + # Gather the size of the `cat` dimension. + size = 1 + cat_dim = data.__cat_dim__(key, data[key]) + # 0-dimensional tensors have no dimension along which to + # concatenate, so we set `cat_dim` to `None`. + if isinstance(item, Tensor) and item.dim() == 0: + cat_dim = None + cat_dims[key] = cat_dim + + # Add a batch dimension to items whose `cat_dim` is `None`: + if isinstance(item, Tensor) and cat_dim is None: + cat_dim = 0 # Concatenate along this new batch dimension. + item = item.unsqueeze(0) + device = item.device + elif isinstance(item, Tensor): + size = item.size(cat_dim) + device = item.device + + batch[key].append(item) # Append item to the attribute list. + + slices[key].append(size + slices[key][-1]) + inc = data.__inc__(key, item) + if isinstance(inc, (tuple, list)): + inc = torch.tensor(inc) + cumsum[key].append(inc + cumsum[key][-1]) + + if key in follow_batch: + if isinstance(size, Tensor): + for j, size in enumerate(size.tolist()): + tmp = f"{key}_{j}_batch" + batch[tmp] = [] if i == 0 else batch[tmp] + batch[tmp].append( + torch.full((size,), i, dtype=torch.long, device=device) + ) + else: + tmp = f"{key}_batch" + batch[tmp] = [] if i == 0 else batch[tmp] + batch[tmp].append( + torch.full((size,), i, dtype=torch.long, device=device) + ) + + if hasattr(data, "__num_nodes__"): + num_nodes_list.append(data.__num_nodes__) + else: + num_nodes_list.append(None) + + num_nodes = data.num_nodes + if num_nodes is not None: + item = torch.full((num_nodes,), i, dtype=torch.long, device=device) + batch.batch.append(item) + batch.ptr.append(batch.ptr[-1] + num_nodes) + + batch.batch = None if len(batch.batch) == 0 else batch.batch + batch.ptr = None if len(batch.ptr) == 1 else batch.ptr + batch.__slices__ = slices + batch.__cumsum__ = cumsum + batch.__cat_dims__ = cat_dims + batch.__num_nodes_list__ = num_nodes_list + + ref_data = data_list[0] + for key in batch.keys: + items = batch[key] + item = items[0] + cat_dim = ref_data.__cat_dim__(key, item) + cat_dim = 0 if cat_dim is None else cat_dim + if isinstance(item, Tensor): + batch[key] = torch.cat(items, cat_dim) + elif isinstance(item, (int, float)): + batch[key] = torch.tensor(items) + + # if torch_geometric.is_debug_enabled(): + # batch.debug() + + return batch.contiguous() + + def get_example(self, idx: int) -> Data: + r"""Reconstructs the :class:`torch_geometric.data.Data` object at index + :obj:`idx` from the batch object. + The batch object must have been created via :meth:`from_data_list` in + order to be able to reconstruct the initial objects.""" + + if self.__slices__ is None: + raise RuntimeError( + ( + "Cannot reconstruct data list from batch because the batch " + "object was not created using `Batch.from_data_list()`." + ) + ) + + data = self.__data_class__() + idx = self.num_graphs + idx if idx < 0 else idx + + for key in self.__slices__.keys(): + item = self[key] + if self.__cat_dims__[key] is None: + # The item was concatenated along a new batch dimension, + # so just index in that dimension: + item = item[idx] + else: + # Narrow the item based on the values in `__slices__`. + if isinstance(item, Tensor): + dim = self.__cat_dims__[key] + start = self.__slices__[key][idx] + end = self.__slices__[key][idx + 1] + item = item.narrow(dim, start, end - start) + else: + start = self.__slices__[key][idx] + end = self.__slices__[key][idx + 1] + item = item[start:end] + item = item[0] if len(item) == 1 else item + + # Decrease its value by `cumsum` value: + cum = self.__cumsum__[key][idx] + if isinstance(item, Tensor): + if not isinstance(cum, int) or cum != 0: + item = item - cum + elif isinstance(item, (int, float)): + item = item - cum + + data[key] = item + + if self.__num_nodes_list__[idx] is not None: + data.num_nodes = self.__num_nodes_list__[idx] + + return data + + def index_select(self, idx: IndexType) -> List[Data]: + if isinstance(idx, slice): + idx = list(range(self.num_graphs)[idx]) + + elif isinstance(idx, Tensor) and idx.dtype == torch.long: + idx = idx.flatten().tolist() + + elif isinstance(idx, Tensor) and idx.dtype == torch.bool: + idx = idx.flatten().nonzero(as_tuple=False).flatten().tolist() + + elif isinstance(idx, np.ndarray) and idx.dtype == np.int64: + idx = idx.flatten().tolist() + + elif isinstance(idx, np.ndarray) and idx.dtype == np.bool: + idx = idx.flatten().nonzero()[0].flatten().tolist() + + elif isinstance(idx, Sequence) and not isinstance(idx, str): + pass + + else: + raise IndexError( + f"Only integers, slices (':'), list, tuples, torch.tensor and " + f"np.ndarray of dtype long or bool are valid indices (got " + f"'{type(idx).__name__}')" + ) + + return [self.get_example(i) for i in idx] + + def __getitem__(self, idx): + if isinstance(idx, str): + return super(Batch, self).__getitem__(idx) + elif isinstance(idx, (int, np.integer)): + return self.get_example(idx) + else: + return self.index_select(idx) + + def to_data_list(self) -> List[Data]: + r"""Reconstructs the list of :class:`torch_geometric.data.Data` objects + from the batch object. + The batch object must have been created via :meth:`from_data_list` in + order to be able to reconstruct the initial objects.""" + return [self.get_example(i) for i in range(self.num_graphs)] + + @property + def num_graphs(self) -> int: + """Returns the number of graphs in the batch.""" + if self.__num_graphs__ is not None: + return self.__num_graphs__ + elif self.ptr is not None: + return self.ptr.numel() - 1 + elif self.batch is not None: + return int(self.batch.max()) + 1 + else: + raise ValueError diff --git a/mace/tools/torch_geometric/data.py b/mace/tools/torch_geometric/data.py new file mode 100644 index 00000000..4e1ab308 --- /dev/null +++ b/mace/tools/torch_geometric/data.py @@ -0,0 +1,441 @@ +import collections +import copy +import re + +import torch + +# from ..utils.num_nodes import maybe_num_nodes + +__num_nodes_warn_msg__ = ( + "The number of nodes in your data object can only be inferred by its {} " + "indices, and hence may result in unexpected batch-wise behavior, e.g., " + "in case there exists isolated nodes. Please consider explicitly setting " + "the number of nodes for this data object by assigning it to " + "data.num_nodes." +) + + +def size_repr(key, item, indent=0): + indent_str = " " * indent + if torch.is_tensor(item) and item.dim() == 0: + out = item.item() + elif torch.is_tensor(item): + out = str(list(item.size())) + elif isinstance(item, list) or isinstance(item, tuple): + out = str([len(item)]) + elif isinstance(item, dict): + lines = [indent_str + size_repr(k, v, 2) for k, v in item.items()] + out = "{\n" + ",\n".join(lines) + "\n" + indent_str + "}" + elif isinstance(item, str): + out = f'"{item}"' + else: + out = str(item) + + return f"{indent_str}{key}={out}" + + +class Data(object): + r"""A plain old python object modeling a single graph with various + (optional) attributes: + + Args: + x (Tensor, optional): Node feature matrix with shape :obj:`[num_nodes, + num_node_features]`. (default: :obj:`None`) + edge_index (LongTensor, optional): Graph connectivity in COO format + with shape :obj:`[2, num_edges]`. (default: :obj:`None`) + edge_attr (Tensor, optional): Edge feature matrix with shape + :obj:`[num_edges, num_edge_features]`. (default: :obj:`None`) + y (Tensor, optional): Graph or node targets with arbitrary shape. + (default: :obj:`None`) + pos (Tensor, optional): Node position matrix with shape + :obj:`[num_nodes, num_dimensions]`. (default: :obj:`None`) + normal (Tensor, optional): Normal vector matrix with shape + :obj:`[num_nodes, num_dimensions]`. (default: :obj:`None`) + face (LongTensor, optional): Face adjacency matrix with shape + :obj:`[3, num_faces]`. (default: :obj:`None`) + + The data object is not restricted to these attributes and can be extended + by any other additional data. + + Example:: + + data = Data(x=x, edge_index=edge_index) + data.train_idx = torch.tensor([...], dtype=torch.long) + data.test_mask = torch.tensor([...], dtype=torch.bool) + """ + + def __init__( + self, + x=None, + edge_index=None, + edge_attr=None, + y=None, + pos=None, + normal=None, + face=None, + **kwargs, + ): + self.x = x + self.edge_index = edge_index + self.edge_attr = edge_attr + self.y = y + self.pos = pos + self.normal = normal + self.face = face + for key, item in kwargs.items(): + if key == "num_nodes": + self.__num_nodes__ = item + else: + self[key] = item + + if edge_index is not None and edge_index.dtype != torch.long: + raise ValueError( + ( + f"Argument `edge_index` needs to be of type `torch.long` but " + f"found type `{edge_index.dtype}`." + ) + ) + + if face is not None and face.dtype != torch.long: + raise ValueError( + ( + f"Argument `face` needs to be of type `torch.long` but found " + f"type `{face.dtype}`." + ) + ) + + @classmethod + def from_dict(cls, dictionary): + r"""Creates a data object from a python dictionary.""" + data = cls() + + for key, item in dictionary.items(): + data[key] = item + + return data + + def to_dict(self): + return {key: item for key, item in self} + + def to_namedtuple(self): + keys = self.keys + DataTuple = collections.namedtuple("DataTuple", keys) + return DataTuple(*[self[key] for key in keys]) + + def __getitem__(self, key): + r"""Gets the data of the attribute :obj:`key`.""" + return getattr(self, key, None) + + def __setitem__(self, key, value): + """Sets the attribute :obj:`key` to :obj:`value`.""" + setattr(self, key, value) + + def __delitem__(self, key): + r"""Delete the data of the attribute :obj:`key`.""" + return delattr(self, key) + + @property + def keys(self): + r"""Returns all names of graph attributes.""" + keys = [key for key in self.__dict__.keys() if self[key] is not None] + keys = [key for key in keys if key[:2] != "__" and key[-2:] != "__"] + return keys + + def __len__(self): + r"""Returns the number of all present attributes.""" + return len(self.keys) + + def __contains__(self, key): + r"""Returns :obj:`True`, if the attribute :obj:`key` is present in the + data.""" + return key in self.keys + + def __iter__(self): + r"""Iterates over all present attributes in the data, yielding their + attribute names and content.""" + for key in sorted(self.keys): + yield key, self[key] + + def __call__(self, *keys): + r"""Iterates over all attributes :obj:`*keys` in the data, yielding + their attribute names and content. + If :obj:`*keys` is not given this method will iterative over all + present attributes.""" + for key in sorted(self.keys) if not keys else keys: + if key in self: + yield key, self[key] + + def __cat_dim__(self, key, value): + r"""Returns the dimension for which :obj:`value` of attribute + :obj:`key` will get concatenated when creating batches. + + .. note:: + + This method is for internal use only, and should only be overridden + if the batch concatenation process is corrupted for a specific data + attribute. + """ + if bool(re.search("(index|face)", key)): + return -1 + return 0 + + def __inc__(self, key, value): + r"""Returns the incremental count to cumulatively increase the value + of the next attribute of :obj:`key` when creating batches. + + .. note:: + + This method is for internal use only, and should only be overridden + if the batch concatenation process is corrupted for a specific data + attribute. + """ + # Only `*index*` and `*face*` attributes should be cumulatively summed + # up when creating batches. + return self.num_nodes if bool(re.search("(index|face)", key)) else 0 + + @property + def num_nodes(self): + r"""Returns or sets the number of nodes in the graph. + + .. note:: + The number of nodes in your data object is typically automatically + inferred, *e.g.*, when node features :obj:`x` are present. + In some cases however, a graph may only be given by its edge + indices :obj:`edge_index`. + PyTorch Geometric then *guesses* the number of nodes + according to :obj:`edge_index.max().item() + 1`, but in case there + exists isolated nodes, this number has not to be correct and can + therefore result in unexpected batch-wise behavior. + Thus, we recommend to set the number of nodes in your data object + explicitly via :obj:`data.num_nodes = ...`. + You will be given a warning that requests you to do so. + """ + if hasattr(self, "__num_nodes__"): + return self.__num_nodes__ + for key, item in self("x", "pos", "normal", "batch"): + return item.size(self.__cat_dim__(key, item)) + if hasattr(self, "adj"): + return self.adj.size(0) + if hasattr(self, "adj_t"): + return self.adj_t.size(1) + # if self.face is not None: + # logging.warning(__num_nodes_warn_msg__.format("face")) + # return maybe_num_nodes(self.face) + # if self.edge_index is not None: + # logging.warning(__num_nodes_warn_msg__.format("edge")) + # return maybe_num_nodes(self.edge_index) + return None + + @num_nodes.setter + def num_nodes(self, num_nodes): + self.__num_nodes__ = num_nodes + + @property + def num_edges(self): + """ + Returns the number of edges in the graph. + For undirected graphs, this will return the number of bi-directional + edges, which is double the amount of unique edges. + """ + for key, item in self("edge_index", "edge_attr"): + return item.size(self.__cat_dim__(key, item)) + for key, item in self("adj", "adj_t"): + return item.nnz() + return None + + @property + def num_faces(self): + r"""Returns the number of faces in the mesh.""" + if self.face is not None: + return self.face.size(self.__cat_dim__("face", self.face)) + return None + + @property + def num_node_features(self): + r"""Returns the number of features per node in the graph.""" + if self.x is None: + return 0 + return 1 if self.x.dim() == 1 else self.x.size(1) + + @property + def num_features(self): + r"""Alias for :py:attr:`~num_node_features`.""" + return self.num_node_features + + @property + def num_edge_features(self): + r"""Returns the number of features per edge in the graph.""" + if self.edge_attr is None: + return 0 + return 1 if self.edge_attr.dim() == 1 else self.edge_attr.size(1) + + def __apply__(self, item, func): + if torch.is_tensor(item): + return func(item) + elif isinstance(item, (tuple, list)): + return [self.__apply__(v, func) for v in item] + elif isinstance(item, dict): + return {k: self.__apply__(v, func) for k, v in item.items()} + else: + return item + + def apply(self, func, *keys): + r"""Applies the function :obj:`func` to all tensor attributes + :obj:`*keys`. If :obj:`*keys` is not given, :obj:`func` is applied to + all present attributes. + """ + for key, item in self(*keys): + self[key] = self.__apply__(item, func) + return self + + def contiguous(self, *keys): + r"""Ensures a contiguous memory layout for all attributes :obj:`*keys`. + If :obj:`*keys` is not given, all present attributes are ensured to + have a contiguous memory layout.""" + return self.apply(lambda x: x.contiguous(), *keys) + + def to(self, device, *keys, **kwargs): + r"""Performs tensor dtype and/or device conversion to all attributes + :obj:`*keys`. + If :obj:`*keys` is not given, the conversion is applied to all present + attributes.""" + return self.apply(lambda x: x.to(device, **kwargs), *keys) + + def cpu(self, *keys): + r"""Copies all attributes :obj:`*keys` to CPU memory. + If :obj:`*keys` is not given, the conversion is applied to all present + attributes.""" + return self.apply(lambda x: x.cpu(), *keys) + + def cuda(self, device=None, non_blocking=False, *keys): + r"""Copies all attributes :obj:`*keys` to CUDA memory. + If :obj:`*keys` is not given, the conversion is applied to all present + attributes.""" + return self.apply( + lambda x: x.cuda(device=device, non_blocking=non_blocking), *keys + ) + + def clone(self): + r"""Performs a deep-copy of the data object.""" + return self.__class__.from_dict( + { + k: v.clone() if torch.is_tensor(v) else copy.deepcopy(v) + for k, v in self.__dict__.items() + } + ) + + def pin_memory(self, *keys): + r"""Copies all attributes :obj:`*keys` to pinned memory. + If :obj:`*keys` is not given, the conversion is applied to all present + attributes.""" + return self.apply(lambda x: x.pin_memory(), *keys) + + def debug(self): + if self.edge_index is not None: + if self.edge_index.dtype != torch.long: + raise RuntimeError( + ( + "Expected edge indices of dtype {}, but found dtype " " {}" + ).format(torch.long, self.edge_index.dtype) + ) + + if self.face is not None: + if self.face.dtype != torch.long: + raise RuntimeError( + ( + "Expected face indices of dtype {}, but found dtype " " {}" + ).format(torch.long, self.face.dtype) + ) + + if self.edge_index is not None: + if self.edge_index.dim() != 2 or self.edge_index.size(0) != 2: + raise RuntimeError( + ( + "Edge indices should have shape [2, num_edges] but found" + " shape {}" + ).format(self.edge_index.size()) + ) + + if self.edge_index is not None and self.num_nodes is not None: + if self.edge_index.numel() > 0: + min_index = self.edge_index.min() + max_index = self.edge_index.max() + else: + min_index = max_index = 0 + if min_index < 0 or max_index > self.num_nodes - 1: + raise RuntimeError( + ( + "Edge indices must lay in the interval [0, {}]" + " but found them in the interval [{}, {}]" + ).format(self.num_nodes - 1, min_index, max_index) + ) + + if self.face is not None: + if self.face.dim() != 2 or self.face.size(0) != 3: + raise RuntimeError( + ( + "Face indices should have shape [3, num_faces] but found" + " shape {}" + ).format(self.face.size()) + ) + + if self.face is not None and self.num_nodes is not None: + if self.face.numel() > 0: + min_index = self.face.min() + max_index = self.face.max() + else: + min_index = max_index = 0 + if min_index < 0 or max_index > self.num_nodes - 1: + raise RuntimeError( + ( + "Face indices must lay in the interval [0, {}]" + " but found them in the interval [{}, {}]" + ).format(self.num_nodes - 1, min_index, max_index) + ) + + if self.edge_index is not None and self.edge_attr is not None: + if self.edge_index.size(1) != self.edge_attr.size(0): + raise RuntimeError( + ( + "Edge indices and edge attributes hold a differing " + "number of edges, found {} and {}" + ).format(self.edge_index.size(), self.edge_attr.size()) + ) + + if self.x is not None and self.num_nodes is not None: + if self.x.size(0) != self.num_nodes: + raise RuntimeError( + ( + "Node features should hold {} elements in the first " + "dimension but found {}" + ).format(self.num_nodes, self.x.size(0)) + ) + + if self.pos is not None and self.num_nodes is not None: + if self.pos.size(0) != self.num_nodes: + raise RuntimeError( + ( + "Node positions should hold {} elements in the first " + "dimension but found {}" + ).format(self.num_nodes, self.pos.size(0)) + ) + + if self.normal is not None and self.num_nodes is not None: + if self.normal.size(0) != self.num_nodes: + raise RuntimeError( + ( + "Node normals should hold {} elements in the first " + "dimension but found {}" + ).format(self.num_nodes, self.normal.size(0)) + ) + + def __repr__(self): + cls = str(self.__class__.__name__) + has_dict = any([isinstance(item, dict) for _, item in self]) + + if not has_dict: + info = [size_repr(key, item) for key, item in self] + return "{}({})".format(cls, ", ".join(info)) + else: + info = [size_repr(key, item, indent=2) for key, item in self] + return "{}(\n{}\n)".format(cls, ",\n".join(info)) diff --git a/mace/tools/torch_geometric/dataloader.py b/mace/tools/torch_geometric/dataloader.py new file mode 100644 index 00000000..396b7e72 --- /dev/null +++ b/mace/tools/torch_geometric/dataloader.py @@ -0,0 +1,87 @@ +from collections.abc import Mapping, Sequence +from typing import List, Optional, Union + +import torch.utils.data +from torch.utils.data.dataloader import default_collate + +from .batch import Batch +from .data import Data +from .dataset import Dataset + + +class Collater: + def __init__(self, follow_batch, exclude_keys): + self.follow_batch = follow_batch + self.exclude_keys = exclude_keys + + def __call__(self, batch): + elem = batch[0] + if isinstance(elem, Data): + return Batch.from_data_list( + batch, + follow_batch=self.follow_batch, + exclude_keys=self.exclude_keys, + ) + elif isinstance(elem, torch.Tensor): + return default_collate(batch) + elif isinstance(elem, float): + return torch.tensor(batch, dtype=torch.float) + elif isinstance(elem, int): + return torch.tensor(batch) + elif isinstance(elem, str): + return batch + elif isinstance(elem, Mapping): + return {key: self([data[key] for data in batch]) for key in elem} + elif isinstance(elem, tuple) and hasattr(elem, "_fields"): + return type(elem)(*(self(s) for s in zip(*batch))) + elif isinstance(elem, Sequence) and not isinstance(elem, str): + return [self(s) for s in zip(*batch)] + + raise TypeError(f"DataLoader found invalid type: {type(elem)}") + + def collate(self, batch): # Deprecated... + return self(batch) + + +class DataLoader(torch.utils.data.DataLoader): + r"""A data loader which merges data objects from a + :class:`torch_geometric.data.Dataset` to a mini-batch. + Data objects can be either of type :class:`~torch_geometric.data.Data` or + :class:`~torch_geometric.data.HeteroData`. + Args: + dataset (Dataset): The dataset from which to load the data. + batch_size (int, optional): How many samples per batch to load. + (default: :obj:`1`) + shuffle (bool, optional): If set to :obj:`True`, the data will be + reshuffled at every epoch. (default: :obj:`False`) + follow_batch (List[str], optional): Creates assignment batch + vectors for each key in the list. (default: :obj:`None`) + exclude_keys (List[str], optional): Will exclude each key in the + list. (default: :obj:`None`) + **kwargs (optional): Additional arguments of + :class:`torch.utils.data.DataLoader`. + """ + + def __init__( + self, + dataset: Dataset, + batch_size: int = 1, + shuffle: bool = False, + follow_batch: Optional[List[str]] = [None], + exclude_keys: Optional[List[str]] = [None], + **kwargs, + ): + if "collate_fn" in kwargs: + del kwargs["collate_fn"] + + # Save for PyTorch Lightning < 1.6: + self.follow_batch = follow_batch + self.exclude_keys = exclude_keys + + super().__init__( + dataset, + batch_size, + shuffle, + collate_fn=Collater(follow_batch, exclude_keys), + **kwargs, + ) diff --git a/mace/tools/torch_geometric/dataset.py b/mace/tools/torch_geometric/dataset.py new file mode 100644 index 00000000..b4aeb2be --- /dev/null +++ b/mace/tools/torch_geometric/dataset.py @@ -0,0 +1,280 @@ +import copy +import os.path as osp +import re +import warnings +from collections.abc import Sequence +from typing import Any, Callable, List, Optional, Tuple, Union + +import numpy as np +import torch.utils.data +from torch import Tensor + +from .data import Data +from .utils import makedirs + +IndexType = Union[slice, Tensor, np.ndarray, Sequence] + + +class Dataset(torch.utils.data.Dataset): + r"""Dataset base class for creating graph datasets. + See `here `__ for the accompanying tutorial. + + Args: + root (string, optional): Root directory where the dataset should be + saved. (optional: :obj:`None`) + transform (callable, optional): A function/transform that takes in an + :obj:`torch_geometric.data.Data` object and returns a transformed + version. The data object will be transformed before every access. + (default: :obj:`None`) + pre_transform (callable, optional): A function/transform that takes in + an :obj:`torch_geometric.data.Data` object and returns a + transformed version. The data object will be transformed before + being saved to disk. (default: :obj:`None`) + pre_filter (callable, optional): A function that takes in an + :obj:`torch_geometric.data.Data` object and returns a boolean + value, indicating whether the data object should be included in the + final dataset. (default: :obj:`None`) + """ + + @property + def raw_file_names(self) -> Union[str, List[str], Tuple]: + r"""The name of the files to find in the :obj:`self.raw_dir` folder in + order to skip the download.""" + raise NotImplementedError + + @property + def processed_file_names(self) -> Union[str, List[str], Tuple]: + r"""The name of the files to find in the :obj:`self.processed_dir` + folder in order to skip the processing.""" + raise NotImplementedError + + def download(self): + r"""Downloads the dataset to the :obj:`self.raw_dir` folder.""" + raise NotImplementedError + + def process(self): + r"""Processes the dataset to the :obj:`self.processed_dir` folder.""" + raise NotImplementedError + + def len(self) -> int: + raise NotImplementedError + + def get(self, idx: int) -> Data: + r"""Gets the data object at index :obj:`idx`.""" + raise NotImplementedError + + def __init__( + self, + root: Optional[str] = None, + transform: Optional[Callable] = None, + pre_transform: Optional[Callable] = None, + pre_filter: Optional[Callable] = None, + ): + super().__init__() + + if isinstance(root, str): + root = osp.expanduser(osp.normpath(root)) + + self.root = root + self.transform = transform + self.pre_transform = pre_transform + self.pre_filter = pre_filter + self._indices: Optional[Sequence] = None + + if "download" in self.__class__.__dict__.keys(): + self._download() + + if "process" in self.__class__.__dict__.keys(): + self._process() + + def indices(self) -> Sequence: + return range(self.len()) if self._indices is None else self._indices + + @property + def raw_dir(self) -> str: + return osp.join(self.root, "raw") + + @property + def processed_dir(self) -> str: + return osp.join(self.root, "processed") + + @property + def num_node_features(self) -> int: + r"""Returns the number of features per node in the dataset.""" + data = self[0] + if hasattr(data, "num_node_features"): + return data.num_node_features + raise AttributeError( + f"'{data.__class__.__name__}' object has no " + f"attribute 'num_node_features'" + ) + + @property + def num_features(self) -> int: + r"""Alias for :py:attr:`~num_node_features`.""" + return self.num_node_features + + @property + def num_edge_features(self) -> int: + r"""Returns the number of features per edge in the dataset.""" + data = self[0] + if hasattr(data, "num_edge_features"): + return data.num_edge_features + raise AttributeError( + f"'{data.__class__.__name__}' object has no " + f"attribute 'num_edge_features'" + ) + + @property + def raw_paths(self) -> List[str]: + r"""The filepaths to find in order to skip the download.""" + files = to_list(self.raw_file_names) + return [osp.join(self.raw_dir, f) for f in files] + + @property + def processed_paths(self) -> List[str]: + r"""The filepaths to find in the :obj:`self.processed_dir` + folder in order to skip the processing.""" + files = to_list(self.processed_file_names) + return [osp.join(self.processed_dir, f) for f in files] + + def _download(self): + if files_exist(self.raw_paths): # pragma: no cover + return + + makedirs(self.raw_dir) + self.download() + + def _process(self): + f = osp.join(self.processed_dir, "pre_transform.pt") + if osp.exists(f) and torch.load(f) != _repr(self.pre_transform): + warnings.warn( + f"The `pre_transform` argument differs from the one used in " + f"the pre-processed version of this dataset. If you want to " + f"make use of another pre-processing technique, make sure to " + f"sure to delete '{self.processed_dir}' first" + ) + + f = osp.join(self.processed_dir, "pre_filter.pt") + if osp.exists(f) and torch.load(f) != _repr(self.pre_filter): + warnings.warn( + "The `pre_filter` argument differs from the one used in the " + "pre-processed version of this dataset. If you want to make " + "use of another pre-fitering technique, make sure to delete " + "'{self.processed_dir}' first" + ) + + if files_exist(self.processed_paths): # pragma: no cover + return + + print("Processing...") + + makedirs(self.processed_dir) + self.process() + + path = osp.join(self.processed_dir, "pre_transform.pt") + torch.save(_repr(self.pre_transform), path) + path = osp.join(self.processed_dir, "pre_filter.pt") + torch.save(_repr(self.pre_filter), path) + + print("Done!") + + def __len__(self) -> int: + r"""The number of examples in the dataset.""" + return len(self.indices()) + + def __getitem__( + self, + idx: Union[int, np.integer, IndexType], + ) -> Union["Dataset", Data]: + r"""In case :obj:`idx` is of type integer, will return the data object + at index :obj:`idx` (and transforms it in case :obj:`transform` is + present). + In case :obj:`idx` is a slicing object, *e.g.*, :obj:`[2:5]`, a list, a + tuple, a PyTorch :obj:`LongTensor` or a :obj:`BoolTensor`, or a numpy + :obj:`np.array`, will return a subset of the dataset at the specified + indices.""" + if ( + isinstance(idx, (int, np.integer)) + or (isinstance(idx, Tensor) and idx.dim() == 0) + or (isinstance(idx, np.ndarray) and np.isscalar(idx)) + ): + data = self.get(self.indices()[idx]) + data = data if self.transform is None else self.transform(data) + return data + + else: + return self.index_select(idx) + + def index_select(self, idx: IndexType) -> "Dataset": + indices = self.indices() + + if isinstance(idx, slice): + indices = indices[idx] + + elif isinstance(idx, Tensor) and idx.dtype == torch.long: + return self.index_select(idx.flatten().tolist()) + + elif isinstance(idx, Tensor) and idx.dtype == torch.bool: + idx = idx.flatten().nonzero(as_tuple=False) + return self.index_select(idx.flatten().tolist()) + + elif isinstance(idx, np.ndarray) and idx.dtype == np.int64: + return self.index_select(idx.flatten().tolist()) + + elif isinstance(idx, np.ndarray) and idx.dtype == np.bool: + idx = idx.flatten().nonzero()[0] + return self.index_select(idx.flatten().tolist()) + + elif isinstance(idx, Sequence) and not isinstance(idx, str): + indices = [indices[i] for i in idx] + + else: + raise IndexError( + f"Only integers, slices (':'), list, tuples, torch.tensor and " + f"np.ndarray of dtype long or bool are valid indices (got " + f"'{type(idx).__name__}')" + ) + + dataset = copy.copy(self) + dataset._indices = indices + return dataset + + def shuffle( + self, + return_perm: bool = False, + ) -> Union["Dataset", Tuple["Dataset", Tensor]]: + r"""Randomly shuffles the examples in the dataset. + + Args: + return_perm (bool, optional): If set to :obj:`True`, will return + the random permutation used to shuffle the dataset in addition. + (default: :obj:`False`) + """ + perm = torch.randperm(len(self)) + dataset = self.index_select(perm) + return (dataset, perm) if return_perm is True else dataset + + def __repr__(self) -> str: + arg_repr = str(len(self)) if len(self) > 1 else "" + return f"{self.__class__.__name__}({arg_repr})" + + +def to_list(value: Any) -> Sequence: + if isinstance(value, Sequence) and not isinstance(value, str): + return value + else: + return [value] + + +def files_exist(files: List[str]) -> bool: + # NOTE: We return `False` in case `files` is empty, leading to a + # re-processing of files on every instantiation. + return len(files) != 0 and all([osp.exists(f) for f in files]) + + +def _repr(obj: Any) -> str: + if obj is None: + return "None" + return re.sub("(<.*?)\\s.*(>)", r"\1\2", obj.__repr__()) diff --git a/mace/tools/torch_geometric/seed.py b/mace/tools/torch_geometric/seed.py new file mode 100644 index 00000000..be27fcaa --- /dev/null +++ b/mace/tools/torch_geometric/seed.py @@ -0,0 +1,17 @@ +import random + +import numpy as np +import torch + + +def seed_everything(seed: int): + r"""Sets the seed for generating random numbers in :pytorch:`PyTorch`, + :obj:`numpy` and Python. + + Args: + seed (int): The desired seed. + """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) diff --git a/mace/tools/torch_geometric/utils.py b/mace/tools/torch_geometric/utils.py new file mode 100644 index 00000000..f53b8f80 --- /dev/null +++ b/mace/tools/torch_geometric/utils.py @@ -0,0 +1,54 @@ +import os +import os.path as osp +import ssl +import urllib +import zipfile + + +def makedirs(dir): + os.makedirs(dir, exist_ok=True) + + +def download_url(url, folder, log=True): + r"""Downloads the content of an URL to a specific folder. + + Args: + url (string): The url. + folder (string): The folder. + log (bool, optional): If :obj:`False`, will not print anything to the + console. (default: :obj:`True`) + """ + + filename = url.rpartition("/")[2].split("?")[0] + path = osp.join(folder, filename) + + if osp.exists(path): # pragma: no cover + if log: + print("Using exist file", filename) + return path + + if log: + print("Downloading", url) + + makedirs(folder) + + context = ssl._create_unverified_context() + data = urllib.request.urlopen(url, context=context) + + with open(path, "wb") as f: + f.write(data.read()) + + return path + + +def extract_zip(path, folder, log=True): + r"""Extracts a zip archive to a specific folder. + + Args: + path (string): The path to the tar archive. + folder (string): The folder. + log (bool, optional): If :obj:`False`, will not print anything to the + console. (default: :obj:`True`) + """ + with zipfile.ZipFile(path, "r") as f: + f.extractall(folder) diff --git a/mace/tools/torch_tools.py b/mace/tools/torch_tools.py new file mode 100644 index 00000000..e42a74f8 --- /dev/null +++ b/mace/tools/torch_tools.py @@ -0,0 +1,141 @@ +########################################################################################### +# Tools for torch +# Authors: Ilyes Batatia, Gregor Simm +# This program is distributed under the MIT License (see MIT.md) +########################################################################################### + +import logging +from contextlib import contextmanager +from typing import Dict + +import numpy as np +import torch +from e3nn.io import CartesianTensor + +TensorDict = Dict[str, torch.Tensor] + + +def to_one_hot(indices: torch.Tensor, num_classes: int) -> torch.Tensor: + """ + Generates one-hot encoding with classes from + :param indices: (N x 1) tensor + :param num_classes: number of classes + :param device: torch device + :return: (N x num_classes) tensor + """ + shape = indices.shape[:-1] + (num_classes,) + oh = torch.zeros(shape, device=indices.device).view(shape) + + # scatter_ is the in-place version of scatter + oh.scatter_(dim=-1, index=indices, value=1) + + return oh.view(*shape) + + +def count_parameters(module: torch.nn.Module) -> int: + return int(sum(np.prod(p.shape) for p in module.parameters())) + + +def tensor_dict_to_device(td: TensorDict, device: torch.device) -> TensorDict: + return {k: v.to(device) if v is not None else None for k, v in td.items()} + + +def set_seeds(seed: int) -> None: + np.random.seed(seed) + torch.manual_seed(seed) + + +def to_numpy(t: torch.Tensor) -> np.ndarray: + return t.cpu().detach().numpy() + + +def init_device(device_str: str) -> torch.device: + if "cuda" in device_str: + assert torch.cuda.is_available(), "No CUDA device available!" + if ":" in device_str: + # Check if the desired device is available + assert int(device_str.split(":")[-1]) < torch.cuda.device_count() + logging.info( + f"CUDA version: {torch.version.cuda}, CUDA device: {torch.cuda.current_device()}" + ) + torch.cuda.init() + return torch.device(device_str) + if device_str == "mps": + assert torch.backends.mps.is_available(), "No MPS backend is available!" + logging.info("Using MPS GPU acceleration") + return torch.device("mps") + if device_str == "xpu": + torch.xpu.is_available() + return torch.device("xpu") + + logging.info("Using CPU") + return torch.device("cpu") + + +dtype_dict = {"float32": torch.float32, "float64": torch.float64} + + +def set_default_dtype(dtype: str) -> None: + torch.set_default_dtype(dtype_dict[dtype]) + + +def spherical_to_cartesian(t: torch.Tensor): + """ + Convert spherical notation to cartesian notation + """ + stress_cart_tensor = CartesianTensor("ij=ji") + stress_rtp = stress_cart_tensor.reduced_tensor_products() + return stress_cart_tensor.to_cartesian(t, rtp=stress_rtp) + + +def cartesian_to_spherical(t: torch.Tensor): + """ + Convert cartesian notation to spherical notation + """ + stress_cart_tensor = CartesianTensor("ij=ji") + stress_rtp = stress_cart_tensor.reduced_tensor_products() + return stress_cart_tensor.to_cartesian(t, rtp=stress_rtp) + + +def voigt_to_matrix(t: torch.Tensor): + """ + Convert voigt notation to matrix notation + :param t: (6,) tensor or (3, 3) tensor or (9,) tensor + :return: (3, 3) tensor + """ + if t.shape == (3, 3): + return t + if t.shape == (6,): + return torch.tensor( + [ + [t[0], t[5], t[4]], + [t[5], t[1], t[3]], + [t[4], t[3], t[2]], + ], + dtype=t.dtype, + ) + if t.shape == (9,): + return t.view(3, 3) + + raise ValueError( + f"Stress tensor must be of shape (6,) or (3, 3), or (9,) but has shape {t.shape}" + ) + + +def init_wandb(project: str, entity: str, name: str, config: dict, directory: str): + import wandb + + wandb.init(project=project, entity=entity, name=name, config=config, dir=directory) + + +@contextmanager +def default_dtype(dtype: torch.dtype): + """Context manager for configuring the default_dtype used by torch + + Args: + dtype (torch.dtype): the default dtype to use within this context manager + """ + init = torch.get_default_dtype() + torch.set_default_dtype(dtype) + yield + torch.set_default_dtype(init) diff --git a/mace/tools/train.py b/mace/tools/train.py new file mode 100644 index 00000000..3c6b8325 --- /dev/null +++ b/mace/tools/train.py @@ -0,0 +1,538 @@ +########################################################################################### +# Training script +# Authors: Ilyes Batatia, Gregor Simm, David Kovacs +# This program is distributed under the MIT License (see MIT.md) +########################################################################################### + +import dataclasses +import logging +import time +from contextlib import nullcontext +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.distributed +from torch.nn.parallel import DistributedDataParallel +from torch.optim.swa_utils import SWALR, AveragedModel +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler +from torch_ema import ExponentialMovingAverage +from torchmetrics import Metric + +from . import torch_geometric +from .checkpoint import CheckpointHandler, CheckpointState +from .torch_tools import to_numpy +from .utils import ( + MetricsLogger, + compute_mae, + compute_q95, + compute_rel_mae, + compute_rel_rmse, + compute_rmse, +) + + +@dataclasses.dataclass +class SWAContainer: + model: AveragedModel + scheduler: SWALR + start: int + loss_fn: torch.nn.Module + + +def valid_err_log( + valid_loss, + eval_metrics, + logger, + log_errors, + epoch=None, + valid_loader_name="Default", +): + eval_metrics["mode"] = "eval" + eval_metrics["epoch"] = epoch + logger.log(eval_metrics) + if epoch is None: + inintial_phrase = "Initial" + else: + inintial_phrase = f"Epoch {epoch}" + if log_errors == "PerAtomRMSE": + error_e = eval_metrics["rmse_e_per_atom"] * 1e3 + error_f = eval_metrics["rmse_f"] * 1e3 + logging.info( + f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.8f}, RMSE_E_per_atom={error_e:8.2f} meV, RMSE_F={error_f:8.2f} meV / A" + ) + elif ( + log_errors == "PerAtomRMSEstressvirials" + and eval_metrics["rmse_stress"] is not None + ): + error_e = eval_metrics["rmse_e_per_atom"] * 1e3 + error_f = eval_metrics["rmse_f"] * 1e3 + error_stress = eval_metrics["rmse_stress"] * 1e3 + logging.info( + f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.8f}, RMSE_E_per_atom={error_e:8.2f} meV, RMSE_F={error_f:8.2f} meV / A, RMSE_stress={error_stress:8.2f} meV / A^3", + ) + elif ( + log_errors == "PerAtomRMSEstressvirials" + and eval_metrics["rmse_virials_per_atom"] is not None + ): + error_e = eval_metrics["rmse_e_per_atom"] * 1e3 + error_f = eval_metrics["rmse_f"] * 1e3 + error_virials = eval_metrics["rmse_virials_per_atom"] * 1e3 + logging.info( + f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.8f}, RMSE_E_per_atom={error_e:8.2f} meV, RMSE_F={error_f:8.2f} meV / A, RMSE_virials_per_atom={error_virials:8.2f} meV", + ) + elif ( + log_errors == "PerAtomMAEstressvirials" + and eval_metrics["mae_stress_per_atom"] is not None + ): + error_e = eval_metrics["mae_e_per_atom"] * 1e3 + error_f = eval_metrics["mae_f"] * 1e3 + error_stress = eval_metrics["mae_stress"] * 1e3 + logging.info( + f"{inintial_phrase}: loss={valid_loss:8.8f}, MAE_E_per_atom={error_e:8.2f} meV, MAE_F={error_f:8.2f} meV / A, MAE_stress={error_stress:8.2f} meV / A^3" + ) + elif ( + log_errors == "PerAtomMAEstressvirials" + and eval_metrics["mae_virials_per_atom"] is not None + ): + error_e = eval_metrics["mae_e_per_atom"] * 1e3 + error_f = eval_metrics["mae_f"] * 1e3 + error_virials = eval_metrics["mae_virials"] * 1e3 + logging.info( + f"{inintial_phrase}: loss={valid_loss:8.8f}, MAE_E_per_atom={error_e:8.2f} meV, MAE_F={error_f:8.2f} meV / A, MAE_virials={error_virials:8.2f} meV" + ) + elif log_errors == "TotalRMSE": + error_e = eval_metrics["rmse_e"] * 1e3 + error_f = eval_metrics["rmse_f"] * 1e3 + logging.info( + f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.8f}, RMSE_E={error_e:8.2f} meV, RMSE_F={error_f:8.2f} meV / A", + ) + elif log_errors == "PerAtomMAE": + error_e = eval_metrics["mae_e_per_atom"] * 1e3 + error_f = eval_metrics["mae_f"] * 1e3 + logging.info( + f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.8f}, MAE_E_per_atom={error_e:8.2f} meV, MAE_F={error_f:8.2f} meV / A", + ) + elif log_errors == "TotalMAE": + error_e = eval_metrics["mae_e"] * 1e3 + error_f = eval_metrics["mae_f"] * 1e3 + logging.info( + f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.8f}, MAE_E={error_e:8.2f} meV, MAE_F={error_f:8.2f} meV / A", + ) + elif log_errors == "DipoleRMSE": + error_mu = eval_metrics["rmse_mu_per_atom"] * 1e3 + logging.info( + f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.8f}, RMSE_MU_per_atom={error_mu:8.2f} mDebye", + ) + elif log_errors == "EnergyDipoleRMSE": + error_e = eval_metrics["rmse_e_per_atom"] * 1e3 + error_f = eval_metrics["rmse_f"] * 1e3 + error_mu = eval_metrics["rmse_mu_per_atom"] * 1e3 + logging.info( + f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.8f}, RMSE_E_per_atom={error_e:8.2f} meV, RMSE_F={error_f:8.2f} meV / A, RMSE_Mu_per_atom={error_mu:8.2f} mDebye", + ) + + +def train( + model: torch.nn.Module, + loss_fn: torch.nn.Module, + train_loader: DataLoader, + valid_loaders: Dict[str, DataLoader], + optimizer: torch.optim.Optimizer, + lr_scheduler: torch.optim.lr_scheduler.ExponentialLR, + start_epoch: int, + max_num_epochs: int, + patience: int, + checkpoint_handler: CheckpointHandler, + logger: MetricsLogger, + eval_interval: int, + output_args: Dict[str, bool], + device: torch.device, + log_errors: str, + swa: Optional[SWAContainer] = None, + ema: Optional[ExponentialMovingAverage] = None, + max_grad_norm: Optional[float] = 10.0, + log_wandb: bool = False, + distributed: bool = False, + save_all_checkpoints: bool = False, + distributed_model: Optional[DistributedDataParallel] = None, + train_sampler: Optional[DistributedSampler] = None, + rank: Optional[int] = 0, +): + lowest_loss = np.inf + valid_loss = np.inf + patience_counter = 0 + swa_start = True + keep_last = False + if log_wandb: + import wandb + + if max_grad_norm is not None: + logging.info(f"Using gradient clipping with tolerance={max_grad_norm:.3f}") + + logging.info("") + logging.info("===========TRAINING===========") + logging.info("Started training, reporting errors on validation set") + logging.info("Loss metrics on validation set") + epoch = start_epoch + + # log validation loss before _any_ training + valid_loss = 0.0 + for valid_loader_name, valid_loader in valid_loaders.items(): + valid_loss_head, eval_metrics = evaluate( + model=model, + loss_fn=loss_fn, + data_loader=valid_loader, + output_args=output_args, + device=device, + ) + valid_err_log( + valid_loss_head, eval_metrics, logger, log_errors, None, valid_loader_name + ) + valid_loss = valid_loss_head # consider only the last head for the checkpoint + + while epoch < max_num_epochs: + # LR scheduler and SWA update + if swa is None or epoch < swa.start: + if epoch > start_epoch: + lr_scheduler.step( + metrics=valid_loss + ) # Can break if exponential LR, TODO fix that! + else: + if swa_start: + logging.info("Changing loss based on Stage Two Weights") + lowest_loss = np.inf + swa_start = False + keep_last = True + loss_fn = swa.loss_fn + swa.model.update_parameters(model) + if epoch > start_epoch: + swa.scheduler.step() + + # Train + if distributed: + train_sampler.set_epoch(epoch) + if "ScheduleFree" in type(optimizer).__name__: + optimizer.train() + train_one_epoch( + model=model, + loss_fn=loss_fn, + data_loader=train_loader, + optimizer=optimizer, + epoch=epoch, + output_args=output_args, + max_grad_norm=max_grad_norm, + ema=ema, + logger=logger, + device=device, + distributed_model=distributed_model, + rank=rank, + ) + if distributed: + torch.distributed.barrier() + + # Validate + if epoch % eval_interval == 0: + model_to_evaluate = ( + model if distributed_model is None else distributed_model + ) + param_context = ( + ema.average_parameters() if ema is not None else nullcontext() + ) + if "ScheduleFree" in type(optimizer).__name__: + optimizer.eval() + with param_context: + valid_loss = 0.0 + wandb_log_dict = {} + for valid_loader_name, valid_loader in valid_loaders.items(): + valid_loss_head, eval_metrics = evaluate( + model=model_to_evaluate, + loss_fn=loss_fn, + data_loader=valid_loader, + output_args=output_args, + device=device, + ) + if rank == 0: + valid_err_log( + valid_loss_head, + eval_metrics, + logger, + log_errors, + epoch, + valid_loader_name, + ) + if log_wandb: + wandb_log_dict[valid_loader_name] = { + "epoch": epoch, + "valid_loss": valid_loss_head, + "valid_rmse_e_per_atom": eval_metrics[ + "rmse_e_per_atom" + ], + "valid_rmse_f": eval_metrics["rmse_f"], + } + valid_loss = ( + valid_loss_head # consider only the last head for the checkpoint + ) + if log_wandb: + wandb.log(wandb_log_dict) + if rank == 0: + if valid_loss >= lowest_loss: + patience_counter += 1 + if patience_counter >= patience: + if swa is not None and epoch < swa.start: + logging.info( + f"Stopping optimization after {patience_counter} epochs without improvement and starting Stage Two" + ) + epoch = swa.start + else: + logging.info( + f"Stopping optimization after {patience_counter} epochs without improvement" + ) + break + if save_all_checkpoints: + param_context = ( + ema.average_parameters() + if ema is not None + else nullcontext() + ) + with param_context: + checkpoint_handler.save( + state=CheckpointState(model, optimizer, lr_scheduler), + epochs=epoch, + keep_last=True, + ) + else: + lowest_loss = valid_loss + patience_counter = 0 + param_context = ( + ema.average_parameters() if ema is not None else nullcontext() + ) + with param_context: + checkpoint_handler.save( + state=CheckpointState(model, optimizer, lr_scheduler), + epochs=epoch, + keep_last=keep_last, + ) + keep_last = False or save_all_checkpoints + if distributed: + torch.distributed.barrier() + epoch += 1 + + logging.info("Training complete") + + +def train_one_epoch( + model: torch.nn.Module, + loss_fn: torch.nn.Module, + data_loader: DataLoader, + optimizer: torch.optim.Optimizer, + epoch: int, + output_args: Dict[str, bool], + max_grad_norm: Optional[float], + ema: Optional[ExponentialMovingAverage], + logger: MetricsLogger, + device: torch.device, + distributed_model: Optional[DistributedDataParallel] = None, + rank: Optional[int] = 0, +) -> None: + model_to_train = model if distributed_model is None else distributed_model + for batch in data_loader: + _, opt_metrics = take_step( + model=model_to_train, + loss_fn=loss_fn, + batch=batch, + optimizer=optimizer, + ema=ema, + output_args=output_args, + max_grad_norm=max_grad_norm, + device=device, + ) + opt_metrics["mode"] = "opt" + opt_metrics["epoch"] = epoch + if rank == 0: + logger.log(opt_metrics) + + +def take_step( + model: torch.nn.Module, + loss_fn: torch.nn.Module, + batch: torch_geometric.batch.Batch, + optimizer: torch.optim.Optimizer, + ema: Optional[ExponentialMovingAverage], + output_args: Dict[str, bool], + max_grad_norm: Optional[float], + device: torch.device, +) -> Tuple[float, Dict[str, Any]]: + start_time = time.time() + batch = batch.to(device) + optimizer.zero_grad(set_to_none=True) + batch_dict = batch.to_dict() + output = model( + batch_dict, + training=True, + compute_force=output_args["forces"], + compute_virials=output_args["virials"], + compute_stress=output_args["stress"], + ) + loss = loss_fn(pred=output, ref=batch) + loss.backward() + if max_grad_norm is not None: + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm) + optimizer.step() + + if ema is not None: + ema.update() + + loss_dict = { + "loss": to_numpy(loss), + "time": time.time() - start_time, + } + + return loss, loss_dict + + +def evaluate( + model: torch.nn.Module, + loss_fn: torch.nn.Module, + data_loader: DataLoader, + output_args: Dict[str, bool], + device: torch.device, +) -> Tuple[float, Dict[str, Any]]: + for param in model.parameters(): + param.requires_grad = False + + metrics = MACELoss(loss_fn=loss_fn).to(device) + + start_time = time.time() + for batch in data_loader: + batch = batch.to(device) + batch_dict = batch.to_dict() + output = model( + batch_dict, + training=False, + compute_force=output_args["forces"], + compute_virials=output_args["virials"], + compute_stress=output_args["stress"], + ) + avg_loss, aux = metrics(batch, output) + + avg_loss, aux = metrics.compute() + aux["time"] = time.time() - start_time + metrics.reset() + + for param in model.parameters(): + param.requires_grad = True + + return avg_loss, aux + + +class MACELoss(Metric): + def __init__(self, loss_fn: torch.nn.Module): + super().__init__() + self.loss_fn = loss_fn + self.add_state("total_loss", default=torch.tensor(0.0), dist_reduce_fx="sum") + self.add_state("num_data", default=torch.tensor(0.0), dist_reduce_fx="sum") + self.add_state("E_computed", default=torch.tensor(0.0), dist_reduce_fx="sum") + self.add_state("delta_es", default=[], dist_reduce_fx="cat") + self.add_state("delta_es_per_atom", default=[], dist_reduce_fx="cat") + self.add_state("Fs_computed", default=torch.tensor(0.0), dist_reduce_fx="sum") + self.add_state("fs", default=[], dist_reduce_fx="cat") + self.add_state("delta_fs", default=[], dist_reduce_fx="cat") + self.add_state( + "stress_computed", default=torch.tensor(0.0), dist_reduce_fx="sum" + ) + self.add_state("delta_stress", default=[], dist_reduce_fx="cat") + self.add_state( + "virials_computed", default=torch.tensor(0.0), dist_reduce_fx="sum" + ) + self.add_state("delta_virials", default=[], dist_reduce_fx="cat") + self.add_state("delta_virials_per_atom", default=[], dist_reduce_fx="cat") + self.add_state("Mus_computed", default=torch.tensor(0.0), dist_reduce_fx="sum") + self.add_state("mus", default=[], dist_reduce_fx="cat") + self.add_state("delta_mus", default=[], dist_reduce_fx="cat") + self.add_state("delta_mus_per_atom", default=[], dist_reduce_fx="cat") + + def update(self, batch, output): # pylint: disable=arguments-differ + loss = self.loss_fn(pred=output, ref=batch) + self.total_loss += loss + self.num_data += batch.num_graphs + + if output.get("energy") is not None and batch.energy is not None: + self.E_computed += 1.0 + self.delta_es.append(batch.energy - output["energy"]) + self.delta_es_per_atom.append( + (batch.energy - output["energy"]) / (batch.ptr[1:] - batch.ptr[:-1]) + ) + if output.get("forces") is not None and batch.forces is not None: + self.Fs_computed += 1.0 + self.fs.append(batch.forces) + self.delta_fs.append(batch.forces - output["forces"]) + if output.get("stress") is not None and batch.stress is not None: + self.stress_computed += 1.0 + self.delta_stress.append(batch.stress - output["stress"]) + if output.get("virials") is not None and batch.virials is not None: + self.virials_computed += 1.0 + self.delta_virials.append(batch.virials - output["virials"]) + self.delta_virials_per_atom.append( + (batch.virials - output["virials"]) + / (batch.ptr[1:] - batch.ptr[:-1]).view(-1, 1, 1) + ) + if output.get("dipole") is not None and batch.dipole is not None: + self.Mus_computed += 1.0 + self.mus.append(batch.dipole) + self.delta_mus.append(batch.dipole - output["dipole"]) + self.delta_mus_per_atom.append( + (batch.dipole - output["dipole"]) + / (batch.ptr[1:] - batch.ptr[:-1]).unsqueeze(-1) + ) + + def convert(self, delta: Union[torch.Tensor, List[torch.Tensor]]) -> np.ndarray: + if isinstance(delta, list): + delta = torch.cat(delta) + return to_numpy(delta) + + def compute(self): + aux = {} + aux["loss"] = to_numpy(self.total_loss / self.num_data).item() + if self.E_computed: + delta_es = self.convert(self.delta_es) + delta_es_per_atom = self.convert(self.delta_es_per_atom) + aux["mae_e"] = compute_mae(delta_es) + aux["mae_e_per_atom"] = compute_mae(delta_es_per_atom) + aux["rmse_e"] = compute_rmse(delta_es) + aux["rmse_e_per_atom"] = compute_rmse(delta_es_per_atom) + aux["q95_e"] = compute_q95(delta_es) + if self.Fs_computed: + fs = self.convert(self.fs) + delta_fs = self.convert(self.delta_fs) + aux["mae_f"] = compute_mae(delta_fs) + aux["rel_mae_f"] = compute_rel_mae(delta_fs, fs) + aux["rmse_f"] = compute_rmse(delta_fs) + aux["rel_rmse_f"] = compute_rel_rmse(delta_fs, fs) + aux["q95_f"] = compute_q95(delta_fs) + if self.stress_computed: + delta_stress = self.convert(self.delta_stress) + aux["mae_stress"] = compute_mae(delta_stress) + aux["rmse_stress"] = compute_rmse(delta_stress) + aux["q95_stress"] = compute_q95(delta_stress) + if self.virials_computed: + delta_virials = self.convert(self.delta_virials) + delta_virials_per_atom = self.convert(self.delta_virials_per_atom) + aux["mae_virials"] = compute_mae(delta_virials) + aux["rmse_virials"] = compute_rmse(delta_virials) + aux["rmse_virials_per_atom"] = compute_rmse(delta_virials_per_atom) + aux["q95_virials"] = compute_q95(delta_virials) + if self.Mus_computed: + mus = self.convert(self.mus) + delta_mus = self.convert(self.delta_mus) + delta_mus_per_atom = self.convert(self.delta_mus_per_atom) + aux["mae_mu"] = compute_mae(delta_mus) + aux["mae_mu_per_atom"] = compute_mae(delta_mus_per_atom) + aux["rel_mae_mu"] = compute_rel_mae(delta_mus, mus) + aux["rmse_mu"] = compute_rmse(delta_mus) + aux["rmse_mu_per_atom"] = compute_rmse(delta_mus_per_atom) + aux["rel_rmse_mu"] = compute_rel_rmse(delta_mus, mus) + aux["q95_mu"] = compute_q95(delta_mus) + + return aux["loss"], aux diff --git a/mace/tools/utils.py b/mace/tools/utils.py new file mode 100644 index 00000000..28a77efe --- /dev/null +++ b/mace/tools/utils.py @@ -0,0 +1,147 @@ +########################################################################################### +# Statistics utilities +# Authors: Ilyes Batatia, Gregor Simm, David Kovacs +# This program is distributed under the MIT License (see MIT.md) +########################################################################################### + +import json +import logging +import os +import sys +from typing import Any, Dict, Iterable, Optional, Sequence, Union + +import numpy as np +import torch + +from .torch_tools import to_numpy + + +def compute_mae(delta: np.ndarray) -> float: + return np.mean(np.abs(delta)).item() + + +def compute_rel_mae(delta: np.ndarray, target_val: np.ndarray) -> float: + target_norm = np.mean(np.abs(target_val)) + return np.mean(np.abs(delta)).item() / (target_norm + 1e-9) * 100 + + +def compute_rmse(delta: np.ndarray) -> float: + return np.sqrt(np.mean(np.square(delta))).item() + + +def compute_rel_rmse(delta: np.ndarray, target_val: np.ndarray) -> float: + target_norm = np.sqrt(np.mean(np.square(target_val))).item() + return np.sqrt(np.mean(np.square(delta))).item() / (target_norm + 1e-9) * 100 + + +def compute_q95(delta: np.ndarray) -> float: + return np.percentile(np.abs(delta), q=95) + + +def compute_c(delta: np.ndarray, eta: float) -> float: + return np.mean(np.abs(delta) < eta).item() + + +def get_tag(name: str, seed: int) -> str: + return f"{name}_run-{seed}" + + +def setup_logger( + level: Union[int, str] = logging.INFO, + tag: Optional[str] = None, + directory: Optional[str] = None, + rank: Optional[int] = 0, +): + # Create a logger + logger = logging.getLogger() + logger.setLevel(logging.DEBUG) # Set to DEBUG to capture all levels + + # Create formatters + formatter = logging.Formatter( + "%(asctime)s.%(msecs)03d %(levelname)s: %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + + # Add filter for rank + logger.addFilter(lambda _: rank == 0) + + # Create console handler + ch = logging.StreamHandler(stream=sys.stdout) + ch.setLevel(level) + ch.setFormatter(formatter) + logger.addHandler(ch) + + if directory is not None and tag is not None: + os.makedirs(name=directory, exist_ok=True) + + # Create file handler for non-debug logs + main_log_path = os.path.join(directory, f"{tag}.log") + fh_main = logging.FileHandler(main_log_path) + fh_main.setLevel(level) + fh_main.setFormatter(formatter) + logger.addHandler(fh_main) + + # Create file handler for debug logs + debug_log_path = os.path.join(directory, f"{tag}_debug.log") + fh_debug = logging.FileHandler(debug_log_path) + fh_debug.setLevel(logging.DEBUG) + fh_debug.setFormatter(formatter) + fh_debug.addFilter(lambda record: record.levelno >= logging.DEBUG) + logger.addHandler(fh_debug) + + +class AtomicNumberTable: + def __init__(self, zs: Sequence[int]): + self.zs = zs + + def __len__(self) -> int: + return len(self.zs) + + def __str__(self): + return f"AtomicNumberTable: {tuple(s for s in self.zs)}" + + def index_to_z(self, index: int) -> int: + return self.zs[index] + + def z_to_index(self, atomic_number: str) -> int: + return self.zs.index(atomic_number) + + +def get_atomic_number_table_from_zs(zs: Iterable[int]) -> AtomicNumberTable: + z_set = set() + for z in zs: + z_set.add(z) + return AtomicNumberTable(sorted(list(z_set))) + + +def atomic_numbers_to_indices( + atomic_numbers: np.ndarray, z_table: AtomicNumberTable +) -> np.ndarray: + to_index_fn = np.vectorize(z_table.z_to_index) + return to_index_fn(atomic_numbers) + + +class UniversalEncoder(json.JSONEncoder): + def default(self, o): + if isinstance(o, np.integer): + return int(o) + if isinstance(o, np.floating): + return float(o) + if isinstance(o, np.ndarray): + return o.tolist() + if isinstance(o, torch.Tensor): + return to_numpy(o) + return json.JSONEncoder.default(self, o) + + +class MetricsLogger: + def __init__(self, directory: str, tag: str) -> None: + self.directory = directory + self.filename = tag + ".txt" + self.path = os.path.join(self.directory, self.filename) + + def log(self, d: Dict[str, Any]) -> None: + os.makedirs(name=self.directory, exist_ok=True) + with open(self.path, mode="a", encoding="utf-8") as f: + f.write(json.dumps(d, cls=UniversalEncoder)) + f.write("\n") diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..489bc6e5 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,41 @@ +[build-system] +requires = [ + "setuptools>=42", + "wheel", +] +build-backend = "setuptools.build_meta" + +# Make isort compatible with black +[tool.isort] +profile = "black" + +# Pylint +[tool.pylint.'MESSAGES CONTROL'] +disable = [ + "line-too-long", + "no-member", + "missing-module-docstring", + "missing-class-docstring", + "missing-function-docstring", + "too-many-arguments", + "too-many-locals", + "not-callable", + "logging-fstring-interpolation", + "logging-not-lazy", + "logging-too-many-args", + "invalid-name", + "too-few-public-methods", + "too-many-instance-attributes", + "too-many-statements", + "too-many-branches", + "import-outside-toplevel", + "cell-var-from-loop", + "duplicate-code", + "use-dict-literal", +] + +[tool.pylint.MASTER] +ignore-paths = [ + "^mace/tools/torch_geometric/.*$", + "^mace/tools/scatter.py$", +] diff --git a/scripts/__init__.py b/scripts/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/scripts/distributed_example.sbatch b/scripts/distributed_example.sbatch new file mode 100644 index 00000000..c39f55fd --- /dev/null +++ b/scripts/distributed_example.sbatch @@ -0,0 +1,34 @@ +#!/bin/bash +#SBATCH --partition=gpu +#SBATCH --job-name=train +#SBATCH --output=train.out +#SBATCH --nodes=2 +#SBATCH --ntasks=20 +#SBATCH --ntasks-per-node=10 +#SBATCH --gpus-per-node=10 +#SBATCH --cpus-per-task=8 +#SBATCH --exclusive +#SBATCH --time=1:00:00 + +srun python mace/scripts/run_train.py \ + --name='model' \ + --model='MACE' \ + --num_interactions=2 \ + --num_channels=128 \ + --max_L=2 \ + --correlation=3 \ + --E0s='average' \ + --r_max=5.0 \ + --train_file='./h5_data/train.h5' \ + --valid_file='./h5_data/valid.h5' \ + --statistics_file='./h5_data/statistics.json' \ + --num_workers=8 \ + --batch_size=20 \ + --valid_batch_size=80 \ + --max_num_epochs=100 \ + --loss='weighted' \ + --error_table='PerAtomRMSE' \ + --default_dtype='float32' \ + --device='cuda' \ + --distributed \ + --seed=2222 \ \ No newline at end of file diff --git a/scripts/eval_configs.py b/scripts/eval_configs.py new file mode 100644 index 00000000..d2f4e217 --- /dev/null +++ b/scripts/eval_configs.py @@ -0,0 +1,6 @@ +## Wrapper for mace.cli.eval_configs.main ## + +from mace.cli.eval_configs import main + +if __name__ == "__main__": + main() diff --git a/scripts/preprocess_data.py b/scripts/preprocess_data.py new file mode 100644 index 00000000..3c2c288c --- /dev/null +++ b/scripts/preprocess_data.py @@ -0,0 +1,6 @@ +## Wrapper for mace.cli.run_train.main ## + +from mace.cli.preprocess_data import main + +if __name__ == "__main__": + main() diff --git a/scripts/run_checks.sh b/scripts/run_checks.sh new file mode 100644 index 00000000..bd1214a4 --- /dev/null +++ b/scripts/run_checks.sh @@ -0,0 +1,9 @@ +# Format +python -m black . +python -m isort . + +# Check +python -m pylint --rcfile=pyproject.toml mace tests scripts + +# Tests +python -m pytest tests diff --git a/scripts/run_train.py b/scripts/run_train.py new file mode 100644 index 00000000..d14952db --- /dev/null +++ b/scripts/run_train.py @@ -0,0 +1,6 @@ +## Wrapper for mace.cli.run_train.main ## + +from mace.cli.run_train import main + +if __name__ == "__main__": + main() diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 00000000..76467fda --- /dev/null +++ b/setup.cfg @@ -0,0 +1,59 @@ +[metadata] +name = mace-torch +version = attr: mace.__version__ +short_description = MACE - Fast and accurate machine learning interatomic potentials with higher order equivariant message passing. +long_description = file: README.md +long_description_content_type = text/markdown +url = https://github.com/ACEsuit/mace +classifiers = + Programming Language :: Python :: 3 + Operating System :: OS Independent + License :: OSI Approved :: MIT License + +[options] +packages = find: +python_requires = >=3.7 +install_requires = + torch>=1.12 + e3nn==0.4.4 + numpy<2.0 + opt_einsum + ase + torch-ema + prettytable + matscipy + h5py + torchmetrics + python-hostlist + configargparse + GitPython + pyYAML + tqdm + # for plotting: + matplotlib + pandas + +[options.entry_points] +console_scripts = + mace_active_learning_md = mace.cli.active_learning_md:main + mace_create_lammps_model = mace.cli.create_lammps_model:main + mace_eval_configs = mace.cli.eval_configs:main + mace_plot_train = mace.cli.plot_train:main + mace_run_train = mace.cli.run_train:main + mace_prepare_data = mace.cli.preprocess_data:main + mace_finetuning = mace.cli.fine_tuning_select:main + mace_convert_device = mace.cli.convert_device:main + mace_select_head = mace.cli.select_head:main + +[options.extras_require] +wandb = wandb +fpsample = fpsample +dev = + black + isort + mypy + pre-commit + pytest + pytest-benchmark + pylint +schedulefree = schedulefree diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_calculator.py b/tests/test_calculator.py new file mode 100644 index 00000000..6590935c --- /dev/null +++ b/tests/test_calculator.py @@ -0,0 +1,508 @@ +import os +import subprocess +import sys +from pathlib import Path + +import ase.io +import numpy as np +import pytest +import torch +from ase import build +from ase.atoms import Atoms +from ase.calculators.test import gradient_test +from ase.constraints import ExpCellFilter + +from mace.calculators import mace_mp, mace_off +from mace.calculators.mace import MACECalculator +from mace.modules.models import ScaleShiftMACE + +pytest_mace_dir = Path(__file__).parent.parent +run_train = Path(__file__).parent.parent / "mace" / "cli" / "run_train.py" + + +@pytest.fixture(scope="module", name="fitting_configs") +def fitting_configs_fixture(): + water = Atoms( + numbers=[8, 1, 1], + positions=[[0, -2.0, 0], [1, 0, 0], [0, 1, 0]], + cell=[4] * 3, + pbc=[True] * 3, + ) + fit_configs = [ + Atoms(numbers=[8], positions=[[0, 0, 0]], cell=[6] * 3), + Atoms(numbers=[1], positions=[[0, 0, 0]], cell=[6] * 3), + ] + fit_configs[0].info["REF_energy"] = 1.0 + fit_configs[0].info["config_type"] = "IsolatedAtom" + fit_configs[1].info["REF_energy"] = -0.5 + fit_configs[1].info["config_type"] = "IsolatedAtom" + + np.random.seed(5) + for _ in range(20): + c = water.copy() + c.positions += np.random.normal(0.1, size=c.positions.shape) + c.info["REF_energy"] = np.random.normal(0.1) + c.info["REF_dipole"] = np.random.normal(0.1, size=3) + c.new_array("REF_forces", np.random.normal(0.1, size=c.positions.shape)) + c.new_array("Qs", np.random.normal(0.1, size=c.positions.shape[0])) + c.info["REF_stress"] = np.random.normal(0.1, size=6) + fit_configs.append(c) + + return fit_configs + + +@pytest.fixture(scope="module", name="trained_model") +def trained_model_fixture(tmp_path_factory, fitting_configs): + _mace_params = { + "name": "MACE", + "valid_fraction": 0.05, + "energy_weight": 1.0, + "forces_weight": 10.0, + "stress_weight": 1.0, + "model": "MACE", + "hidden_irreps": "128x0e", + "r_max": 3.5, + "batch_size": 5, + "max_num_epochs": 10, + "swa": None, + "start_swa": 5, + "ema": None, + "ema_decay": 0.99, + "amsgrad": None, + "restart_latest": None, + "device": "cpu", + "seed": 5, + "loss": "stress", + "energy_key": "REF_energy", + "forces_key": "REF_forces", + "stress_key": "REF_stress", + "eval_interval": 2, + } + + tmp_path = tmp_path_factory.mktemp("run_") + + ase.io.write(tmp_path / "fit.xyz", fitting_configs) + + mace_params = _mace_params.copy() + mace_params["checkpoints_dir"] = str(tmp_path) + mace_params["model_dir"] = str(tmp_path) + mace_params["train_file"] = tmp_path / "fit.xyz" + + # make sure run_train.py is using the mace that is currently being tested + run_env = os.environ.copy() + sys.path.insert(0, str(Path(__file__).parent.parent)) + run_env["PYTHONPATH"] = ":".join(sys.path) + print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) + + cmd = ( + sys.executable + + " " + + str(run_train) + + " " + + " ".join( + [ + (f"--{k}={v}" if v is not None else f"--{k}") + for k, v in mace_params.items() + ] + ) + ) + + p = subprocess.run(cmd.split(), env=run_env, check=True) + + assert p.returncode == 0 + + return MACECalculator(model_paths=tmp_path / "MACE.model", device="cpu") + + +@pytest.fixture(scope="module", name="trained_equivariant_model") +def trained_model_equivariant_fixture(tmp_path_factory, fitting_configs): + _mace_params = { + "name": "MACE", + "valid_fraction": 0.05, + "energy_weight": 1.0, + "forces_weight": 10.0, + "stress_weight": 1.0, + "model": "MACE", + "hidden_irreps": "16x0e+16x1o", + "r_max": 3.5, + "batch_size": 5, + "max_num_epochs": 10, + "swa": None, + "start_swa": 5, + "ema": None, + "ema_decay": 0.99, + "amsgrad": None, + "restart_latest": None, + "device": "cpu", + "seed": 5, + "loss": "stress", + "energy_key": "REF_energy", + "forces_key": "REF_forces", + "stress_key": "REF_stress", + "eval_interval": 2, + } + + tmp_path = tmp_path_factory.mktemp("run_") + + ase.io.write(tmp_path / "fit.xyz", fitting_configs) + + mace_params = _mace_params.copy() + mace_params["checkpoints_dir"] = str(tmp_path) + mace_params["model_dir"] = str(tmp_path) + mace_params["train_file"] = tmp_path / "fit.xyz" + + # make sure run_train.py is using the mace that is currently being tested + run_env = os.environ.copy() + sys.path.insert(0, str(Path(__file__).parent.parent)) + run_env["PYTHONPATH"] = ":".join(sys.path) + print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) + + cmd = ( + sys.executable + + " " + + str(run_train) + + " " + + " ".join( + [ + (f"--{k}={v}" if v is not None else f"--{k}") + for k, v in mace_params.items() + ] + ) + ) + + p = subprocess.run(cmd.split(), env=run_env, check=True) + + assert p.returncode == 0 + + return MACECalculator(model_paths=tmp_path / "MACE.model", device="cpu") + + +@pytest.fixture(scope="module", name="trained_dipole_model") +def trained_dipole_fixture(tmp_path_factory, fitting_configs): + _mace_params = { + "name": "MACE", + "valid_fraction": 0.05, + "energy_weight": 1.0, + "forces_weight": 10.0, + "stress_weight": 1.0, + "model": "AtomicDipolesMACE", + "num_channels": 8, + "max_L": 2, + "r_max": 3.5, + "batch_size": 5, + "max_num_epochs": 10, + "ema": None, + "ema_decay": 0.99, + "amsgrad": None, + "restart_latest": None, + "device": "cpu", + "seed": 5, + "loss": "dipole", + "energy_key": "", + "forces_key": "", + "stress_key": "", + "dipole_key": "REF_dipole", + "error_table": "DipoleRMSE", + "eval_interval": 2, + } + + tmp_path = tmp_path_factory.mktemp("run_") + + ase.io.write(tmp_path / "fit.xyz", fitting_configs) + + mace_params = _mace_params.copy() + mace_params["checkpoints_dir"] = str(tmp_path) + mace_params["model_dir"] = str(tmp_path) + mace_params["train_file"] = tmp_path / "fit.xyz" + + # make sure run_train.py is using the mace that is currently being tested + run_env = os.environ.copy() + sys.path.insert(0, str(Path(__file__).parent.parent)) + run_env["PYTHONPATH"] = ":".join(sys.path) + print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) + + cmd = ( + sys.executable + + " " + + str(run_train) + + " " + + " ".join( + [ + (f"--{k}={v}" if v is not None else f"--{k}") + for k, v in mace_params.items() + ] + ) + ) + + p = subprocess.run(cmd.split(), env=run_env, check=True) + + assert p.returncode == 0 + + return MACECalculator( + model_paths=tmp_path / "MACE.model", device="cpu", model_type="DipoleMACE" + ) + + +@pytest.fixture(scope="module", name="trained_energy_dipole_model") +def trained_energy_dipole_fixture(tmp_path_factory, fitting_configs): + _mace_params = { + "name": "MACE", + "valid_fraction": 0.05, + "energy_weight": 1.0, + "forces_weight": 10.0, + "stress_weight": 1.0, + "model": "EnergyDipolesMACE", + "num_channels": 32, + "max_L": 1, + "r_max": 3.5, + "batch_size": 5, + "max_num_epochs": 10, + "ema": None, + "ema_decay": 0.99, + "amsgrad": None, + "restart_latest": None, + "device": "cpu", + "seed": 5, + "loss": "energy_forces_dipole", + "energy_key": "REF_energy", + "forces_key": "", + "stress_key": "", + "dipole_key": "REF_dipole", + "error_table": "EnergyDipoleRMSE", + "eval_interval": 2, + } + + tmp_path = tmp_path_factory.mktemp("run_") + + ase.io.write(tmp_path / "fit.xyz", fitting_configs) + + mace_params = _mace_params.copy() + mace_params["checkpoints_dir"] = str(tmp_path) + mace_params["model_dir"] = str(tmp_path) + mace_params["train_file"] = tmp_path / "fit.xyz" + + # make sure run_train.py is using the mace that is currently being tested + run_env = os.environ.copy() + sys.path.insert(0, str(Path(__file__).parent.parent)) + run_env["PYTHONPATH"] = ":".join(sys.path) + print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) + + cmd = ( + sys.executable + + " " + + str(run_train) + + " " + + " ".join( + [ + (f"--{k}={v}" if v is not None else f"--{k}") + for k, v in mace_params.items() + ] + ) + ) + + p = subprocess.run(cmd.split(), env=run_env, check=True) + + assert p.returncode == 0 + + return MACECalculator( + model_paths=tmp_path / "MACE.model", device="cpu", model_type="EnergyDipoleMACE" + ) + + +@pytest.fixture(scope="module", name="trained_committee") +def trained_committee_fixture(tmp_path_factory, fitting_configs): + _seeds = [5, 6, 7] + _model_paths = [] + for seed in _seeds: + _mace_params = { + "name": f"MACE{seed}", + "valid_fraction": 0.05, + "energy_weight": 1.0, + "forces_weight": 10.0, + "stress_weight": 1.0, + "model": "MACE", + "hidden_irreps": "16x0e", + "r_max": 3.5, + "batch_size": 5, + "max_num_epochs": 10, + "swa": None, + "start_swa": 5, + "ema": None, + "ema_decay": 0.99, + "amsgrad": None, + "restart_latest": None, + "device": "cpu", + "seed": seed, + "loss": "stress", + "energy_key": "REF_energy", + "forces_key": "REF_forces", + "stress_key": "REF_stress", + "eval_interval": 2, + } + + tmp_path = tmp_path_factory.mktemp(f"run{seed}_") + + ase.io.write(tmp_path / "fit.xyz", fitting_configs) + + mace_params = _mace_params.copy() + mace_params["checkpoints_dir"] = str(tmp_path) + mace_params["model_dir"] = str(tmp_path) + mace_params["train_file"] = tmp_path / "fit.xyz" + + # make sure run_train.py is using the mace that is currently being tested + run_env = os.environ.copy() + sys.path.insert(0, str(Path(__file__).parent.parent)) + run_env["PYTHONPATH"] = ":".join(sys.path) + print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) + + cmd = ( + sys.executable + + " " + + str(run_train) + + " " + + " ".join( + [ + (f"--{k}={v}" if v is not None else f"--{k}") + for k, v in mace_params.items() + ] + ) + ) + + p = subprocess.run(cmd.split(), env=run_env, check=True) + + assert p.returncode == 0 + + _model_paths.append(tmp_path / f"MACE{seed}.model") + + return MACECalculator(model_paths=_model_paths, device="cpu") + + +def test_calculator_node_energy(fitting_configs, trained_model): + for at in fitting_configs: + trained_model.calculate(at) + node_energies = trained_model.results["node_energy"] + batch = trained_model._atoms_to_batch(at) # pylint: disable=protected-access + node_heads = batch["head"][batch["batch"]] + num_atoms_arange = torch.arange(batch["positions"].shape[0]) + node_e0 = ( + trained_model.models[0].atomic_energies_fn(batch["node_attrs"]).detach() + ) + node_e0 = node_e0[num_atoms_arange, node_heads].cpu().numpy() + energy_via_nodes = np.sum(node_energies + node_e0) + energy = trained_model.results["energy"] + np.testing.assert_allclose(energy, energy_via_nodes, atol=1e-6) + + +def test_calculator_forces(fitting_configs, trained_model): + at = fitting_configs[2].copy() + at.calc = trained_model + + # test just forces + grads = gradient_test(at) + + assert np.allclose(grads[0], grads[1]) + + +def test_calculator_stress(fitting_configs, trained_model): + at = fitting_configs[2].copy() + at.calc = trained_model + + # test forces and stress + at_wrapped = ExpCellFilter(at) + grads = gradient_test(at_wrapped) + + assert np.allclose(grads[0], grads[1]) + + +def test_calculator_committee(fitting_configs, trained_committee): + at = fitting_configs[2].copy() + at.calc = trained_committee + + # test just forces + grads = gradient_test(at) + + assert np.allclose(grads[0], grads[1]) + + E = at.get_potential_energy() + energies = at.calc.results["energies"] + energies_var = at.calc.results["energy_var"] + forces_var = np.var(at.calc.results["forces_comm"], axis=0) + assert np.allclose(E, np.mean(energies)) + assert np.allclose(energies_var, np.var(energies)) + assert forces_var.shape == at.calc.results["forces"].shape + + +def test_calculator_from_model(fitting_configs, trained_committee): + # test single model + test_calculator_forces( + fitting_configs, + trained_model=MACECalculator(models=trained_committee.models[0], device="cpu"), + ) + + # test committee model + test_calculator_committee( + fitting_configs, + trained_committee=MACECalculator(models=trained_committee.models, device="cpu"), + ) + + +def test_calculator_dipole(fitting_configs, trained_dipole_model): + at = fitting_configs[2].copy() + at.calc = trained_dipole_model + + dip = at.get_dipole_moment() + + assert len(dip) == 3 + + +def test_calculator_energy_dipole(fitting_configs, trained_energy_dipole_model): + at = fitting_configs[2].copy() + at.calc = trained_energy_dipole_model + + grads = gradient_test(at) + dip = at.get_dipole_moment() + + assert np.allclose(grads[0], grads[1]) + assert len(dip) == 3 + + +def test_calculator_descriptor(fitting_configs, trained_equivariant_model): + at = fitting_configs[2].copy() + at.calc = trained_equivariant_model + + desc_invariant = at.calc.get_descriptors(at, invariants_only=True) + desc_single_layer = at.calc.get_descriptors(at, invariants_only=True, num_layers=1) + desc = at.calc.get_descriptors(at, invariants_only=False) + + assert desc_invariant.shape[0] == 3 + assert desc_invariant.shape[1] == 32 + assert desc_single_layer.shape[0] == 3 + assert desc_single_layer.shape[1] == 16 + assert desc.shape[0] == 3 + assert desc.shape[1] == 80 + + +def test_mace_mp(capsys: pytest.CaptureFixture): + mp_mace = mace_mp() + assert isinstance(mp_mace, MACECalculator) + assert mp_mace.model_type == "MACE" + assert len(mp_mace.models) == 1 + assert isinstance(mp_mace.models[0], ScaleShiftMACE) + + _, stderr = capsys.readouterr() + assert stderr == "" + + +def test_mace_off(): + mace_off_model = mace_off(model="small", device="cpu") + assert isinstance(mace_off_model, MACECalculator) + assert mace_off_model.model_type == "MACE" + assert len(mace_off_model.models) == 1 + assert isinstance(mace_off_model.models[0], ScaleShiftMACE) + + atoms = build.molecule("H2O") + atoms.calc = mace_off_model + + E = atoms.get_potential_energy() + + assert np.allclose(E, -2081.116128586803, atol=1e-9) diff --git a/tests/test_cg.py b/tests/test_cg.py new file mode 100644 index 00000000..36b119b9 --- /dev/null +++ b/tests/test_cg.py @@ -0,0 +1,12 @@ +from e3nn import o3 + +from mace.tools import cg + + +def test_U_matrix(): + irreps_in = o3.Irreps("1x0e + 1x1o + 1x2e") + irreps_out = o3.Irreps("1x0e + 1x1o") + u_matrix = cg.U_matrix_real( + irreps_in=irreps_in, irreps_out=irreps_out, correlation=3 + )[-1] + assert u_matrix.shape == (3, 9, 9, 9, 21) diff --git a/tests/test_compile.py b/tests/test_compile.py new file mode 100644 index 00000000..d7d585e8 --- /dev/null +++ b/tests/test_compile.py @@ -0,0 +1,154 @@ +import os +from functools import wraps +from typing import Callable + +import numpy as np +import pytest +import torch +import torch.nn.functional as F +from e3nn import o3 +from torch.testing import assert_close + +from mace import data, modules, tools +from mace.tools import compile as mace_compile +from mace.tools import torch_geometric + +table = tools.AtomicNumberTable([6]) +atomic_energies = np.array([1.0], dtype=float) +cutoff = 5.0 + + +def create_mace(device: str, seed: int = 1702): + torch_geometric.seed_everything(seed) + + model_config = { + "r_max": cutoff, + "num_bessel": 8, + "num_polynomial_cutoff": 6, + "max_ell": 3, + "interaction_cls": modules.interaction_classes[ + "RealAgnosticResidualInteractionBlock" + ], + "interaction_cls_first": modules.interaction_classes[ + "RealAgnosticResidualInteractionBlock" + ], + "num_interactions": 2, + "num_elements": 1, + "hidden_irreps": o3.Irreps("128x0e + 128x1o"), + "MLP_irreps": o3.Irreps("16x0e"), + "gate": F.silu, + "atomic_energies": atomic_energies, + "avg_num_neighbors": 8, + "atomic_numbers": table.zs, + "correlation": 3, + "radial_type": "bessel", + "atomic_inter_scale": 1.0, + "atomic_inter_shift": 0.0, + } + model = modules.ScaleShiftMACE(**model_config) + return model.to(device) + + +def create_batch(device: str): + from ase import build + + size = 2 + atoms = build.bulk("C", "diamond", a=3.567, cubic=True) + atoms_list = [atoms.repeat((size, size, size))] + print("Number of atoms", len(atoms_list[0])) + + configs = [data.config_from_atoms(atoms) for atoms in atoms_list] + data_loader = torch_geometric.dataloader.DataLoader( + dataset=[ + data.AtomicData.from_config(config, z_table=table, cutoff=cutoff) + for config in configs + ], + batch_size=1, + shuffle=False, + drop_last=False, + ) + batch = next(iter(data_loader)) + batch = batch.to(device) + batch = batch.to_dict() + return batch + + +def time_func(func: Callable): + @wraps(func) + def wrapper(*args, **kwargs): + torch._inductor.cudagraph_mark_step_begin() # pylint: disable=W0212 + outputs = func(*args, **kwargs) + torch.cuda.synchronize() + return outputs + + return wrapper + + +@pytest.fixture(params=[torch.float32, torch.float64], ids=["fp32", "fp64"]) +def default_dtype(request): + with tools.torch_tools.default_dtype(request.param): + yield torch.get_default_dtype() + + +# skip if on windows +@pytest.mark.skipif(os.name == "nt", reason="Not supported on Windows") +@pytest.mark.parametrize("device", ["cpu", "cuda"]) +def test_mace(device, default_dtype): # pylint: disable=W0621 + print(f"using default dtype = {default_dtype}") + if device == "cuda" and not torch.cuda.is_available(): + pytest.skip(reason="cuda is not available") + + model_defaults = create_mace(device) + tmp_model = mace_compile.prepare(create_mace)(device) + model_compiled = torch.compile(tmp_model, mode="default") + + batch = create_batch(device) + output1 = model_defaults(batch, training=True) + output2 = model_compiled(batch, training=True) + assert_close(output1["energy"], output2["energy"]) + assert_close(output1["forces"], output2["forces"]) + + +@pytest.mark.skipif(os.name == "nt", reason="Not supported on Windows") +@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda is not available") +def test_eager_benchmark(benchmark, default_dtype): # pylint: disable=W0621 + print(f"using default dtype = {default_dtype}") + batch = create_batch("cuda") + model = create_mace("cuda") + model = time_func(model) + benchmark(model, batch, training=True) + + +@pytest.mark.skipif(os.name == "nt", reason="Not supported on Windows") +@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda is not available") +@pytest.mark.parametrize("compile_mode", ["default", "reduce-overhead", "max-autotune"]) +@pytest.mark.parametrize("enable_amp", [False, True], ids=["fp32", "mixed"]) +def test_compile_benchmark(benchmark, compile_mode, enable_amp): + if enable_amp: + pytest.skip(reason="autocast compiler assertion aten.slice_scatter.default") + + with tools.torch_tools.default_dtype(torch.float32): + batch = create_batch("cuda") + torch.compiler.reset() + model = mace_compile.prepare(create_mace)("cuda") + model = torch.compile(model, mode=compile_mode) + model = time_func(model) + + with torch.autocast("cuda", enabled=enable_amp): + benchmark(model, batch, training=True) + + +@pytest.mark.skipif(os.name == "nt", reason="Not supported on Windows") +@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda is not available") +def test_graph_breaks(): + import torch._dynamo as dynamo + + batch = create_batch("cuda") + model = mace_compile.prepare(create_mace)("cuda") + explanation = dynamo.explain(model)(batch, training=False) + + # these clutter the output but might be useful for investigating graph breaks + explanation.ops_per_graph = None + explanation.out_guards = None + print(explanation) + assert explanation.graph_break_count == 0 diff --git a/tests/test_data.py b/tests/test_data.py new file mode 100644 index 00000000..9e0c49e6 --- /dev/null +++ b/tests/test_data.py @@ -0,0 +1,207 @@ +from copy import deepcopy +from pathlib import Path + +import ase.build +import h5py +import numpy as np +import torch + +from mace.data import ( + AtomicData, + Configuration, + HDF5Dataset, + config_from_atoms, + get_neighborhood, + save_configurations_as_HDF5, +) +from mace.tools import AtomicNumberTable, torch_geometric + +mace_path = Path(__file__).parent.parent + + +class TestAtomicData: + config = Configuration( + atomic_numbers=np.array([8, 1, 1]), + positions=np.array( + [ + [0.0, -2.0, 0.0], + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + ] + ), + forces=np.array( + [ + [0.0, -1.3, 0.0], + [1.0, 0.2, 0.0], + [0.0, 1.1, 0.3], + ] + ), + energy=-1.5, + ) + config_2 = deepcopy(config) + config_2.positions = config.positions + 0.01 + + table = AtomicNumberTable([1, 8]) + + def test_atomic_data(self): + data = AtomicData.from_config(self.config, z_table=self.table, cutoff=3.0) + + assert data.edge_index.shape == (2, 4) + assert data.forces.shape == (3, 3) + assert data.node_attrs.shape == (3, 2) + + def test_data_loader(self): + data1 = AtomicData.from_config(self.config, z_table=self.table, cutoff=3.0) + data2 = AtomicData.from_config(self.config, z_table=self.table, cutoff=3.0) + + data_loader = torch_geometric.dataloader.DataLoader( + dataset=[data1, data2], + batch_size=2, + shuffle=True, + drop_last=False, + ) + + for batch in data_loader: + assert batch.batch.shape == (6,) + assert batch.edge_index.shape == (2, 8) + assert batch.shifts.shape == (8, 3) + assert batch.positions.shape == (6, 3) + assert batch.node_attrs.shape == (6, 2) + assert batch.energy.shape == (2,) + assert batch.forces.shape == (6, 3) + + def test_to_atomic_data_dict(self): + data1 = AtomicData.from_config(self.config, z_table=self.table, cutoff=3.0) + data2 = AtomicData.from_config(self.config, z_table=self.table, cutoff=3.0) + + data_loader = torch_geometric.dataloader.DataLoader( + dataset=[data1, data2], + batch_size=2, + shuffle=True, + drop_last=False, + ) + for batch in data_loader: + batch_dict = batch.to_dict() + assert batch_dict["batch"].shape == (6,) + assert batch_dict["edge_index"].shape == (2, 8) + assert batch_dict["shifts"].shape == (8, 3) + assert batch_dict["positions"].shape == (6, 3) + assert batch_dict["node_attrs"].shape == (6, 2) + assert batch_dict["energy"].shape == (2,) + assert batch_dict["forces"].shape == (6, 3) + + def test_hdf5_dataloader(self): + datasets = [self.config, self.config_2] * 5 + # get path of the mace package + with h5py.File(str(mace_path) + "test.h5", "w") as f: + save_configurations_as_HDF5(datasets, 0, f) + train_dataset = HDF5Dataset( + str(mace_path) + "test.h5", z_table=self.table, r_max=3.0 + ) + train_loader = torch_geometric.dataloader.DataLoader( + dataset=train_dataset, + batch_size=2, + shuffle=False, + drop_last=False, + ) + batch_count = 0 + for batch in train_loader: + batch_count += 1 + assert batch.batch.shape == (6,) + assert batch.edge_index.shape == (2, 8) + assert batch.shifts.shape == (8, 3) + assert batch.positions.shape == (6, 3) + assert batch.node_attrs.shape == (6, 2) + assert batch.energy.shape == (2,) + assert batch.forces.shape == (6, 3) + print(batch_count, len(train_loader), len(train_dataset)) + assert batch_count == len(train_loader) == len(train_dataset) / 2 + train_loader_direct = torch_geometric.dataloader.DataLoader( + dataset=[ + AtomicData.from_config(config, z_table=self.table, cutoff=3.0) + for config in datasets + ], + batch_size=2, + shuffle=False, + drop_last=False, + ) + for batch_direct, batch in zip(train_loader_direct, train_loader): + assert torch.all(batch_direct.edge_index == batch.edge_index) + assert torch.all(batch_direct.shifts == batch.shifts) + assert torch.all(batch_direct.positions == batch.positions) + assert torch.all(batch_direct.node_attrs == batch.node_attrs) + assert torch.all(batch_direct.energy == batch.energy) + assert torch.all(batch_direct.forces == batch.forces) + + +class TestNeighborhood: + def test_basic(self): + positions = np.array( + [ + [-1.0, 0.0, 0.0], + [+0.0, 0.0, 0.0], + [+1.0, 0.0, 0.0], + ] + ) + + indices, shifts, unit_shifts, _ = get_neighborhood(positions, cutoff=1.5) + assert indices.shape == (2, 4) + assert shifts.shape == (4, 3) + assert unit_shifts.shape == (4, 3) + + def test_signs(self): + positions = np.array( + [ + [+0.5, 0.5, 0.0], + [+1.0, 1.0, 0.0], + ] + ) + + cell = np.array([[2.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]) + edge_index, shifts, unit_shifts, _ = get_neighborhood( + positions, cutoff=3.5, pbc=(True, False, False), cell=cell + ) + num_edges = 10 + assert edge_index.shape == (2, num_edges) + assert shifts.shape == (num_edges, 3) + assert unit_shifts.shape == (num_edges, 3) + + +# Based on mir-group/nequip +def test_periodic_edge(): + atoms = ase.build.bulk("Cu", "fcc") + dist = np.linalg.norm(atoms.cell[0]).item() + config = config_from_atoms(atoms) + edge_index, shifts, _, _ = get_neighborhood( + config.positions, cutoff=1.05 * dist, pbc=(True, True, True), cell=config.cell + ) + sender, receiver = edge_index + vectors = ( + config.positions[receiver] - config.positions[sender] + shifts + ) # [n_edges, 3] + assert vectors.shape == (12, 3) # 12 neighbors in close-packed bulk + assert np.allclose( + np.linalg.norm(vectors, axis=-1), + dist, + ) + + +def test_half_periodic(): + atoms = ase.build.fcc111("Al", size=(3, 3, 1), vacuum=0.0) + assert all(atoms.pbc == (True, True, False)) + config = config_from_atoms(atoms) # first shell dist is 2.864A + edge_index, shifts, _, _ = get_neighborhood( + config.positions, cutoff=2.9, pbc=(True, True, False), cell=config.cell + ) + sender, receiver = edge_index + vectors = ( + config.positions[receiver] - config.positions[sender] + shifts + ) # [n_edges, 3] + # Check number of neighbors: + _, neighbor_count = np.unique(edge_index[0], return_counts=True) + assert (neighbor_count == 6).all() # 6 neighbors + # Check not periodic in z + assert np.allclose( + vectors[:, 2], + np.zeros(vectors.shape[0]), + ) diff --git a/tests/test_foundations.py b/tests/test_foundations.py new file mode 100644 index 00000000..44879395 --- /dev/null +++ b/tests/test_foundations.py @@ -0,0 +1,447 @@ +from pathlib import Path + +import numpy as np +import pytest +import torch +import torch.nn.functional +from ase.build import molecule +from e3nn import o3 +from scipy.spatial.transform import Rotation as R + +from mace import data, modules, tools +from mace.calculators import mace_mp, mace_off +from mace.tools import torch_geometric +from mace.tools.finetuning_utils import load_foundations_elements +from mace.tools.scripts_utils import extract_config_mace_model, remove_pt_head +from mace.tools.utils import AtomicNumberTable + +MODEL_PATH = ( + Path(__file__).parent.parent + / "mace" + / "calculators" + / "foundations_models" + / "2023-12-03-mace-mp.model" +) + +torch.set_default_dtype(torch.float64) +config = data.Configuration( + atomic_numbers=molecule("H2COH").numbers, + positions=molecule("H2COH").positions, + forces=molecule("H2COH").positions, + energy=-1.5, + charges=molecule("H2COH").numbers, + dipole=np.array([-1.5, 1.5, 2.0]), +) +# Created the rotated environment +rot = R.from_euler("z", 60, degrees=True).as_matrix() +positions_rotated = np.array(rot @ config.positions.T).T +config_rotated = data.Configuration( + atomic_numbers=molecule("H2COH").numbers, + positions=positions_rotated, + forces=molecule("H2COH").positions, + energy=-1.5, + charges=molecule("H2COH").numbers, + dipole=np.array([-1.5, 1.5, 2.0]), +) +table = tools.AtomicNumberTable([1, 6, 8]) +atomic_energies = np.array([0.0, 0.0, 0.0], dtype=float) + + +# @pytest.skip("Problem with the float type", allow_module_level=True) +def test_foundations(): + # Create MACE model + model_config = dict( + r_max=6, + num_bessel=10, + num_polynomial_cutoff=5, + max_ell=3, + interaction_cls=modules.interaction_classes[ + "RealAgnosticResidualInteractionBlock" + ], + interaction_cls_first=modules.interaction_classes[ + "RealAgnosticResidualInteractionBlock" + ], + num_interactions=2, + num_elements=3, + hidden_irreps=o3.Irreps("128x0e"), + MLP_irreps=o3.Irreps("16x0e"), + gate=torch.nn.functional.silu, + atomic_energies=atomic_energies, + avg_num_neighbors=3, + atomic_numbers=table.zs, + correlation=3, + radial_type="bessel", + atomic_inter_scale=0.1, + atomic_inter_shift=0.0, + ) + model = modules.ScaleShiftMACE(**model_config) + calc = mace_mp( + model="small", + device="cpu", + default_dtype="float64", + ) + model_foundations = calc.models[0] + model_loaded = load_foundations_elements( + model, + model_foundations, + table=table, + load_readout=True, + use_shift=False, + max_L=0, + ) + atomic_data = data.AtomicData.from_config(config, z_table=table, cutoff=6.0) + atomic_data2 = data.AtomicData.from_config( + config_rotated, z_table=table, cutoff=6.0 + ) + + data_loader = torch_geometric.dataloader.DataLoader( + dataset=[atomic_data, atomic_data2], + batch_size=2, + shuffle=True, + drop_last=False, + ) + batch = next(iter(data_loader)) + forces_loaded = model_loaded(batch)["forces"] + forces = model(batch)["forces"] + assert torch.allclose(forces, forces_loaded) + + +def test_multi_reference(): + config_multi = data.Configuration( + atomic_numbers=molecule("H2COH").numbers, + positions=molecule("H2COH").positions, + forces=molecule("H2COH").positions, + energy=-1.5, + charges=molecule("H2COH").numbers, + dipole=np.array([-1.5, 1.5, 2.0]), + head="MP2", + ) + table_multi = tools.AtomicNumberTable([1, 6, 8]) + atomic_energies_multi = np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], dtype=float) + + # Create MACE model + model_config = dict( + r_max=6, + num_bessel=10, + num_polynomial_cutoff=5, + max_ell=3, + interaction_cls=modules.interaction_classes[ + "RealAgnosticResidualInteractionBlock" + ], + interaction_cls_first=modules.interaction_classes[ + "RealAgnosticResidualInteractionBlock" + ], + num_interactions=2, + num_elements=3, + hidden_irreps=o3.Irreps("128x0e + 128x1o"), + MLP_irreps=o3.Irreps("16x0e"), + gate=torch.nn.functional.silu, + atomic_energies=atomic_energies_multi, + avg_num_neighbors=61, + atomic_numbers=table.zs, + correlation=3, + radial_type="bessel", + atomic_inter_scale=[1.0, 1.0], + atomic_inter_shift=[0.0, 0.0], + heads=["MP2", "DFT"], + ) + model = modules.ScaleShiftMACE(**model_config) + calc_foundation = mace_mp(device="cpu", default_dtype="float64") + model_loaded = load_foundations_elements( + model, + calc_foundation.models[0], + table=table, + load_readout=True, + use_shift=False, + max_L=1, + ) + atomic_data = data.AtomicData.from_config( + config_multi, z_table=table_multi, cutoff=6.0, heads=["MP2", "DFT"] + ) + data_loader = torch_geometric.dataloader.DataLoader( + dataset=[atomic_data, atomic_data], + batch_size=2, + shuffle=True, + drop_last=False, + ) + batch = next(iter(data_loader)) + forces_loaded = model_loaded(batch)["forces"] + calc_foundation = mace_mp(device="cpu", default_dtype="float64") + atoms = molecule("H2COH") + atoms.info["head"] = "MP2" + atoms.calc = calc_foundation + forces = atoms.get_forces() + assert np.allclose( + forces, forces_loaded.detach().numpy()[:5, :], atol=1e-5, rtol=1e-5 + ) + + +@pytest.mark.parametrize( + "model", + [ + mace_mp(model="small", device="cpu", default_dtype="float64").models[0], + mace_mp(model="medium", device="cpu", default_dtype="float64").models[0], + mace_mp(model="large", device="cpu", default_dtype="float64").models[0], + mace_mp(model=MODEL_PATH, device="cpu", default_dtype="float64").models[0], + mace_off(model="small", device="cpu", default_dtype="float64").models[0], + mace_off(model="medium", device="cpu", default_dtype="float64").models[0], + mace_off(model="large", device="cpu", default_dtype="float64").models[0], + mace_off(model=MODEL_PATH, device="cpu", default_dtype="float64").models[0], + ], +) +def test_extract_config(model): + assert isinstance(model, modules.ScaleShiftMACE) + model_copy = modules.ScaleShiftMACE(**extract_config_mace_model(model)) + model_copy.load_state_dict(model.state_dict()) + z_table = AtomicNumberTable([int(z) for z in model.atomic_numbers]) + atomic_data = data.AtomicData.from_config(config, z_table=z_table, cutoff=6.0) + data_loader = torch_geometric.dataloader.DataLoader( + dataset=[atomic_data, atomic_data], + batch_size=2, + shuffle=True, + drop_last=False, + ) + batch = next(iter(data_loader)) + output = model(batch) + output_copy = model_copy(batch) + # assert all items of the output dicts are equal + for key in output.keys(): + if isinstance(output[key], torch.Tensor): + assert torch.allclose(output[key], output_copy[key], atol=1e-5) + + +def test_remove_pt_head(): + # Set up test data + torch.manual_seed(42) + atomic_energies_pt_head = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=float) + z_table = AtomicNumberTable([1, 8]) # H and O + + # Create multihead model + model_config = { + "r_max": 5.0, + "num_bessel": 8, + "num_polynomial_cutoff": 5, + "max_ell": 2, + "interaction_cls": modules.interaction_classes[ + "RealAgnosticResidualInteractionBlock" + ], + "interaction_cls_first": modules.interaction_classes[ + "RealAgnosticResidualInteractionBlock" + ], + "num_interactions": 2, + "num_elements": len(z_table), + "hidden_irreps": o3.Irreps("32x0e + 32x1o"), + "MLP_irreps": o3.Irreps("16x0e"), + "gate": torch.nn.functional.silu, + "atomic_energies": atomic_energies_pt_head, + "avg_num_neighbors": 8, + "atomic_numbers": z_table.zs, + "correlation": 3, + "heads": ["pt_head", "DFT"], + "atomic_inter_scale": [1.0, 1.0], + "atomic_inter_shift": [0.0, 0.1], + } + + model = modules.ScaleShiftMACE(**model_config) + + # Create test molecule + mol = molecule("H2O") + config_pt_head = data.Configuration( + atomic_numbers=mol.numbers, + positions=mol.positions, + energy=1.0, + forces=np.random.randn(len(mol), 3), + head="DFT", + ) + atomic_data = data.AtomicData.from_config( + config_pt_head, z_table=z_table, cutoff=5.0, heads=["pt_head", "DFT"] + ) + dataloader = torch_geometric.dataloader.DataLoader( + dataset=[atomic_data], batch_size=1, shuffle=False + ) + batch = next(iter(dataloader)) + # Test original mode + output_orig = model(batch) + + # Convert to single head model + new_model = remove_pt_head(model, head_to_keep="DFT") + + # Basic structure tests + assert len(new_model.heads) == 1 + assert new_model.heads[0] == "DFT" + assert new_model.atomic_energies_fn.atomic_energies.shape[0] == 1 + assert len(torch.atleast_1d(new_model.scale_shift.scale)) == 1 + assert len(torch.atleast_1d(new_model.scale_shift.shift)) == 1 + + # Test output consistency + atomic_data = data.AtomicData.from_config( + config_pt_head, z_table=z_table, cutoff=5.0, heads=["DFT"] + ) + dataloader = torch_geometric.dataloader.DataLoader( + dataset=[atomic_data], batch_size=1, shuffle=False + ) + batch = next(iter(dataloader)) + output_new = new_model(batch) + torch.testing.assert_close( + output_orig["energy"], output_new["energy"], rtol=1e-5, atol=1e-5 + ) + torch.testing.assert_close( + output_orig["forces"], output_new["forces"], rtol=1e-5, atol=1e-5 + ) + + +def test_remove_pt_head_multihead(): + # Set up test data + torch.manual_seed(42) + atomic_energies_pt_head = np.array( + [ + [1.0, 2.0], # H energies for each head + [3.0, 4.0], # O energies for each head + ] + * 2 + ) + z_table = AtomicNumberTable([1, 8]) # H and O + + # Create multihead model + model_config = { + "r_max": 5.0, + "num_bessel": 8, + "num_polynomial_cutoff": 5, + "max_ell": 2, + "interaction_cls": modules.interaction_classes[ + "RealAgnosticResidualInteractionBlock" + ], + "interaction_cls_first": modules.interaction_classes[ + "RealAgnosticResidualInteractionBlock" + ], + "num_interactions": 2, + "num_elements": len(z_table), + "hidden_irreps": o3.Irreps("32x0e + 32x1o"), + "MLP_irreps": o3.Irreps("16x0e"), + "gate": torch.nn.functional.silu, + "atomic_energies": atomic_energies_pt_head, + "avg_num_neighbors": 8, + "atomic_numbers": z_table.zs, + "correlation": 3, + "heads": ["pt_head", "DFT", "MP2", "CCSD"], + "atomic_inter_scale": [1.0, 1.0, 1.0, 1.0], + "atomic_inter_shift": [0.0, 0.1, 0.2, 0.3], + } + + model = modules.ScaleShiftMACE(**model_config) + + # Create test configurations for each head + mol = molecule("H2O") + configs = {} + atomic_datas = {} + dataloaders = {} + original_outputs = {} + + # First get outputs from original model for each head + for head in model.heads: + config_pt_head = data.Configuration( + atomic_numbers=mol.numbers, + positions=mol.positions, + energy=1.0, + forces=np.random.randn(len(mol), 3), + head=head, + ) + configs[head] = config_pt_head + + atomic_data = data.AtomicData.from_config( + config_pt_head, z_table=z_table, cutoff=5.0, heads=model.heads + ) + atomic_datas[head] = atomic_data + + dataloader = torch_geometric.dataloader.DataLoader( + dataset=[atomic_data], batch_size=1, shuffle=False + ) + dataloaders[head] = dataloader + + batch = next(iter(dataloader)) + output = model(batch) + original_outputs[head] = output + + # Now test each head separately + for i, head in enumerate(model.heads): + # Convert to single head model + new_model = remove_pt_head(model, head_to_keep=head) + + # Basic structure tests + assert len(new_model.heads) == 1, f"Failed for head {head}" + assert new_model.heads[0] == head, f"Failed for head {head}" + assert ( + new_model.atomic_energies_fn.atomic_energies.shape[0] == 1 + ), f"Failed for head {head}" + assert ( + len(torch.atleast_1d(new_model.scale_shift.scale)) == 1 + ), f"Failed for head {head}" + assert ( + len(torch.atleast_1d(new_model.scale_shift.shift)) == 1 + ), f"Failed for head {head}" + + # Verify scale and shift values + assert torch.allclose( + new_model.scale_shift.scale, model.scale_shift.scale[i : i + 1] + ), f"Failed for head {head}" + assert torch.allclose( + new_model.scale_shift.shift, model.scale_shift.shift[i : i + 1] + ), f"Failed for head {head}" + + # Test output consistency + single_head_data = data.AtomicData.from_config( + configs[head], z_table=z_table, cutoff=5.0, heads=[head] + ) + single_head_loader = torch_geometric.dataloader.DataLoader( + dataset=[single_head_data], batch_size=1, shuffle=False + ) + batch = next(iter(single_head_loader)) + new_output = new_model(batch) + + # Compare outputs + print( + original_outputs[head]["energy"], + new_output["energy"], + ) + torch.testing.assert_close( + original_outputs[head]["energy"], + new_output["energy"], + rtol=1e-5, + atol=1e-5, + msg=f"Energy mismatch for head {head}", + ) + torch.testing.assert_close( + original_outputs[head]["forces"], + new_output["forces"], + rtol=1e-5, + atol=1e-5, + msg=f"Forces mismatch for head {head}", + ) + + # Test error cases + with pytest.raises(ValueError, match="Head non_existent not found in model"): + remove_pt_head(model, head_to_keep="non_existent") + + # Test default behavior (first non-PT head) + default_model = remove_pt_head(model) + assert default_model.heads[0] == "DFT" + + # Additional test: check if each model's computation graph is independent + models = {head: remove_pt_head(model, head_to_keep=head) for head in model.heads} + results = {} + + for head, head_model in models.items(): + single_head_data = data.AtomicData.from_config( + configs[head], z_table=z_table, cutoff=5.0, heads=[head] + ) + single_head_loader = torch_geometric.dataloader.DataLoader( + dataset=[single_head_data], batch_size=1, shuffle=False + ) + batch = next(iter(single_head_loader)) + results[head] = head_model(batch) + + # Verify each model produces different outputs + energies = torch.stack([results[head]["energy"] for head in model.heads]) + assert not torch.allclose( + energies[0], energies[1], rtol=1e-3 + ), "Different heads should produce different outputs" diff --git a/tests/test_hessian.py b/tests/test_hessian.py new file mode 100644 index 00000000..53457335 --- /dev/null +++ b/tests/test_hessian.py @@ -0,0 +1,54 @@ +import numpy as np +import pytest +from ase.build import fcc111 + +from mace.calculators import mace_mp + + +@pytest.fixture(name="setup_calculator_") +def setup_calculator(): + calc = mace_mp( + model="medium", dispersion=False, default_dtype="float64", device="cpu" + ) + return calc + + +@pytest.fixture(name="setup_structure_") +def setup_structure(setup_calculator_): + initial = fcc111("Pt", size=(4, 4, 1), vacuum=10.0, orthogonal=True) + initial.calc = setup_calculator_ + return initial + + +def test_potential_energy_and_hessian(setup_structure_): + initial = setup_structure_ + h_autograd = initial.calc.get_hessian(atoms=initial) + assert h_autograd.shape == (len(initial) * 3, len(initial), 3) + + +def test_finite_difference_hessian(setup_structure_): + initial = setup_structure_ + indicies = list(range(len(initial))) + delta, ndim = 1e-4, 3 + hessian = np.zeros((len(indicies) * ndim, len(indicies) * ndim)) + atoms_h = initial.copy() + for i, index in enumerate(indicies): + for j in range(ndim): + atoms_i = atoms_h.copy() + atoms_i.positions[index, j] += delta + atoms_i.calc = initial.calc + forces_i = atoms_i.get_forces() + + atoms_j = atoms_h.copy() + atoms_j.positions[index, j] -= delta + atoms_j.calc = initial.calc + forces_j = atoms_j.get_forces() + + hessian[:, i * ndim + j] = -(forces_i - forces_j)[indicies].flatten() / ( + 2 * delta + ) + + hessian = hessian.reshape((-1, len(initial), 3)) + h_autograd = initial.calc.get_hessian(atoms=initial) + is_close = np.allclose(h_autograd, hessian, atol=1e-6) + assert is_close diff --git a/tests/test_models.py b/tests/test_models.py new file mode 100644 index 00000000..8e8c60da --- /dev/null +++ b/tests/test_models.py @@ -0,0 +1,251 @@ +import numpy as np +import torch +import torch.nn.functional +from e3nn import o3 +from e3nn.util import jit +from scipy.spatial.transform import Rotation as R + +from mace import data, modules, tools +from mace.tools import torch_geometric + +torch.set_default_dtype(torch.float64) +config = data.Configuration( + atomic_numbers=np.array([8, 1, 1]), + positions=np.array( + [ + [0.0, -2.0, 0.0], + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + ] + ), + forces=np.array( + [ + [0.0, -1.3, 0.0], + [1.0, 0.2, 0.0], + [0.0, 1.1, 0.3], + ] + ), + energy=-1.5, + charges=np.array([-2.0, 1.0, 1.0]), + dipole=np.array([-1.5, 1.5, 2.0]), +) +# Created the rotated environment +rot = R.from_euler("z", 60, degrees=True).as_matrix() +positions_rotated = np.array(rot @ config.positions.T).T +config_rotated = data.Configuration( + atomic_numbers=np.array([8, 1, 1]), + positions=positions_rotated, + forces=np.array( + [ + [0.0, -1.3, 0.0], + [1.0, 0.2, 0.0], + [0.0, 1.1, 0.3], + ] + ), + energy=-1.5, + charges=np.array([-2.0, 1.0, 1.0]), + dipole=np.array([-1.5, 1.5, 2.0]), +) +table = tools.AtomicNumberTable([1, 8]) +atomic_energies = np.array([1.0, 3.0], dtype=float) + + +def test_mace(): + # Create MACE model + model_config = dict( + r_max=5, + num_bessel=8, + num_polynomial_cutoff=6, + max_ell=2, + interaction_cls=modules.interaction_classes[ + "RealAgnosticResidualInteractionBlock" + ], + interaction_cls_first=modules.interaction_classes[ + "RealAgnosticResidualInteractionBlock" + ], + num_interactions=5, + num_elements=2, + hidden_irreps=o3.Irreps("32x0e + 32x1o"), + MLP_irreps=o3.Irreps("16x0e"), + gate=torch.nn.functional.silu, + atomic_energies=atomic_energies, + avg_num_neighbors=8, + atomic_numbers=table.zs, + correlation=3, + radial_type="bessel", + ) + model = modules.MACE(**model_config) + model_compiled = jit.compile(model) + + atomic_data = data.AtomicData.from_config(config, z_table=table, cutoff=3.0) + atomic_data2 = data.AtomicData.from_config( + config_rotated, z_table=table, cutoff=3.0 + ) + + data_loader = torch_geometric.dataloader.DataLoader( + dataset=[atomic_data, atomic_data2], + batch_size=2, + shuffle=True, + drop_last=False, + ) + batch = next(iter(data_loader)) + output1 = model(batch.to_dict(), training=True) + output2 = model_compiled(batch.to_dict(), training=True) + assert torch.allclose(output1["energy"][0], output2["energy"][0]) + assert torch.allclose(output2["energy"][0], output2["energy"][1]) + + +def test_dipole_mace(): + # create dipole MACE model + model_config = dict( + r_max=5, + num_bessel=8, + num_polynomial_cutoff=5, + max_ell=2, + interaction_cls=modules.interaction_classes[ + "RealAgnosticResidualInteractionBlock" + ], + interaction_cls_first=modules.interaction_classes[ + "RealAgnosticResidualInteractionBlock" + ], + num_interactions=2, + num_elements=2, + hidden_irreps=o3.Irreps("16x0e + 16x1o + 16x2e"), + MLP_irreps=o3.Irreps("16x0e"), + gate=torch.nn.functional.silu, + atomic_energies=None, + avg_num_neighbors=3, + atomic_numbers=table.zs, + correlation=3, + radial_type="gaussian", + ) + model = modules.AtomicDipolesMACE(**model_config) + + atomic_data = data.AtomicData.from_config(config, z_table=table, cutoff=3.0) + atomic_data2 = data.AtomicData.from_config( + config_rotated, z_table=table, cutoff=3.0 + ) + + data_loader = torch_geometric.dataloader.DataLoader( + dataset=[atomic_data, atomic_data2], + batch_size=2, + shuffle=False, + drop_last=False, + ) + batch = next(iter(data_loader)) + output = model( + batch, + training=True, + ) + # sanity check of dipoles being the right shape + assert output["dipole"][0].unsqueeze(0).shape == atomic_data.dipole.shape + # test equivariance of output dipoles + assert np.allclose( + np.array(rot @ output["dipole"][0].detach().numpy()), + output["dipole"][1].detach().numpy(), + ) + + +def test_energy_dipole_mace(): + # create dipole MACE model + model_config = dict( + r_max=5, + num_bessel=8, + num_polynomial_cutoff=5, + max_ell=2, + interaction_cls=modules.interaction_classes[ + "RealAgnosticResidualInteractionBlock" + ], + interaction_cls_first=modules.interaction_classes[ + "RealAgnosticResidualInteractionBlock" + ], + num_interactions=2, + num_elements=2, + hidden_irreps=o3.Irreps("16x0e + 16x1o + 16x2e"), + MLP_irreps=o3.Irreps("16x0e"), + gate=torch.nn.functional.silu, + atomic_energies=atomic_energies, + avg_num_neighbors=3, + atomic_numbers=table.zs, + correlation=3, + ) + model = modules.EnergyDipolesMACE(**model_config) + + atomic_data = data.AtomicData.from_config(config, z_table=table, cutoff=3.0) + atomic_data2 = data.AtomicData.from_config( + config_rotated, z_table=table, cutoff=3.0 + ) + + data_loader = torch_geometric.dataloader.DataLoader( + dataset=[atomic_data, atomic_data2], + batch_size=2, + shuffle=False, + drop_last=False, + ) + batch = next(iter(data_loader)) + output = model( + batch, + training=True, + ) + # sanity check of dipoles being the right shape + assert output["dipole"][0].unsqueeze(0).shape == atomic_data.dipole.shape + # test energy is invariant + assert torch.allclose(output["energy"][0], output["energy"][1]) + # test equivariance of output dipoles + assert np.allclose( + np.array(rot @ output["dipole"][0].detach().numpy()), + output["dipole"][1].detach().numpy(), + ) + + +def test_mace_multi_reference(): + atomic_energies_multi = np.array([[1.0, 3.0], [0.0, 0.0]], dtype=float) + model_config = dict( + r_max=5, + num_bessel=8, + num_polynomial_cutoff=6, + max_ell=3, + interaction_cls=modules.interaction_classes[ + "RealAgnosticResidualInteractionBlock" + ], + interaction_cls_first=modules.interaction_classes[ + "RealAgnosticResidualInteractionBlock" + ], + num_interactions=2, + num_elements=2, + hidden_irreps=o3.Irreps("96x0e + 96x1o"), + MLP_irreps=o3.Irreps("16x0e"), + gate=torch.nn.functional.silu, + atomic_energies=atomic_energies_multi, + avg_num_neighbors=8, + atomic_numbers=table.zs, + distance_transform=True, + pair_repulsion=True, + correlation=3, + heads=["Default", "dft"], + # radial_type="chebyshev", + atomic_inter_scale=[1.0, 1.0], + atomic_inter_shift=[0.0, 0.1], + ) + model = modules.ScaleShiftMACE(**model_config) + model_compiled = jit.compile(model) + config.head = "Default" + config_rotated.head = "dft" + atomic_data = data.AtomicData.from_config( + config, z_table=table, cutoff=3.0, heads=["Default", "dft"] + ) + atomic_data2 = data.AtomicData.from_config( + config_rotated, z_table=table, cutoff=3.0, heads=["Default", "dft"] + ) + + data_loader = torch_geometric.dataloader.DataLoader( + dataset=[atomic_data, atomic_data2], + batch_size=2, + shuffle=True, + drop_last=False, + ) + batch = next(iter(data_loader)) + output1 = model(batch.to_dict(), training=True) + output2 = model_compiled(batch.to_dict(), training=True) + assert torch.allclose(output1["energy"][0], output2["energy"][0]) + assert output2["energy"].shape[0] == 2 diff --git a/tests/test_modules.py b/tests/test_modules.py new file mode 100644 index 00000000..b99d7d6d --- /dev/null +++ b/tests/test_modules.py @@ -0,0 +1,249 @@ +import numpy as np +import pytest +import torch +import torch.nn.functional +from e3nn import o3 + +from mace.data import AtomicData, Configuration +from mace.modules import ( + AtomicEnergiesBlock, + BesselBasis, + PolynomialCutoff, + SymmetricContraction, + WeightedEnergyForcesLoss, + WeightedHuberEnergyForcesStressLoss, + compute_mean_rms_energy_forces, + compute_statistics, +) +from mace.tools import AtomicNumberTable, scatter, to_numpy, torch_geometric +from mace.tools.scripts_utils import dict_to_array + + +@pytest.fixture(name="config") +def _config(): + return Configuration( + atomic_numbers=np.array([8, 1, 1]), + positions=np.array( + [ + [0.0, -2.0, 0.0], + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + ] + ), + forces=np.array( + [ + [0.0, -1.3, 0.0], + [1.0, 0.2, 0.0], + [0.0, 1.1, 0.3], + ] + ), + energy=-1.5, + stress=np.array([1.0, 0.0, 0.5, 0.0, -1.0, 0.0]), + ) + + +@pytest.fixture(name="table") +def _table(): + return AtomicNumberTable([1, 8]) + + +@pytest.fixture(name="config1") +def _config1(): + return Configuration( + atomic_numbers=np.array([8, 1, 1]), + positions=np.array( + [ + [0.0, -2.0, 0.0], + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + ] + ), + forces=np.array( + [ + [0.0, -1.3, 0.0], + [1.0, 0.2, 0.0], + [0.0, 1.1, 0.3], + ] + ), + energy=-1.5, + head="DFT", + ) + + +@pytest.fixture(name="config2") +def _config2(): + return Configuration( + atomic_numbers=np.array([8, 1, 1]), + positions=np.array( + [ + [0.1, -1.9, 0.1], + [1.1, 0.1, 0.1], + [0.1, 1.1, 0.1], + ] + ), + forces=np.array( + [ + [0.1, -1.2, 0.1], + [1.1, 0.3, 0.1], + [0.1, 1.2, 0.4], + ] + ), + energy=-1.4, + head="MP2", + ) + + +@pytest.fixture(name="atomic_data") +def _atomic_data(config1, config2, table): + atomic_data1 = AtomicData.from_config( + config1, z_table=table, cutoff=3.0, heads=["DFT", "MP2"] + ) + atomic_data2 = AtomicData.from_config( + config2, z_table=table, cutoff=3.0, heads=["DFT", "MP2"] + ) + return [atomic_data1, atomic_data2] + + +@pytest.fixture(name="data_loader") +def _data_loader(atomic_data): + return torch_geometric.dataloader.DataLoader( + dataset=atomic_data, + batch_size=2, + shuffle=False, + drop_last=False, + ) + + +@pytest.fixture(name="atomic_energies") +def _atomic_energies(): + atomic_energies_dict = { + "DFT": np.array([0.0, 0.0]), + "MP2": np.array([0.1, 0.1]), + } + return dict_to_array(atomic_energies_dict, ["DFT", "MP2"]) + + +@pytest.fixture(autouse=True) +def _set_torch_default_dtype(): + torch.set_default_dtype(torch.float64) + + +def test_weighted_loss(config, table): + loss1 = WeightedEnergyForcesLoss(energy_weight=1, forces_weight=10) + loss2 = WeightedHuberEnergyForcesStressLoss(energy_weight=1, forces_weight=10) + data = AtomicData.from_config(config, z_table=table, cutoff=3.0) + data_loader = torch_geometric.dataloader.DataLoader( + dataset=[data, data], + batch_size=2, + shuffle=True, + drop_last=False, + ) + batch = next(iter(data_loader)) + pred = { + "energy": batch.energy, + "forces": batch.forces, + "stress": batch.stress, + } + out1 = loss1(batch, pred) + assert out1 == 0.0 + out2 = loss2(batch, pred) + assert out2 == 0.0 + + +def test_symmetric_contraction(): + operation = SymmetricContraction( + irreps_in=o3.Irreps("16x0e + 16x1o + 16x2e"), + irreps_out=o3.Irreps("16x0e + 16x1o"), + correlation=3, + num_elements=2, + ) + torch.manual_seed(123) + features = torch.randn(30, 16, 9) + one_hots = torch.nn.functional.one_hot(torch.arange(0, 30) % 2).to( + torch.get_default_dtype() + ) + out = operation(features, one_hots) + assert out.shape == (30, 64) + assert operation.contractions[0].weights_max.shape == (2, 11, 16) + + +def test_bessel_basis(): + d = torch.linspace(start=0.5, end=5.5, steps=10) + bessel_basis = BesselBasis(r_max=6.0, num_basis=5) + output = bessel_basis(d.unsqueeze(-1)) + assert output.shape == (10, 5) + + +def test_polynomial_cutoff(): + d = torch.linspace(start=0.5, end=5.5, steps=10) + cutoff_fn = PolynomialCutoff(r_max=5.0) + output = cutoff_fn(d) + assert output.shape == (10,) + + +def test_atomic_energies(config, table): + energies_block = AtomicEnergiesBlock(atomic_energies=np.array([1.0, 3.0])) + data = AtomicData.from_config(config, z_table=table, cutoff=3.0) + data_loader = torch_geometric.dataloader.DataLoader( + dataset=[data, data], + batch_size=2, + shuffle=True, + drop_last=False, + ) + batch = next(iter(data_loader)) + energies = energies_block(batch.node_attrs).squeeze(-1) + out = scatter.scatter_sum(src=energies, index=batch.batch, dim=-1, reduce="sum") + out = to_numpy(out) + assert np.allclose(out, np.array([5.0, 5.0])) + + +def test_atomic_energies_multireference(config, table): + energies_block = AtomicEnergiesBlock( + atomic_energies=np.array([[1.0, 3.0], [2.0, 4.0]]) + ) + config.head = "MP2" + data = AtomicData.from_config( + config, z_table=table, cutoff=3.0, heads=["DFT", "MP2"] + ) + data_loader = torch_geometric.dataloader.DataLoader( + dataset=[data, data], + batch_size=2, + shuffle=True, + drop_last=False, + ) + batch = next(iter(data_loader)) + num_atoms_arange = torch.arange(batch["positions"].shape[0]) + node_heads = ( + batch["head"][batch["batch"]] + if "head" in batch + else torch.zeros_like(batch["batch"]) + ) + energies = energies_block(batch.node_attrs).squeeze(-1) + energies = energies[num_atoms_arange, node_heads] + out = scatter.scatter_sum(src=energies, index=batch.batch, dim=-1, reduce="sum") + out = to_numpy(out) + assert np.allclose(out, np.array([8.0, 8.0])) + + +def test_compute_mean_rms_energy_forces_multi_head(data_loader, atomic_energies): + mean, rms = compute_mean_rms_energy_forces(data_loader, atomic_energies) + assert isinstance(mean, np.ndarray) + assert isinstance(rms, np.ndarray) + assert mean.shape == (2,) + assert rms.shape == (2,) + assert np.all(rms >= 0) + assert rms[0] != rms[1] + + +def test_compute_statistics(data_loader, atomic_energies): + avg_num_neighbors, mean, std = compute_statistics(data_loader, atomic_energies) + assert isinstance(avg_num_neighbors, float) + assert isinstance(mean, np.ndarray) + assert isinstance(std, np.ndarray) + assert mean.shape == (2,) + assert std.shape == (2,) + assert avg_num_neighbors > 0 + assert np.all(mean != 0) + assert np.all(std > 0) + assert mean[0] != mean[1] + assert std[0] != std[1] diff --git a/tests/test_preprocess.py b/tests/test_preprocess.py new file mode 100644 index 00000000..e0258bd4 --- /dev/null +++ b/tests/test_preprocess.py @@ -0,0 +1,166 @@ +import os +import subprocess +import sys +from pathlib import Path + +import ase.io +import numpy as np +import pytest +from ase.atoms import Atoms + +pytest_mace_dir = Path(__file__).parent.parent +preprocess_data = Path(__file__).parent.parent / "mace" / "cli" / "preprocess_data.py" + + +@pytest.fixture(name="sample_configs") +def fixture_sample_configs(): + water = Atoms( + numbers=[8, 1, 1], + positions=[[0, -2.0, 0], [1, 0, 0], [0, 1, 0]], + cell=[4] * 3, + pbc=[True] * 3, + ) + configs = [ + Atoms(numbers=[8], positions=[[0, 0, 0]], cell=[6] * 3), + Atoms(numbers=[1], positions=[[0, 0, 0]], cell=[6] * 3), + ] + configs[0].info["REF_energy"] = 0.0 + configs[0].info["config_type"] = "IsolatedAtom" + configs[1].info["REF_energy"] = 0.0 + configs[1].info["config_type"] = "IsolatedAtom" + + np.random.seed(5) + for _ in range(10): + c = water.copy() + c.positions += np.random.normal(0.1, size=c.positions.shape) + c.info["REF_energy"] = np.random.normal(0.1) + c.new_array("REF_forces", np.random.normal(0.1, size=c.positions.shape)) + c.info["REF_stress"] = np.random.normal(0.1, size=6) + configs.append(c) + + return configs + + +def test_preprocess_data(tmp_path, sample_configs): + ase.io.write(tmp_path / "sample.xyz", sample_configs) + + preprocess_params = { + "train_file": tmp_path / "sample.xyz", + "r_max": 5.0, + "config_type_weights": "{'Default':1.0}", + "num_process": 2, + "valid_fraction": 0.1, + "h5_prefix": tmp_path / "preprocessed_", + "compute_statistics": None, + "seed": 42, + "energy_key": "REF_energy", + "forces_key": "REF_forces", + "stress_key": "REF_stress", + } + + run_env = os.environ.copy() + sys.path.insert(0, str(Path(__file__).parent.parent)) + run_env["PYTHONPATH"] = ":".join(sys.path) + print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) + + cmd = ( + sys.executable + + " " + + str(preprocess_data) + + " " + + " ".join( + [ + (f"--{k}={v}" if v is not None else f"--{k}") + for k, v in preprocess_params.items() + ] + ) + ) + + p = subprocess.run(cmd.split(), env=run_env, check=True) + assert p.returncode == 0 + + # Check if the output files are created + assert (tmp_path / "preprocessed_train").is_dir() + assert (tmp_path / "preprocessed_val").is_dir() + assert (tmp_path / "preprocessed_statistics.json").is_file() + + # Check if the correct number of files are created + train_files = list((tmp_path / "preprocessed_train").glob("*.h5")) + val_files = list((tmp_path / "preprocessed_val").glob("*.h5")) + assert len(train_files) == preprocess_params["num_process"] + assert len(val_files) == preprocess_params["num_process"] + + # Example of checking statistics file content: + import json + + with open(tmp_path / "preprocessed_statistics.json", "r", encoding="utf-8") as f: + statistics = json.load(f) + assert "atomic_energies" in statistics + assert "avg_num_neighbors" in statistics + assert "mean" in statistics + assert "std" in statistics + assert "atomic_numbers" in statistics + assert "r_max" in statistics + + # Example of checking H5 file content: + import h5py + + with h5py.File(train_files[0], "r") as f: + assert "config_batch_0" in f + config = f["config_batch_0"]["config_0"] + assert "atomic_numbers" in config + assert "positions" in config + assert "energy" in config + assert "forces" in config + + original_energies = [ + config.info["REF_energy"] + for config in sample_configs[2:] + if "REF_energy" in config.info + ] + original_forces = [ + config.arrays["REF_forces"] + for config in sample_configs[2:] + if "REF_forces" in config.arrays + ] + + h5_energies = [] + h5_forces = [] + + for train_file in train_files: + with h5py.File(train_file, "r") as f: + for _, batch in f.items(): + for config_key in batch.keys(): + config = batch[config_key] + assert "atomic_numbers" in config + assert "positions" in config + assert "energy" in config + assert "forces" in config + + h5_energies.append(config["energy"][()]) + h5_forces.append(config["forces"][()]) + + for val_file in val_files: + with h5py.File(val_file, "r") as f: + for _, batch in f.items(): + for config_key in batch.keys(): + config = batch[config_key] + h5_energies.append(config["energy"][()]) + h5_forces.append(config["forces"][()]) + + print("Original energies", original_energies) + print("H5 energies", h5_energies) + print("Original forces", original_forces) + print("H5 forces", h5_forces) + original_energies.sort() + h5_energies.sort() + original_forces = np.concatenate(original_forces).flatten() + h5_forces = np.concatenate(h5_forces).flatten() + original_forces.sort() + h5_forces.sort() + + # Compare energies and forces + np.testing.assert_allclose(original_energies, h5_energies, rtol=1e-5, atol=1e-8) + np.testing.assert_allclose(original_forces, h5_forces, rtol=1e-5, atol=1e-8) + + print("All checks passed successfully!") diff --git a/tests/test_run_train.py b/tests/test_run_train.py new file mode 100644 index 00000000..ca196c47 --- /dev/null +++ b/tests/test_run_train.py @@ -0,0 +1,849 @@ +import json +import os +import subprocess +import sys +from pathlib import Path + +import ase.io +import numpy as np +import pytest +from ase.atoms import Atoms + +from mace.calculators.mace import MACECalculator + +run_train = Path(__file__).parent.parent / "mace" / "cli" / "run_train.py" + + +@pytest.fixture(name="fitting_configs") +def fixture_fitting_configs(): + water = Atoms( + numbers=[8, 1, 1], + positions=[[0, -2.0, 0], [1, 0, 0], [0, 1, 0]], + cell=[4] * 3, + pbc=[True] * 3, + ) + fit_configs = [ + Atoms(numbers=[8], positions=[[0, 0, 0]], cell=[6] * 3), + Atoms(numbers=[1], positions=[[0, 0, 0]], cell=[6] * 3), + ] + fit_configs[0].info["REF_energy"] = 0.0 + fit_configs[0].info["config_type"] = "IsolatedAtom" + fit_configs[1].info["REF_energy"] = 0.0 + fit_configs[1].info["config_type"] = "IsolatedAtom" + + np.random.seed(5) + for _ in range(20): + c = water.copy() + c.positions += np.random.normal(0.1, size=c.positions.shape) + c.info["REF_energy"] = np.random.normal(0.1) + print(c.info["REF_energy"]) + c.new_array("REF_forces", np.random.normal(0.1, size=c.positions.shape)) + c.info["REF_stress"] = np.random.normal(0.1, size=6) + fit_configs.append(c) + + return fit_configs + + +@pytest.fixture(name="pretraining_configs") +def fixture_pretraining_configs(): + configs = [] + for _ in range(10): + atoms = Atoms( + numbers=[8, 1, 1], + positions=np.random.rand(3, 3) * 3, + cell=[5, 5, 5], + pbc=[True] * 3, + ) + atoms.info["REF_energy"] = np.random.normal(0, 1) + atoms.arrays["REF_forces"] = np.random.normal(0, 1, size=(3, 3)) + atoms.info["REF_stress"] = np.random.normal(0, 1, size=6) + configs.append(atoms) + configs.append( + Atoms(numbers=[8], positions=[[0, 0, 0]], cell=[6] * 3, pbc=[True] * 3), + ) + configs.append( + Atoms(numbers=[1], positions=[[0, 0, 0]], cell=[6] * 3, pbc=[True] * 3) + ) + configs[-2].info["REF_energy"] = -2.0 + configs[-2].info["config_type"] = "IsolatedAtom" + configs[-1].info["REF_energy"] = -4.0 + configs[-1].info["config_type"] = "IsolatedAtom" + return configs + + +_mace_params = { + "name": "MACE", + "valid_fraction": 0.05, + "energy_weight": 1.0, + "forces_weight": 10.0, + "stress_weight": 1.0, + "model": "MACE", + "hidden_irreps": "128x0e", + "r_max": 3.5, + "batch_size": 5, + "max_num_epochs": 10, + "swa": None, + "start_swa": 5, + "ema": None, + "ema_decay": 0.99, + "amsgrad": None, + "restart_latest": None, + "device": "cpu", + "seed": 5, + "loss": "stress", + "energy_key": "REF_energy", + "forces_key": "REF_forces", + "stress_key": "REF_stress", + "eval_interval": 2, +} + + +def test_run_train(tmp_path, fitting_configs): + ase.io.write(tmp_path / "fit.xyz", fitting_configs) + + mace_params = _mace_params.copy() + mace_params["checkpoints_dir"] = str(tmp_path) + mace_params["model_dir"] = str(tmp_path) + mace_params["train_file"] = tmp_path / "fit.xyz" + + # make sure run_train.py is using the mace that is currently being tested + run_env = os.environ.copy() + sys.path.insert(0, str(Path(__file__).parent.parent)) + run_env["PYTHONPATH"] = ":".join(sys.path) + print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) + + cmd = ( + sys.executable + + " " + + str(run_train) + + " " + + " ".join( + [ + (f"--{k}={v}" if v is not None else f"--{k}") + for k, v in mace_params.items() + ] + ) + ) + + p = subprocess.run(cmd.split(), env=run_env, check=True) + assert p.returncode == 0 + + calc = MACECalculator(model_paths=tmp_path / "MACE.model", device="cpu") + + Es = [] + for at in fitting_configs: + at.calc = calc + Es.append(at.get_potential_energy()) + + print("Es", Es) + # from a run on 04/06/2024 on stress_bugfix 967f0bfb6490086599da247874b24595d149caa7 + ref_Es = [ + 0.0, + 0.0, + -0.039181344585828524, + -0.0915223395136733, + -0.14953484236456582, + -0.06662480820063998, + -0.09983737353050133, + 0.12477442296789745, + -0.06486086271762856, + -0.1460607988519944, + 0.12886334908465508, + -0.14000990081920373, + -0.05319886578958313, + 0.07780520158391, + -0.08895480281886901, + -0.15474719614734422, + 0.007756765146527644, + -0.044879267197498685, + -0.036065736712447574, + -0.24413743841886623, + -0.0838104612106429, + -0.14751978636626545, + ] + + assert np.allclose(Es, ref_Es) + + +def test_run_train_missing_data(tmp_path, fitting_configs): + del fitting_configs[5].info["REF_energy"] + del fitting_configs[6].arrays["REF_forces"] + del fitting_configs[7].info["REF_stress"] + + ase.io.write(tmp_path / "fit.xyz", fitting_configs) + + mace_params = _mace_params.copy() + mace_params["checkpoints_dir"] = str(tmp_path) + mace_params["model_dir"] = str(tmp_path) + mace_params["train_file"] = tmp_path / "fit.xyz" + + # make sure run_train.py is using the mace that is currently being tested + run_env = os.environ.copy() + sys.path.insert(0, str(Path(__file__).parent.parent)) + run_env["PYTHONPATH"] = ":".join(sys.path) + print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) + + cmd = ( + sys.executable + + " " + + str(run_train) + + " " + + " ".join( + [ + (f"--{k}={v}" if v is not None else f"--{k}") + for k, v in mace_params.items() + ] + ) + ) + + p = subprocess.run(cmd.split(), env=run_env, check=True) + assert p.returncode == 0 + + calc = MACECalculator(model_paths=tmp_path / "MACE.model", device="cpu") + + Es = [] + for at in fitting_configs: + at.calc = calc + Es.append(at.get_potential_energy()) + + print("Es", Es) + # from a run on 04/06/2024 on stress_bugfix 967f0bfb6490086599da247874b24595d149caa7 + ref_Es = [ + 0.0, + 0.0, + -0.05464025113696155, + -0.11272131295940478, + 0.039200919331076826, + -0.07517990972827505, + -0.13504202474582666, + 0.0292022872055344, + -0.06541099574579018, + -0.1497824717832886, + 0.19397709360828813, + -0.13587609467143014, + -0.05242956276828463, + -0.0504862057364953, + -0.07095795959430119, + -0.2463753796753703, + -0.002031543147676121, + -0.03864918790300681, + -0.13680153117705554, + -0.23418951968636786, + -0.11790833839379238, + -0.14930562311066484, + ] + assert np.allclose(Es, ref_Es) + + +def test_run_train_no_stress(tmp_path, fitting_configs): + del fitting_configs[5].info["REF_energy"] + del fitting_configs[6].arrays["REF_forces"] + del fitting_configs[7].info["REF_stress"] + + ase.io.write(tmp_path / "fit.xyz", fitting_configs) + + mace_params = _mace_params.copy() + mace_params["checkpoints_dir"] = str(tmp_path) + mace_params["model_dir"] = str(tmp_path) + mace_params["train_file"] = tmp_path / "fit.xyz" + mace_params["loss"] = "weighted" + + # make sure run_train.py is using the mace that is currently being tested + run_env = os.environ.copy() + sys.path.insert(0, str(Path(__file__).parent.parent)) + run_env["PYTHONPATH"] = ":".join(sys.path) + print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) + + cmd = ( + sys.executable + + " " + + str(run_train) + + " " + + " ".join( + [ + (f"--{k}={v}" if v is not None else f"--{k}") + for k, v in mace_params.items() + ] + ) + ) + + p = subprocess.run(cmd.split(), env=run_env, check=True) + assert p.returncode == 0 + + calc = MACECalculator(model_paths=tmp_path / "MACE.model", device="cpu") + + Es = [] + for at in fitting_configs: + at.calc = calc + Es.append(at.get_potential_energy()) + + print("Es", Es) + # from a run on 28/03/2023 on main 88d49f9ed6925dec07d1777043a36e1fe4872ff3 + ref_Es = [ + 0.0, + 0.0, + -0.05450093218377135, + -0.11235475232750518, + 0.03914558031854152, + -0.07500839914816063, + -0.13469160624431492, + 0.029384214243251838, + -0.06521819204166135, + -0.14944896282001804, + 0.19413948083049481, + -0.13543541860473626, + -0.05235495076237124, + -0.049556206595684105, + -0.07080758913030646, + -0.24571898386301153, + -0.002070636306950905, + -0.03863113401320783, + -0.13620291339913712, + -0.23383074855679695, + -0.11776449630199368, + -0.1489441490225184, + ] + assert np.allclose(Es, ref_Es) + + +def test_run_train_multihead(tmp_path, fitting_configs): + fitting_configs_dft = [] + fitting_configs_mp2 = [] + fitting_configs_ccd = [] + for _, c in enumerate(fitting_configs): + c_dft = c.copy() + c_dft.info["head"] = "DFT" + fitting_configs_dft.append(c_dft) + + c_mp2 = c.copy() + c_mp2.info["head"] = "MP2" + fitting_configs_mp2.append(c_mp2) + + c_ccd = c.copy() + c_ccd.info["head"] = "CCD" + fitting_configs_ccd.append(c_ccd) + ase.io.write(tmp_path / "fit_multihead_dft.xyz", fitting_configs_dft) + ase.io.write(tmp_path / "fit_multihead_mp2.xyz", fitting_configs_mp2) + ase.io.write(tmp_path / "fit_multihead_ccd.xyz", fitting_configs_ccd) + + heads = { + "DFT": {"train_file": f"{str(tmp_path)}/fit_multihead_dft.xyz"}, + "MP2": {"train_file": f"{str(tmp_path)}/fit_multihead_mp2.xyz"}, + "CCD": {"train_file": f"{str(tmp_path)}/fit_multihead_ccd.xyz"}, + } + yaml_str = "heads:\n" + for key, value in heads.items(): + yaml_str += f" {key}:\n" + for sub_key, sub_value in value.items(): + yaml_str += f" {sub_key}: {sub_value}\n" + filename = tmp_path / "config.yaml" + with open(filename, "w", encoding="utf-8") as file: + file.write(yaml_str) + + mace_params = _mace_params.copy() + mace_params["valid_fraction"] = 0.1 + mace_params["checkpoints_dir"] = str(tmp_path) + mace_params["model_dir"] = str(tmp_path) + mace_params["loss"] = "weighted" + mace_params["hidden_irreps"] = "128x0e" + mace_params["r_max"] = 6.0 + mace_params["default_dtype"] = "float64" + mace_params["num_radial_basis"] = 10 + mace_params["interaction_first"] = "RealAgnosticResidualInteractionBlock" + mace_params["config"] = tmp_path / "config.yaml" + mace_params["batch_size"] = 2 + mace_params["num_samples_pt"] = 50 + mace_params["subselect_pt"] = "random" + # make sure run_train.py is using the mace that is currently being tested + run_env = os.environ.copy() + sys.path.insert(0, str(Path(__file__).parent.parent)) + run_env["PYTHONPATH"] = ":".join(sys.path) + print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) + + cmd = ( + sys.executable + + " " + + str(run_train) + + " " + + " ".join( + [ + (f"--{k}={v}" if v is not None else f"--{k}") + for k, v in mace_params.items() + ] + ) + ) + + p = subprocess.run(cmd.split(), env=run_env, check=True) + assert p.returncode == 0 + + calc = MACECalculator( + model_paths=tmp_path / "MACE.model", device="cpu", default_dtype="float64" + ) + + Es = [] + for at in fitting_configs: + at.calc = calc + Es.append(at.get_potential_energy()) + + print("Es", Es) + # from a run on 02/09/2024 on develop branch + ref_Es = [ + 0.0, + 0.0, + 0.10637113905361611, + -0.012499594026624754, + 0.08983077108171753, + 0.21071322543112597, + -0.028921849222784398, + -0.02423359575741567, + 0.022923252188079057, + -0.02048334610058991, + 0.4349711162741364, + -0.04455577015569887, + -0.09765806785570091, + 0.16013134616829822, + 0.0758442928017698, + -0.05931856557011721, + 0.33964473532953265, + 0.134338442158641, + 0.18024119757783053, + -0.18914740992058765, + -0.06503477155294624, + 0.03436649147415213, + ] + assert np.allclose(Es, ref_Es) + + +def test_run_train_foundation(tmp_path, fitting_configs): + ase.io.write(tmp_path / "fit.xyz", fitting_configs) + + mace_params = _mace_params.copy() + mace_params["checkpoints_dir"] = str(tmp_path) + mace_params["model_dir"] = str(tmp_path) + mace_params["train_file"] = tmp_path / "fit.xyz" + mace_params["loss"] = "weighted" + mace_params["foundation_model"] = "small" + mace_params["hidden_irreps"] = "128x0e" + mace_params["r_max"] = 6.0 + mace_params["default_dtype"] = "float64" + mace_params["num_radial_basis"] = 10 + mace_params["interaction_first"] = "RealAgnosticResidualInteractionBlock" + mace_params["multiheads_finetuning"] = False + print("mace_params", mace_params) + # mace_params["num_samples_pt"] = 50 + # mace_params["subselect_pt"] = "random" + # make sure run_train.py is using the mace that is currently being tested + run_env = os.environ.copy() + sys.path.insert(0, str(Path(__file__).parent.parent)) + run_env["PYTHONPATH"] = ":".join(sys.path) + print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) + + cmd = ( + sys.executable + + " " + + str(run_train) + + " " + + " ".join( + [ + (f"--{k}={v}" if v is not None else f"--{k}") + for k, v in mace_params.items() + ] + ) + ) + + p = subprocess.run(cmd.split(), env=run_env, check=True) + assert p.returncode == 0 + + calc = MACECalculator( + model_paths=tmp_path / "MACE.model", device="cpu", default_dtype="float64" + ) + + Es = [] + for at in fitting_configs: + at.calc = calc + Es.append(at.get_potential_energy()) + + print("Es", Es) + # from a run on 28/03/2023 on repulsion a63434aaab70c84ee016e13e4aca8d57297a0f26 + ref_Es = [ + 1.6780993938446045, + 0.8916864395141602, + 0.7290308475494385, + 0.6194742918014526, + 0.6697757840156555, + 0.7025266289710999, + 0.5818213224411011, + 0.7897703647613525, + 0.6558921337127686, + 0.5071806907653809, + 3.581131935119629, + 0.691562294960022, + 0.6257331967353821, + 0.9560437202453613, + 0.7716934680938721, + 0.6730310916900635, + 0.8297463655471802, + 0.8053972721099854, + 0.8337507247924805, + 0.4107491970062256, + 0.6019601821899414, + 0.7301387786865234, + ] + assert np.allclose(Es, ref_Es) + + +def test_run_train_foundation_multihead(tmp_path, fitting_configs): + fitting_configs_dft = [] + fitting_configs_mp2 = [] + for i, c in enumerate(fitting_configs): + if i in (0, 1): + c_dft = c.copy() + c_dft.info["head"] = "DFT" + fitting_configs_dft.append(c_dft) + fitting_configs_dft.append(c) + c_mp2 = c.copy() + c_mp2.info["head"] = "MP2" + fitting_configs_mp2.append(c_mp2) + elif i % 2 == 0: + c.info["head"] = "DFT" + fitting_configs_dft.append(c) + else: + c.info["head"] = "MP2" + fitting_configs_mp2.append(c) + ase.io.write(tmp_path / "fit_multihead_dft.xyz", fitting_configs_dft) + ase.io.write(tmp_path / "fit_multihead_mp2.xyz", fitting_configs_mp2) + + heads = { + "DFT": {"train_file": f"{str(tmp_path)}/fit_multihead_dft.xyz"}, + "MP2": {"train_file": f"{str(tmp_path)}/fit_multihead_mp2.xyz"}, + } + yaml_str = "heads:\n" + for key, value in heads.items(): + yaml_str += f" {key}:\n" + for sub_key, sub_value in value.items(): + yaml_str += f" {sub_key}: {sub_value}\n" + filename = tmp_path / "config.yaml" + with open(filename, "w", encoding="utf-8") as file: + file.write(yaml_str) + mace_params = _mace_params.copy() + mace_params["valid_fraction"] = 0.1 + mace_params["checkpoints_dir"] = str(tmp_path) + mace_params["model_dir"] = str(tmp_path) + mace_params["config"] = tmp_path / "config.yaml" + mace_params["loss"] = "weighted" + mace_params["foundation_model"] = "small" + mace_params["hidden_irreps"] = "128x0e" + mace_params["r_max"] = 6.0 + mace_params["default_dtype"] = "float64" + mace_params["num_radial_basis"] = 10 + mace_params["interaction_first"] = "RealAgnosticResidualInteractionBlock" + mace_params["batch_size"] = 2 + mace_params["valid_batch_size"] = 1 + mace_params["num_samples_pt"] = 50 + mace_params["subselect_pt"] = "random" + # make sure run_train.py is using the mace that is currently being tested + run_env = os.environ.copy() + sys.path.insert(0, str(Path(__file__).parent.parent)) + run_env["PYTHONPATH"] = ":".join(sys.path) + print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) + + cmd = ( + sys.executable + + " " + + str(run_train) + + " " + + " ".join( + [ + (f"--{k}={v}" if v is not None else f"--{k}") + for k, v in mace_params.items() + ] + ) + ) + + p = subprocess.run(cmd.split(), env=run_env, check=True) + assert p.returncode == 0 + + calc = MACECalculator( + model_paths=tmp_path / "MACE.model", device="cpu", default_dtype="float64" + ) + + Es = [] + for at in fitting_configs: + at.calc = calc + Es.append(at.get_potential_energy()) + + print("Es", Es) + # from a run on 20/08/2024 on commit + ref_Es = [ + 1.654685616493225, + 0.44693732261657715, + 0.8741313815116882, + 0.569085955619812, + 0.7161882519721985, + 0.8654778599739075, + 0.8722733855247498, + 0.49582308530807495, + 0.814422607421875, + 0.7027317881584167, + 0.7196993827819824, + 0.517953097820282, + 0.8631765246391296, + 0.4679797887802124, + 0.8163984417915344, + 0.4252359867095947, + 1.0861445665359497, + 0.6829671263694763, + 0.7136879563331604, + 0.5160345435142517, + 0.7002358436584473, + 0.5574042201042175, + ] + assert np.allclose(Es, ref_Es, atol=1e-1) + + +def test_run_train_foundation_multihead_json(tmp_path, fitting_configs): + fitting_configs_dft = [] + fitting_configs_mp2 = [] + for i, c in enumerate(fitting_configs): + + if i in (0, 1): + continue # skip isolated atoms, as energies specified by json files below + if i % 2 == 0: + c.info["head"] = "DFT" + fitting_configs_dft.append(c) + else: + c.info["head"] = "MP2" + fitting_configs_mp2.append(c) + ase.io.write(tmp_path / "fit_multihead_dft.xyz", fitting_configs_dft) + ase.io.write(tmp_path / "fit_multihead_mp2.xyz", fitting_configs_mp2) + + # write E0s to json files + E0s = {1: 0.0, 8: 0.0} + with open(tmp_path / "fit_multihead_dft.json", "w", encoding="utf-8") as f: + json.dump(E0s, f) + with open(tmp_path / "fit_multihead_mp2.json", "w", encoding="utf-8") as f: + json.dump(E0s, f) + + heads = { + "DFT": { + "train_file": f"{str(tmp_path)}/fit_multihead_dft.xyz", + "E0s": f"{str(tmp_path)}/fit_multihead_dft.json", + }, + "MP2": { + "train_file": f"{str(tmp_path)}/fit_multihead_mp2.xyz", + "E0s": f"{str(tmp_path)}/fit_multihead_mp2.json", + }, + } + yaml_str = "heads:\n" + for key, value in heads.items(): + yaml_str += f" {key}:\n" + for sub_key, sub_value in value.items(): + yaml_str += f" {sub_key}: {sub_value}\n" + filename = tmp_path / "config.yaml" + with open(filename, "w", encoding="utf-8") as file: + file.write(yaml_str) + mace_params = _mace_params.copy() + mace_params["valid_fraction"] = 0.1 + mace_params["checkpoints_dir"] = str(tmp_path) + mace_params["model_dir"] = str(tmp_path) + mace_params["config"] = tmp_path / "config.yaml" + mace_params["loss"] = "weighted" + mace_params["foundation_model"] = "small" + mace_params["hidden_irreps"] = "128x0e" + mace_params["r_max"] = 6.0 + mace_params["default_dtype"] = "float64" + mace_params["num_radial_basis"] = 10 + mace_params["interaction_first"] = "RealAgnosticResidualInteractionBlock" + mace_params["batch_size"] = 2 + mace_params["valid_batch_size"] = 1 + mace_params["num_samples_pt"] = 50 + mace_params["subselect_pt"] = "random" + # make sure run_train.py is using the mace that is currently being tested + run_env = os.environ.copy() + sys.path.insert(0, str(Path(__file__).parent.parent)) + run_env["PYTHONPATH"] = ":".join(sys.path) + print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) + + cmd = ( + sys.executable + + " " + + str(run_train) + + " " + + " ".join( + [ + (f"--{k}={v}" if v is not None else f"--{k}") + for k, v in mace_params.items() + ] + ) + ) + + p = subprocess.run(cmd.split(), env=run_env, check=True) + assert p.returncode == 0 + + calc = MACECalculator( + model_paths=tmp_path / "MACE.model", device="cpu", default_dtype="float64" + ) + + Es = [] + for at in fitting_configs: + at.calc = calc + Es.append(at.get_potential_energy()) + + print("Es", Es) + # from a run on 20/08/2024 on commit + ref_Es = [ + 1.654685616493225, + 0.44693732261657715, + 0.8741313815116882, + 0.569085955619812, + 0.7161882519721985, + 0.8654778599739075, + 0.8722733855247498, + 0.49582308530807495, + 0.814422607421875, + 0.7027317881584167, + 0.7196993827819824, + 0.517953097820282, + 0.8631765246391296, + 0.4679797887802124, + 0.8163984417915344, + 0.4252359867095947, + 1.0861445665359497, + 0.6829671263694763, + 0.7136879563331604, + 0.5160345435142517, + 0.7002358436584473, + 0.5574042201042175, + ] + assert np.allclose(Es, ref_Es, atol=1e-1) + + +def test_run_train_multihead_replay_custum_finetuning( + tmp_path, fitting_configs, pretraining_configs +): + ase.io.write(tmp_path / "pretrain.xyz", pretraining_configs) + + foundation_params = { + "name": "foundation", + "train_file": os.path.join(tmp_path, "pretrain.xyz"), + "valid_fraction": 0.2, + "energy_weight": 1.0, + "forces_weight": 10.0, + "stress_weight": 1.0, + "model": "MACE", + "hidden_irreps": "32x0e", + "r_max": 5.0, + "batch_size": 2, + "max_num_epochs": 5, + "swa": None, + "start_swa": 3, + "device": "cpu", + "seed": 42, + "loss": "weighted", + "energy_key": "REF_energy", + "forces_key": "REF_forces", + "stress_key": "REF_stress", + "default_dtype": "float64", + "checkpoints_dir": str(tmp_path), + "model_dir": str(tmp_path), + } + + run_env = os.environ.copy() + sys.path.insert(0, str(Path(__file__).parent.parent)) + run_env["PYTHONPATH"] = ":".join(sys.path) + + cmd = [sys.executable, str(run_train)] + for k, v in foundation_params.items(): + if v is None: + cmd.append(f"--{k}") + else: + cmd.append(f"--{k}={v}") + + p = subprocess.run(cmd, env=run_env, check=True) + assert p.returncode == 0 + + # Step 3: Create finetuning set + fitting_configs_dft = [] + fitting_configs_mp2 = [] + for i, c in enumerate(fitting_configs): + if i in (0, 1): + c_dft = c.copy() + c_dft.info["head"] = "DFT" + fitting_configs_dft.append(c_dft) + fitting_configs_dft.append(c) + c_mp2 = c.copy() + c_mp2.info["head"] = "MP2" + fitting_configs_mp2.append(c_mp2) + elif i % 2 == 0: + c.info["head"] = "DFT" + fitting_configs_dft.append(c) + else: + c.info["head"] = "MP2" + fitting_configs_mp2.append(c) + ase.io.write(tmp_path / "fit_multihead_dft.xyz", fitting_configs_dft) + ase.io.write(tmp_path / "fit_multihead_mp2.xyz", fitting_configs_mp2) + + # Step 4: Finetune the pretrained model with multihead replay + heads = { + "DFT": {"train_file": f"{str(tmp_path)}/fit_multihead_dft.xyz"}, + "MP2": {"train_file": f"{str(tmp_path)}/fit_multihead_mp2.xyz"}, + } + yaml_str = "heads:\n" + for key, value in heads.items(): + yaml_str += f" {key}:\n" + for sub_key, sub_value in value.items(): + yaml_str += f" {sub_key}: {sub_value}\n" + filename = tmp_path / "config.yaml" + with open(filename, "w", encoding="utf-8") as file: + file.write(yaml_str) + + finetuning_params = { + "name": "finetuned", + "valid_fraction": 0.1, + "energy_weight": 1.0, + "forces_weight": 10.0, + "stress_weight": 1.0, + "model": "MACE", + "hidden_irreps": "32x0e", + "r_max": 5.0, + "batch_size": 2, + "max_num_epochs": 5, + "device": "cpu", + "seed": 42, + "loss": "weighted", + "default_dtype": "float64", + "checkpoints_dir": str(tmp_path), + "model_dir": str(tmp_path), + "foundation_model": os.path.join(tmp_path, "foundation.model"), + "config": os.path.join(tmp_path, "config.yaml"), + "pt_train_file": os.path.join(tmp_path, "pretrain.xyz"), + "num_samples_pt": 3, + "subselect_pt": "random", + } + + cmd = [sys.executable, str(run_train)] + for k, v in finetuning_params.items(): + if v is None: + cmd.append(f"--{k}") + else: + cmd.append(f"--{k}={v}") + + p = subprocess.run(cmd, env=run_env, check=True) + assert p.returncode == 0 + + # Load and test the finetuned model + calc = MACECalculator( + model_paths=tmp_path / "finetuned.model", device="cpu", default_dtype="float64" + ) + + Es = [] + for at in fitting_configs: + at.calc = calc + Es.append(at.get_potential_energy()) + + print("Energies:", Es) + + # Add some basic checks + assert len(Es) == len(fitting_configs) + assert all(isinstance(E, float) for E in Es) + assert len(set(Es)) > 1 # Ens diff --git a/tests/test_schedulefree.py b/tests/test_schedulefree.py new file mode 100644 index 00000000..00b20750 --- /dev/null +++ b/tests/test_schedulefree.py @@ -0,0 +1,127 @@ +import tempfile +from unittest.mock import MagicMock + +import numpy as np +import pytest +import torch +import torch.nn.functional as F +from e3nn import o3 + +from mace import data, modules, tools +from mace.tools import scripts_utils, torch_geometric + +try: + import schedulefree +except ImportError: + pytest.skip( + "Skipping schedulefree tests due to ImportError", allow_module_level=True + ) + +torch.set_default_dtype(torch.float64) + +table = tools.AtomicNumberTable([6]) +atomic_energies = np.array([1.0], dtype=float) +cutoff = 5.0 + + +def create_mace(device: str, seed: int = 1702): + torch_geometric.seed_everything(seed) + + model_config = { + "r_max": cutoff, + "num_bessel": 8, + "num_polynomial_cutoff": 6, + "max_ell": 3, + "interaction_cls": modules.interaction_classes[ + "RealAgnosticResidualInteractionBlock" + ], + "interaction_cls_first": modules.interaction_classes[ + "RealAgnosticResidualInteractionBlock" + ], + "num_interactions": 2, + "num_elements": 1, + "hidden_irreps": o3.Irreps("8x0e + 8x1o"), + "MLP_irreps": o3.Irreps("16x0e"), + "gate": F.silu, + "atomic_energies": atomic_energies, + "avg_num_neighbors": 8, + "atomic_numbers": table.zs, + "correlation": 3, + "radial_type": "bessel", + } + model = modules.MACE(**model_config) + return model.to(device) + + +def create_batch(device: str): + from ase import build + + size = 2 + atoms = build.bulk("C", "diamond", a=3.567, cubic=True) + atoms_list = [atoms.repeat((size, size, size))] + print("Number of atoms", len(atoms_list[0])) + + configs = [data.config_from_atoms(atoms) for atoms in atoms_list] + data_loader = torch_geometric.dataloader.DataLoader( + dataset=[ + data.AtomicData.from_config(config, z_table=table, cutoff=cutoff) + for config in configs + ], + batch_size=1, + shuffle=False, + drop_last=False, + ) + batch = next(iter(data_loader)) + batch = batch.to(device) + batch = batch.to_dict() + return batch + + +def do_optimization_step( + model, + optimizer, + device, +): + batch = create_batch(device) + model.train() + optimizer.train() + optimizer.zero_grad() + output = model(batch, training=True, compute_force=False) + loss = output["energy"].mean() + loss.backward() + optimizer.step() + model.eval() + optimizer.eval() + + +@pytest.mark.parametrize("device", ["cpu", "cuda"]) +def test_can_load_checkpoint(device): + model = create_mace(device) + optimizer = schedulefree.adamw_schedulefree.AdamWScheduleFree(model.parameters()) + args = MagicMock() + args.optimizer = "schedulefree" + args.scheduler = "ExponentialLR" + args.lr_scheduler_gamma = 0.9 + lr_scheduler = scripts_utils.LRScheduler(optimizer, args) + with tempfile.TemporaryDirectory() as d: + checkpoint_handler = tools.CheckpointHandler( + directory=d, keep=False, tag="schedulefree" + ) + for _ in range(10): + do_optimization_step(model, optimizer, device) + batch = create_batch(device) + output = model(batch) + energy = output["energy"].detach().cpu().numpy() + + state = tools.CheckpointState( + model=model, optimizer=optimizer, lr_scheduler=lr_scheduler + ) + checkpoint_handler.save(state, epochs=0, keep_last=False) + checkpoint_handler.load_latest( + state=tools.CheckpointState(model, optimizer, lr_scheduler), + swa=False, + ) + batch = create_batch(device) + output = model(batch) + new_energy = output["energy"].detach().cpu().numpy() + assert np.allclose(energy, new_energy, atol=1e-9) diff --git a/tests/test_tools.py b/tests/test_tools.py new file mode 100644 index 00000000..227a1bfc --- /dev/null +++ b/tests/test_tools.py @@ -0,0 +1,48 @@ +import tempfile + +import numpy as np +import torch +import torch.nn.functional +from torch import nn, optim + +from mace.tools import ( + AtomicNumberTable, + CheckpointHandler, + CheckpointState, + atomic_numbers_to_indices, +) + + +def test_atomic_number_table(): + table = AtomicNumberTable(zs=[1, 8]) + array = np.array([8, 8, 1]) + indices = atomic_numbers_to_indices(array, z_table=table) + expected = np.array([1, 1, 0], dtype=int) + assert np.allclose(expected, indices) + + +class MyModel(nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(3, 4) + + def forward(self, x): + return torch.nn.functional.relu(self.linear(x)) + + +def test_save_load(): + model = MyModel() + initial_lr = 0.001 + optimizer = optim.SGD(model.parameters(), lr=initial_lr, momentum=0.9) + scheduler = optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=0.99) + + with tempfile.TemporaryDirectory() as directory: + handler = CheckpointHandler(directory=directory, tag="test", keep=True) + handler.save(state=CheckpointState(model, optimizer, scheduler), epochs=50) + + optimizer.step() + scheduler.step() + assert not np.isclose(optimizer.param_groups[0]["lr"], initial_lr) + + handler.load_latest(state=CheckpointState(model, optimizer, scheduler)) + assert np.isclose(optimizer.param_groups[0]["lr"], initial_lr) From a00ab0fa0180c38606e87c0af66684625150aa12 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Thu, 14 Nov 2024 13:08:02 +0000 Subject: [PATCH 02/17] add configdata, support in block, wrappers --- mace/modules/blocks.py | 181 +++++++++++++++++----------- mace/modules/irreps_tools.py | 28 ++++- mace/modules/wrapper_ops.py | 220 +++++++++++++++++++++++++++++++++++ 3 files changed, 358 insertions(+), 71 deletions(-) create mode 100644 mace/modules/wrapper_ops.py diff --git a/mace/modules/blocks.py b/mace/modules/blocks.py index 0db3b02e..b5e6bfb0 100644 --- a/mace/modules/blocks.py +++ b/mace/modules/blocks.py @@ -5,6 +5,7 @@ ########################################################################################### from abc import abstractmethod +import dataclasses from typing import Callable, List, Optional, Tuple, Union import numpy as np @@ -12,6 +13,7 @@ from e3nn import nn, o3 from e3nn.util.jit import compile_mode +from mace.modules.wrapper_ops import CuEquivarianceConfig, FullyConnectedTensorProduct, Linear, SymmetricContractionWrapper, TensorProduct from mace.tools.compile import simplify_if_compile from mace.tools.scatter import scatter_sum @@ -31,12 +33,19 @@ ) from .symmetric_contraction import SymmetricContraction +try: + import cuequivariance as cue + import cuequivariance_torch as cuet + CUET_AVAILABLE = True +except ImportError: + CUET_AVAILABLE = False @compile_mode("script") class LinearNodeEmbeddingBlock(torch.nn.Module): - def __init__(self, irreps_in: o3.Irreps, irreps_out: o3.Irreps): + def __init__(self, irreps_in: o3.Irreps, irreps_out: o3.Irreps, cueq_config: Optional[CuEquivarianceConfig] = None, +): super().__init__() - self.linear = o3.Linear(irreps_in=irreps_in, irreps_out=irreps_out) + self.linear = Linear(irreps_in=irreps_in, irreps_out=irreps_out, cueq_config=cueq_config) def forward( self, @@ -47,9 +56,9 @@ def forward( @compile_mode("script") class LinearReadoutBlock(torch.nn.Module): - def __init__(self, irreps_in: o3.Irreps, irrep_out: o3.Irreps = o3.Irreps("0e")): + def __init__(self, irreps_in: o3.Irreps, irrep_out: o3.Irreps = o3.Irreps("0e"), cueq_config: Optional[CuEquivarianceConfig] = None): super().__init__() - self.linear = o3.Linear(irreps_in=irreps_in, irreps_out=irrep_out) + self.linear = Linear(irreps_in=irreps_in, irreps_out=irrep_out, cueq_config=cueq_config) def forward( self, @@ -69,13 +78,14 @@ def __init__( gate: Optional[Callable], irrep_out: o3.Irreps = o3.Irreps("0e"), num_heads: int = 1, + cueq_config: Optional[CuEquivarianceConfig] = None, ): super().__init__() self.hidden_irreps = MLP_irreps self.num_heads = num_heads - self.linear_1 = o3.Linear(irreps_in=irreps_in, irreps_out=self.hidden_irreps) + self.linear_1 = Linear(irreps_in=irreps_in, irreps_out=self.hidden_irreps, cueq_config=cueq_config) self.non_linearity = nn.Activation(irreps_in=self.hidden_irreps, acts=[gate]) - self.linear_2 = o3.Linear(irreps_in=self.hidden_irreps, irreps_out=irrep_out) + self.linear_2 = Linear(irreps_in=self.hidden_irreps, irreps_out=irrep_out, cueq_config=cueq_config) def forward( self, x: torch.Tensor, heads: Optional[torch.Tensor] = None @@ -89,13 +99,13 @@ def forward( @compile_mode("script") class LinearDipoleReadoutBlock(torch.nn.Module): - def __init__(self, irreps_in: o3.Irreps, dipole_only: bool = False): + def __init__(self, irreps_in: o3.Irreps, dipole_only: bool = False, cueq_config: Optional[CuEquivarianceConfig] = None): super().__init__() if dipole_only: self.irreps_out = o3.Irreps("1x1o") else: self.irreps_out = o3.Irreps("1x0e + 1x1o") - self.linear = o3.Linear(irreps_in=irreps_in, irreps_out=self.irreps_out) + self.linear = o3.Linear(irreps_in=irreps_in, irreps_out=self.irreps_out, cueq_config=cueq_config) def forward(self, x: torch.Tensor) -> torch.Tensor: # [n_nodes, irreps] # [..., ] return self.linear(x) # [n_nodes, 1] @@ -109,6 +119,7 @@ def __init__( MLP_irreps: o3.Irreps, gate: Callable, dipole_only: bool = False, + cueq_config: Optional[CuEquivarianceConfig] = None, ): super().__init__() self.hidden_irreps = MLP_irreps @@ -131,9 +142,9 @@ def __init__( irreps_gated=irreps_gated, ) self.irreps_nonlin = self.equivariant_nonlin.irreps_in.simplify() - self.linear_1 = o3.Linear(irreps_in=irreps_in, irreps_out=self.irreps_nonlin) - self.linear_2 = o3.Linear( - irreps_in=self.hidden_irreps, irreps_out=self.irreps_out + self.linear_1 = Linear(irreps_in=irreps_in, irreps_out=self.irreps_nonlin, cueq_config=cueq_config) + self.linear_2 = Linear( + irreps_in=self.hidden_irreps, irreps_out=self.irreps_out, cueq_config=cueq_config ) def forward(self, x: torch.Tensor) -> torch.Tensor: # [n_nodes, irreps] # [..., ] @@ -218,22 +229,25 @@ def __init__( correlation: int, use_sc: bool = True, num_elements: Optional[int] = None, + cueq_config: Optional[CuEquivarianceConfig] = None, ) -> None: super().__init__() self.use_sc = use_sc - self.symmetric_contractions = SymmetricContraction( + self.symmetric_contractions = SymmetricContractionWrapper( irreps_in=node_feats_irreps, irreps_out=target_irreps, correlation=correlation, num_elements=num_elements, + cueq_config=cueq_config, ) # Update linear - self.linear = o3.Linear( + self.linear = Linear( target_irreps, target_irreps, internal_weights=True, shared_weights=True, + cueq_config=cueq_config, ) def forward( @@ -260,6 +274,7 @@ def __init__( hidden_irreps: o3.Irreps, avg_num_neighbors: float, radial_MLP: Optional[List[int]] = None, + cueq_config: Optional[CuEquivarianceConfig] = None, ) -> None: super().__init__() self.node_attrs_irreps = node_attrs_irreps @@ -272,6 +287,7 @@ def __init__( if radial_MLP is None: radial_MLP = [64, 64, 64] self.radial_MLP = radial_MLP + self.cueq_config = cueq_config self._setup() @@ -325,23 +341,29 @@ def __repr__(self): @compile_mode("script") class ResidualElementDependentInteractionBlock(InteractionBlock): def _setup(self) -> None: - self.linear_up = o3.Linear( + if not hasattr(self, "cueq_config"): + self.cueq_config = None + + # First linear + self.linear_up = Linear( self.node_feats_irreps, self.node_feats_irreps, internal_weights=True, shared_weights=True, + cueq_config=self.cueq_config, ) # TensorProduct irreps_mid, instructions = tp_out_irreps_with_instructions( self.node_feats_irreps, self.edge_attrs_irreps, self.target_irreps ) - self.conv_tp = o3.TensorProduct( + self.conv_tp = TensorProduct( self.node_feats_irreps, self.edge_attrs_irreps, irreps_mid, instructions=instructions, shared_weights=False, internal_weights=False, + cueq_config=self.cueq_config, ) self.conv_tp_weights = TensorProductWeightsBlock( num_elements=self.node_attrs_irreps.num_irreps, @@ -353,13 +375,13 @@ def _setup(self) -> None: irreps_mid = irreps_mid.simplify() self.irreps_out = linear_out_irreps(irreps_mid, self.target_irreps) self.irreps_out = self.irreps_out.simplify() - self.linear = o3.Linear( - irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True + self.linear = Linear( + irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True, cueq_config=self.cueq_config ) # Selector TensorProduct - self.skip_tp = o3.FullyConnectedTensorProduct( - self.node_feats_irreps, self.node_attrs_irreps, self.irreps_out + self.skip_tp = FullyConnectedTensorProduct( + self.node_feats_irreps, self.node_attrs_irreps, self.irreps_out, cueq_config=self.cueq_config ) def forward( @@ -389,23 +411,27 @@ def forward( @compile_mode("script") class AgnosticNonlinearInteractionBlock(InteractionBlock): def _setup(self) -> None: - self.linear_up = o3.Linear( + if not hasattr(self, "cueq_config"): + self.cueq_config = None + self.linear_up = Linear( self.node_feats_irreps, self.node_feats_irreps, internal_weights=True, shared_weights=True, + cueq_config=self.cueq_config, ) # TensorProduct irreps_mid, instructions = tp_out_irreps_with_instructions( self.node_feats_irreps, self.edge_attrs_irreps, self.target_irreps ) - self.conv_tp = o3.TensorProduct( + self.conv_tp = TensorProduct( self.node_feats_irreps, self.edge_attrs_irreps, irreps_mid, instructions=instructions, shared_weights=False, internal_weights=False, + cueq_config=self.cueq_config ) # Convolution weights @@ -419,13 +445,13 @@ def _setup(self) -> None: irreps_mid = irreps_mid.simplify() self.irreps_out = linear_out_irreps(irreps_mid, self.target_irreps) self.irreps_out = self.irreps_out.simplify() - self.linear = o3.Linear( - irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True + self.linear = Linear( + irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True, cueq_config=self.cueq_config ) # Selector TensorProduct - self.skip_tp = o3.FullyConnectedTensorProduct( - self.irreps_out, self.node_attrs_irreps, self.irreps_out + self.skip_tp = FullyConnectedTensorProduct( + self.irreps_out, self.node_attrs_irreps, self.irreps_out, cueq_config=self.cueq_config ) def forward( @@ -456,23 +482,25 @@ def forward( class AgnosticResidualNonlinearInteractionBlock(InteractionBlock): def _setup(self) -> None: # First linear - self.linear_up = o3.Linear( + self.linear_up = Linear( self.node_feats_irreps, self.node_feats_irreps, internal_weights=True, shared_weights=True, + cueq_config=self.cueq_config, ) # TensorProduct irreps_mid, instructions = tp_out_irreps_with_instructions( self.node_feats_irreps, self.edge_attrs_irreps, self.target_irreps ) - self.conv_tp = o3.TensorProduct( + self.conv_tp = TensorProduct( self.node_feats_irreps, self.edge_attrs_irreps, irreps_mid, instructions=instructions, shared_weights=False, internal_weights=False, + cueq_config=self.cueq_config ) # Convolution weights @@ -486,13 +514,13 @@ def _setup(self) -> None: irreps_mid = irreps_mid.simplify() self.irreps_out = linear_out_irreps(irreps_mid, self.target_irreps) self.irreps_out = self.irreps_out.simplify() - self.linear = o3.Linear( - irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True + self.linear = Linear( + irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True, cueq_config=self.cueq_config ) # Selector TensorProduct - self.skip_tp = o3.FullyConnectedTensorProduct( - self.node_feats_irreps, self.node_attrs_irreps, self.irreps_out + self.skip_tp = FullyConnectedTensorProduct( + self.node_feats_irreps, self.node_attrs_irreps, self.irreps_out, cueq_config=self.cueq_config ) def forward( @@ -523,12 +551,15 @@ def forward( @compile_mode("script") class RealAgnosticInteractionBlock(InteractionBlock): def _setup(self) -> None: + if not hasattr(self, "cueq_config"): + self.cueq_config = None # First linear - self.linear_up = o3.Linear( + self.linear_up = Linear( self.node_feats_irreps, self.node_feats_irreps, internal_weights=True, shared_weights=True, + cueq_config=self.cueq_config ) # TensorProduct irreps_mid, instructions = tp_out_irreps_with_instructions( @@ -543,6 +574,7 @@ def _setup(self) -> None: instructions=instructions, shared_weights=False, internal_weights=False, + cueq_config=self.cueq_config, ) # Convolution weights @@ -555,15 +587,15 @@ def _setup(self) -> None: # Linear irreps_mid = irreps_mid.simplify() self.irreps_out = self.target_irreps - self.linear = o3.Linear( - irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True + self.linear = Linear( + irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True, cueq_config=self.cueq_config ) # Selector TensorProduct - self.skip_tp = o3.FullyConnectedTensorProduct( - self.irreps_out, self.node_attrs_irreps, self.irreps_out + self.skip_tp = FullyConnectedTensorProduct( + self.irreps_out, self.node_attrs_irreps, self.irreps_out, cueq_config=self.cueq_config ) - self.reshape = reshape_irreps(self.irreps_out) + self.reshape = reshape_irreps(self.irreps_out, cueq_config=self.cueq_config) def forward( self, @@ -595,12 +627,15 @@ def forward( @compile_mode("script") class RealAgnosticResidualInteractionBlock(InteractionBlock): def _setup(self) -> None: + if not hasattr(self, "cueq_config"): + self.cueq_config = None # First linear - self.linear_up = o3.Linear( + self.linear_up = Linear( self.node_feats_irreps, self.node_feats_irreps, internal_weights=True, shared_weights=True, + cueq_config=self.cueq_config, ) # TensorProduct irreps_mid, instructions = tp_out_irreps_with_instructions( @@ -608,13 +643,14 @@ def _setup(self) -> None: self.edge_attrs_irreps, self.target_irreps, ) - self.conv_tp = o3.TensorProduct( + self.conv_tp = TensorProduct( self.node_feats_irreps, self.edge_attrs_irreps, irreps_mid, instructions=instructions, shared_weights=False, internal_weights=False, + cueq_config=self.cueq_config, ) # Convolution weights @@ -627,15 +663,15 @@ def _setup(self) -> None: # Linear irreps_mid = irreps_mid.simplify() self.irreps_out = self.target_irreps - self.linear = o3.Linear( - irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True + self.linear = Linear( + irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True, cueq_config=self.cueq_config ) # Selector TensorProduct - self.skip_tp = o3.FullyConnectedTensorProduct( - self.node_feats_irreps, self.node_attrs_irreps, self.hidden_irreps + self.skip_tp = FullyConnectedTensorProduct( + self.node_feats_irreps, self.node_attrs_irreps, self.hidden_irreps, cueq_config=self.cueq_config ) - self.reshape = reshape_irreps(self.irreps_out) + self.reshape = reshape_irreps(self.irreps_out, cueq_config=self.cueq_config) def forward( self, @@ -667,12 +703,15 @@ def forward( @compile_mode("script") class RealAgnosticDensityInteractionBlock(InteractionBlock): def _setup(self) -> None: + if not hasattr(self, "cueq_config"): + self.cueq_config = None # First linear - self.linear_up = o3.Linear( + self.linear_up = Linear( self.node_feats_irreps, self.node_feats_irreps, internal_weights=True, shared_weights=True, + cueq_config=self.cueq_config ) # TensorProduct irreps_mid, instructions = tp_out_irreps_with_instructions( @@ -680,13 +719,14 @@ def _setup(self) -> None: self.edge_attrs_irreps, self.target_irreps, ) - self.conv_tp = o3.TensorProduct( + self.conv_tp = TensorProduct( self.node_feats_irreps, self.edge_attrs_irreps, irreps_mid, instructions=instructions, shared_weights=False, internal_weights=False, + cueq_config=self.cueq_config, ) # Convolution weights @@ -699,15 +739,14 @@ def _setup(self) -> None: # Linear irreps_mid = irreps_mid.simplify() self.irreps_out = self.target_irreps - self.linear = o3.Linear( - irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True + self.linear = Linear( + irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True, cueq_config=self.cueq_config ) # Selector TensorProduct - self.skip_tp = o3.FullyConnectedTensorProduct( - self.irreps_out, self.node_attrs_irreps, self.irreps_out + self.skip_tp = FullyConnectedTensorProduct( + self.irreps_out, self.node_attrs_irreps, self.irreps_out, cueq_config=self.cueq_config ) - self.reshape = reshape_irreps(self.irreps_out) # Density normalization self.density_fn = nn.FullyConnectedNet( @@ -718,7 +757,7 @@ def _setup(self) -> None: torch.nn.functional.silu, ) # Reshape - self.reshape = reshape_irreps(self.irreps_out) + self.reshape = reshape_irreps(self.irreps_out, cueq_config=self.cueq_config) def forward( self, @@ -754,12 +793,16 @@ def forward( @compile_mode("script") class RealAgnosticDensityResidualInteractionBlock(InteractionBlock): def _setup(self) -> None: + if not hasattr(self, "cueq_config"): + self.cueq_config = None + # First linear - self.linear_up = o3.Linear( + self.linear_up = Linear( self.node_feats_irreps, self.node_feats_irreps, internal_weights=True, shared_weights=True, + cueq_config=self.cueq_config, ) # TensorProduct irreps_mid, instructions = tp_out_irreps_with_instructions( @@ -767,13 +810,14 @@ def _setup(self) -> None: self.edge_attrs_irreps, self.target_irreps, ) - self.conv_tp = o3.TensorProduct( + self.conv_tp = TensorProduct( self.node_feats_irreps, self.edge_attrs_irreps, irreps_mid, instructions=instructions, shared_weights=False, internal_weights=False, + cueq_config=self.cueq_config, ) # Convolution weights @@ -786,15 +830,14 @@ def _setup(self) -> None: # Linear irreps_mid = irreps_mid.simplify() self.irreps_out = self.target_irreps - self.linear = o3.Linear( - irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True + self.linear = Linear( + irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True, cueq_config=self.cueq_config ) # Selector TensorProduct - self.skip_tp = o3.FullyConnectedTensorProduct( - self.node_feats_irreps, self.node_attrs_irreps, self.hidden_irreps + self.skip_tp = FullyConnectedTensorProduct( + self.node_feats_irreps, self.node_attrs_irreps, self.hidden_irreps, cueq_config=self.cueq_config ) - self.reshape = reshape_irreps(self.irreps_out) # Density normalization self.density_fn = nn.FullyConnectedNet( @@ -806,7 +849,7 @@ def _setup(self) -> None: ) # Reshape - self.reshape = reshape_irreps(self.irreps_out) + self.reshape = reshape_irreps(self.irreps_out, cueq_config=self.cueq_config) def forward( self, @@ -842,13 +885,16 @@ def forward( @compile_mode("script") class RealAgnosticAttResidualInteractionBlock(InteractionBlock): def _setup(self) -> None: + if not hasattr(self, "cueq_config"): + self.cueq_config = None self.node_feats_down_irreps = o3.Irreps("64x0e") # First linear - self.linear_up = o3.Linear( + self.linear_up = Linear( self.node_feats_irreps, self.node_feats_irreps, internal_weights=True, shared_weights=True, + cueq_config=self.cueq_config ) # TensorProduct irreps_mid, instructions = tp_out_irreps_with_instructions( @@ -856,21 +902,23 @@ def _setup(self) -> None: self.edge_attrs_irreps, self.target_irreps, ) - self.conv_tp = o3.TensorProduct( + self.conv_tp = TensorProduct( self.node_feats_irreps, self.edge_attrs_irreps, irreps_mid, instructions=instructions, shared_weights=False, internal_weights=False, + cueq_config=self.cueq_config, ) # Convolution weights - self.linear_down = o3.Linear( + self.linear_down = Linear( self.node_feats_irreps, self.node_feats_down_irreps, internal_weights=True, shared_weights=True, + cueq_config=self.cueq_config, ) input_dim = ( self.edge_feats_irreps.num_irreps @@ -884,17 +932,18 @@ def _setup(self) -> None: # Linear irreps_mid = irreps_mid.simplify() self.irreps_out = self.target_irreps - self.linear = o3.Linear( + self.linear = Linear( irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True, + cueq_config=self.cueq_config ) - self.reshape = reshape_irreps(self.irreps_out) + self.reshape = reshape_irreps(self.irreps_out, cueq_config=self.cueq_config) # Skip connection. - self.skip_linear = o3.Linear(self.node_feats_irreps, self.hidden_irreps) + self.skip_linear = Linear(self.node_feats_irreps, self.hidden_irreps, cueq_config=self.cueq_config) def forward( self, diff --git a/mace/modules/irreps_tools.py b/mace/modules/irreps_tools.py index b0960193..81122d30 100644 --- a/mace/modules/irreps_tools.py +++ b/mace/modules/irreps_tools.py @@ -4,12 +4,20 @@ # This program is distributed under the MIT License (see MIT.md) ########################################################################################### -from typing import List, Tuple +from typing import List, Optional, Tuple import torch from e3nn import o3 from e3nn.util.jit import compile_mode +from mace.modules.wrapper_ops import CuEquivarianceConfig + +try: + import cuequivariance as cue + LAYOUTS = cue +except ImportError: + LAYOUTS = DefaultLayouts + # Based on mir-group/nequip def tp_out_irreps_with_instructions( @@ -61,12 +69,16 @@ def linear_out_irreps(irreps: o3.Irreps, target_irreps: o3.Irreps) -> o3.Irreps: return o3.Irreps(irreps_mid) +class DefaultLayouts: + mul_ir = "mul_ir" + ir_mul = "ir_mul" @compile_mode("script") class reshape_irreps(torch.nn.Module): - def __init__(self, irreps: o3.Irreps) -> None: + def __init__(self, irreps: o3.Irreps, cueq_config: Optional[CuEquivarianceConfig] = None) -> None: super().__init__() self.irreps = o3.Irreps(irreps) + self.cueq_config = cueq_config self.dims = [] self.muls = [] for mul, ir in self.irreps: @@ -81,10 +93,16 @@ def forward(self, tensor: torch.Tensor) -> torch.Tensor: for mul, d in zip(self.muls, self.dims): field = tensor[:, ix : ix + mul * d] # [batch, sample, mul * repr] ix += mul * d - field = field.reshape(batch, mul, d) + if hasattr(self, "cueq_config") and (not self.cueq_config or self.cueq_config.layout_str == "mul_ir"): + field = field.reshape(batch, mul, d) + else: + field = field.reshape(batch, d, mul) out.append(field) - return torch.cat(out, dim=-1) - + + if hasattr(self, "cueq_config") and (not self.cueq_config or self.cueq_config.layout_str == "mul_ir"): + return torch.cat(out, dim=-1) + else: + return torch.cat(out, dim=-2) def mask_head(x: torch.Tensor, head: torch.Tensor, num_heads: int) -> torch.Tensor: mask = torch.zeros(x.shape[0], x.shape[1] // num_heads, num_heads, device=x.device) diff --git a/mace/modules/wrapper_ops.py b/mace/modules/wrapper_ops.py new file mode 100644 index 00000000..a80b7df5 --- /dev/null +++ b/mace/modules/wrapper_ops.py @@ -0,0 +1,220 @@ +""" +Wrapper class for o3.Linear that optionally uses cuet.Linear +""" +import dataclasses +import torch +from typing import List, Optional +import e3nn.o3 as o3 + +from mace.modules.symmetric_contraction import SymmetricContraction +try: + import cuequivariance as cue + import cuequivariance_torch as cuet + CUET_AVAILABLE = True +except ImportError: + CUET_AVAILABLE = False + +@dataclasses.dataclass +class CuEquivarianceConfig: + """Configuration for cuequivariance acceleration""" + enabled: bool = False + layout: str = "mul_ir" # One of: mul_ir, ir_mul + group: str = "O3" + optimize_all: bool = False # Set to True to enable all optimizations + optimize_linear: bool = False + optimize_channelwise: bool = False + optimize_symmetric: bool = False + optimize_fctp: bool = False + + def __post_init__(self): + if self.enabled and CUET_AVAILABLE: + self.layout_str = self.layout + self.layout = getattr(cue, self.layout) + self.group = getattr(cue, self.group) + +class Linear(torch.nn.Module): + """Wrapper around o3.Linear that optionally uses cuet.Linear when enabled""" + def __init__( + self, + irreps_in: o3.Irreps, + irreps_out: o3.Irreps, + shared_weights: bool = True, + internal_weights: bool = True, + cueq_config: Optional[CuEquivarianceConfig] = None, + ): + super().__init__() + if (CUET_AVAILABLE and cueq_config is not None and cueq_config.enabled + and (cueq_config.optimize_all or cueq_config.optimize_linear)): + self.linear = cuet.Linear( + cue.Irreps(cueq_config.group, irreps_in), + cue.Irreps(cueq_config.group, irreps_out), + layout=cueq_config.layout, + shared_weights=shared_weights, + optimize_fallback=not cueq_config.optimize_linear + ) + self.use_cuet = True + self.cueq_config = cueq_config + else: + self.linear = o3.Linear( + irreps_in, + irreps_out, + shared_weights=shared_weights, + internal_weights=internal_weights + ) + self.use_cuet = False + + def __getattr__(self, name): + """Forward any unknown attribute access to the underlying linear object""" + try: + return super().__getattr__(name) + except AttributeError: + return getattr(self.linear, name) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.use_cuet: + return self.linear(x, use_fallback=not self.cueq_config.optimize_linear) + return self.linear(x) + +class TensorProduct(torch.nn.Module): + """Wrapper around o3.TensorProduct/cuet.ChannelwiseTensorProduct""" + def __init__( + self, + irreps_in1: o3.Irreps, + irreps_in2: o3.Irreps, + irreps_out: o3.Irreps, + instructions: Optional[List] = None, + shared_weights: bool = False, + internal_weights: bool = False, + cueq_config: Optional[CuEquivarianceConfig] = None, + ): + super().__init__() + if (CUET_AVAILABLE and cueq_config is not None and cueq_config.enabled + and (cueq_config.optimize_all or cueq_config.optimize_channelwise)): + self.tp = cuet.ChannelwiseTensorProduct( + cue.Irreps(cueq_config.group, irreps_in1), + cue.Irreps(cueq_config.group, irreps_in2), + cue.Irreps(cueq_config.group, irreps_out), + layout=cueq_config.layout, + shared_weights=shared_weights, + internal_weights=internal_weights, + optimize_fallback=not cueq_config.optimize_channelwise + ) + self.use_cuet = True + self.cueq_config = cueq_config + else: + self.tp = o3.TensorProduct( + irreps_in1, + irreps_in2, + irreps_out, + instructions=instructions, + shared_weights=shared_weights, + internal_weights=internal_weights + ) + self.use_cuet = False + + def __getattr__(self, name): + """Forward any unknown attribute access to the underlying linear object""" + try: + return super().__getattr__(name) + except AttributeError: + return getattr(self.tp, name) + + def forward(self, x1: torch.Tensor, x2: torch.Tensor, weights: Optional[torch.Tensor] = None) -> torch.Tensor: + if self.use_cuet: + return self.tp(x1, x2, weights, use_fallback=not self.cueq_config.optimize_channelwise) + return self.tp(x1, x2, weights) + +class FullyConnectedTensorProduct(torch.nn.Module): + """Wrapper around o3.FullyConnectedTensorProduct/cuet.FullyConnectedTensorProduct""" + def __init__( + self, + irreps_in1: o3.Irreps, + irreps_in2: o3.Irreps, + irreps_out: o3.Irreps, + shared_weights: bool = True, + internal_weights: bool = True, + cueq_config: Optional[CuEquivarianceConfig] = None, + ): + super().__init__() + if (CUET_AVAILABLE and cueq_config is not None and cueq_config.enabled + and (cueq_config.optimize_all or cueq_config.optimize_fctp)): + self.tp = cuet.FullyConnectedTensorProduct( + cue.Irreps(cueq_config.group, irreps_in1), + cue.Irreps(cueq_config.group, irreps_in2), + cue.Irreps(cueq_config.group, irreps_out), + layout=cueq_config.layout, + shared_weights=shared_weights, + internal_weights=internal_weights, + optimize_fallback=not cueq_config.optimize_fctp + ) + self.use_cuet = True + self.cueq_config = cueq_config + else: + self.tp = o3.FullyConnectedTensorProduct( + irreps_in1, + irreps_in2, + irreps_out, + shared_weights=shared_weights, + internal_weights=internal_weights + ) + self.use_cuet = False + + def __getattr__(self, name): + """Forward any unknown attribute access to the underlying linear object""" + try: + return super().__getattr__(name) + except AttributeError: + return getattr(self.tp, name) + + def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: + if self.use_cuet: + return self.tp(x1, x2, use_fallback=not self.cueq_config.optimize_fctp) + return self.tp(x1, x2) + +class SymmetricContractionWrapper(torch.nn.Module): + """Wrapper around SymmetricContraction/cuet.SymmetricContraction""" + def __init__( + self, + irreps_in: o3.Irreps, + irreps_out: o3.Irreps, + correlation: int, + num_elements: Optional[int] = None, + cueq_config: Optional[CuEquivarianceConfig] = None, + ): + super().__init__() + if (CUET_AVAILABLE and cueq_config is not None and cueq_config.enabled + and (cueq_config.optimize_all or cueq_config.optimize_symmetric)): + self.sc = cuet.SymmetricContraction( + cue.Irreps(cueq_config.group, irreps_in), + cue.Irreps(cueq_config.group, irreps_out), + layout_in=cue.ir_mul, + layout_out=cueq_config.layout, + contraction_degree=correlation, + num_elements=num_elements, + optimize_fallback=not cueq_config.optimize_symmetric + ) + self.use_cuet = True + self.cueq_config = cueq_config + self.layout = cueq_config.layout + else: + self.sc = SymmetricContraction( + irreps_in=irreps_in, + irreps_out=irreps_out, + correlation=correlation, + num_elements=num_elements + ) + self.use_cuet = False + + def __getattr__(self, name): + """Forward any unknown attribute access to the underlying linear object""" + try: + return super().__getattr__(name) + except AttributeError: + return getattr(self.sc, name) + + def forward(self, x: torch.Tensor, attrs: torch.Tensor) -> torch.Tensor: + if self.use_cuet: + if self.layout == cue.mul_ir: + x = torch.transpose(x, 1, 2) + return self.sc(x.flatten(1), attrs, use_fallback=not self.cueq_config.optimize_symmetric) + return self.sc(x, attrs) \ No newline at end of file From 126fa91bd10e68145c8b334432e540ae94a05b32 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Fri, 15 Nov 2024 10:39:21 +0000 Subject: [PATCH 03/17] fix the config --- mace/__init__.py | 72 +----------------------------------- mace/modules/blocks.py | 3 +- mace/modules/irreps_tools.py | 18 ++++++--- mace/modules/models.py | 12 ++++-- mace/modules/wrapper_ops.py | 16 ++++---- mace/py.typed | 1 - setup.cfg | 2 +- 7 files changed, 33 insertions(+), 91 deletions(-) delete mode 100644 mace/py.typed diff --git a/mace/__init__.py b/mace/__init__.py index 8ad80243..77f33d04 100644 --- a/mace/__init__.py +++ b/mace/__init__.py @@ -1,71 +1,3 @@ -from .arg_parser import build_default_arg_parser, build_preprocess_arg_parser -from .arg_parser_tools import check_args -from .cg import U_matrix_real -from .checkpoint import CheckpointHandler, CheckpointIO, CheckpointState -from .finetuning_utils import load_foundations, load_foundations_elements -from .torch_tools import ( - TensorDict, - cartesian_to_spherical, - count_parameters, - init_device, - init_wandb, - set_default_dtype, - set_seeds, - spherical_to_cartesian, - to_numpy, - to_one_hot, - voigt_to_matrix, -) -from .train import SWAContainer, evaluate, train -from .utils import ( - AtomicNumberTable, - MetricsLogger, - atomic_numbers_to_indices, - compute_c, - compute_mae, - compute_q95, - compute_rel_mae, - compute_rel_rmse, - compute_rmse, - get_atomic_number_table_from_zs, - get_tag, - setup_logger, -) +from .__version__ import __version__ -__all__ = [ - "TensorDict", - "AtomicNumberTable", - "atomic_numbers_to_indices", - "to_numpy", - "to_one_hot", - "build_default_arg_parser", - "check_args", - "set_seeds", - "init_device", - "setup_logger", - "get_tag", - "count_parameters", - "MetricsLogger", - "get_atomic_number_table_from_zs", - "train", - "evaluate", - "SWAContainer", - "CheckpointHandler", - "CheckpointIO", - "CheckpointState", - "set_default_dtype", - "compute_mae", - "compute_rel_mae", - "compute_rmse", - "compute_rel_rmse", - "compute_q95", - "compute_c", - "U_matrix_real", - "spherical_to_cartesian", - "cartesian_to_spherical", - "voigt_to_matrix", - "init_wandb", - "load_foundations", - "load_foundations_elements", - "build_preprocess_arg_parser", -] +__all__ = ["__version__"] \ No newline at end of file diff --git a/mace/modules/blocks.py b/mace/modules/blocks.py index b5e6bfb0..1728097c 100644 --- a/mace/modules/blocks.py +++ b/mace/modules/blocks.py @@ -42,8 +42,7 @@ @compile_mode("script") class LinearNodeEmbeddingBlock(torch.nn.Module): - def __init__(self, irreps_in: o3.Irreps, irreps_out: o3.Irreps, cueq_config: Optional[CuEquivarianceConfig] = None, -): + def __init__(self, irreps_in: o3.Irreps, irreps_out: o3.Irreps, cueq_config: Optional[CuEquivarianceConfig] = None): super().__init__() self.linear = Linear(irreps_in=irreps_in, irreps_out=irreps_out, cueq_config=cueq_config) diff --git a/mace/modules/irreps_tools.py b/mace/modules/irreps_tools.py index 81122d30..597e439a 100644 --- a/mace/modules/irreps_tools.py +++ b/mace/modules/irreps_tools.py @@ -93,16 +93,22 @@ def forward(self, tensor: torch.Tensor) -> torch.Tensor: for mul, d in zip(self.muls, self.dims): field = tensor[:, ix : ix + mul * d] # [batch, sample, mul * repr] ix += mul * d - if hasattr(self, "cueq_config") and (not self.cueq_config or self.cueq_config.layout_str == "mul_ir"): - field = field.reshape(batch, mul, d) + if hasattr(self, "cueq_config") and self.cueq_config is not None: + if self.cueq_config.layout_str == "mul_ir": + field = field.reshape(batch, mul, d) + else: + field = field.reshape(batch, d, mul) else: - field = field.reshape(batch, d, mul) + field = field.reshape(batch, mul, d) out.append(field) - if hasattr(self, "cueq_config") and (not self.cueq_config or self.cueq_config.layout_str == "mul_ir"): - return torch.cat(out, dim=-1) + if hasattr(self, "cueq_config") and self.cueq_config is not None: + if self.cueq_config.layout_str == "mul_ir": + return torch.cat(out, dim=-1) + else: + return torch.cat(out, dim=-2) else: - return torch.cat(out, dim=-2) + return torch.cat(out, dim=-1) def mask_head(x: torch.Tensor, head: torch.Tensor, num_heads: int) -> torch.Tensor: mask = torch.zeros(x.shape[0], x.shape[1] // num_heads, num_heads, device=x.device) diff --git a/mace/modules/models.py b/mace/modules/models.py index c0d8ab43..419af4f1 100644 --- a/mace/modules/models.py +++ b/mace/modules/models.py @@ -62,6 +62,7 @@ def __init__( radial_MLP: Optional[List[int]] = None, radial_type: Optional[str] = "bessel", heads: Optional[List[str]] = None, + cueq_config: Optional[Dict[str, Any]] = None, ): super().__init__() self.register_buffer( @@ -82,7 +83,7 @@ def __init__( node_attr_irreps = o3.Irreps([(num_elements, (0, 1))]) node_feats_irreps = o3.Irreps([(hidden_irreps.count(o3.Irrep(0, 1)), (0, 1))]) self.node_embedding = LinearNodeEmbeddingBlock( - irreps_in=node_attr_irreps, irreps_out=node_feats_irreps + irreps_in=node_attr_irreps, irreps_out=node_feats_irreps, cueq_config=cueq_config ) self.radial_embedding = RadialEmbeddingBlock( r_max=r_max, @@ -116,6 +117,7 @@ def __init__( hidden_irreps=hidden_irreps, avg_num_neighbors=avg_num_neighbors, radial_MLP=radial_MLP, + cueq_config=cueq_config, ) self.interactions = torch.nn.ModuleList([inter]) @@ -131,12 +133,13 @@ def __init__( correlation=correlation[0], num_elements=num_elements, use_sc=use_sc_first, + cueq_config=cueq_config, ) self.products = torch.nn.ModuleList([prod]) self.readouts = torch.nn.ModuleList() self.readouts.append( - LinearReadoutBlock(hidden_irreps, o3.Irreps(f"{len(heads)}x0e")) + LinearReadoutBlock(hidden_irreps, o3.Irreps(f"{len(heads)}x0e"), cueq_config) ) for i in range(num_interactions - 1): @@ -155,6 +158,7 @@ def __init__( hidden_irreps=hidden_irreps_out, avg_num_neighbors=avg_num_neighbors, radial_MLP=radial_MLP, + cueq_config=cueq_config, ) self.interactions.append(inter) prod = EquivariantProductBasisBlock( @@ -163,6 +167,7 @@ def __init__( correlation=correlation[i + 1], num_elements=num_elements, use_sc=True, + cueq_config=cueq_config, ) self.products.append(prod) if i == num_interactions - 2: @@ -173,11 +178,12 @@ def __init__( gate, o3.Irreps(f"{len(heads)}x0e"), len(heads), + cueq_config, ) ) else: self.readouts.append( - LinearReadoutBlock(hidden_irreps, o3.Irreps(f"{len(heads)}x0e")) + LinearReadoutBlock(hidden_irreps, o3.Irreps(f"{len(heads)}x0e"), cueq_config) ) def forward( diff --git a/mace/modules/wrapper_ops.py b/mace/modules/wrapper_ops.py index a80b7df5..d820abb9 100644 --- a/mace/modules/wrapper_ops.py +++ b/mace/modules/wrapper_ops.py @@ -71,7 +71,7 @@ def __getattr__(self, name): return getattr(self.linear, name) def forward(self, x: torch.Tensor) -> torch.Tensor: - if self.use_cuet: + if self.use_cuet and hasattr(self, "cueq_config"): return self.linear(x, use_fallback=not self.cueq_config.optimize_linear) return self.linear(x) @@ -120,7 +120,7 @@ def __getattr__(self, name): return getattr(self.tp, name) def forward(self, x1: torch.Tensor, x2: torch.Tensor, weights: Optional[torch.Tensor] = None) -> torch.Tensor: - if self.use_cuet: + if self.use_cuet and hasattr(self, "cueq_config"): return self.tp(x1, x2, weights, use_fallback=not self.cueq_config.optimize_channelwise) return self.tp(x1, x2, weights) @@ -167,7 +167,7 @@ def __getattr__(self, name): return getattr(self.tp, name) def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: - if self.use_cuet: + if self.use_cuet and hasattr(self, "cueq_config"): return self.tp(x1, x2, use_fallback=not self.cueq_config.optimize_fctp) return self.tp(x1, x2) @@ -184,7 +184,7 @@ def __init__( super().__init__() if (CUET_AVAILABLE and cueq_config is not None and cueq_config.enabled and (cueq_config.optimize_all or cueq_config.optimize_symmetric)): - self.sc = cuet.SymmetricContraction( + self.sconctaction = cuet.SymmetricContraction( cue.Irreps(cueq_config.group, irreps_in), cue.Irreps(cueq_config.group, irreps_out), layout_in=cue.ir_mul, @@ -197,7 +197,7 @@ def __init__( self.cueq_config = cueq_config self.layout = cueq_config.layout else: - self.sc = SymmetricContraction( + self.sconctaction = SymmetricContraction( irreps_in=irreps_in, irreps_out=irreps_out, correlation=correlation, @@ -213,8 +213,8 @@ def __getattr__(self, name): return getattr(self.sc, name) def forward(self, x: torch.Tensor, attrs: torch.Tensor) -> torch.Tensor: - if self.use_cuet: + if self.use_cuet and hasattr(self, "cueq_config"): if self.layout == cue.mul_ir: x = torch.transpose(x, 1, 2) - return self.sc(x.flatten(1), attrs, use_fallback=not self.cueq_config.optimize_symmetric) - return self.sc(x, attrs) \ No newline at end of file + return self.sconctaction(x.flatten(1), attrs, use_fallback=not self.cueq_config.optimize_symmetric) + return self.sconctaction(x, attrs) \ No newline at end of file diff --git a/mace/py.typed b/mace/py.typed deleted file mode 100644 index 8b137891..00000000 --- a/mace/py.typed +++ /dev/null @@ -1 +0,0 @@ - diff --git a/setup.cfg b/setup.cfg index 76467fda..b5b2c7a2 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,5 +1,5 @@ [metadata] -name = mace-torch +name = mace-cueq version = attr: mace.__version__ short_description = MACE - Fast and accurate machine learning interatomic potentials with higher order equivariant message passing. long_description = file: README.md From 24a870da835a23ad30ecfbfd4276f825606acbf1 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Fri, 15 Nov 2024 13:50:59 +0000 Subject: [PATCH 04/17] add cueq_config to the run_train --- mace/modules/models.py | 3 ++ mace/modules/wrapper_ops.py | 32 ++++++++++++++++++-- mace/tools/arg_parser.py | 51 ++++++++++++++++++++++++++++++++ mace/tools/model_script_utils.py | 12 ++++++++ 4 files changed, 96 insertions(+), 2 deletions(-) diff --git a/mace/modules/models.py b/mace/modules/models.py index 419af4f1..345ee213 100644 --- a/mace/modules/models.py +++ b/mace/modules/models.py @@ -477,6 +477,7 @@ def __init__( gate: Optional[Callable], avg_num_neighbors: float, atomic_numbers: List[int], + cueq_config: Optional[Dict[str, Any]] = None, ): super().__init__() self.r_max = r_max @@ -681,6 +682,7 @@ def __init__( ], # Just here to make it compatible with energy models, MUST be None radial_type: Optional[str] = "bessel", radial_MLP: Optional[List[int]] = None, + cueq_config: Optional[Dict[str, Any]] = None, ): super().__init__() self.register_buffer( @@ -882,6 +884,7 @@ def __init__( gate: Optional[Callable], atomic_energies: Optional[np.ndarray], radial_MLP: Optional[List[int]] = None, + cueq_config: Optional[Dict[str, Any]] = None, ): super().__init__() self.register_buffer( diff --git a/mace/modules/wrapper_ops.py b/mace/modules/wrapper_ops.py index d820abb9..6f295de7 100644 --- a/mace/modules/wrapper_ops.py +++ b/mace/modules/wrapper_ops.py @@ -2,8 +2,10 @@ Wrapper class for o3.Linear that optionally uses cuet.Linear """ import dataclasses +import itertools +import numpy as np import torch -from typing import List, Optional +from typing import Iterator, List, Optional import e3nn.o3 as o3 from mace.modules.symmetric_contraction import SymmetricContraction @@ -30,7 +32,33 @@ def __post_init__(self): if self.enabled and CUET_AVAILABLE: self.layout_str = self.layout self.layout = getattr(cue, self.layout) - self.group = getattr(cue, self.group) + self.group = O3_e3nn if self.group == "O3" else getattr(cue, self.group) + +class O3_e3nn(cue.O3): + def __mul__(rep1: "O3_e3nn", rep2: "O3_e3nn") -> Iterator["O3_e3nn"]: + return [O3_e3nn(l=ir.l, p=ir.p) for ir in cue.O3.__mul__(rep1, rep2)] + + @classmethod + def clebsch_gordan( + cls, rep1: "O3_e3nn", rep2: "O3_e3nn", rep3: "O3_e3nn" + ) -> np.ndarray: + rep1, rep2, rep3 = cls._from(rep1), cls._from(rep2), cls._from(rep3) + + if rep1.p * rep2.p == rep3.p: + return o3.wigner_3j(rep1.l, rep2.l, rep3.l).numpy()[None] * np.sqrt(rep3.dim) + else: + return np.zeros((0, rep1.dim, rep2.dim, rep3.dim)) + + def __lt__(rep1: "O3_e3nn", rep2: "O3_e3nn") -> bool: + rep2 = rep1._from(rep2) + return (rep1.l, rep1.p) < (rep2.l, rep2.p) + + @classmethod + def iterator(cls) -> Iterator["O3_e3nn"]: + for l in itertools.count(0): + yield O3_e3nn(l=l, p=1 * (-1) ** l) + yield O3_e3nn(l=l, p=-1 * (-1) ** l) + class Linear(torch.nn.Module): """Wrapper around o3.Linear that optionally uses cuet.Linear when enabled""" diff --git a/mace/tools/arg_parser.py b/mace/tools/arg_parser.py index cb4f8ac5..8cf6990a 100644 --- a/mace/tools/arg_parser.py +++ b/mace/tools/arg_parser.py @@ -660,6 +660,57 @@ def build_default_arg_parser() -> argparse.ArgumentParser: type=check_float_or_none, default=10.0, ) + # option for cuequivariance acceleration + parser.add_argument( + "--cue_enabled", + help="Enable cuequivariance acceleration", + type=str2bool, + default=False + ) + parser.add_argument( + "--cue_layout", + help="Memory layout for cuequivariance tensors", + type=str, + choices=["mul_ir", "ir_mul"], + default="mul_ir" + ) + parser.add_argument( + "--cue_group", + help="Symmetry group for cuequivariance", + type=str, + choices=["O3nn, O3"], + default="O3nn" + ) + parser.add_argument( + "--cue_optimize_all", + help="Enable all cuequivariance optimizations", + type=str2bool, + default=False + ) + parser.add_argument( + "--cue_optimize_linear", + help="Enable cuequivariance linear layer optimization", + type=str2bool, + default=False + ) + parser.add_argument( + "--cue_optimize_channelwise", + help="Enable cuequivariance channelwise optimization", + type=str2bool, + default=False + ) + parser.add_argument( + "--cue_optimize_symmetric", + help="Enable cuequivariance symmetric contraction optimization", + type=str2bool, + default=False + ) + parser.add_argument( + "--cue_optimize_fctp", + help="Enable cuequivariance fully connected tensor product optimization", + type=str2bool, + default=False + ) # options for using Weights and Biases for experiment tracking # to install see https://wandb.ai parser.add_argument( diff --git a/mace/tools/model_script_utils.py b/mace/tools/model_script_utils.py index 3f49eb41..25ef0a93 100644 --- a/mace/tools/model_script_utils.py +++ b/mace/tools/model_script_utils.py @@ -5,6 +5,7 @@ from e3nn import o3 from mace import modules +from mace.modules.wrapper_ops import CuEquivarianceConfig from mace.tools.finetuning_utils import load_foundations_elements from mace.tools.scripts_utils import extract_config_mace_model @@ -28,6 +29,16 @@ def configure_model( logging.info( f"During training the following quantities will be reported: {', '.join([f'{report}' for report, value in output_args.items() if value])}" ) + cueq_config = CuEquivarianceConfig( + enabled=args.cue_enabled, + layout=args.cue_layout, + group=args.cue_group, + optimize_all=args.cue_optimize_all, + optimize_linear=args.cue_optimize_linear, + optimize_channelwise=args.cue_optimize_channelwise, + optimize_symmetric=args.cue_optimize_symmetric, + optimize_fctp=args.cue_optimize_fctp, + ) logging.info("===========MODEL DETAILS===========") if args.scaling == "no_scaling": @@ -109,6 +120,7 @@ def configure_model( atomic_energies=atomic_energies, avg_num_neighbors=args.avg_num_neighbors, atomic_numbers=z_table.zs, + cueq_config=cueq_config, ) model_config_foundation = None From 96a9f5f0c8557c05a9d090fe3c6bd82f1d4b192e Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Fri, 15 Nov 2024 14:32:44 +0000 Subject: [PATCH 05/17] Create test_cueq.py --- tests/test_cueq.py | 102 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 102 insertions(+) create mode 100644 tests/test_cueq.py diff --git a/tests/test_cueq.py b/tests/test_cueq.py new file mode 100644 index 00000000..256f3570 --- /dev/null +++ b/tests/test_cueq.py @@ -0,0 +1,102 @@ +import pytest +import torch +import torch.nn.functional as F +from e3nn import o3 +from typing import Dict, Any + +from mace import data, modules, tools +from mace.modules.wrapper_ops import CuEquivarianceConfig +from mace.tools import torch_geometric + +try: + import cuequivariance as cue + CUET_AVAILABLE = True +except ImportError: + CUET_AVAILABLE = False + +@pytest.mark.skipif(not CUET_AVAILABLE, reason="cuequivariance not installed") +class TestCueq: + @pytest.fixture + def model_config(self) -> Dict[str, Any]: + table = tools.AtomicNumberTable([6]) + return { + "r_max": 5.0, + "num_bessel": 8, + "num_polynomial_cutoff": 6, + "max_ell": 3, + "interaction_cls": modules.interaction_classes["RealAgnosticResidualInteractionBlock"], + "interaction_cls_first": modules.interaction_classes["RealAgnosticResidualInteractionBlock"], + "num_interactions": 2, + "num_elements": 1, + "hidden_irreps": o3.Irreps("32x0e + 32x1o"), + "MLP_irreps": o3.Irreps("16x0e"), + "gate": F.silu, + "atomic_energies": torch.tensor([1.0]), + "avg_num_neighbors": 8, + "atomic_numbers": table.zs, + "correlation": 3, + "radial_type": "bessel" + } + + @pytest.fixture + def batch(self, device: str): + from ase import build + table = tools.AtomicNumberTable([6]) + + atoms = build.bulk("C", "diamond", a=3.567, cubic=True) + atoms_list = [atoms.repeat((2, 2, 2))] + + configs = [data.config_from_atoms(atoms) for atoms in atoms_list] + data_loader = torch_geometric.dataloader.DataLoader( + dataset=[data.AtomicData.from_config(config, z_table=table, cutoff=5.0) + for config in configs], + batch_size=1, + shuffle=False, + drop_last=False, + ) + batch = next(iter(data_loader)) + return batch.to(device).to_dict() + + @pytest.mark.parametrize("device", ["cuda"]) + def test_cueq_equivalence(self, model_config: Dict[str, Any], batch: Dict[str, torch.Tensor], device: str): + torch.manual_seed(42) + + # Create model without cuequivariance + model_std = modules.MACE(**model_config) + model_std = model_std.to(device) + + # Create model with cuequivariance + cueq_config = CuEquivarianceConfig( + enabled=True, + layout="mul_ir", + group="O3nn", + optimize_all=True + ) + model_config["cueq_config"] = cueq_config + model_cueq = modules.MACE(**model_config) + model_cueq = model_cueq.to(device) + + # Copy weights + model_cueq.load_state_dict(model_std.state_dict()) + + # Compare outputs + with torch.no_grad(): + out_std = model_std(batch, training=True) + out_cueq = model_cueq(batch, training=True) + + torch.testing.assert_close(out_std["energy"], out_cueq["energy"]) + torch.testing.assert_close(out_std["forces"], out_cueq["forces"]) + + # Test gradients + out_std = model_std(batch, training=True) + out_cueq = model_cueq(batch, training=True) + + loss_std = out_std["energy"].sum() + loss_cueq = out_cueq["energy"].sum() + + loss_std.backward() + loss_cueq.backward() + + for p1, p2 in zip(model_std.parameters(), model_cueq.parameters()): + if p1.grad is not None: + torch.testing.assert_close(p1.grad, p2.grad) \ No newline at end of file From f09d6fd1a8daabaa637f21bb6e33b681b0e090da Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Fri, 15 Nov 2024 15:47:48 +0000 Subject: [PATCH 06/17] fix the formatting --- .pre-commit-config.yaml | 1 + mace/__init__.py | 2 +- mace/modules/blocks.py | 164 ++++++++++++++++++++++++++--------- mace/modules/irreps_tools.py | 23 ++--- mace/modules/models.py | 18 ++-- mace/modules/wrapper_ops.py | 156 +++++++++++++++++++++------------ mace/tools/arg_parser.py | 32 +++---- pyproject.toml | 3 + tests/test_cueq.py | 42 +++++---- 9 files changed, 287 insertions(+), 154 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d78624bb..6f8c2daa 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -55,5 +55,6 @@ repos: '--disable=cell-var-from-loop', '--disable=duplicate-code', '--disable=use-dict-literal', + '--max-module-lines=1500', ] exclude: *exclude_files \ No newline at end of file diff --git a/mace/__init__.py b/mace/__init__.py index 77f33d04..6ebb05a2 100644 --- a/mace/__init__.py +++ b/mace/__init__.py @@ -1,3 +1,3 @@ from .__version__ import __version__ -__all__ = ["__version__"] \ No newline at end of file +__all__ = ["__version__"] diff --git a/mace/modules/blocks.py b/mace/modules/blocks.py index 1728097c..7bc3561f 100644 --- a/mace/modules/blocks.py +++ b/mace/modules/blocks.py @@ -5,7 +5,6 @@ ########################################################################################### from abc import abstractmethod -import dataclasses from typing import Callable, List, Optional, Tuple, Union import numpy as np @@ -13,7 +12,13 @@ from e3nn import nn, o3 from e3nn.util.jit import compile_mode -from mace.modules.wrapper_ops import CuEquivarianceConfig, FullyConnectedTensorProduct, Linear, SymmetricContractionWrapper, TensorProduct +from mace.modules.wrapper_ops import ( + CuEquivarianceConfig, + FullyConnectedTensorProduct, + Linear, + SymmetricContractionWrapper, + TensorProduct, +) from mace.tools.compile import simplify_if_compile from mace.tools.scatter import scatter_sum @@ -31,20 +36,20 @@ PolynomialCutoff, SoftTransform, ) -from .symmetric_contraction import SymmetricContraction -try: - import cuequivariance as cue - import cuequivariance_torch as cuet - CUET_AVAILABLE = True -except ImportError: - CUET_AVAILABLE = False @compile_mode("script") class LinearNodeEmbeddingBlock(torch.nn.Module): - def __init__(self, irreps_in: o3.Irreps, irreps_out: o3.Irreps, cueq_config: Optional[CuEquivarianceConfig] = None): + def __init__( + self, + irreps_in: o3.Irreps, + irreps_out: o3.Irreps, + cueq_config: Optional[CuEquivarianceConfig] = None, + ): super().__init__() - self.linear = Linear(irreps_in=irreps_in, irreps_out=irreps_out, cueq_config=cueq_config) + self.linear = Linear( + irreps_in=irreps_in, irreps_out=irreps_out, cueq_config=cueq_config + ) def forward( self, @@ -55,9 +60,16 @@ def forward( @compile_mode("script") class LinearReadoutBlock(torch.nn.Module): - def __init__(self, irreps_in: o3.Irreps, irrep_out: o3.Irreps = o3.Irreps("0e"), cueq_config: Optional[CuEquivarianceConfig] = None): + def __init__( + self, + irreps_in: o3.Irreps, + irrep_out: o3.Irreps = o3.Irreps("0e"), + cueq_config: Optional[CuEquivarianceConfig] = None, + ): super().__init__() - self.linear = Linear(irreps_in=irreps_in, irreps_out=irrep_out, cueq_config=cueq_config) + self.linear = Linear( + irreps_in=irreps_in, irreps_out=irrep_out, cueq_config=cueq_config + ) def forward( self, @@ -77,14 +89,18 @@ def __init__( gate: Optional[Callable], irrep_out: o3.Irreps = o3.Irreps("0e"), num_heads: int = 1, - cueq_config: Optional[CuEquivarianceConfig] = None, + cueq_config: Optional[CuEquivarianceConfig] = None, ): super().__init__() self.hidden_irreps = MLP_irreps self.num_heads = num_heads - self.linear_1 = Linear(irreps_in=irreps_in, irreps_out=self.hidden_irreps, cueq_config=cueq_config) + self.linear_1 = Linear( + irreps_in=irreps_in, irreps_out=self.hidden_irreps, cueq_config=cueq_config + ) self.non_linearity = nn.Activation(irreps_in=self.hidden_irreps, acts=[gate]) - self.linear_2 = Linear(irreps_in=self.hidden_irreps, irreps_out=irrep_out, cueq_config=cueq_config) + self.linear_2 = Linear( + irreps_in=self.hidden_irreps, irreps_out=irrep_out, cueq_config=cueq_config + ) def forward( self, x: torch.Tensor, heads: Optional[torch.Tensor] = None @@ -98,13 +114,20 @@ def forward( @compile_mode("script") class LinearDipoleReadoutBlock(torch.nn.Module): - def __init__(self, irreps_in: o3.Irreps, dipole_only: bool = False, cueq_config: Optional[CuEquivarianceConfig] = None): + def __init__( + self, + irreps_in: o3.Irreps, + dipole_only: bool = False, + cueq_config: Optional[CuEquivarianceConfig] = None, + ): super().__init__() if dipole_only: self.irreps_out = o3.Irreps("1x1o") else: self.irreps_out = o3.Irreps("1x0e + 1x1o") - self.linear = o3.Linear(irreps_in=irreps_in, irreps_out=self.irreps_out, cueq_config=cueq_config) + self.linear = Linear( + irreps_in=irreps_in, irreps_out=self.irreps_out, cueq_config=cueq_config + ) def forward(self, x: torch.Tensor) -> torch.Tensor: # [n_nodes, irreps] # [..., ] return self.linear(x) # [n_nodes, 1] @@ -141,9 +164,13 @@ def __init__( irreps_gated=irreps_gated, ) self.irreps_nonlin = self.equivariant_nonlin.irreps_in.simplify() - self.linear_1 = Linear(irreps_in=irreps_in, irreps_out=self.irreps_nonlin, cueq_config=cueq_config) + self.linear_1 = Linear( + irreps_in=irreps_in, irreps_out=self.irreps_nonlin, cueq_config=cueq_config + ) self.linear_2 = Linear( - irreps_in=self.hidden_irreps, irreps_out=self.irreps_out, cueq_config=cueq_config + irreps_in=self.hidden_irreps, + irreps_out=self.irreps_out, + cueq_config=cueq_config, ) def forward(self, x: torch.Tensor) -> torch.Tensor: # [n_nodes, irreps] # [..., ] @@ -375,12 +402,19 @@ def _setup(self) -> None: self.irreps_out = linear_out_irreps(irreps_mid, self.target_irreps) self.irreps_out = self.irreps_out.simplify() self.linear = Linear( - irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True, cueq_config=self.cueq_config + irreps_mid, + self.irreps_out, + internal_weights=True, + shared_weights=True, + cueq_config=self.cueq_config, ) # Selector TensorProduct self.skip_tp = FullyConnectedTensorProduct( - self.node_feats_irreps, self.node_attrs_irreps, self.irreps_out, cueq_config=self.cueq_config + self.node_feats_irreps, + self.node_attrs_irreps, + self.irreps_out, + cueq_config=self.cueq_config, ) def forward( @@ -430,7 +464,7 @@ def _setup(self) -> None: instructions=instructions, shared_weights=False, internal_weights=False, - cueq_config=self.cueq_config + cueq_config=self.cueq_config, ) # Convolution weights @@ -445,12 +479,19 @@ def _setup(self) -> None: self.irreps_out = linear_out_irreps(irreps_mid, self.target_irreps) self.irreps_out = self.irreps_out.simplify() self.linear = Linear( - irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True, cueq_config=self.cueq_config + irreps_mid, + self.irreps_out, + internal_weights=True, + shared_weights=True, + cueq_config=self.cueq_config, ) # Selector TensorProduct self.skip_tp = FullyConnectedTensorProduct( - self.irreps_out, self.node_attrs_irreps, self.irreps_out, cueq_config=self.cueq_config + self.irreps_out, + self.node_attrs_irreps, + self.irreps_out, + cueq_config=self.cueq_config, ) def forward( @@ -480,6 +521,8 @@ def forward( @compile_mode("script") class AgnosticResidualNonlinearInteractionBlock(InteractionBlock): def _setup(self) -> None: + if not hasattr(self, "cueq_config"): + self.cueq_config = None # First linear self.linear_up = Linear( self.node_feats_irreps, @@ -499,7 +542,7 @@ def _setup(self) -> None: instructions=instructions, shared_weights=False, internal_weights=False, - cueq_config=self.cueq_config + cueq_config=self.cueq_config, ) # Convolution weights @@ -514,12 +557,19 @@ def _setup(self) -> None: self.irreps_out = linear_out_irreps(irreps_mid, self.target_irreps) self.irreps_out = self.irreps_out.simplify() self.linear = Linear( - irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True, cueq_config=self.cueq_config + irreps_mid, + self.irreps_out, + internal_weights=True, + shared_weights=True, + cueq_config=self.cueq_config, ) # Selector TensorProduct self.skip_tp = FullyConnectedTensorProduct( - self.node_feats_irreps, self.node_attrs_irreps, self.irreps_out, cueq_config=self.cueq_config + self.node_feats_irreps, + self.node_attrs_irreps, + self.irreps_out, + cueq_config=self.cueq_config, ) def forward( @@ -558,7 +608,7 @@ def _setup(self) -> None: self.node_feats_irreps, internal_weights=True, shared_weights=True, - cueq_config=self.cueq_config + cueq_config=self.cueq_config, ) # TensorProduct irreps_mid, instructions = tp_out_irreps_with_instructions( @@ -566,7 +616,7 @@ def _setup(self) -> None: self.edge_attrs_irreps, self.target_irreps, ) - self.conv_tp = o3.TensorProduct( + self.conv_tp = TensorProduct( self.node_feats_irreps, self.edge_attrs_irreps, irreps_mid, @@ -587,12 +637,19 @@ def _setup(self) -> None: irreps_mid = irreps_mid.simplify() self.irreps_out = self.target_irreps self.linear = Linear( - irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True, cueq_config=self.cueq_config + irreps_mid, + self.irreps_out, + internal_weights=True, + shared_weights=True, + cueq_config=self.cueq_config, ) # Selector TensorProduct self.skip_tp = FullyConnectedTensorProduct( - self.irreps_out, self.node_attrs_irreps, self.irreps_out, cueq_config=self.cueq_config + self.irreps_out, + self.node_attrs_irreps, + self.irreps_out, + cueq_config=self.cueq_config, ) self.reshape = reshape_irreps(self.irreps_out, cueq_config=self.cueq_config) @@ -663,12 +720,19 @@ def _setup(self) -> None: irreps_mid = irreps_mid.simplify() self.irreps_out = self.target_irreps self.linear = Linear( - irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True, cueq_config=self.cueq_config + irreps_mid, + self.irreps_out, + internal_weights=True, + shared_weights=True, + cueq_config=self.cueq_config, ) # Selector TensorProduct self.skip_tp = FullyConnectedTensorProduct( - self.node_feats_irreps, self.node_attrs_irreps, self.hidden_irreps, cueq_config=self.cueq_config + self.node_feats_irreps, + self.node_attrs_irreps, + self.hidden_irreps, + cueq_config=self.cueq_config, ) self.reshape = reshape_irreps(self.irreps_out, cueq_config=self.cueq_config) @@ -710,7 +774,7 @@ def _setup(self) -> None: self.node_feats_irreps, internal_weights=True, shared_weights=True, - cueq_config=self.cueq_config + cueq_config=self.cueq_config, ) # TensorProduct irreps_mid, instructions = tp_out_irreps_with_instructions( @@ -739,12 +803,19 @@ def _setup(self) -> None: irreps_mid = irreps_mid.simplify() self.irreps_out = self.target_irreps self.linear = Linear( - irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True, cueq_config=self.cueq_config + irreps_mid, + self.irreps_out, + internal_weights=True, + shared_weights=True, + cueq_config=self.cueq_config, ) # Selector TensorProduct self.skip_tp = FullyConnectedTensorProduct( - self.irreps_out, self.node_attrs_irreps, self.irreps_out, cueq_config=self.cueq_config + self.irreps_out, + self.node_attrs_irreps, + self.irreps_out, + cueq_config=self.cueq_config, ) # Density normalization @@ -830,12 +901,19 @@ def _setup(self) -> None: irreps_mid = irreps_mid.simplify() self.irreps_out = self.target_irreps self.linear = Linear( - irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True, cueq_config=self.cueq_config + irreps_mid, + self.irreps_out, + internal_weights=True, + shared_weights=True, + cueq_config=self.cueq_config, ) # Selector TensorProduct self.skip_tp = FullyConnectedTensorProduct( - self.node_feats_irreps, self.node_attrs_irreps, self.hidden_irreps, cueq_config=self.cueq_config + self.node_feats_irreps, + self.node_attrs_irreps, + self.hidden_irreps, + cueq_config=self.cueq_config, ) # Density normalization @@ -893,7 +971,7 @@ def _setup(self) -> None: self.node_feats_irreps, internal_weights=True, shared_weights=True, - cueq_config=self.cueq_config + cueq_config=self.cueq_config, ) # TensorProduct irreps_mid, instructions = tp_out_irreps_with_instructions( @@ -936,13 +1014,15 @@ def _setup(self) -> None: self.irreps_out, internal_weights=True, shared_weights=True, - cueq_config=self.cueq_config + cueq_config=self.cueq_config, ) self.reshape = reshape_irreps(self.irreps_out, cueq_config=self.cueq_config) # Skip connection. - self.skip_linear = Linear(self.node_feats_irreps, self.hidden_irreps, cueq_config=self.cueq_config) + self.skip_linear = Linear( + self.node_feats_irreps, self.hidden_irreps, cueq_config=self.cueq_config + ) def forward( self, diff --git a/mace/modules/irreps_tools.py b/mace/modules/irreps_tools.py index 597e439a..3e4cc6f6 100644 --- a/mace/modules/irreps_tools.py +++ b/mace/modules/irreps_tools.py @@ -12,13 +12,6 @@ from mace.modules.wrapper_ops import CuEquivarianceConfig -try: - import cuequivariance as cue - LAYOUTS = cue -except ImportError: - LAYOUTS = DefaultLayouts - - # Based on mir-group/nequip def tp_out_irreps_with_instructions( irreps1: o3.Irreps, irreps2: o3.Irreps, target_irreps: o3.Irreps @@ -69,13 +62,12 @@ def linear_out_irreps(irreps: o3.Irreps, target_irreps: o3.Irreps) -> o3.Irreps: return o3.Irreps(irreps_mid) -class DefaultLayouts: - mul_ir = "mul_ir" - ir_mul = "ir_mul" @compile_mode("script") class reshape_irreps(torch.nn.Module): - def __init__(self, irreps: o3.Irreps, cueq_config: Optional[CuEquivarianceConfig] = None) -> None: + def __init__( + self, irreps: o3.Irreps, cueq_config: Optional[CuEquivarianceConfig] = None + ) -> None: super().__init__() self.irreps = o3.Irreps(irreps) self.cueq_config = cueq_config @@ -101,14 +93,13 @@ def forward(self, tensor: torch.Tensor) -> torch.Tensor: else: field = field.reshape(batch, mul, d) out.append(field) - + if hasattr(self, "cueq_config") and self.cueq_config is not None: if self.cueq_config.layout_str == "mul_ir": return torch.cat(out, dim=-1) - else: - return torch.cat(out, dim=-2) - else: - return torch.cat(out, dim=-1) + return torch.cat(out, dim=-2) + return torch.cat(out, dim=-1) + def mask_head(x: torch.Tensor, head: torch.Tensor, num_heads: int) -> torch.Tensor: mask = torch.zeros(x.shape[0], x.shape[1] // num_heads, num_heads, device=x.device) diff --git a/mace/modules/models.py b/mace/modules/models.py index 345ee213..0e03317e 100644 --- a/mace/modules/models.py +++ b/mace/modules/models.py @@ -83,7 +83,9 @@ def __init__( node_attr_irreps = o3.Irreps([(num_elements, (0, 1))]) node_feats_irreps = o3.Irreps([(hidden_irreps.count(o3.Irrep(0, 1)), (0, 1))]) self.node_embedding = LinearNodeEmbeddingBlock( - irreps_in=node_attr_irreps, irreps_out=node_feats_irreps, cueq_config=cueq_config + irreps_in=node_attr_irreps, + irreps_out=node_feats_irreps, + cueq_config=cueq_config, ) self.radial_embedding = RadialEmbeddingBlock( r_max=r_max, @@ -139,7 +141,9 @@ def __init__( self.readouts = torch.nn.ModuleList() self.readouts.append( - LinearReadoutBlock(hidden_irreps, o3.Irreps(f"{len(heads)}x0e"), cueq_config) + LinearReadoutBlock( + hidden_irreps, o3.Irreps(f"{len(heads)}x0e"), cueq_config + ) ) for i in range(num_interactions - 1): @@ -183,7 +187,9 @@ def __init__( ) else: self.readouts.append( - LinearReadoutBlock(hidden_irreps, o3.Irreps(f"{len(heads)}x0e"), cueq_config) + LinearReadoutBlock( + hidden_irreps, o3.Irreps(f"{len(heads)}x0e"), cueq_config + ) ) def forward( @@ -477,7 +483,7 @@ def __init__( gate: Optional[Callable], avg_num_neighbors: float, atomic_numbers: List[int], - cueq_config: Optional[Dict[str, Any]] = None, + cueq_config: Optional[Dict[str, Any]] = None, # pylint: disable=unused-argument ): super().__init__() self.r_max = r_max @@ -682,7 +688,7 @@ def __init__( ], # Just here to make it compatible with energy models, MUST be None radial_type: Optional[str] = "bessel", radial_MLP: Optional[List[int]] = None, - cueq_config: Optional[Dict[str, Any]] = None, + cueq_config: Optional[Dict[str, Any]] = None, # pylint: disable=unused-argument ): super().__init__() self.register_buffer( @@ -884,7 +890,7 @@ def __init__( gate: Optional[Callable], atomic_energies: Optional[np.ndarray], radial_MLP: Optional[List[int]] = None, - cueq_config: Optional[Dict[str, Any]] = None, + cueq_config: Optional[Dict[str, Any]] = None, # pylint: disable=unused-argument ): super().__init__() self.register_buffer( diff --git a/mace/modules/wrapper_ops.py b/mace/modules/wrapper_ops.py index 6f295de7..84ccca0d 100644 --- a/mace/modules/wrapper_ops.py +++ b/mace/modules/wrapper_ops.py @@ -1,29 +1,67 @@ """ Wrapper class for o3.Linear that optionally uses cuet.Linear """ + import dataclasses import itertools +from typing import Iterator, List, Optional + import numpy as np import torch -from typing import Iterator, List, Optional -import e3nn.o3 as o3 +from e3nn import o3 from mace.modules.symmetric_contraction import SymmetricContraction + try: import cuequivariance as cue import cuequivariance_torch as cuet + CUET_AVAILABLE = True except ImportError: CUET_AVAILABLE = False -@dataclasses.dataclass +if CUET_AVAILABLE: + + class O3_e3nn(cue.O3): + def __mul__(rep1: "O3_e3nn", rep2: "O3_e3nn") -> Iterator["O3_e3nn"]: + return [O3_e3nn(l=ir.l, p=ir.p) for ir in cue.O3.__mul__(rep1, rep2)] + + @classmethod + def clebsch_gordan( + cls, rep1: "O3_e3nn", rep2: "O3_e3nn", rep3: "O3_e3nn" + ) -> np.ndarray: + rep1, rep2, rep3 = cls._from(rep1), cls._from(rep2), cls._from(rep3) + + if rep1.p * rep2.p == rep3.p: + return o3.wigner_3j(rep1.l, rep2.l, rep3.l).numpy()[None] * np.sqrt( + rep3.dim + ) + return np.zeros((0, rep1.dim, rep2.dim, rep3.dim)) + + def __lt__(rep1: "O3_e3nn", rep2: "O3_e3nn") -> bool: + rep2 = rep1._from(rep2) + return (rep1.l, rep1.p) < (rep2.l, rep2.p) + + @classmethod + def iterator(cls) -> Iterator["O3_e3nn"]: + for l in itertools.count(0): + yield O3_e3nn(l=l, p=1 * (-1) ** l) + yield O3_e3nn(l=l, p=-1 * (-1) ** l) +else: + print( + "cuequivariance or cuequivariance_torch is not available. Cuequivariance acceleration will be disabled." + ) + + +@dataclasses.dataclass class CuEquivarianceConfig: """Configuration for cuequivariance acceleration""" + enabled: bool = False layout: str = "mul_ir" # One of: mul_ir, ir_mul group: str = "O3" optimize_all: bool = False # Set to True to enable all optimizations - optimize_linear: bool = False + optimize_linear: bool = False optimize_channelwise: bool = False optimize_symmetric: bool = False optimize_fctp: bool = False @@ -34,34 +72,10 @@ def __post_init__(self): self.layout = getattr(cue, self.layout) self.group = O3_e3nn if self.group == "O3" else getattr(cue, self.group) -class O3_e3nn(cue.O3): - def __mul__(rep1: "O3_e3nn", rep2: "O3_e3nn") -> Iterator["O3_e3nn"]: - return [O3_e3nn(l=ir.l, p=ir.p) for ir in cue.O3.__mul__(rep1, rep2)] - - @classmethod - def clebsch_gordan( - cls, rep1: "O3_e3nn", rep2: "O3_e3nn", rep3: "O3_e3nn" - ) -> np.ndarray: - rep1, rep2, rep3 = cls._from(rep1), cls._from(rep2), cls._from(rep3) - - if rep1.p * rep2.p == rep3.p: - return o3.wigner_3j(rep1.l, rep2.l, rep3.l).numpy()[None] * np.sqrt(rep3.dim) - else: - return np.zeros((0, rep1.dim, rep2.dim, rep3.dim)) - - def __lt__(rep1: "O3_e3nn", rep2: "O3_e3nn") -> bool: - rep2 = rep1._from(rep2) - return (rep1.l, rep1.p) < (rep2.l, rep2.p) - - @classmethod - def iterator(cls) -> Iterator["O3_e3nn"]: - for l in itertools.count(0): - yield O3_e3nn(l=l, p=1 * (-1) ** l) - yield O3_e3nn(l=l, p=-1 * (-1) ** l) - class Linear(torch.nn.Module): """Wrapper around o3.Linear that optionally uses cuet.Linear when enabled""" + def __init__( self, irreps_in: o3.Irreps, @@ -71,14 +85,18 @@ def __init__( cueq_config: Optional[CuEquivarianceConfig] = None, ): super().__init__() - if (CUET_AVAILABLE and cueq_config is not None and cueq_config.enabled - and (cueq_config.optimize_all or cueq_config.optimize_linear)): + if ( + CUET_AVAILABLE + and cueq_config is not None + and cueq_config.enabled + and (cueq_config.optimize_all or cueq_config.optimize_linear) + ): self.linear = cuet.Linear( cue.Irreps(cueq_config.group, irreps_in), cue.Irreps(cueq_config.group, irreps_out), layout=cueq_config.layout, shared_weights=shared_weights, - optimize_fallback=not cueq_config.optimize_linear + optimize_fallback=not cueq_config.optimize_linear, ) self.use_cuet = True self.cueq_config = cueq_config @@ -87,7 +105,7 @@ def __init__( irreps_in, irreps_out, shared_weights=shared_weights, - internal_weights=internal_weights + internal_weights=internal_weights, ) self.use_cuet = False @@ -97,14 +115,16 @@ def __getattr__(self, name): return super().__getattr__(name) except AttributeError: return getattr(self.linear, name) - + def forward(self, x: torch.Tensor) -> torch.Tensor: if self.use_cuet and hasattr(self, "cueq_config"): return self.linear(x, use_fallback=not self.cueq_config.optimize_linear) return self.linear(x) + class TensorProduct(torch.nn.Module): """Wrapper around o3.TensorProduct/cuet.ChannelwiseTensorProduct""" + def __init__( self, irreps_in1: o3.Irreps, @@ -116,8 +136,12 @@ def __init__( cueq_config: Optional[CuEquivarianceConfig] = None, ): super().__init__() - if (CUET_AVAILABLE and cueq_config is not None and cueq_config.enabled - and (cueq_config.optimize_all or cueq_config.optimize_channelwise)): + if ( + CUET_AVAILABLE + and cueq_config is not None + and cueq_config.enabled + and (cueq_config.optimize_all or cueq_config.optimize_channelwise) + ): self.tp = cuet.ChannelwiseTensorProduct( cue.Irreps(cueq_config.group, irreps_in1), cue.Irreps(cueq_config.group, irreps_in2), @@ -125,7 +149,7 @@ def __init__( layout=cueq_config.layout, shared_weights=shared_weights, internal_weights=internal_weights, - optimize_fallback=not cueq_config.optimize_channelwise + optimize_fallback=not cueq_config.optimize_channelwise, ) self.use_cuet = True self.cueq_config = cueq_config @@ -136,7 +160,7 @@ def __init__( irreps_out, instructions=instructions, shared_weights=shared_weights, - internal_weights=internal_weights + internal_weights=internal_weights, ) self.use_cuet = False @@ -146,14 +170,20 @@ def __getattr__(self, name): return super().__getattr__(name) except AttributeError: return getattr(self.tp, name) - - def forward(self, x1: torch.Tensor, x2: torch.Tensor, weights: Optional[torch.Tensor] = None) -> torch.Tensor: + + def forward( + self, x1: torch.Tensor, x2: torch.Tensor, weights: Optional[torch.Tensor] = None + ) -> torch.Tensor: if self.use_cuet and hasattr(self, "cueq_config"): - return self.tp(x1, x2, weights, use_fallback=not self.cueq_config.optimize_channelwise) + return self.tp( + x1, x2, weights, use_fallback=not self.cueq_config.optimize_channelwise + ) return self.tp(x1, x2, weights) + class FullyConnectedTensorProduct(torch.nn.Module): """Wrapper around o3.FullyConnectedTensorProduct/cuet.FullyConnectedTensorProduct""" + def __init__( self, irreps_in1: o3.Irreps, @@ -164,16 +194,20 @@ def __init__( cueq_config: Optional[CuEquivarianceConfig] = None, ): super().__init__() - if (CUET_AVAILABLE and cueq_config is not None and cueq_config.enabled - and (cueq_config.optimize_all or cueq_config.optimize_fctp)): + if ( + CUET_AVAILABLE + and cueq_config is not None + and cueq_config.enabled + and (cueq_config.optimize_all or cueq_config.optimize_fctp) + ): self.tp = cuet.FullyConnectedTensorProduct( cue.Irreps(cueq_config.group, irreps_in1), - cue.Irreps(cueq_config.group, irreps_in2), + cue.Irreps(cueq_config.group, irreps_in2), cue.Irreps(cueq_config.group, irreps_out), layout=cueq_config.layout, shared_weights=shared_weights, internal_weights=internal_weights, - optimize_fallback=not cueq_config.optimize_fctp + optimize_fallback=not cueq_config.optimize_fctp, ) self.use_cuet = True self.cueq_config = cueq_config @@ -183,24 +217,26 @@ def __init__( irreps_in2, irreps_out, shared_weights=shared_weights, - internal_weights=internal_weights + internal_weights=internal_weights, ) self.use_cuet = False - + def __getattr__(self, name): """Forward any unknown attribute access to the underlying linear object""" try: return super().__getattr__(name) except AttributeError: return getattr(self.tp, name) - + def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: if self.use_cuet and hasattr(self, "cueq_config"): return self.tp(x1, x2, use_fallback=not self.cueq_config.optimize_fctp) return self.tp(x1, x2) + class SymmetricContractionWrapper(torch.nn.Module): """Wrapper around SymmetricContraction/cuet.SymmetricContraction""" + def __init__( self, irreps_in: o3.Irreps, @@ -210,16 +246,20 @@ def __init__( cueq_config: Optional[CuEquivarianceConfig] = None, ): super().__init__() - if (CUET_AVAILABLE and cueq_config is not None and cueq_config.enabled - and (cueq_config.optimize_all or cueq_config.optimize_symmetric)): + if ( + CUET_AVAILABLE + and cueq_config is not None + and cueq_config.enabled + and (cueq_config.optimize_all or cueq_config.optimize_symmetric) + ): self.sconctaction = cuet.SymmetricContraction( cue.Irreps(cueq_config.group, irreps_in), - cue.Irreps(cueq_config.group, irreps_out), + cue.Irreps(cueq_config.group, irreps_out), layout_in=cue.ir_mul, layout_out=cueq_config.layout, contraction_degree=correlation, num_elements=num_elements, - optimize_fallback=not cueq_config.optimize_symmetric + optimize_fallback=not cueq_config.optimize_symmetric, ) self.use_cuet = True self.cueq_config = cueq_config @@ -229,7 +269,7 @@ def __init__( irreps_in=irreps_in, irreps_out=irreps_out, correlation=correlation, - num_elements=num_elements + num_elements=num_elements, ) self.use_cuet = False @@ -239,10 +279,14 @@ def __getattr__(self, name): return super().__getattr__(name) except AttributeError: return getattr(self.sc, name) - + def forward(self, x: torch.Tensor, attrs: torch.Tensor) -> torch.Tensor: if self.use_cuet and hasattr(self, "cueq_config"): if self.layout == cue.mul_ir: x = torch.transpose(x, 1, 2) - return self.sconctaction(x.flatten(1), attrs, use_fallback=not self.cueq_config.optimize_symmetric) - return self.sconctaction(x, attrs) \ No newline at end of file + return self.sconctaction( + x.flatten(1), + attrs, + use_fallback=not self.cueq_config.optimize_symmetric, + ) + return self.sconctaction(x, attrs) diff --git a/mace/tools/arg_parser.py b/mace/tools/arg_parser.py index 8cf6990a..3df3960a 100644 --- a/mace/tools/arg_parser.py +++ b/mace/tools/arg_parser.py @@ -662,54 +662,54 @@ def build_default_arg_parser() -> argparse.ArgumentParser: ) # option for cuequivariance acceleration parser.add_argument( - "--cue_enabled", - help="Enable cuequivariance acceleration", - type=str2bool, - default=False + "--cue_enabled", + help="Enable cuequivariance acceleration", + type=str2bool, + default=False, ) parser.add_argument( - "--cue_layout", + "--cue_layout", help="Memory layout for cuequivariance tensors", type=str, choices=["mul_ir", "ir_mul"], - default="mul_ir" + default="mul_ir", ) parser.add_argument( "--cue_group", help="Symmetry group for cuequivariance", - type=str, + type=str, choices=["O3nn, O3"], - default="O3nn" + default="O3nn", ) parser.add_argument( "--cue_optimize_all", - help="Enable all cuequivariance optimizations", + help="Enable all cuequivariance optimizations", type=str2bool, - default=False + default=False, ) parser.add_argument( - "--cue_optimize_linear", + "--cue_optimize_linear", help="Enable cuequivariance linear layer optimization", type=str2bool, - default=False + default=False, ) parser.add_argument( "--cue_optimize_channelwise", help="Enable cuequivariance channelwise optimization", - type=str2bool, - default=False + type=str2bool, + default=False, ) parser.add_argument( "--cue_optimize_symmetric", help="Enable cuequivariance symmetric contraction optimization", type=str2bool, - default=False + default=False, ) parser.add_argument( "--cue_optimize_fctp", help="Enable cuequivariance fully connected tensor product optimization", type=str2bool, - default=False + default=False, ) # options for using Weights and Biases for experiment tracking # to install see https://wandb.ai diff --git a/pyproject.toml b/pyproject.toml index 489bc6e5..41685059 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,3 +39,6 @@ ignore-paths = [ "^mace/tools/torch_geometric/.*$", "^mace/tools/scatter.py$", ] + +[tool.pylint.FORMAT] +max-module-lines = 1500 \ No newline at end of file diff --git a/tests/test_cueq.py b/tests/test_cueq.py index 256f3570..beb0c45e 100644 --- a/tests/test_cueq.py +++ b/tests/test_cueq.py @@ -1,19 +1,21 @@ +from typing import Any, Dict + import pytest import torch import torch.nn.functional as F from e3nn import o3 -from typing import Dict, Any from mace import data, modules, tools from mace.modules.wrapper_ops import CuEquivarianceConfig from mace.tools import torch_geometric try: - import cuequivariance as cue + import cuequivariance as cue # pylint: disable=unused-import CUET_AVAILABLE = True except ImportError: CUET_AVAILABLE = False + @pytest.mark.skipif(not CUET_AVAILABLE, reason="cuequivariance not installed") class TestCueq: @pytest.fixture @@ -24,8 +26,12 @@ def model_config(self) -> Dict[str, Any]: "num_bessel": 8, "num_polynomial_cutoff": 6, "max_ell": 3, - "interaction_cls": modules.interaction_classes["RealAgnosticResidualInteractionBlock"], - "interaction_cls_first": modules.interaction_classes["RealAgnosticResidualInteractionBlock"], + "interaction_cls": modules.interaction_classes[ + "RealAgnosticResidualInteractionBlock" + ], + "interaction_cls_first": modules.interaction_classes[ + "RealAgnosticResidualInteractionBlock" + ], "num_interactions": 2, "num_elements": 1, "hidden_irreps": o3.Irreps("32x0e + 32x1o"), @@ -35,12 +41,13 @@ def model_config(self) -> Dict[str, Any]: "avg_num_neighbors": 8, "atomic_numbers": table.zs, "correlation": 3, - "radial_type": "bessel" + "radial_type": "bessel", } @pytest.fixture def batch(self, device: str): from ase import build + table = tools.AtomicNumberTable([6]) atoms = build.bulk("C", "diamond", a=3.567, cubic=True) @@ -48,8 +55,10 @@ def batch(self, device: str): configs = [data.config_from_atoms(atoms) for atoms in atoms_list] data_loader = torch_geometric.dataloader.DataLoader( - dataset=[data.AtomicData.from_config(config, z_table=table, cutoff=5.0) - for config in configs], + dataset=[ + data.AtomicData.from_config(config, z_table=table, cutoff=5.0) + for config in configs + ], batch_size=1, shuffle=False, drop_last=False, @@ -58,24 +67,23 @@ def batch(self, device: str): return batch.to(device).to_dict() @pytest.mark.parametrize("device", ["cuda"]) - def test_cueq_equivalence(self, model_config: Dict[str, Any], batch: Dict[str, torch.Tensor], device: str): + def test_cueq_equivalence( + self, model_config: Dict[str, Any], batch: Dict[str, torch.Tensor], device: str + ): torch.manual_seed(42) - + # Create model without cuequivariance model_std = modules.MACE(**model_config) model_std = model_std.to(device) # Create model with cuequivariance cueq_config = CuEquivarianceConfig( - enabled=True, - layout="mul_ir", - group="O3nn", - optimize_all=True + enabled=True, layout="mul_ir", group="O3nn", optimize_all=True ) model_config["cueq_config"] = cueq_config model_cueq = modules.MACE(**model_config) model_cueq = model_cueq.to(device) - + # Copy weights model_cueq.load_state_dict(model_std.state_dict()) @@ -90,13 +98,13 @@ def test_cueq_equivalence(self, model_config: Dict[str, Any], batch: Dict[str, t # Test gradients out_std = model_std(batch, training=True) out_cueq = model_cueq(batch, training=True) - + loss_std = out_std["energy"].sum() loss_cueq = out_cueq["energy"].sum() - + loss_std.backward() loss_cueq.backward() for p1, p2 in zip(model_std.parameters(), model_cueq.parameters()): if p1.grad is not None: - torch.testing.assert_close(p1.grad, p2.grad) \ No newline at end of file + torch.testing.assert_close(p1.grad, p2.grad) From 42966579093af74d1f874bc25dcfb72864f1126b Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Fri, 15 Nov 2024 15:58:02 +0000 Subject: [PATCH 07/17] add pylint diable to cuet class --- mace/modules/wrapper_ops.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/mace/modules/wrapper_ops.py b/mace/modules/wrapper_ops.py index 84ccca0d..36da62db 100644 --- a/mace/modules/wrapper_ops.py +++ b/mace/modules/wrapper_ops.py @@ -23,7 +23,7 @@ if CUET_AVAILABLE: class O3_e3nn(cue.O3): - def __mul__(rep1: "O3_e3nn", rep2: "O3_e3nn") -> Iterator["O3_e3nn"]: + def __mul__(rep1: "O3_e3nn", rep2: "O3_e3nn") -> Iterator["O3_e3nn"]: # pylint: disable=no-self-argument return [O3_e3nn(l=ir.l, p=ir.p) for ir in cue.O3.__mul__(rep1, rep2)] @classmethod @@ -38,7 +38,7 @@ def clebsch_gordan( ) return np.zeros((0, rep1.dim, rep2.dim, rep3.dim)) - def __lt__(rep1: "O3_e3nn", rep2: "O3_e3nn") -> bool: + def __lt__(rep1: "O3_e3nn", rep2: "O3_e3nn") -> bool: # pylint: disable=no-self-argument rep2 = rep1._from(rep2) return (rep1.l, rep1.p) < (rep2.l, rep2.p) @@ -96,7 +96,7 @@ def __init__( cue.Irreps(cueq_config.group, irreps_out), layout=cueq_config.layout, shared_weights=shared_weights, - optimize_fallback=not cueq_config.optimize_linear, + optimize_fallback=not cueq_config.optimize_linear, # pylint: disable=unexpected-keyword-arg ) self.use_cuet = True self.cueq_config = cueq_config @@ -149,7 +149,7 @@ def __init__( layout=cueq_config.layout, shared_weights=shared_weights, internal_weights=internal_weights, - optimize_fallback=not cueq_config.optimize_channelwise, + optimize_fallback=not cueq_config.optimize_channelwise, # pylint: disable=unexpected-keyword-arg ) self.use_cuet = True self.cueq_config = cueq_config @@ -207,7 +207,7 @@ def __init__( layout=cueq_config.layout, shared_weights=shared_weights, internal_weights=internal_weights, - optimize_fallback=not cueq_config.optimize_fctp, + optimize_fallback=not cueq_config.optimize_fctp, # pylint: disable=unexpected-keyword-arg ) self.use_cuet = True self.cueq_config = cueq_config @@ -259,7 +259,7 @@ def __init__( layout_out=cueq_config.layout, contraction_degree=correlation, num_elements=num_elements, - optimize_fallback=not cueq_config.optimize_symmetric, + optimize_fallback=not cueq_config.optimize_symmetric, # pylint: disable=unexpected-keyword-arg ) self.use_cuet = True self.cueq_config = cueq_config From 5c30ada2cdc801f2adb84fb64667a746bc3c3075 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Sat, 16 Nov 2024 19:51:41 +0000 Subject: [PATCH 08/17] fix indices --- mace/modules/wrapper_ops.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/mace/modules/wrapper_ops.py b/mace/modules/wrapper_ops.py index 36da62db..6643b8ed 100644 --- a/mace/modules/wrapper_ops.py +++ b/mace/modules/wrapper_ops.py @@ -96,7 +96,6 @@ def __init__( cue.Irreps(cueq_config.group, irreps_out), layout=cueq_config.layout, shared_weights=shared_weights, - optimize_fallback=not cueq_config.optimize_linear, # pylint: disable=unexpected-keyword-arg ) self.use_cuet = True self.cueq_config = cueq_config @@ -142,14 +141,13 @@ def __init__( and cueq_config.enabled and (cueq_config.optimize_all or cueq_config.optimize_channelwise) ): - self.tp = cuet.ChannelwiseTensorProduct( + self.tp = cuet.ChannelWiseTensorProduct( cue.Irreps(cueq_config.group, irreps_in1), cue.Irreps(cueq_config.group, irreps_in2), cue.Irreps(cueq_config.group, irreps_out), layout=cueq_config.layout, shared_weights=shared_weights, internal_weights=internal_weights, - optimize_fallback=not cueq_config.optimize_channelwise, # pylint: disable=unexpected-keyword-arg ) self.use_cuet = True self.cueq_config = cueq_config @@ -207,7 +205,6 @@ def __init__( layout=cueq_config.layout, shared_weights=shared_weights, internal_weights=internal_weights, - optimize_fallback=not cueq_config.optimize_fctp, # pylint: disable=unexpected-keyword-arg ) self.use_cuet = True self.cueq_config = cueq_config @@ -259,7 +256,6 @@ def __init__( layout_out=cueq_config.layout, contraction_degree=correlation, num_elements=num_elements, - optimize_fallback=not cueq_config.optimize_symmetric, # pylint: disable=unexpected-keyword-arg ) self.use_cuet = True self.cueq_config = cueq_config @@ -284,9 +280,10 @@ def forward(self, x: torch.Tensor, attrs: torch.Tensor) -> torch.Tensor: if self.use_cuet and hasattr(self, "cueq_config"): if self.layout == cue.mul_ir: x = torch.transpose(x, 1, 2) + index_attrs = torch.nonzero(attrs)[:,1].int() return self.sconctaction( x.flatten(1), - attrs, + index_attrs, use_fallback=not self.cueq_config.optimize_symmetric, ) return self.sconctaction(x, attrs) From 653e268f379e53617cc2cbf420b3ea9a660d0960 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Sat, 16 Nov 2024 19:55:11 +0000 Subject: [PATCH 09/17] fix wrapper --- mace/modules/wrapper_ops.py | 3 ++- mace/tools/arg_parser.py | 4 ++-- tests/test_cueq.py | 13 ++++--------- 3 files changed, 8 insertions(+), 12 deletions(-) diff --git a/mace/modules/wrapper_ops.py b/mace/modules/wrapper_ops.py index 6643b8ed..f9412a97 100644 --- a/mace/modules/wrapper_ops.py +++ b/mace/modules/wrapper_ops.py @@ -59,6 +59,7 @@ class CuEquivarianceConfig: enabled: bool = False layout: str = "mul_ir" # One of: mul_ir, ir_mul + layout_str: str = "mul_ir" group: str = "O3" optimize_all: bool = False # Set to True to enable all optimizations optimize_linear: bool = False @@ -70,7 +71,7 @@ def __post_init__(self): if self.enabled and CUET_AVAILABLE: self.layout_str = self.layout self.layout = getattr(cue, self.layout) - self.group = O3_e3nn if self.group == "O3" else getattr(cue, self.group) + self.group = O3_e3nn if self.group == "O3_e3nn" else getattr(cue, self.group) class Linear(torch.nn.Module): diff --git a/mace/tools/arg_parser.py b/mace/tools/arg_parser.py index 3df3960a..f513b9f9 100644 --- a/mace/tools/arg_parser.py +++ b/mace/tools/arg_parser.py @@ -678,8 +678,8 @@ def build_default_arg_parser() -> argparse.ArgumentParser: "--cue_group", help="Symmetry group for cuequivariance", type=str, - choices=["O3nn, O3"], - default="O3nn", + choices=["O3_e3nn, O3"], + default="O3_e3nn", ) parser.add_argument( "--cue_optimize_all", diff --git a/tests/test_cueq.py b/tests/test_cueq.py index beb0c45e..129c3489 100644 --- a/tests/test_cueq.py +++ b/tests/test_cueq.py @@ -78,27 +78,22 @@ def test_cueq_equivalence( # Create model with cuequivariance cueq_config = CuEquivarianceConfig( - enabled=True, layout="mul_ir", group="O3nn", optimize_all=True + enabled=True, layout="mul_ir", group="O3_e3nn", optimize_all=True ) model_config["cueq_config"] = cueq_config model_cueq = modules.MACE(**model_config) model_cueq = model_cueq.to(device) # Copy weights - model_cueq.load_state_dict(model_std.state_dict()) + # model_cueq.load_state_dict(model_std.state_dict()) # Compare outputs - with torch.no_grad(): - out_std = model_std(batch, training=True) - out_cueq = model_cueq(batch, training=True) + out_std = model_std(batch, training=True) + out_cueq = model_cueq(batch, training=True) torch.testing.assert_close(out_std["energy"], out_cueq["energy"]) torch.testing.assert_close(out_std["forces"], out_cueq["forces"]) - # Test gradients - out_std = model_std(batch, training=True) - out_cueq = model_cueq(batch, training=True) - loss_std = out_std["energy"].sum() loss_cueq = out_cueq["energy"].sum() From 126f00f35519e51df38f507815c7b6223870a9b0 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Sat, 16 Nov 2024 20:36:44 +0000 Subject: [PATCH 10/17] add new wrapper --- mace/modules/wrapper_ops.py | 162 ++++++++++-------------------------- 1 file changed, 45 insertions(+), 117 deletions(-) diff --git a/mace/modules/wrapper_ops.py b/mace/modules/wrapper_ops.py index f9412a97..9ee68cdc 100644 --- a/mace/modules/wrapper_ops.py +++ b/mace/modules/wrapper_ops.py @@ -74,59 +74,36 @@ def __post_init__(self): self.group = O3_e3nn if self.group == "O3_e3nn" else getattr(cue, self.group) -class Linear(torch.nn.Module): - """Wrapper around o3.Linear that optionally uses cuet.Linear when enabled""" - - def __init__( - self, +class Linear: + """Returns either a cuet.Linear or o3.Linear based on config""" + def __new__( + cls, irreps_in: o3.Irreps, irreps_out: o3.Irreps, shared_weights: bool = True, internal_weights: bool = True, cueq_config: Optional[CuEquivarianceConfig] = None, ): - super().__init__() - if ( - CUET_AVAILABLE - and cueq_config is not None - and cueq_config.enabled - and (cueq_config.optimize_all or cueq_config.optimize_linear) - ): - self.linear = cuet.Linear( + if (CUET_AVAILABLE and cueq_config is not None and cueq_config.enabled + and (cueq_config.optimize_all or cueq_config.optimize_linear)): + return cuet.Linear( cue.Irreps(cueq_config.group, irreps_in), cue.Irreps(cueq_config.group, irreps_out), layout=cueq_config.layout, shared_weights=shared_weights, ) - self.use_cuet = True - self.cueq_config = cueq_config - else: - self.linear = o3.Linear( - irreps_in, - irreps_out, - shared_weights=shared_weights, - internal_weights=internal_weights, - ) - self.use_cuet = False - - def __getattr__(self, name): - """Forward any unknown attribute access to the underlying linear object""" - try: - return super().__getattr__(name) - except AttributeError: - return getattr(self.linear, name) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - if self.use_cuet and hasattr(self, "cueq_config"): - return self.linear(x, use_fallback=not self.cueq_config.optimize_linear) - return self.linear(x) - - -class TensorProduct(torch.nn.Module): + return o3.Linear( + irreps_in, + irreps_out, + shared_weights=shared_weights, + internal_weights=internal_weights, + ) + +class TensorProduct: """Wrapper around o3.TensorProduct/cuet.ChannelwiseTensorProduct""" - def __init__( - self, + def __new__( + cls, irreps_in1: o3.Irreps, irreps_in2: o3.Irreps, irreps_out: o3.Irreps, @@ -135,14 +112,13 @@ def __init__( internal_weights: bool = False, cueq_config: Optional[CuEquivarianceConfig] = None, ): - super().__init__() if ( CUET_AVAILABLE and cueq_config is not None and cueq_config.enabled and (cueq_config.optimize_all or cueq_config.optimize_channelwise) ): - self.tp = cuet.ChannelWiseTensorProduct( + return cuet.ChannelWiseTensorProduct( cue.Irreps(cueq_config.group, irreps_in1), cue.Irreps(cueq_config.group, irreps_in2), cue.Irreps(cueq_config.group, irreps_out), @@ -150,10 +126,7 @@ def __init__( shared_weights=shared_weights, internal_weights=internal_weights, ) - self.use_cuet = True - self.cueq_config = cueq_config - else: - self.tp = o3.TensorProduct( + return o3.TensorProduct( irreps_in1, irreps_in2, irreps_out, @@ -161,30 +134,12 @@ def __init__( shared_weights=shared_weights, internal_weights=internal_weights, ) - self.use_cuet = False - - def __getattr__(self, name): - """Forward any unknown attribute access to the underlying linear object""" - try: - return super().__getattr__(name) - except AttributeError: - return getattr(self.tp, name) - - def forward( - self, x1: torch.Tensor, x2: torch.Tensor, weights: Optional[torch.Tensor] = None - ) -> torch.Tensor: - if self.use_cuet and hasattr(self, "cueq_config"): - return self.tp( - x1, x2, weights, use_fallback=not self.cueq_config.optimize_channelwise - ) - return self.tp(x1, x2, weights) - -class FullyConnectedTensorProduct(torch.nn.Module): +class FullyConnectedTensorProduct: """Wrapper around o3.FullyConnectedTensorProduct/cuet.FullyConnectedTensorProduct""" - def __init__( - self, + def __new__( + cls, irreps_in1: o3.Irreps, irreps_in2: o3.Irreps, irreps_out: o3.Irreps, @@ -192,14 +147,13 @@ def __init__( internal_weights: bool = True, cueq_config: Optional[CuEquivarianceConfig] = None, ): - super().__init__() if ( CUET_AVAILABLE and cueq_config is not None and cueq_config.enabled and (cueq_config.optimize_all or cueq_config.optimize_fctp) ): - self.tp = cuet.FullyConnectedTensorProduct( + return cuet.FullyConnectedTensorProduct( cue.Irreps(cueq_config.group, irreps_in1), cue.Irreps(cueq_config.group, irreps_in2), cue.Irreps(cueq_config.group, irreps_out), @@ -207,50 +161,45 @@ def __init__( shared_weights=shared_weights, internal_weights=internal_weights, ) - self.use_cuet = True - self.cueq_config = cueq_config - else: - self.tp = o3.FullyConnectedTensorProduct( + return o3.FullyConnectedTensorProduct( irreps_in1, irreps_in2, irreps_out, shared_weights=shared_weights, internal_weights=internal_weights, ) - self.use_cuet = False - - def __getattr__(self, name): - """Forward any unknown attribute access to the underlying linear object""" - try: - return super().__getattr__(name) - except AttributeError: - return getattr(self.tp, name) - - def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: - if self.use_cuet and hasattr(self, "cueq_config"): - return self.tp(x1, x2, use_fallback=not self.cueq_config.optimize_fctp) - return self.tp(x1, x2) - class SymmetricContractionWrapper(torch.nn.Module): """Wrapper around SymmetricContraction/cuet.SymmetricContraction""" + class CuetForward: + def __init__(self, instance, layout): + self.instance = instance + self.layout = layout + + def __call__(self, x: torch.Tensor, attrs: torch.Tensor) -> torch.Tensor: + if self.layout == cue.mul_ir: + x = torch.transpose(x, 1, 2) + index_attrs = torch.nonzero(attrs)[:,1].int() + return self.instance( + x.flatten(1), + index_attrs, + ) - def __init__( - self, + def __new__( + cls, irreps_in: o3.Irreps, irreps_out: o3.Irreps, correlation: int, num_elements: Optional[int] = None, cueq_config: Optional[CuEquivarianceConfig] = None, ): - super().__init__() if ( CUET_AVAILABLE and cueq_config is not None and cueq_config.enabled and (cueq_config.optimize_all or cueq_config.optimize_symmetric) ): - self.sconctaction = cuet.SymmetricContraction( + instance = cuet.SymmetricContraction( cue.Irreps(cueq_config.group, irreps_in), cue.Irreps(cueq_config.group, irreps_out), layout_in=cue.ir_mul, @@ -258,33 +207,12 @@ def __init__( contraction_degree=correlation, num_elements=num_elements, ) - self.use_cuet = True - self.cueq_config = cueq_config - self.layout = cueq_config.layout - else: - self.sconctaction = SymmetricContraction( + instance.forward = cls.CuetForward(instance, cueq_config.layout) + return instance + + return SymmetricContraction( irreps_in=irreps_in, irreps_out=irreps_out, correlation=correlation, num_elements=num_elements, - ) - self.use_cuet = False - - def __getattr__(self, name): - """Forward any unknown attribute access to the underlying linear object""" - try: - return super().__getattr__(name) - except AttributeError: - return getattr(self.sc, name) - - def forward(self, x: torch.Tensor, attrs: torch.Tensor) -> torch.Tensor: - if self.use_cuet and hasattr(self, "cueq_config"): - if self.layout == cue.mul_ir: - x = torch.transpose(x, 1, 2) - index_attrs = torch.nonzero(attrs)[:,1].int() - return self.sconctaction( - x.flatten(1), - index_attrs, - use_fallback=not self.cueq_config.optimize_symmetric, - ) - return self.sconctaction(x, attrs) + ) \ No newline at end of file From 7aae22b6440df098ffe807fcdf24d5c3cb1ac629 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Mon, 18 Nov 2024 11:49:05 +0000 Subject: [PATCH 11/17] fix the formatting --- mace/modules/irreps_tools.py | 1 + mace/modules/wrapper_ops.py | 70 ++++++++++++++++++++++-------------- tests/test_cueq.py | 3 +- 3 files changed, 46 insertions(+), 28 deletions(-) diff --git a/mace/modules/irreps_tools.py b/mace/modules/irreps_tools.py index 3e4cc6f6..2e79c0ab 100644 --- a/mace/modules/irreps_tools.py +++ b/mace/modules/irreps_tools.py @@ -12,6 +12,7 @@ from mace.modules.wrapper_ops import CuEquivarianceConfig + # Based on mir-group/nequip def tp_out_irreps_with_instructions( irreps1: o3.Irreps, irreps2: o3.Irreps, target_irreps: o3.Irreps diff --git a/mace/modules/wrapper_ops.py b/mace/modules/wrapper_ops.py index 9ee68cdc..6ef98faa 100644 --- a/mace/modules/wrapper_ops.py +++ b/mace/modules/wrapper_ops.py @@ -23,7 +23,9 @@ if CUET_AVAILABLE: class O3_e3nn(cue.O3): - def __mul__(rep1: "O3_e3nn", rep2: "O3_e3nn") -> Iterator["O3_e3nn"]: # pylint: disable=no-self-argument + def __mul__( # pylint: disable=no-self-argument + rep1: "O3_e3nn", rep2: "O3_e3nn" + ) -> Iterator["O3_e3nn"]: return [O3_e3nn(l=ir.l, p=ir.p) for ir in cue.O3.__mul__(rep1, rep2)] @classmethod @@ -38,7 +40,9 @@ def clebsch_gordan( ) return np.zeros((0, rep1.dim, rep2.dim, rep3.dim)) - def __lt__(rep1: "O3_e3nn", rep2: "O3_e3nn") -> bool: # pylint: disable=no-self-argument + def __lt__( # pylint: disable=no-self-argument + rep1: "O3_e3nn", rep2: "O3_e3nn" + ) -> bool: rep2 = rep1._from(rep2) return (rep1.l, rep1.p) < (rep2.l, rep2.p) @@ -47,6 +51,7 @@ def iterator(cls) -> Iterator["O3_e3nn"]: for l in itertools.count(0): yield O3_e3nn(l=l, p=1 * (-1) ** l) yield O3_e3nn(l=l, p=-1 * (-1) ** l) + else: print( "cuequivariance or cuequivariance_torch is not available. Cuequivariance acceleration will be disabled." @@ -71,11 +76,14 @@ def __post_init__(self): if self.enabled and CUET_AVAILABLE: self.layout_str = self.layout self.layout = getattr(cue, self.layout) - self.group = O3_e3nn if self.group == "O3_e3nn" else getattr(cue, self.group) + self.group = ( + O3_e3nn if self.group == "O3_e3nn" else getattr(cue, self.group) + ) class Linear: """Returns either a cuet.Linear or o3.Linear based on config""" + def __new__( cls, irreps_in: o3.Irreps, @@ -84,8 +92,12 @@ def __new__( internal_weights: bool = True, cueq_config: Optional[CuEquivarianceConfig] = None, ): - if (CUET_AVAILABLE and cueq_config is not None and cueq_config.enabled - and (cueq_config.optimize_all or cueq_config.optimize_linear)): + if ( + CUET_AVAILABLE + and cueq_config is not None + and cueq_config.enabled + and (cueq_config.optimize_all or cueq_config.optimize_linear) + ): return cuet.Linear( cue.Irreps(cueq_config.group, irreps_in), cue.Irreps(cueq_config.group, irreps_out), @@ -99,6 +111,7 @@ def __new__( internal_weights=internal_weights, ) + class TensorProduct: """Wrapper around o3.TensorProduct/cuet.ChannelwiseTensorProduct""" @@ -127,13 +140,14 @@ def __new__( internal_weights=internal_weights, ) return o3.TensorProduct( - irreps_in1, - irreps_in2, - irreps_out, - instructions=instructions, - shared_weights=shared_weights, - internal_weights=internal_weights, - ) + irreps_in1, + irreps_in2, + irreps_out, + instructions=instructions, + shared_weights=shared_weights, + internal_weights=internal_weights, + ) + class FullyConnectedTensorProduct: """Wrapper around o3.FullyConnectedTensorProduct/cuet.FullyConnectedTensorProduct""" @@ -162,24 +176,26 @@ def __new__( internal_weights=internal_weights, ) return o3.FullyConnectedTensorProduct( - irreps_in1, - irreps_in2, - irreps_out, - shared_weights=shared_weights, - internal_weights=internal_weights, - ) + irreps_in1, + irreps_in2, + irreps_out, + shared_weights=shared_weights, + internal_weights=internal_weights, + ) + -class SymmetricContractionWrapper(torch.nn.Module): +class SymmetricContractionWrapper: """Wrapper around SymmetricContraction/cuet.SymmetricContraction""" + class CuetForward: def __init__(self, instance, layout): self.instance = instance self.layout = layout - + def __call__(self, x: torch.Tensor, attrs: torch.Tensor) -> torch.Tensor: if self.layout == cue.mul_ir: x = torch.transpose(x, 1, 2) - index_attrs = torch.nonzero(attrs)[:,1].int() + index_attrs = torch.nonzero(attrs)[:, 1].int() return self.instance( x.flatten(1), index_attrs, @@ -209,10 +225,10 @@ def __new__( ) instance.forward = cls.CuetForward(instance, cueq_config.layout) return instance - + return SymmetricContraction( - irreps_in=irreps_in, - irreps_out=irreps_out, - correlation=correlation, - num_elements=num_elements, - ) \ No newline at end of file + irreps_in=irreps_in, + irreps_out=irreps_out, + correlation=correlation, + num_elements=num_elements, + ) diff --git a/tests/test_cueq.py b/tests/test_cueq.py index 129c3489..3dce77cd 100644 --- a/tests/test_cueq.py +++ b/tests/test_cueq.py @@ -10,7 +10,8 @@ from mace.tools import torch_geometric try: - import cuequivariance as cue # pylint: disable=unused-import + import cuequivariance as cue # pylint: disable=unused-import + CUET_AVAILABLE = True except ImportError: CUET_AVAILABLE = False From 6451abc8ee5bca7659689098b57885588c8da932 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Mon, 18 Nov 2024 22:11:13 +0000 Subject: [PATCH 12/17] fix the convertion e3nn to cueq --- mace/cli/convert_cueq_e3nn.py | 144 ++++++++++++++++++++++++++++ mace/cli/convert_e3nn_cueq.py | 172 ++++++++++++++++++++++++++++++++++ mace/modules/wrapper_ops.py | 6 +- tests/test_cueq.py | 71 +++++++++----- 4 files changed, 369 insertions(+), 24 deletions(-) create mode 100644 mace/cli/convert_cueq_e3nn.py create mode 100644 mace/cli/convert_e3nn_cueq.py diff --git a/mace/cli/convert_cueq_e3nn.py b/mace/cli/convert_cueq_e3nn.py new file mode 100644 index 00000000..6363b2a4 --- /dev/null +++ b/mace/cli/convert_cueq_e3nn.py @@ -0,0 +1,144 @@ +import torch +import argparse +import logging +from pathlib import Path +from typing import Dict, List, Tuple +try: + import cuequivariance as cue + import cuequivariance_torch as cuet + CUET_AVAILABLE = True +except ImportError: + raise ImportError("cuequivariance or cuequivariance_torch is not available. Cuequivariance acceleration will be disabled.") + +from mace.modules.wrapper_ops import CuEquivarianceConfig +from mace.tools.scripts_utils import extract_config_mace_model +from mace import modules + + +def get_transfer_keys() -> List[str]: + """Get list of keys that need to be transferred""" + return [ + 'node_embedding.linear.weight', + 'radial_embedding.bessel_fn.bessel_weights', + 'atomic_energies_fn.atomic_energies', + 'readouts.0.linear.weight', + 'scale_shift.scale', + 'scale_shift.shift', + *[f'readouts.1.linear_{i}.weight' for i in range(1,3)] + ] + [ + s for j in range(2) for s in [ + f'interactions.{j}.linear_up.weight', + *[f'interactions.{j}.conv_tp_weights.layer{i}.weight' for i in range(4)], + f'interactions.{j}.linear.weight', + f'interactions.{j}.skip_tp.weight', + f'products.{j}.linear.weight' + ] + ] + +def get_kmax_pairs(max_L: int, correlation: int) -> List[Tuple[int, int]]: + """Determine kmax pairs based on max_L and correlation""" + if correlation == 2: + # For 3-body correlations + return [[0,1], [1,0]] + elif correlation == 3: + # For 4-body correlations + if max_L <= 2: + return [[0,1], [1,0]] + else: + return [[0,2], [1,0]] + else: + logging.warning(f"Unexpected correlation {correlation}, defaulting to [[0,1], [1,0]]") + return [[0,1], [1,0]] + +def transfer_symmetric_contractions(source_dict: Dict[str, torch.Tensor], + target_dict: Dict[str, torch.Tensor], + max_L: int, + correlation: int): + """Transfer symmetric contraction weights""" + kmax_pairs = get_kmax_pairs(max_L, correlation) + logging.info(f"Using kmax pairs {kmax_pairs} for max_L={max_L}, correlation={correlation}") + + for i, kmax in kmax_pairs: + for k in range(kmax + 1): + for suffix in ['.0', '.1', '_max']: + key = f'products.{i}.symmetric_contractions.contractions.{k}.weights{suffix}' + if key in source_dict: # Check if key exists to avoid errors + target_dict[key] = source_dict[key] + else: + logging.warning(f"Key {key} not found in source model") + +def transfer_weights(source_model: torch.nn.Module, target_model: torch.nn.Module, + max_L: int, correlation: int): + """Transfer weights with proper remapping""" + # Get source state dict + source_dict = source_model.state_dict() + target_dict = target_model.state_dict() + + # Transfer main weights + transfer_keys = get_transfer_keys() + logging.info("Transferring main weights...") + for key in transfer_keys: + if key in source_dict: # Check if key exists + target_dict[key] = source_dict[key] + else: + logging.warning(f"Key {key} not found in source model") + + # Transfer symmetric contractions + logging.info("Transferring symmetric contractions...") + transfer_symmetric_contractions(source_dict, target_dict, max_L, correlation) + + # Transfer avg_num_neighbors + for i in range(2): + target_model.interactions[i].avg_num_neighbors = source_model.interactions[i].avg_num_neighbors + + # Load state dict into target model + target_model.load_state_dict(target_dict) + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('input_model', help='Path to input cuequivariance model') + parser.add_argument('output_model', help='Path to output e3nn model') + parser.add_argument('--device', default='cpu', help='Device to use') + args = parser.parse_args() + + # Setup logging + logging.basicConfig(level=logging.INFO) + + # Load cuequivariance model + logging.info(f"Loading model from {args.input_model}") + source_model = torch.load(args.input_model, map_location=args.device) + + # Extract configuration + logging.info("Extracting model configuration") + config = extract_config_mace_model(source_model) + + # Get max_L and correlation from config + max_L = config["max_ell"] + correlation = config["correlation"] + logging.info(f"Extracted max_L={max_L}, correlation={correlation}") + + # Replace cuequivariance config with disabled version + config["cueq_config"] = CuEquivarianceConfig( + layout_str="ir_mul", + group="O3", + max_L=max_L, + correlation=correlation + ) + + # Create new model with e3nn config + logging.info("Creating new model with e3nn settings") + if isinstance(source_model, modules.MACE): + target_model = modules.MACE(**config) + else: + target_model = modules.ScaleShiftMACE(**config) + + # Transfer weights with proper remapping + logging.info("Transferring weights with remapping...") + transfer_weights(source_model, target_model, max_L, correlation) + + # Save model + logging.info(f"Saving e3nn model to {args.output_model}") + torch.save(target_model, args.output_model) + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/mace/cli/convert_e3nn_cueq.py b/mace/cli/convert_e3nn_cueq.py new file mode 100644 index 00000000..13a10ab1 --- /dev/null +++ b/mace/cli/convert_e3nn_cueq.py @@ -0,0 +1,172 @@ +import torch +import argparse +import logging +from pathlib import Path +from typing import Dict, List, Tuple +import cuequivariance as cue + +from mace.modules.wrapper_ops import CuEquivarianceConfig +from mace.tools.scripts_utils import extract_config_mace_model +from mace import modules + + +def get_transfer_keys() -> List[str]: + """Get list of keys that need to be transferred""" + return [ + 'node_embedding.linear.weight', + 'radial_embedding.bessel_fn.bessel_weights', + 'atomic_energies_fn.atomic_energies', + 'readouts.0.linear.weight', + 'scale_shift.scale', + 'scale_shift.shift', + *[f'readouts.1.linear_{i}.weight' for i in range(1,3)] + ] + [ + s for j in range(2) for s in [ + f'interactions.{j}.linear_up.weight', + *[f'interactions.{j}.conv_tp_weights.layer{i}.weight' for i in range(4)], + f'interactions.{j}.linear.weight', + f'interactions.{j}.skip_tp.weight', + f'products.{j}.linear.weight' + ] + ] + +def get_kmax_pairs(max_L: int, correlation: int) -> List[Tuple[int, int]]: + """Determine kmax pairs based on max_L and correlation""" + if correlation == 2: + raise NotImplementedError("Correlation 2 not supported yet") + elif correlation == 3: + return [[0, max_L], [1, 0]] + else: + raise NotImplementedError(f"Correlation {correlation} not supported") + +def transfer_symmetric_contractions(source_dict: Dict[str, torch.Tensor], + target_dict: Dict[str, torch.Tensor], + max_L: int, + correlation: int): + """Transfer symmetric contraction weights""" + kmax_pairs = get_kmax_pairs(max_L, correlation) + logging.info(f"Using kmax pairs {kmax_pairs} for max_L={max_L}, correlation={correlation}") + + for i, kmax in kmax_pairs: + wm = torch.concatenate([ + source_dict[f'products.{i}.symmetric_contractions.contractions.{k}.weights{j}'] + for k in range(kmax+1) for j in ['_max','.0','.1']],dim=1) #.float() + target_dict[f'products.{i}.symmetric_contractions.weight'] = wm + +def transfer_weights(source_model: torch.nn.Module, target_model: torch.nn.Module, + max_L: int, correlation: int): + """Transfer weights with proper remapping""" + # Get source state dict + source_dict = source_model.state_dict() + target_dict = target_model.state_dict() + + # Transfer main weights + transfer_keys = get_transfer_keys() + logging.info("Transferring main weights...") + for key in transfer_keys: + if key in source_dict: # Check if key exists + target_dict[key] = source_dict[key] + else: + logging.warning(f"Key {key} not found in source model") + + # Transfer symmetric contractions + logging.info("Transferring symmetric contractions...") + transfer_symmetric_contractions(source_dict, target_dict, max_L, correlation) + + transferred_keys = set(transfer_keys) + remaining_keys = set(source_dict.keys()) & set(target_dict.keys()) - transferred_keys + remaining_keys = {k for k in remaining_keys if 'symmetric_contraction' not in k} + + if remaining_keys: + logging.info(f"Found {len(remaining_keys)} additional matching keys to transfer") + for key in remaining_keys: + if source_dict[key].shape == target_dict[key].shape: + logging.debug(f"Transferring additional key: {key}") + target_dict[key] = source_dict[key] + else: + logging.warning( + f"Shape mismatch for key {key}: " + f"source {source_dict[key].shape} vs target {target_dict[key].shape}" + ) + # Transfer avg_num_neighbors + for i in range(2): + target_model.interactions[i].avg_num_neighbors = source_model.interactions[i].avg_num_neighbors + + # Load state dict into target model + target_model.load_state_dict(target_dict) + +def run( + input_model, + output_model, + device='cuda', + layout='mul_ir', + group='O3', + return_model=True +): + # Setup logging + logging.basicConfig(level=logging.INFO) + + # Load original model + logging.info(f"Loading model from {input_model}") + # check if input_model is a path or a model + if isinstance(input_model, str): + source_model = torch.load(input_model, map_location=device) + else: + source_model = input_model + + # Extract configuration + logging.info("Extracting model configuration") + config = extract_config_mace_model(source_model) + + # Get max_L and correlation from config + max_L = config["hidden_irreps"].lmax + correlation = config["correlation"] + logging.info(f"Extracted max_L={max_L}, correlation={correlation}") + + # Add cuequivariance config + config["cueq_config"] = CuEquivarianceConfig( + enabled=True, + layout="mul_ir", + group="O3_e3nn", + optimize_all=True, + ) + + # Create new model with cuequivariance config + logging.info("Creating new model with cuequivariance settings") + target_model = source_model.__class__(**config) + + # Transfer weights with proper remapping + logging.info("Transferring weights with remapping...") + transfer_weights(source_model, target_model, max_L, correlation) + + if return_model: + return target_model + else: + # Save model + output_model = Path(input_model).parent / output_model + logging.info(f"Saving cuequivariance model to {output_model}") + torch.save(target_model, output_model) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('input_model', help='Path to input MACE model') + parser.add_argument('output_model', help='Path to output cuequivariance model', default='cuequivariance_model.pt') + parser.add_argument('--device', default='cpu', help='Device to use') + parser.add_argument('--layout', default='mul_ir', choices=['mul_ir', 'ir_mul'], help='Memory layout for tensors') + parser.add_argument('--group', default='O3_e3nn', choices=['O3', 'O3_e3nn'], help='Symmetry group') + parser.add_argument('--return_model', action='store_false', help='Return model instead of saving to file') + args = parser.parse_args() + + run( + input_model=args.input_model, + output_model=args.output_model, + device=args.device, + layout=args.layout, + group=args.group, + return_model=args.return_model + ) + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/mace/modules/wrapper_ops.py b/mace/modules/wrapper_ops.py index 6ef98faa..87770d3d 100644 --- a/mace/modules/wrapper_ops.py +++ b/mace/modules/wrapper_ops.py @@ -191,12 +191,13 @@ class CuetForward: def __init__(self, instance, layout): self.instance = instance self.layout = layout + self.original_forward = instance.forward def __call__(self, x: torch.Tensor, attrs: torch.Tensor) -> torch.Tensor: if self.layout == cue.mul_ir: x = torch.transpose(x, 1, 2) index_attrs = torch.nonzero(attrs)[:, 1].int() - return self.instance( + return self.original_forward( x.flatten(1), index_attrs, ) @@ -222,6 +223,9 @@ def __new__( layout_out=cueq_config.layout, contraction_degree=correlation, num_elements=num_elements, + original_mace=True, + dtype=torch.get_default_dtype(), + math_dtype=torch.get_default_dtype(), ) instance.forward = cls.CuetForward(instance, cueq_config.layout) return instance diff --git a/tests/test_cueq.py b/tests/test_cueq.py index 3dce77cd..e11ed5ed 100644 --- a/tests/test_cueq.py +++ b/tests/test_cueq.py @@ -8,34 +8,33 @@ from mace import data, modules, tools from mace.modules.wrapper_ops import CuEquivarianceConfig from mace.tools import torch_geometric +from mace.cli.convert_e3nn_cueq import run as run_convert try: import cuequivariance as cue # pylint: disable=unused-import - CUET_AVAILABLE = True except ImportError: CUET_AVAILABLE = False +torch.set_default_dtype(torch.float64) @pytest.mark.skipif(not CUET_AVAILABLE, reason="cuequivariance not installed") class TestCueq: @pytest.fixture - def model_config(self) -> Dict[str, Any]: + def model_config(self, interaction_cls_first, hidden_irreps) -> Dict[str, Any]: table = tools.AtomicNumberTable([6]) + print("interaction_cls_first", interaction_cls_first) + print("hidden_irreps", hidden_irreps) return { "r_max": 5.0, "num_bessel": 8, "num_polynomial_cutoff": 6, "max_ell": 3, - "interaction_cls": modules.interaction_classes[ - "RealAgnosticResidualInteractionBlock" - ], - "interaction_cls_first": modules.interaction_classes[ - "RealAgnosticResidualInteractionBlock" - ], + "interaction_cls": modules.interaction_classes["RealAgnosticResidualInteractionBlock"], + "interaction_cls_first": interaction_cls_first, "num_interactions": 2, "num_elements": 1, - "hidden_irreps": o3.Irreps("32x0e + 32x1o"), + "hidden_irreps": hidden_irreps, "MLP_irreps": o3.Irreps("16x0e"), "gate": F.silu, "atomic_energies": torch.tensor([1.0]), @@ -43,15 +42,21 @@ def model_config(self) -> Dict[str, Any]: "atomic_numbers": table.zs, "correlation": 3, "radial_type": "bessel", + "atomic_inter_scale": 1.0, + "atomic_inter_shift": 0.0, } @pytest.fixture def batch(self, device: str): from ase import build - table = tools.AtomicNumberTable([6]) atoms = build.bulk("C", "diamond", a=3.567, cubic=True) + import numpy as np + displacement = np.random.uniform( + -0.1, 0.1, size=atoms.positions.shape + ) + atoms.positions += displacement atoms_list = [atoms.repeat((2, 2, 2))] configs = [data.config_from_atoms(atoms) for atoms in atoms_list] @@ -68,13 +73,28 @@ def batch(self, device: str): return batch.to(device).to_dict() @pytest.mark.parametrize("device", ["cuda"]) + @pytest.mark.parametrize("interaction_cls_first", [ + modules.interaction_classes["RealAgnosticResidualInteractionBlock"], + modules.interaction_classes["RealAgnosticInteractionBlock"], + modules.interaction_classes["RealAgnosticDensityInteractionBlock"], + ]) + @pytest.mark.parametrize("hidden_irreps", [ + #o3.Irreps("32x0e + 32x1o"), + #o3.Irreps("32x0e + 32x1o + 32x2e"), + o3.Irreps("32x0e"), + ]) def test_cueq_equivalence( - self, model_config: Dict[str, Any], batch: Dict[str, torch.Tensor], device: str + self, + model_config: Dict[str, Any], + batch: Dict[str, torch.Tensor], + device: str, + interaction_cls_first, + hidden_irreps ): torch.manual_seed(42) # Create model without cuequivariance - model_std = modules.MACE(**model_config) + model_std = modules.ScaleShiftMACE(**model_config) model_std = model_std.to(device) # Create model with cuequivariance @@ -82,25 +102,30 @@ def test_cueq_equivalence( enabled=True, layout="mul_ir", group="O3_e3nn", optimize_all=True ) model_config["cueq_config"] = cueq_config - model_cueq = modules.MACE(**model_config) + model_cueq = modules.ScaleShiftMACE(**model_config) model_cueq = model_cueq.to(device) - # Copy weights - # model_cueq.load_state_dict(model_std.state_dict()) - + # Copy weights + model_cueq_convert = run_convert(model_std, None) + model_cueq_convert = model_cueq_convert.to(device) + # Compare outputs out_std = model_std(batch, training=True) - out_cueq = model_cueq(batch, training=True) - - torch.testing.assert_close(out_std["energy"], out_cueq["energy"]) - torch.testing.assert_close(out_std["forces"], out_cueq["forces"]) + out_cueq_convert = model_cueq_convert(batch, training=True) + torch.testing.assert_close(out_std["energy"], out_cueq_convert["energy"]) + torch.testing.assert_close(out_std["forces"], out_cueq_convert["forces"]) + loss_std = out_std["energy"].sum() - loss_cueq = out_cueq["energy"].sum() + loss_cueq = out_cueq_convert["energy"].sum() loss_std.backward() loss_cueq.backward() - for p1, p2 in zip(model_std.parameters(), model_cueq.parameters()): + for (name_1, p1), (name_2, p2) in zip(model_std.named_parameters(), model_cueq_convert.named_parameters()): if p1.grad is not None: - torch.testing.assert_close(p1.grad, p2.grad) + if p1.grad.shape == p2.grad.shape: + if name_1.split(".", 2)[:2] == name_2.split(".", 2)[:2]: + error = torch.abs(p1.grad - p2.grad) + print(f"Parameter {name_1}, Parameter {name_2}, Max error: {error.max()}") + torch.testing.assert_close(p1.grad, p2.grad, atol=1e-5, rtol=1e-10) \ No newline at end of file From 8b51f34732b35600ed3646f0c010e332b278771e Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Mon, 18 Nov 2024 22:49:46 +0000 Subject: [PATCH 13/17] fix e3nn to cueq convertion and add full circular test --- mace/cli/convert_cueq_e3nn.py | 137 ++++++++++++++++++++-------------- mace/cli/convert_e3nn_cueq.py | 2 - mace/tools/scripts_utils.py | 11 ++- tests/test_cueq.py | 92 ++++++++++++++--------- 4 files changed, 147 insertions(+), 95 deletions(-) diff --git a/mace/cli/convert_cueq_e3nn.py b/mace/cli/convert_cueq_e3nn.py index 6363b2a4..286540dd 100644 --- a/mace/cli/convert_cueq_e3nn.py +++ b/mace/cli/convert_cueq_e3nn.py @@ -3,17 +3,8 @@ import logging from pathlib import Path from typing import Dict, List, Tuple -try: - import cuequivariance as cue - import cuequivariance_torch as cuet - CUET_AVAILABLE = True -except ImportError: - raise ImportError("cuequivariance or cuequivariance_torch is not available. Cuequivariance acceleration will be disabled.") -from mace.modules.wrapper_ops import CuEquivarianceConfig from mace.tools.scripts_utils import extract_config_mace_model -from mace import modules - def get_transfer_keys() -> List[str]: """Get list of keys that need to be transferred""" @@ -38,39 +29,47 @@ def get_transfer_keys() -> List[str]: def get_kmax_pairs(max_L: int, correlation: int) -> List[Tuple[int, int]]: """Determine kmax pairs based on max_L and correlation""" if correlation == 2: - # For 3-body correlations - return [[0,1], [1,0]] + raise NotImplementedError("Correlation 2 not supported yet") elif correlation == 3: - # For 4-body correlations - if max_L <= 2: - return [[0,1], [1,0]] - else: - return [[0,2], [1,0]] + return [[0, max_L], [1, 0]] else: - logging.warning(f"Unexpected correlation {correlation}, defaulting to [[0,1], [1,0]]") - return [[0,1], [1,0]] + raise NotImplementedError(f"Correlation {correlation} not supported") def transfer_symmetric_contractions(source_dict: Dict[str, torch.Tensor], target_dict: Dict[str, torch.Tensor], max_L: int, correlation: int): - """Transfer symmetric contraction weights""" + """Transfer symmetric contraction weights from CuEq to E3nn format""" kmax_pairs = get_kmax_pairs(max_L, correlation) logging.info(f"Using kmax pairs {kmax_pairs} for max_L={max_L}, correlation={correlation}") - + for i, kmax in kmax_pairs: + # Get the combined weight tensor from source + wm = source_dict[f'products.{i}.symmetric_contractions.weight'] + + # Get split sizes based on target dimensions + splits = [] for k in range(kmax + 1): - for suffix in ['.0', '.1', '_max']: + for suffix in ['_max', '.0', '.1']: key = f'products.{i}.symmetric_contractions.contractions.{k}.weights{suffix}' - if key in source_dict: # Check if key exists to avoid errors - target_dict[key] = source_dict[key] - else: - logging.warning(f"Key {key} not found in source model") + target_shape = target_dict[key].shape + splits.append(target_shape[1]) + + # Split the weights using the calculated sizes + weights_split = torch.split(wm, splits, dim=1) + + # Assign back to target dictionary + idx = 0 + for k in range(kmax + 1): + target_dict[f'products.{i}.symmetric_contractions.contractions.{k}.weights_max'] = weights_split[idx] + target_dict[f'products.{i}.symmetric_contractions.contractions.{k}.weights.0'] = weights_split[idx + 1] + target_dict[f'products.{i}.symmetric_contractions.contractions.{k}.weights.1'] = weights_split[idx + 2] + idx += 3 def transfer_weights(source_model: torch.nn.Module, target_model: torch.nn.Module, max_L: int, correlation: int): - """Transfer weights with proper remapping""" - # Get source state dict + """Transfer weights from CuEq to E3nn format""" + # Get state dicts source_dict = source_model.state_dict() target_dict = target_model.state_dict() @@ -87,6 +86,23 @@ def transfer_weights(source_model: torch.nn.Module, target_model: torch.nn.Modul logging.info("Transferring symmetric contractions...") transfer_symmetric_contractions(source_dict, target_dict, max_L, correlation) + # Transfer remaining matching keys + transferred_keys = set(transfer_keys) + remaining_keys = set(source_dict.keys()) & set(target_dict.keys()) - transferred_keys + remaining_keys = {k for k in remaining_keys if 'symmetric_contraction' not in k} + + if remaining_keys: + logging.info(f"Found {len(remaining_keys)} additional matching keys to transfer") + for key in remaining_keys: + if source_dict[key].shape == target_dict[key].shape: + logging.debug(f"Transferring additional key: {key}") + target_dict[key] = source_dict[key] + else: + logging.warning( + f"Shape mismatch for key {key}: " + f"source {source_dict[key].shape} vs target {target_dict[key].shape}" + ) + # Transfer avg_num_neighbors for i in range(2): target_model.interactions[i].avg_num_neighbors = source_model.interactions[i].avg_num_neighbors @@ -94,51 +110,64 @@ def transfer_weights(source_model: torch.nn.Module, target_model: torch.nn.Modul # Load state dict into target model target_model.load_state_dict(target_dict) -def main(): - parser = argparse.ArgumentParser() - parser.add_argument('input_model', help='Path to input cuequivariance model') - parser.add_argument('output_model', help='Path to output e3nn model') - parser.add_argument('--device', default='cpu', help='Device to use') - args = parser.parse_args() - +def run( + input_model, + output_model, + device='cuda', + return_model=True +): # Setup logging logging.basicConfig(level=logging.INFO) - # Load cuequivariance model - logging.info(f"Loading model from {args.input_model}") - source_model = torch.load(args.input_model, map_location=args.device) + # Load CuEq model + logging.info(f"Loading CuEq model from {input_model}") + if isinstance(input_model, str): + source_model = torch.load(input_model, map_location=device) + else: + source_model = input_model # Extract configuration logging.info("Extracting model configuration") config = extract_config_mace_model(source_model) # Get max_L and correlation from config - max_L = config["max_ell"] + max_L = config["hidden_irreps"].lmax correlation = config["correlation"] logging.info(f"Extracted max_L={max_L}, correlation={correlation}") - # Replace cuequivariance config with disabled version - config["cueq_config"] = CuEquivarianceConfig( - layout_str="ir_mul", - group="O3", - max_L=max_L, - correlation=correlation - ) + # Remove CuEq config + config.pop("cueq_config", None) - # Create new model with e3nn config - logging.info("Creating new model with e3nn settings") - if isinstance(source_model, modules.MACE): - target_model = modules.MACE(**config) - else: - target_model = modules.ScaleShiftMACE(**config) + # Create new model without CuEq config + logging.info("Creating new model without CuEq settings") + target_model = source_model.__class__(**config) # Transfer weights with proper remapping logging.info("Transferring weights with remapping...") transfer_weights(source_model, target_model, max_L, correlation) - # Save model - logging.info(f"Saving e3nn model to {args.output_model}") - torch.save(target_model, args.output_model) + if return_model: + return target_model + else: + # Save model + output_model = Path(input_model).parent / output_model + logging.info(f"Saving E3nn model to {output_model}") + torch.save(target_model, output_model) + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('input_model', help='Path to input CuEq model') + parser.add_argument('output_model', help='Path to output E3nn model', default='e3nn_model.pt') + parser.add_argument('--device', default='cpu', help='Device to use') + parser.add_argument('--return_model', action='store_false', help='Return model instead of saving to file') + args = parser.parse_args() + + run( + input_model=args.input_model, + output_model=args.output_model, + device=args.device, + return_model=args.return_model + ) if __name__ == '__main__': main() \ No newline at end of file diff --git a/mace/cli/convert_e3nn_cueq.py b/mace/cli/convert_e3nn_cueq.py index 13a10ab1..e0f6abba 100644 --- a/mace/cli/convert_e3nn_cueq.py +++ b/mace/cli/convert_e3nn_cueq.py @@ -3,11 +3,9 @@ import logging from pathlib import Path from typing import Dict, List, Tuple -import cuequivariance as cue from mace.modules.wrapper_ops import CuEquivarianceConfig from mace.tools.scripts_utils import extract_config_mace_model -from mace import modules def get_transfer_keys() -> List[str]: diff --git a/mace/tools/scripts_utils.py b/mace/tools/scripts_utils.py index be96558d..cd201f2b 100644 --- a/mace/tools/scripts_utils.py +++ b/mace/tools/scripts_utils.py @@ -175,6 +175,12 @@ def radial_to_transform(radial): scale = model.scale_shift.scale shift = model.scale_shift.shift + try: + correlation = len( + model.products[0].symmetric_contractions.contractions[0].weights + ) + 1 + except AttributeError: + correlation = model.products[0].symmetric_contractions.contraction_degree config = { "r_max": model.r_max.item(), "num_bessel": len(model.radial_embedding.bessel_fn.bessel_weights), @@ -200,10 +206,7 @@ def radial_to_transform(radial): "atomic_energies": model.atomic_energies_fn.atomic_energies.cpu().numpy(), "avg_num_neighbors": model.interactions[0].avg_num_neighbors, "atomic_numbers": model.atomic_numbers, - "correlation": len( - model.products[0].symmetric_contractions.contractions[0].weights - ) - + 1, + "correlation": correlation, "radial_type": radial_to_name( model.radial_embedding.bessel_fn.__class__.__name__ ), diff --git a/tests/test_cueq.py b/tests/test_cueq.py index e11ed5ed..9ab0a38b 100644 --- a/tests/test_cueq.py +++ b/tests/test_cueq.py @@ -8,7 +8,8 @@ from mace import data, modules, tools from mace.modules.wrapper_ops import CuEquivarianceConfig from mace.tools import torch_geometric -from mace.cli.convert_e3nn_cueq import run as run_convert +from mace.cli.convert_e3nn_cueq import run as run_e3nn_to_cueq +from mace.cli.convert_cueq_e3nn import run as run_cueq_to_e3nn try: import cuequivariance as cue # pylint: disable=unused-import @@ -79,11 +80,11 @@ def batch(self, device: str): modules.interaction_classes["RealAgnosticDensityInteractionBlock"], ]) @pytest.mark.parametrize("hidden_irreps", [ - #o3.Irreps("32x0e + 32x1o"), - #o3.Irreps("32x0e + 32x1o + 32x2e"), + o3.Irreps("32x0e + 32x1o"), + o3.Irreps("32x0e + 32x1o + 32x2e"), o3.Irreps("32x0e"), ]) - def test_cueq_equivalence( + def test_bidirectional_conversion( self, model_config: Dict[str, Any], batch: Dict[str, torch.Tensor], @@ -93,39 +94,60 @@ def test_cueq_equivalence( ): torch.manual_seed(42) - # Create model without cuequivariance - model_std = modules.ScaleShiftMACE(**model_config) - model_std = model_std.to(device) + # Create original E3nn model + model_e3nn = modules.ScaleShiftMACE(**model_config) + model_e3nn = model_e3nn.to(device) - # Create model with cuequivariance - cueq_config = CuEquivarianceConfig( - enabled=True, layout="mul_ir", group="O3_e3nn", optimize_all=True - ) - model_config["cueq_config"] = cueq_config - model_cueq = modules.ScaleShiftMACE(**model_config) + # Convert E3nn to CuEq + model_cueq = run_e3nn_to_cueq(model_e3nn, None) model_cueq = model_cueq.to(device) - # Copy weights - model_cueq_convert = run_convert(model_std, None) - model_cueq_convert = model_cueq_convert.to(device) - - # Compare outputs - out_std = model_std(batch, training=True) - out_cueq_convert = model_cueq_convert(batch, training=True) - - torch.testing.assert_close(out_std["energy"], out_cueq_convert["energy"]) - torch.testing.assert_close(out_std["forces"], out_cueq_convert["forces"]) - - loss_std = out_std["energy"].sum() - loss_cueq = out_cueq_convert["energy"].sum() - - loss_std.backward() + # Convert CuEq back to E3nn + model_e3nn_back = run_cueq_to_e3nn(model_cueq, None) + model_e3nn_back = model_e3nn_back.to(device) + + # Test forward pass equivalence + out_e3nn = model_e3nn(batch, training=True) + out_cueq = model_cueq(batch, training=True) + out_e3nn_back = model_e3nn_back(batch, training=True) + + # Check outputs match for both conversions + torch.testing.assert_close(out_e3nn["energy"], out_cueq["energy"]) + torch.testing.assert_close(out_cueq["energy"], out_e3nn_back["energy"]) + torch.testing.assert_close(out_e3nn["forces"], out_cueq["forces"]) + torch.testing.assert_close(out_cueq["forces"], out_e3nn_back["forces"]) + + # Test backward pass equivalence + loss_e3nn = out_e3nn["energy"].sum() + loss_cueq = out_cueq["energy"].sum() + loss_e3nn_back = out_e3nn_back["energy"].sum() + + loss_e3nn.backward() loss_cueq.backward() + loss_e3nn_back.backward() + + # Compare gradients for all conversions + def print_gradient_diff(name1, p1, name2, p2, conv_type): + if p1.grad is not None and p1.grad.shape == p2.grad.shape: + if name1.split(".", 2)[:2] == name2.split(".", 2)[:2]: + error = torch.abs(p1.grad - p2.grad) + print(f"{conv_type} - Parameter {name1}/{name2}, Max error: {error.max()}") + torch.testing.assert_close(p1.grad, p2.grad, atol=1e-5, rtol=1e-10) + + # E3nn to CuEq gradients + for (name_e3nn, p_e3nn), (name_cueq, p_cueq) in zip( + model_e3nn.named_parameters(), model_cueq.named_parameters() + ): + print_gradient_diff(name_e3nn, p_e3nn, name_cueq, p_cueq, "E3nn->CuEq") + + # CuEq to E3nn gradients + for (name_cueq, p_cueq), (name_e3nn_back, p_e3nn_back) in zip( + model_cueq.named_parameters(), model_e3nn_back.named_parameters() + ): + print_gradient_diff(name_cueq, p_cueq, name_e3nn_back, p_e3nn_back, "CuEq->E3nn") - for (name_1, p1), (name_2, p2) in zip(model_std.named_parameters(), model_cueq_convert.named_parameters()): - if p1.grad is not None: - if p1.grad.shape == p2.grad.shape: - if name_1.split(".", 2)[:2] == name_2.split(".", 2)[:2]: - error = torch.abs(p1.grad - p2.grad) - print(f"Parameter {name_1}, Parameter {name_2}, Max error: {error.max()}") - torch.testing.assert_close(p1.grad, p2.grad, atol=1e-5, rtol=1e-10) \ No newline at end of file + # Full circle comparison (E3nn -> E3nn) + for (name_e3nn, p_e3nn), (name_e3nn_back, p_e3nn_back) in zip( + model_e3nn.named_parameters(), model_e3nn_back.named_parameters() + ): + print_gradient_diff(name_e3nn, p_e3nn, name_e3nn_back, p_e3nn_back, "Full circle") \ No newline at end of file From e17943208c7674a4494f7667dc9e989b6cc08ba1 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Tue, 19 Nov 2024 09:57:21 +0000 Subject: [PATCH 14/17] fix all formatting --- mace/cli/convert_cueq_e3nn.py | 177 ++++++++++++++++++++-------------- mace/cli/convert_e3nn_cueq.py | 163 ++++++++++++++++++------------- mace/modules/wrapper_ops.py | 4 +- mace/tools/scripts_utils.py | 6 +- setup.cfg | 2 + tests/test_cueq.py | 67 +++++++------ 6 files changed, 249 insertions(+), 170 deletions(-) diff --git a/mace/cli/convert_cueq_e3nn.py b/mace/cli/convert_cueq_e3nn.py index 286540dd..57b2808e 100644 --- a/mace/cli/convert_cueq_e3nn.py +++ b/mace/cli/convert_cueq_e3nn.py @@ -1,78 +1,98 @@ -import torch import argparse import logging -from pathlib import Path +import os from typing import Dict, List, Tuple +import torch + from mace.tools.scripts_utils import extract_config_mace_model + def get_transfer_keys() -> List[str]: """Get list of keys that need to be transferred""" return [ - 'node_embedding.linear.weight', - 'radial_embedding.bessel_fn.bessel_weights', - 'atomic_energies_fn.atomic_energies', - 'readouts.0.linear.weight', - 'scale_shift.scale', - 'scale_shift.shift', - *[f'readouts.1.linear_{i}.weight' for i in range(1,3)] + "node_embedding.linear.weight", + "radial_embedding.bessel_fn.bessel_weights", + "atomic_energies_fn.atomic_energies", + "readouts.0.linear.weight", + "scale_shift.scale", + "scale_shift.shift", + *[f"readouts.1.linear_{i}.weight" for i in range(1, 3)], ] + [ - s for j in range(2) for s in [ - f'interactions.{j}.linear_up.weight', - *[f'interactions.{j}.conv_tp_weights.layer{i}.weight' for i in range(4)], - f'interactions.{j}.linear.weight', - f'interactions.{j}.skip_tp.weight', - f'products.{j}.linear.weight' + s + for j in range(2) + for s in [ + f"interactions.{j}.linear_up.weight", + *[f"interactions.{j}.conv_tp_weights.layer{i}.weight" for i in range(4)], + f"interactions.{j}.linear.weight", + f"interactions.{j}.skip_tp.weight", + f"products.{j}.linear.weight", ] ] + def get_kmax_pairs(max_L: int, correlation: int) -> List[Tuple[int, int]]: """Determine kmax pairs based on max_L and correlation""" if correlation == 2: raise NotImplementedError("Correlation 2 not supported yet") - elif correlation == 3: + if correlation == 3: return [[0, max_L], [1, 0]] - else: - raise NotImplementedError(f"Correlation {correlation} not supported") + raise NotImplementedError(f"Correlation {correlation} not supported") + -def transfer_symmetric_contractions(source_dict: Dict[str, torch.Tensor], - target_dict: Dict[str, torch.Tensor], - max_L: int, - correlation: int): +def transfer_symmetric_contractions( + source_dict: Dict[str, torch.Tensor], + target_dict: Dict[str, torch.Tensor], + max_L: int, + correlation: int, +): """Transfer symmetric contraction weights from CuEq to E3nn format""" kmax_pairs = get_kmax_pairs(max_L, correlation) - logging.info(f"Using kmax pairs {kmax_pairs} for max_L={max_L}, correlation={correlation}") - + logging.info( + f"Using kmax pairs {kmax_pairs} for max_L={max_L}, correlation={correlation}" + ) + for i, kmax in kmax_pairs: # Get the combined weight tensor from source - wm = source_dict[f'products.{i}.symmetric_contractions.weight'] - + wm = source_dict[f"products.{i}.symmetric_contractions.weight"] + # Get split sizes based on target dimensions splits = [] for k in range(kmax + 1): - for suffix in ['_max', '.0', '.1']: - key = f'products.{i}.symmetric_contractions.contractions.{k}.weights{suffix}' + for suffix in ["_max", ".0", ".1"]: + key = f"products.{i}.symmetric_contractions.contractions.{k}.weights{suffix}" target_shape = target_dict[key].shape splits.append(target_shape[1]) - + # Split the weights using the calculated sizes weights_split = torch.split(wm, splits, dim=1) - + # Assign back to target dictionary idx = 0 for k in range(kmax + 1): - target_dict[f'products.{i}.symmetric_contractions.contractions.{k}.weights_max'] = weights_split[idx] - target_dict[f'products.{i}.symmetric_contractions.contractions.{k}.weights.0'] = weights_split[idx + 1] - target_dict[f'products.{i}.symmetric_contractions.contractions.{k}.weights.1'] = weights_split[idx + 2] + target_dict[ + f"products.{i}.symmetric_contractions.contractions.{k}.weights_max" + ] = weights_split[idx] + target_dict[ + f"products.{i}.symmetric_contractions.contractions.{k}.weights.0" + ] = weights_split[idx + 1] + target_dict[ + f"products.{i}.symmetric_contractions.contractions.{k}.weights.1" + ] = weights_split[idx + 2] idx += 3 -def transfer_weights(source_model: torch.nn.Module, target_model: torch.nn.Module, - max_L: int, correlation: int): + +def transfer_weights( + source_model: torch.nn.Module, + target_model: torch.nn.Module, + max_L: int, + correlation: int, +): """Transfer weights from CuEq to E3nn format""" # Get state dicts source_dict = source_model.state_dict() target_dict = target_model.state_dict() - + # Transfer main weights transfer_keys = get_transfer_keys() logging.info("Transferring main weights...") @@ -81,18 +101,22 @@ def transfer_weights(source_model: torch.nn.Module, target_model: torch.nn.Modul target_dict[key] = source_dict[key] else: logging.warning(f"Key {key} not found in source model") - + # Transfer symmetric contractions logging.info("Transferring symmetric contractions...") transfer_symmetric_contractions(source_dict, target_dict, max_L, correlation) - - # Transfer remaining matching keys + + # Transfer remaining matching keys transferred_keys = set(transfer_keys) - remaining_keys = set(source_dict.keys()) & set(target_dict.keys()) - transferred_keys - remaining_keys = {k for k in remaining_keys if 'symmetric_contraction' not in k} - + remaining_keys = ( + set(source_dict.keys()) & set(target_dict.keys()) - transferred_keys + ) + remaining_keys = {k for k in remaining_keys if "symmetric_contraction" not in k} + if remaining_keys: - logging.info(f"Found {len(remaining_keys)} additional matching keys to transfer") + logging.info( + f"Found {len(remaining_keys)} additional matching keys to transfer" + ) for key in remaining_keys: if source_dict[key].shape == target_dict[key].shape: logging.debug(f"Transferring additional key: {key}") @@ -102,72 +126,81 @@ def transfer_weights(source_model: torch.nn.Module, target_model: torch.nn.Modul f"Shape mismatch for key {key}: " f"source {source_dict[key].shape} vs target {target_dict[key].shape}" ) - + # Transfer avg_num_neighbors for i in range(2): - target_model.interactions[i].avg_num_neighbors = source_model.interactions[i].avg_num_neighbors - + target_model.interactions[i].avg_num_neighbors = source_model.interactions[ + i + ].avg_num_neighbors + # Load state dict into target model target_model.load_state_dict(target_dict) -def run( - input_model, - output_model, - device='cuda', - return_model=True -): + +def run(input_model, output_model="_e3nn.model", device="cuda", return_model=True): # Setup logging logging.basicConfig(level=logging.INFO) - + # Load CuEq model logging.info(f"Loading CuEq model from {input_model}") if isinstance(input_model, str): source_model = torch.load(input_model, map_location=device) else: source_model = input_model - + # Extract configuration logging.info("Extracting model configuration") config = extract_config_mace_model(source_model) - + # Get max_L and correlation from config max_L = config["hidden_irreps"].lmax correlation = config["correlation"] logging.info(f"Extracted max_L={max_L}, correlation={correlation}") - + # Remove CuEq config config.pop("cueq_config", None) - + # Create new model without CuEq config logging.info("Creating new model without CuEq settings") target_model = source_model.__class__(**config) - + # Transfer weights with proper remapping logging.info("Transferring weights with remapping...") transfer_weights(source_model, target_model, max_L, correlation) - + if return_model: return target_model - else: - # Save model - output_model = Path(input_model).parent / output_model - logging.info(f"Saving E3nn model to {output_model}") - torch.save(target_model, output_model) + + # Save model + if isinstance(input_model, str): + base = os.path.splitext(input_model)[0] + output_model = f"{base}.{output_model}" + logging.info(f"Saving E3nn model to {output_model}") + torch.save(target_model, output_model) + return None + def main(): parser = argparse.ArgumentParser() - parser.add_argument('input_model', help='Path to input CuEq model') - parser.add_argument('output_model', help='Path to output E3nn model', default='e3nn_model.pt') - parser.add_argument('--device', default='cpu', help='Device to use') - parser.add_argument('--return_model', action='store_false', help='Return model instead of saving to file') + parser.add_argument("input_model", help="Path to input CuEq model") + parser.add_argument( + "--output_model", help="Path to output E3nn model", default="e3nn_model.pt" + ) + parser.add_argument("--device", default="cpu", help="Device to use") + parser.add_argument( + "--return_model", + action="store_false", + help="Return model instead of saving to file", + ) args = parser.parse_args() - + run( input_model=args.input_model, output_model=args.output_model, device=args.device, - return_model=args.return_model + return_model=args.return_model, ) -if __name__ == '__main__': - main() \ No newline at end of file + +if __name__ == "__main__": + main() diff --git a/mace/cli/convert_e3nn_cueq.py b/mace/cli/convert_e3nn_cueq.py index e0f6abba..715d7bbf 100644 --- a/mace/cli/convert_e3nn_cueq.py +++ b/mace/cli/convert_e3nn_cueq.py @@ -1,9 +1,10 @@ -import torch import argparse import logging -from pathlib import Path +import os from typing import Dict, List, Tuple +import torch + from mace.modules.wrapper_ops import CuEquivarianceConfig from mace.tools.scripts_utils import extract_config_mace_model @@ -11,53 +12,72 @@ def get_transfer_keys() -> List[str]: """Get list of keys that need to be transferred""" return [ - 'node_embedding.linear.weight', - 'radial_embedding.bessel_fn.bessel_weights', - 'atomic_energies_fn.atomic_energies', - 'readouts.0.linear.weight', - 'scale_shift.scale', - 'scale_shift.shift', - *[f'readouts.1.linear_{i}.weight' for i in range(1,3)] + "node_embedding.linear.weight", + "radial_embedding.bessel_fn.bessel_weights", + "atomic_energies_fn.atomic_energies", + "readouts.0.linear.weight", + "scale_shift.scale", + "scale_shift.shift", + *[f"readouts.1.linear_{i}.weight" for i in range(1, 3)], ] + [ - s for j in range(2) for s in [ - f'interactions.{j}.linear_up.weight', - *[f'interactions.{j}.conv_tp_weights.layer{i}.weight' for i in range(4)], - f'interactions.{j}.linear.weight', - f'interactions.{j}.skip_tp.weight', - f'products.{j}.linear.weight' + s + for j in range(2) + for s in [ + f"interactions.{j}.linear_up.weight", + *[f"interactions.{j}.conv_tp_weights.layer{i}.weight" for i in range(4)], + f"interactions.{j}.linear.weight", + f"interactions.{j}.skip_tp.weight", + f"products.{j}.linear.weight", ] ] + def get_kmax_pairs(max_L: int, correlation: int) -> List[Tuple[int, int]]: """Determine kmax pairs based on max_L and correlation""" if correlation == 2: raise NotImplementedError("Correlation 2 not supported yet") - elif correlation == 3: + if correlation == 3: return [[0, max_L], [1, 0]] - else: - raise NotImplementedError(f"Correlation {correlation} not supported") + raise NotImplementedError(f"Correlation {correlation} not supported") + -def transfer_symmetric_contractions(source_dict: Dict[str, torch.Tensor], - target_dict: Dict[str, torch.Tensor], - max_L: int, - correlation: int): +def transfer_symmetric_contractions( + source_dict: Dict[str, torch.Tensor], + target_dict: Dict[str, torch.Tensor], + max_L: int, + correlation: int, +): """Transfer symmetric contraction weights""" kmax_pairs = get_kmax_pairs(max_L, correlation) - logging.info(f"Using kmax pairs {kmax_pairs} for max_L={max_L}, correlation={correlation}") - + logging.info( + f"Using kmax pairs {kmax_pairs} for max_L={max_L}, correlation={correlation}" + ) + for i, kmax in kmax_pairs: - wm = torch.concatenate([ - source_dict[f'products.{i}.symmetric_contractions.contractions.{k}.weights{j}'] - for k in range(kmax+1) for j in ['_max','.0','.1']],dim=1) #.float() - target_dict[f'products.{i}.symmetric_contractions.weight'] = wm + wm = torch.concatenate( + [ + source_dict[ + f"products.{i}.symmetric_contractions.contractions.{k}.weights{j}" + ] + for k in range(kmax + 1) + for j in ["_max", ".0", ".1"] + ], + dim=1, + ) # .float() + target_dict[f"products.{i}.symmetric_contractions.weight"] = wm + -def transfer_weights(source_model: torch.nn.Module, target_model: torch.nn.Module, - max_L: int, correlation: int): +def transfer_weights( + source_model: torch.nn.Module, + target_model: torch.nn.Module, + max_L: int, + correlation: int, +): """Transfer weights with proper remapping""" # Get source state dict source_dict = source_model.state_dict() target_dict = target_model.state_dict() - + # Transfer main weights transfer_keys = get_transfer_keys() logging.info("Transferring main weights...") @@ -66,17 +86,21 @@ def transfer_weights(source_model: torch.nn.Module, target_model: torch.nn.Modul target_dict[key] = source_dict[key] else: logging.warning(f"Key {key} not found in source model") - + # Transfer symmetric contractions logging.info("Transferring symmetric contractions...") transfer_symmetric_contractions(source_dict, target_dict, max_L, correlation) transferred_keys = set(transfer_keys) - remaining_keys = set(source_dict.keys()) & set(target_dict.keys()) - transferred_keys - remaining_keys = {k for k in remaining_keys if 'symmetric_contraction' not in k} + remaining_keys = ( + set(source_dict.keys()) & set(target_dict.keys()) - transferred_keys + ) + remaining_keys = {k for k in remaining_keys if "symmetric_contraction" not in k} if remaining_keys: - logging.info(f"Found {len(remaining_keys)} additional matching keys to transfer") + logging.info( + f"Found {len(remaining_keys)} additional matching keys to transfer" + ) for key in remaining_keys: if source_dict[key].shape == target_dict[key].shape: logging.debug(f"Transferring additional key: {key}") @@ -88,22 +112,23 @@ def transfer_weights(source_model: torch.nn.Module, target_model: torch.nn.Modul ) # Transfer avg_num_neighbors for i in range(2): - target_model.interactions[i].avg_num_neighbors = source_model.interactions[i].avg_num_neighbors - + target_model.interactions[i].avg_num_neighbors = source_model.interactions[ + i + ].avg_num_neighbors + # Load state dict into target model target_model.load_state_dict(target_dict) + def run( input_model, - output_model, - device='cuda', - layout='mul_ir', - group='O3', - return_model=True + output_model="_cueq.model", + device="cuda", + return_model=True, ): # Setup logging logging.basicConfig(level=logging.INFO) - + # Load original model logging.info(f"Loading model from {input_model}") # check if input_model is a path or a model @@ -111,16 +136,16 @@ def run( source_model = torch.load(input_model, map_location=device) else: source_model = input_model - + # Extract configuration logging.info("Extracting model configuration") config = extract_config_mace_model(source_model) - + # Get max_L and correlation from config max_L = config["hidden_irreps"].lmax correlation = config["correlation"] logging.info(f"Extracted max_L={max_L}, correlation={correlation}") - + # Add cuequivariance config config["cueq_config"] = CuEquivarianceConfig( enabled=True, @@ -128,43 +153,49 @@ def run( group="O3_e3nn", optimize_all=True, ) - + # Create new model with cuequivariance config logging.info("Creating new model with cuequivariance settings") target_model = source_model.__class__(**config) - + # Transfer weights with proper remapping logging.info("Transferring weights with remapping...") transfer_weights(source_model, target_model, max_L, correlation) - + if return_model: return target_model - else: - # Save model - output_model = Path(input_model).parent / output_model - logging.info(f"Saving cuequivariance model to {output_model}") - torch.save(target_model, output_model) + + if isinstance(input_model, str): + base = os.path.splitext(input_model)[0] + output_model = f"{base}.{output_model}" + logging.info(f"Saving CuEq model to {output_model}") + torch.save(target_model, output_model) + return None def main(): parser = argparse.ArgumentParser() - parser.add_argument('input_model', help='Path to input MACE model') - parser.add_argument('output_model', help='Path to output cuequivariance model', default='cuequivariance_model.pt') - parser.add_argument('--device', default='cpu', help='Device to use') - parser.add_argument('--layout', default='mul_ir', choices=['mul_ir', 'ir_mul'], help='Memory layout for tensors') - parser.add_argument('--group', default='O3_e3nn', choices=['O3', 'O3_e3nn'], help='Symmetry group') - parser.add_argument('--return_model', action='store_false', help='Return model instead of saving to file') + parser.add_argument("input_model", help="Path to input MACE model") + parser.add_argument( + "--output_model", + help="Path to output cuequivariance model", + default="cueq_model.pt", + ) + parser.add_argument("--device", default="cpu", help="Device to use") + parser.add_argument( + "--return_model", + action="store_false", + help="Return model instead of saving to file", + ) args = parser.parse_args() - + run( input_model=args.input_model, output_model=args.output_model, device=args.device, - layout=args.layout, - group=args.group, - return_model=args.return_model + return_model=args.return_model, ) -if __name__ == '__main__': - main() \ No newline at end of file +if __name__ == "__main__": + main() diff --git a/mace/modules/wrapper_ops.py b/mace/modules/wrapper_ops.py index 87770d3d..44149f40 100644 --- a/mace/modules/wrapper_ops.py +++ b/mace/modules/wrapper_ops.py @@ -23,7 +23,7 @@ if CUET_AVAILABLE: class O3_e3nn(cue.O3): - def __mul__( # pylint: disable=no-self-argument + def __mul__( # pylint: disable=no-self-argument rep1: "O3_e3nn", rep2: "O3_e3nn" ) -> Iterator["O3_e3nn"]: return [O3_e3nn(l=ir.l, p=ir.p) for ir in cue.O3.__mul__(rep1, rep2)] @@ -40,7 +40,7 @@ def clebsch_gordan( ) return np.zeros((0, rep1.dim, rep2.dim, rep3.dim)) - def __lt__( # pylint: disable=no-self-argument + def __lt__( # pylint: disable=no-self-argument rep1: "O3_e3nn", rep2: "O3_e3nn" ) -> bool: rep2 = rep1._from(rep2) diff --git a/mace/tools/scripts_utils.py b/mace/tools/scripts_utils.py index cd201f2b..1f1be22d 100644 --- a/mace/tools/scripts_utils.py +++ b/mace/tools/scripts_utils.py @@ -176,9 +176,9 @@ def radial_to_transform(radial): scale = model.scale_shift.scale shift = model.scale_shift.shift try: - correlation = len( - model.products[0].symmetric_contractions.contractions[0].weights - ) + 1 + correlation = ( + len(model.products[0].symmetric_contractions.contractions[0].weights) + 1 + ) except AttributeError: correlation = model.products[0].symmetric_contractions.contraction_degree config = { diff --git a/setup.cfg b/setup.cfg index b5b2c7a2..7b13092a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -44,6 +44,8 @@ console_scripts = mace_finetuning = mace.cli.fine_tuning_select:main mace_convert_device = mace.cli.convert_device:main mace_select_head = mace.cli.select_head:main + mace_e3nn_cueq = mace.cli.convert_e3nn_cueq:main + mace_cueq_to_e3nn = mace.cli.convert_cueq_e3nn:main [options.extras_require] wandb = wandb diff --git a/tests/test_cueq.py b/tests/test_cueq.py index 9ab0a38b..302f772d 100644 --- a/tests/test_cueq.py +++ b/tests/test_cueq.py @@ -6,19 +6,20 @@ from e3nn import o3 from mace import data, modules, tools -from mace.modules.wrapper_ops import CuEquivarianceConfig -from mace.tools import torch_geometric -from mace.cli.convert_e3nn_cueq import run as run_e3nn_to_cueq from mace.cli.convert_cueq_e3nn import run as run_cueq_to_e3nn +from mace.cli.convert_e3nn_cueq import run as run_e3nn_to_cueq +from mace.tools import torch_geometric try: import cuequivariance as cue # pylint: disable=unused-import + CUET_AVAILABLE = True except ImportError: CUET_AVAILABLE = False torch.set_default_dtype(torch.float64) + @pytest.mark.skipif(not CUET_AVAILABLE, reason="cuequivariance not installed") class TestCueq: @pytest.fixture @@ -31,7 +32,9 @@ def model_config(self, interaction_cls_first, hidden_irreps) -> Dict[str, Any]: "num_bessel": 8, "num_polynomial_cutoff": 6, "max_ell": 3, - "interaction_cls": modules.interaction_classes["RealAgnosticResidualInteractionBlock"], + "interaction_cls": modules.interaction_classes[ + "RealAgnosticResidualInteractionBlock" + ], "interaction_cls_first": interaction_cls_first, "num_interactions": 2, "num_elements": 1, @@ -50,13 +53,13 @@ def model_config(self, interaction_cls_first, hidden_irreps) -> Dict[str, Any]: @pytest.fixture def batch(self, device: str): from ase import build + table = tools.AtomicNumberTable([6]) atoms = build.bulk("C", "diamond", a=3.567, cubic=True) import numpy as np - displacement = np.random.uniform( - -0.1, 0.1, size=atoms.positions.shape - ) + + displacement = np.random.uniform(-0.1, 0.1, size=atoms.positions.shape) atoms.positions += displacement atoms_list = [atoms.repeat((2, 2, 2))] @@ -74,23 +77,27 @@ def batch(self, device: str): return batch.to(device).to_dict() @pytest.mark.parametrize("device", ["cuda"]) - @pytest.mark.parametrize("interaction_cls_first", [ - modules.interaction_classes["RealAgnosticResidualInteractionBlock"], - modules.interaction_classes["RealAgnosticInteractionBlock"], - modules.interaction_classes["RealAgnosticDensityInteractionBlock"], - ]) - @pytest.mark.parametrize("hidden_irreps", [ - o3.Irreps("32x0e + 32x1o"), - o3.Irreps("32x0e + 32x1o + 32x2e"), - o3.Irreps("32x0e"), - ]) + @pytest.mark.parametrize( + "interaction_cls_first", + [ + modules.interaction_classes["RealAgnosticResidualInteractionBlock"], + modules.interaction_classes["RealAgnosticInteractionBlock"], + modules.interaction_classes["RealAgnosticDensityInteractionBlock"], + ], + ) + @pytest.mark.parametrize( + "hidden_irreps", + [ + o3.Irreps("32x0e + 32x1o"), + o3.Irreps("32x0e + 32x1o + 32x2e"), + o3.Irreps("32x0e"), + ], + ) def test_bidirectional_conversion( - self, - model_config: Dict[str, Any], - batch: Dict[str, torch.Tensor], + self, + model_config: Dict[str, Any], + batch: Dict[str, torch.Tensor], device: str, - interaction_cls_first, - hidden_irreps ): torch.manual_seed(42) @@ -99,11 +106,11 @@ def test_bidirectional_conversion( model_e3nn = model_e3nn.to(device) # Convert E3nn to CuEq - model_cueq = run_e3nn_to_cueq(model_e3nn, None) + model_cueq = run_e3nn_to_cueq(model_e3nn) model_cueq = model_cueq.to(device) # Convert CuEq back to E3nn - model_e3nn_back = run_cueq_to_e3nn(model_cueq, None) + model_e3nn_back = run_cueq_to_e3nn(model_cueq) model_e3nn_back = model_e3nn_back.to(device) # Test forward pass equivalence @@ -131,7 +138,9 @@ def print_gradient_diff(name1, p1, name2, p2, conv_type): if p1.grad is not None and p1.grad.shape == p2.grad.shape: if name1.split(".", 2)[:2] == name2.split(".", 2)[:2]: error = torch.abs(p1.grad - p2.grad) - print(f"{conv_type} - Parameter {name1}/{name2}, Max error: {error.max()}") + print( + f"{conv_type} - Parameter {name1}/{name2}, Max error: {error.max()}" + ) torch.testing.assert_close(p1.grad, p2.grad, atol=1e-5, rtol=1e-10) # E3nn to CuEq gradients @@ -144,10 +153,14 @@ def print_gradient_diff(name1, p1, name2, p2, conv_type): for (name_cueq, p_cueq), (name_e3nn_back, p_e3nn_back) in zip( model_cueq.named_parameters(), model_e3nn_back.named_parameters() ): - print_gradient_diff(name_cueq, p_cueq, name_e3nn_back, p_e3nn_back, "CuEq->E3nn") + print_gradient_diff( + name_cueq, p_cueq, name_e3nn_back, p_e3nn_back, "CuEq->E3nn" + ) # Full circle comparison (E3nn -> E3nn) for (name_e3nn, p_e3nn), (name_e3nn_back, p_e3nn_back) in zip( model_e3nn.named_parameters(), model_e3nn_back.named_parameters() ): - print_gradient_diff(name_e3nn, p_e3nn, name_e3nn_back, p_e3nn_back, "Full circle") \ No newline at end of file + print_gradient_diff( + name_e3nn, p_e3nn, name_e3nn_back, p_e3nn_back, "Full circle" + ) From df0bbf341c5476af1e7781dd865b51c74a4e0dde Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Tue, 19 Nov 2024 10:59:23 +0000 Subject: [PATCH 15/17] swap to types.MethodType wrapper --- mace/modules/wrapper_ops.py | 55 +++++++++------ tests/test_cueq.py | 133 ++++++++++++++++++++++++------------ 2 files changed, 125 insertions(+), 63 deletions(-) diff --git a/mace/modules/wrapper_ops.py b/mace/modules/wrapper_ops.py index 44149f40..bde46598 100644 --- a/mace/modules/wrapper_ops.py +++ b/mace/modules/wrapper_ops.py @@ -4,6 +4,7 @@ import dataclasses import itertools +import types from typing import Iterator, List, Optional import numpy as np @@ -98,12 +99,18 @@ def __new__( and cueq_config.enabled and (cueq_config.optimize_all or cueq_config.optimize_linear) ): - return cuet.Linear( + instance = cuet.Linear( cue.Irreps(cueq_config.group, irreps_in), cue.Irreps(cueq_config.group, irreps_out), layout=cueq_config.layout, shared_weights=shared_weights, ) + instance._original_forward = instance.forward + def cuet_forward(self, x: torch.Tensor) -> torch.Tensor: + return self._original_forward(x, use_fallback=None) + instance.forward = types.MethodType(cuet_forward, instance) + return instance + return o3.Linear( irreps_in, irreps_out, @@ -114,7 +121,7 @@ def __new__( class TensorProduct: """Wrapper around o3.TensorProduct/cuet.ChannelwiseTensorProduct""" - + def __new__( cls, irreps_in1: o3.Irreps, @@ -131,7 +138,7 @@ def __new__( and cueq_config.enabled and (cueq_config.optimize_all or cueq_config.optimize_channelwise) ): - return cuet.ChannelWiseTensorProduct( + instance = cuet.ChannelWiseTensorProduct( cue.Irreps(cueq_config.group, irreps_in1), cue.Irreps(cueq_config.group, irreps_in2), cue.Irreps(cueq_config.group, irreps_out), @@ -139,6 +146,12 @@ def __new__( shared_weights=shared_weights, internal_weights=internal_weights, ) + instance._original_forward = instance.forward + def cuet_forward(self, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor) -> torch.Tensor: + return self._original_forward(x, y, z, use_fallback=None) + instance.forward = types.MethodType(cuet_forward, instance) + return instance + return o3.TensorProduct( irreps_in1, irreps_in2, @@ -152,6 +165,7 @@ def __new__( class FullyConnectedTensorProduct: """Wrapper around o3.FullyConnectedTensorProduct/cuet.FullyConnectedTensorProduct""" + def __new__( cls, irreps_in1: o3.Irreps, @@ -167,7 +181,7 @@ def __new__( and cueq_config.enabled and (cueq_config.optimize_all or cueq_config.optimize_fctp) ): - return cuet.FullyConnectedTensorProduct( + instance = cuet.FullyConnectedTensorProduct( cue.Irreps(cueq_config.group, irreps_in1), cue.Irreps(cueq_config.group, irreps_in2), cue.Irreps(cueq_config.group, irreps_out), @@ -175,6 +189,12 @@ def __new__( shared_weights=shared_weights, internal_weights=internal_weights, ) + instance._original_forward = instance.forward + def cuet_forward(self, x: torch.Tensor, attrs: torch.Tensor) -> torch.Tensor: + return self._original_forward(x, attrs, use_fallback=None) + instance.forward = types.MethodType(cuet_forward, instance) + return instance + return o3.FullyConnectedTensorProduct( irreps_in1, irreps_in2, @@ -187,21 +207,6 @@ def __new__( class SymmetricContractionWrapper: """Wrapper around SymmetricContraction/cuet.SymmetricContraction""" - class CuetForward: - def __init__(self, instance, layout): - self.instance = instance - self.layout = layout - self.original_forward = instance.forward - - def __call__(self, x: torch.Tensor, attrs: torch.Tensor) -> torch.Tensor: - if self.layout == cue.mul_ir: - x = torch.transpose(x, 1, 2) - index_attrs = torch.nonzero(attrs)[:, 1].int() - return self.original_forward( - x.flatten(1), - index_attrs, - ) - def __new__( cls, irreps_in: o3.Irreps, @@ -227,7 +232,17 @@ def __new__( dtype=torch.get_default_dtype(), math_dtype=torch.get_default_dtype(), ) - instance.forward = cls.CuetForward(instance, cueq_config.layout) + instance._original_forward = instance.forward + instance.layout = cueq_config.layout + def cuet_forward(self, x: torch.Tensor, attrs: torch.Tensor) -> torch.Tensor: + if self.layout == cue.mul_ir: + x = torch.transpose(x, 1, 2) + index_attrs = torch.nonzero(attrs)[:, 1].int() + return self._original_forward( + x.flatten(1), + index_attrs, + ) + instance.forward = types.MethodType(cuet_forward, instance) return instance return SymmetricContraction( diff --git a/tests/test_cueq.py b/tests/test_cueq.py index 302f772d..c2b1d048 100644 --- a/tests/test_cueq.py +++ b/tests/test_cueq.py @@ -6,6 +6,7 @@ from e3nn import o3 from mace import data, modules, tools +from e3nn.util import jit from mace.cli.convert_cueq_e3nn import run as run_cueq_to_e3nn from mace.cli.convert_e3nn_cueq import run as run_e3nn_to_cueq from mace.tools import torch_geometric @@ -88,12 +89,84 @@ def batch(self, device: str): @pytest.mark.parametrize( "hidden_irreps", [ - o3.Irreps("32x0e + 32x1o"), - o3.Irreps("32x0e + 32x1o + 32x2e"), + # o3.Irreps("32x0e + 32x1o"), + # o3.Irreps("32x0e + 32x1o + 32x2e"), o3.Irreps("32x0e"), ], ) - def test_bidirectional_conversion( + # def test_bidirectional_conversion( + # self, + # model_config: Dict[str, Any], + # batch: Dict[str, torch.Tensor], + # device: str, + # ): + # torch.manual_seed(42) + + # # Create original E3nn model + # model_e3nn = modules.ScaleShiftMACE(**model_config) + # model_e3nn = model_e3nn.to(device) + + # # Convert E3nn to CuEq + # model_cueq = run_e3nn_to_cueq(model_e3nn) + # model_cueq = model_cueq.to(device) + + # # Convert CuEq back to E3nn + # model_e3nn_back = run_cueq_to_e3nn(model_cueq) + # model_e3nn_back = model_e3nn_back.to(device) + + # # Test forward pass equivalence + # out_e3nn = model_e3nn(batch, training=True) + # out_cueq = model_cueq(batch, training=True) + # out_e3nn_back = model_e3nn_back(batch, training=True) + + # # Check outputs match for both conversions + # torch.testing.assert_close(out_e3nn["energy"], out_cueq["energy"]) + # torch.testing.assert_close(out_cueq["energy"], out_e3nn_back["energy"]) + # torch.testing.assert_close(out_e3nn["forces"], out_cueq["forces"]) + # torch.testing.assert_close(out_cueq["forces"], out_e3nn_back["forces"]) + + # # Test backward pass equivalence + # loss_e3nn = out_e3nn["energy"].sum() + # loss_cueq = out_cueq["energy"].sum() + # loss_e3nn_back = out_e3nn_back["energy"].sum() + + # loss_e3nn.backward() + # loss_cueq.backward() + # loss_e3nn_back.backward() + + # # Compare gradients for all conversions + # def print_gradient_diff(name1, p1, name2, p2, conv_type): + # if p1.grad is not None and p1.grad.shape == p2.grad.shape: + # if name1.split(".", 2)[:2] == name2.split(".", 2)[:2]: + # error = torch.abs(p1.grad - p2.grad) + # print( + # f"{conv_type} - Parameter {name1}/{name2}, Max error: {error.max()}" + # ) + # torch.testing.assert_close(p1.grad, p2.grad, atol=1e-5, rtol=1e-10) + + # # E3nn to CuEq gradients + # for (name_e3nn, p_e3nn), (name_cueq, p_cueq) in zip( + # model_e3nn.named_parameters(), model_cueq.named_parameters() + # ): + # print_gradient_diff(name_e3nn, p_e3nn, name_cueq, p_cueq, "E3nn->CuEq") + + # # CuEq to E3nn gradients + # for (name_cueq, p_cueq), (name_e3nn_back, p_e3nn_back) in zip( + # model_cueq.named_parameters(), model_e3nn_back.named_parameters() + # ): + # print_gradient_diff( + # name_cueq, p_cueq, name_e3nn_back, p_e3nn_back, "CuEq->E3nn" + # ) + + # # Full circle comparison (E3nn -> E3nn) + # for (name_e3nn, p_e3nn), (name_e3nn_back, p_e3nn_back) in zip( + # model_e3nn.named_parameters(), model_e3nn_back.named_parameters() + # ): + # print_gradient_diff( + # name_e3nn, p_e3nn, name_e3nn_back, p_e3nn_back, "Full circle" + # ) + + def test_jit_compile( self, model_config: Dict[str, Any], batch: Dict[str, torch.Tensor], @@ -113,54 +186,28 @@ def test_bidirectional_conversion( model_e3nn_back = run_cueq_to_e3nn(model_cueq) model_e3nn_back = model_e3nn_back.to(device) + # # Compile all models + model_e3nn_compiled = jit.compile(model_e3nn) + model_cueq_compiled = jit.compile(model_cueq) + model_e3nn_back_compiled = jit.compile(model_e3nn_back) + # Test forward pass equivalence out_e3nn = model_e3nn(batch, training=True) out_cueq = model_cueq(batch, training=True) out_e3nn_back = model_e3nn_back(batch, training=True) + out_e3nn_compiled = model_e3nn_compiled(batch, training=True) + out_cueq_compiled = model_cueq_compiled(batch, training=True) + out_e3nn_back_compiled = model_e3nn_back_compiled(batch, training=True) + # Check outputs match for both conversions torch.testing.assert_close(out_e3nn["energy"], out_cueq["energy"]) torch.testing.assert_close(out_cueq["energy"], out_e3nn_back["energy"]) torch.testing.assert_close(out_e3nn["forces"], out_cueq["forces"]) torch.testing.assert_close(out_cueq["forces"], out_e3nn_back["forces"]) - # Test backward pass equivalence - loss_e3nn = out_e3nn["energy"].sum() - loss_cueq = out_cueq["energy"].sum() - loss_e3nn_back = out_e3nn_back["energy"].sum() - - loss_e3nn.backward() - loss_cueq.backward() - loss_e3nn_back.backward() - - # Compare gradients for all conversions - def print_gradient_diff(name1, p1, name2, p2, conv_type): - if p1.grad is not None and p1.grad.shape == p2.grad.shape: - if name1.split(".", 2)[:2] == name2.split(".", 2)[:2]: - error = torch.abs(p1.grad - p2.grad) - print( - f"{conv_type} - Parameter {name1}/{name2}, Max error: {error.max()}" - ) - torch.testing.assert_close(p1.grad, p2.grad, atol=1e-5, rtol=1e-10) - - # E3nn to CuEq gradients - for (name_e3nn, p_e3nn), (name_cueq, p_cueq) in zip( - model_e3nn.named_parameters(), model_cueq.named_parameters() - ): - print_gradient_diff(name_e3nn, p_e3nn, name_cueq, p_cueq, "E3nn->CuEq") - - # CuEq to E3nn gradients - for (name_cueq, p_cueq), (name_e3nn_back, p_e3nn_back) in zip( - model_cueq.named_parameters(), model_e3nn_back.named_parameters() - ): - print_gradient_diff( - name_cueq, p_cueq, name_e3nn_back, p_e3nn_back, "CuEq->E3nn" - ) - - # Full circle comparison (E3nn -> E3nn) - for (name_e3nn, p_e3nn), (name_e3nn_back, p_e3nn_back) in zip( - model_e3nn.named_parameters(), model_e3nn_back.named_parameters() - ): - print_gradient_diff( - name_e3nn, p_e3nn, name_e3nn_back, p_e3nn_back, "Full circle" - ) + torch.testing.assert_close(out_e3nn["energy"], out_e3nn_compiled["energy"]) + torch.testing.assert_close(out_cueq["energy"], out_cueq_compiled["energy"]) + torch.testing.assert_close(out_e3nn_back["energy"], out_e3nn_back_compiled["energy"]) + torch.testing.assert_close(out_e3nn["forces"], out_e3nn_compiled["forces"]) + torch.testing.assert_close(out_cueq["forces"], out_cueq_compiled["forces"]) \ No newline at end of file From 2e8e5c50c4ef3a13f9f08e10c0bdf04a6152bca7 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Tue, 19 Nov 2024 11:10:17 +0000 Subject: [PATCH 16/17] Update wrapper_ops.py --- mace/modules/wrapper_ops.py | 1 + 1 file changed, 1 insertion(+) diff --git a/mace/modules/wrapper_ops.py b/mace/modules/wrapper_ops.py index bde46598..437d106c 100644 --- a/mace/modules/wrapper_ops.py +++ b/mace/modules/wrapper_ops.py @@ -241,6 +241,7 @@ def cuet_forward(self, x: torch.Tensor, attrs: torch.Tensor) -> torch.Tensor: return self._original_forward( x.flatten(1), index_attrs, + use_fallback=None, ) instance.forward = types.MethodType(cuet_forward, instance) return instance From ef42dba4da105bfd6d722242e753ff9f0fc82b3e Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Tue, 19 Nov 2024 12:06:23 +0000 Subject: [PATCH 17/17] add test_run_train_cueq --- mace/cli/run_train.py | 9 +- mace/tools/arg_parser.py | 46 +------- mace/tools/model_script_utils.py | 11 -- tests/test_cueq.py | 174 +++++++++++++++---------------- tests/test_run_train.py | 67 ++++++++++++ 5 files changed, 163 insertions(+), 144 deletions(-) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 3813b055..4196e083 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -54,6 +54,8 @@ from mace.tools.slurm_distributed import DistributedEnvironment from mace.tools.tables_utils import create_error_table from mace.tools.utils import AtomicNumberTable +from mace.cli.convert_cueq_e3nn import run as run_cueq_to_e3nn +from mace.cli.convert_e3nn_cueq import run as run_e3nn_to_cueq def main() -> None: @@ -600,7 +602,10 @@ def run(args: argparse.Namespace) -> None: if args.wandb: setup_wandb(args) - + if args.enable_cueq: + logging.info("Converting model to CUEQ for accelerated training") + assert args.model in ["MACE", "ScaleShiftMACE"], "Model must be MACE or ScaleShiftMACE" + model = run_e3nn_to_cueq(model) if args.distributed: distributed_model = DDP(model, device_ids=[local_rank]) else: @@ -752,6 +757,8 @@ def run(args: argparse.Namespace) -> None: if rank == 0: # Save entire model + if args.enable_cueq: + model = run_cueq_to_e3nn(model) if swa_eval: model_path = Path(args.checkpoints_dir) / (tag + "_stagetwo.model") else: diff --git a/mace/tools/arg_parser.py b/mace/tools/arg_parser.py index f513b9f9..07e02e49 100644 --- a/mace/tools/arg_parser.py +++ b/mace/tools/arg_parser.py @@ -662,55 +662,11 @@ def build_default_arg_parser() -> argparse.ArgumentParser: ) # option for cuequivariance acceleration parser.add_argument( - "--cue_enabled", + "--enable_cueq", help="Enable cuequivariance acceleration", type=str2bool, default=False, ) - parser.add_argument( - "--cue_layout", - help="Memory layout for cuequivariance tensors", - type=str, - choices=["mul_ir", "ir_mul"], - default="mul_ir", - ) - parser.add_argument( - "--cue_group", - help="Symmetry group for cuequivariance", - type=str, - choices=["O3_e3nn, O3"], - default="O3_e3nn", - ) - parser.add_argument( - "--cue_optimize_all", - help="Enable all cuequivariance optimizations", - type=str2bool, - default=False, - ) - parser.add_argument( - "--cue_optimize_linear", - help="Enable cuequivariance linear layer optimization", - type=str2bool, - default=False, - ) - parser.add_argument( - "--cue_optimize_channelwise", - help="Enable cuequivariance channelwise optimization", - type=str2bool, - default=False, - ) - parser.add_argument( - "--cue_optimize_symmetric", - help="Enable cuequivariance symmetric contraction optimization", - type=str2bool, - default=False, - ) - parser.add_argument( - "--cue_optimize_fctp", - help="Enable cuequivariance fully connected tensor product optimization", - type=str2bool, - default=False, - ) # options for using Weights and Biases for experiment tracking # to install see https://wandb.ai parser.add_argument( diff --git a/mace/tools/model_script_utils.py b/mace/tools/model_script_utils.py index 25ef0a93..1d647683 100644 --- a/mace/tools/model_script_utils.py +++ b/mace/tools/model_script_utils.py @@ -29,16 +29,6 @@ def configure_model( logging.info( f"During training the following quantities will be reported: {', '.join([f'{report}' for report, value in output_args.items() if value])}" ) - cueq_config = CuEquivarianceConfig( - enabled=args.cue_enabled, - layout=args.cue_layout, - group=args.cue_group, - optimize_all=args.cue_optimize_all, - optimize_linear=args.cue_optimize_linear, - optimize_channelwise=args.cue_optimize_channelwise, - optimize_symmetric=args.cue_optimize_symmetric, - optimize_fctp=args.cue_optimize_fctp, - ) logging.info("===========MODEL DETAILS===========") if args.scaling == "no_scaling": @@ -120,7 +110,6 @@ def configure_model( atomic_energies=atomic_energies, avg_num_neighbors=args.avg_num_neighbors, atomic_numbers=z_table.zs, - cueq_config=cueq_config, ) model_config_foundation = None diff --git a/tests/test_cueq.py b/tests/test_cueq.py index c2b1d048..79bacc6c 100644 --- a/tests/test_cueq.py +++ b/tests/test_cueq.py @@ -94,79 +94,7 @@ def batch(self, device: str): o3.Irreps("32x0e"), ], ) - # def test_bidirectional_conversion( - # self, - # model_config: Dict[str, Any], - # batch: Dict[str, torch.Tensor], - # device: str, - # ): - # torch.manual_seed(42) - - # # Create original E3nn model - # model_e3nn = modules.ScaleShiftMACE(**model_config) - # model_e3nn = model_e3nn.to(device) - - # # Convert E3nn to CuEq - # model_cueq = run_e3nn_to_cueq(model_e3nn) - # model_cueq = model_cueq.to(device) - - # # Convert CuEq back to E3nn - # model_e3nn_back = run_cueq_to_e3nn(model_cueq) - # model_e3nn_back = model_e3nn_back.to(device) - - # # Test forward pass equivalence - # out_e3nn = model_e3nn(batch, training=True) - # out_cueq = model_cueq(batch, training=True) - # out_e3nn_back = model_e3nn_back(batch, training=True) - - # # Check outputs match for both conversions - # torch.testing.assert_close(out_e3nn["energy"], out_cueq["energy"]) - # torch.testing.assert_close(out_cueq["energy"], out_e3nn_back["energy"]) - # torch.testing.assert_close(out_e3nn["forces"], out_cueq["forces"]) - # torch.testing.assert_close(out_cueq["forces"], out_e3nn_back["forces"]) - - # # Test backward pass equivalence - # loss_e3nn = out_e3nn["energy"].sum() - # loss_cueq = out_cueq["energy"].sum() - # loss_e3nn_back = out_e3nn_back["energy"].sum() - - # loss_e3nn.backward() - # loss_cueq.backward() - # loss_e3nn_back.backward() - - # # Compare gradients for all conversions - # def print_gradient_diff(name1, p1, name2, p2, conv_type): - # if p1.grad is not None and p1.grad.shape == p2.grad.shape: - # if name1.split(".", 2)[:2] == name2.split(".", 2)[:2]: - # error = torch.abs(p1.grad - p2.grad) - # print( - # f"{conv_type} - Parameter {name1}/{name2}, Max error: {error.max()}" - # ) - # torch.testing.assert_close(p1.grad, p2.grad, atol=1e-5, rtol=1e-10) - - # # E3nn to CuEq gradients - # for (name_e3nn, p_e3nn), (name_cueq, p_cueq) in zip( - # model_e3nn.named_parameters(), model_cueq.named_parameters() - # ): - # print_gradient_diff(name_e3nn, p_e3nn, name_cueq, p_cueq, "E3nn->CuEq") - - # # CuEq to E3nn gradients - # for (name_cueq, p_cueq), (name_e3nn_back, p_e3nn_back) in zip( - # model_cueq.named_parameters(), model_e3nn_back.named_parameters() - # ): - # print_gradient_diff( - # name_cueq, p_cueq, name_e3nn_back, p_e3nn_back, "CuEq->E3nn" - # ) - - # # Full circle comparison (E3nn -> E3nn) - # for (name_e3nn, p_e3nn), (name_e3nn_back, p_e3nn_back) in zip( - # model_e3nn.named_parameters(), model_e3nn_back.named_parameters() - # ): - # print_gradient_diff( - # name_e3nn, p_e3nn, name_e3nn_back, p_e3nn_back, "Full circle" - # ) - - def test_jit_compile( + def test_bidirectional_conversion( self, model_config: Dict[str, Any], batch: Dict[str, torch.Tensor], @@ -186,28 +114,100 @@ def test_jit_compile( model_e3nn_back = run_cueq_to_e3nn(model_cueq) model_e3nn_back = model_e3nn_back.to(device) - # # Compile all models - model_e3nn_compiled = jit.compile(model_e3nn) - model_cueq_compiled = jit.compile(model_cueq) - model_e3nn_back_compiled = jit.compile(model_e3nn_back) - # Test forward pass equivalence out_e3nn = model_e3nn(batch, training=True) out_cueq = model_cueq(batch, training=True) out_e3nn_back = model_e3nn_back(batch, training=True) - out_e3nn_compiled = model_e3nn_compiled(batch, training=True) - out_cueq_compiled = model_cueq_compiled(batch, training=True) - out_e3nn_back_compiled = model_e3nn_back_compiled(batch, training=True) - # Check outputs match for both conversions torch.testing.assert_close(out_e3nn["energy"], out_cueq["energy"]) torch.testing.assert_close(out_cueq["energy"], out_e3nn_back["energy"]) torch.testing.assert_close(out_e3nn["forces"], out_cueq["forces"]) torch.testing.assert_close(out_cueq["forces"], out_e3nn_back["forces"]) - torch.testing.assert_close(out_e3nn["energy"], out_e3nn_compiled["energy"]) - torch.testing.assert_close(out_cueq["energy"], out_cueq_compiled["energy"]) - torch.testing.assert_close(out_e3nn_back["energy"], out_e3nn_back_compiled["energy"]) - torch.testing.assert_close(out_e3nn["forces"], out_e3nn_compiled["forces"]) - torch.testing.assert_close(out_cueq["forces"], out_cueq_compiled["forces"]) \ No newline at end of file + # Test backward pass equivalence + loss_e3nn = out_e3nn["energy"].sum() + loss_cueq = out_cueq["energy"].sum() + loss_e3nn_back = out_e3nn_back["energy"].sum() + + loss_e3nn.backward() + loss_cueq.backward() + loss_e3nn_back.backward() + + # Compare gradients for all conversions + def print_gradient_diff(name1, p1, name2, p2, conv_type): + if p1.grad is not None and p1.grad.shape == p2.grad.shape: + if name1.split(".", 2)[:2] == name2.split(".", 2)[:2]: + error = torch.abs(p1.grad - p2.grad) + print( + f"{conv_type} - Parameter {name1}/{name2}, Max error: {error.max()}" + ) + torch.testing.assert_close(p1.grad, p2.grad, atol=1e-5, rtol=1e-10) + + # E3nn to CuEq gradients + for (name_e3nn, p_e3nn), (name_cueq, p_cueq) in zip( + model_e3nn.named_parameters(), model_cueq.named_parameters() + ): + print_gradient_diff(name_e3nn, p_e3nn, name_cueq, p_cueq, "E3nn->CuEq") + + # CuEq to E3nn gradients + for (name_cueq, p_cueq), (name_e3nn_back, p_e3nn_back) in zip( + model_cueq.named_parameters(), model_e3nn_back.named_parameters() + ): + print_gradient_diff( + name_cueq, p_cueq, name_e3nn_back, p_e3nn_back, "CuEq->E3nn" + ) + + # Full circle comparison (E3nn -> E3nn) + for (name_e3nn, p_e3nn), (name_e3nn_back, p_e3nn_back) in zip( + model_e3nn.named_parameters(), model_e3nn_back.named_parameters() + ): + print_gradient_diff( + name_e3nn, p_e3nn, name_e3nn_back, p_e3nn_back, "Full circle" + ) + + # def test_jit_compile( + # self, + # model_config: Dict[str, Any], + # batch: Dict[str, torch.Tensor], + # device: str, + # ): + # torch.manual_seed(42) + + # # Create original E3nn model + # model_e3nn = modules.ScaleShiftMACE(**model_config) + # model_e3nn = model_e3nn.to(device) + + # # Convert E3nn to CuEq + # model_cueq = run_e3nn_to_cueq(model_e3nn) + # model_cueq = model_cueq.to(device) + + # # Convert CuEq back to E3nn + # model_e3nn_back = run_cueq_to_e3nn(model_cueq) + # model_e3nn_back = model_e3nn_back.to(device) + + # # # Compile all models + # model_e3nn_compiled = jit.compile(model_e3nn) + # model_cueq_compiled = jit.compile(model_cueq) + # model_e3nn_back_compiled = jit.compile(model_e3nn_back) + + # # Test forward pass equivalence + # out_e3nn = model_e3nn(batch, training=True) + # out_cueq = model_cueq(batch, training=True) + # out_e3nn_back = model_e3nn_back(batch, training=True) + + # out_e3nn_compiled = model_e3nn_compiled(batch, training=True) + # out_cueq_compiled = model_cueq_compiled(batch, training=True) + # out_e3nn_back_compiled = model_e3nn_back_compiled(batch, training=True) + + # # Check outputs match for both conversions + # torch.testing.assert_close(out_e3nn["energy"], out_cueq["energy"]) + # torch.testing.assert_close(out_cueq["energy"], out_e3nn_back["energy"]) + # torch.testing.assert_close(out_e3nn["forces"], out_cueq["forces"]) + # torch.testing.assert_close(out_cueq["forces"], out_e3nn_back["forces"]) + + # torch.testing.assert_close(out_e3nn["energy"], out_e3nn_compiled["energy"]) + # torch.testing.assert_close(out_cueq["energy"], out_cueq_compiled["energy"]) + # torch.testing.assert_close(out_e3nn_back["energy"], out_e3nn_back_compiled["energy"]) + # torch.testing.assert_close(out_e3nn["forces"], out_e3nn_compiled["forces"]) + # torch.testing.assert_close(out_cueq["forces"], out_cueq_compiled["forces"]) \ No newline at end of file diff --git a/tests/test_run_train.py b/tests/test_run_train.py index ca196c47..2dbd857b 100644 --- a/tests/test_run_train.py +++ b/tests/test_run_train.py @@ -847,3 +847,70 @@ def test_run_train_multihead_replay_custum_finetuning( assert len(Es) == len(fitting_configs) assert all(isinstance(E, float) for E in Es) assert len(set(Es)) > 1 # Ens + +def test_run_train_cueq(tmp_path, fitting_configs): + ase.io.write(tmp_path / "fit.xyz", fitting_configs) + + mace_params = _mace_params.copy() + mace_params["checkpoints_dir"] = str(tmp_path) + mace_params["model_dir"] = str(tmp_path) + mace_params["train_file"] = tmp_path / "fit.xyz" + mace_params["enable_cueq"] = True + + # make sure run_train.py is using the mace that is currently being tested + run_env = os.environ.copy() + sys.path.insert(0, str(Path(__file__).parent.parent)) + run_env["PYTHONPATH"] = ":".join(sys.path) + print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) + + cmd = ( + sys.executable + + " " + + str(run_train) + + " " + + " ".join( + [ + (f"--{k}={v}" if v is not None else f"--{k}") + for k, v in mace_params.items() + ] + ) + ) + + p = subprocess.run(cmd.split(), env=run_env, check=True) + assert p.returncode == 0 + + calc = MACECalculator(model_paths=tmp_path / "MACE.model", device="cpu") + + Es = [] + for at in fitting_configs: + at.calc = calc + Es.append(at.get_potential_energy()) + + print("Es", Es) + # from a run on 04/06/2024 on stress_bugfix 967f0bfb6490086599da247874b24595d149caa7 + ref_Es = [ + 0.0, + 0.0, + -0.039181344585828524, + -0.0915223395136733, + -0.14953484236456582, + -0.06662480820063998, + -0.09983737353050133, + 0.12477442296789745, + -0.06486086271762856, + -0.1460607988519944, + 0.12886334908465508, + -0.14000990081920373, + -0.05319886578958313, + 0.07780520158391, + -0.08895480281886901, + -0.15474719614734422, + 0.007756765146527644, + -0.044879267197498685, + -0.036065736712447574, + -0.24413743841886623, + -0.0838104612106429, + -0.14751978636626545, + ] + + assert np.allclose(Es, ref_Es) \ No newline at end of file