Skip to content

Commit

Permalink
bug: fix spin nlist in spin_model (deepmodeling#3718)
Browse files Browse the repository at this point in the history
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
- Expanded backend support to include PyTorch and DP in the energy spin
training documentation.
- Updated spin settings for TensorFlow and PyTorch/DP in the energy spin
training documentation.
- Added new sections for loss functions and data preparation in the
energy spin training documentation.

- **Bug Fixes**
- Adjusted test conditions and initialization parameters in various
model tests to align with updated functionalities.

- **Tests**
- Increased selection values in permutation tests to enhance test
coverage and reliability.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: Han Wang <[email protected]>
  • Loading branch information
2 people authored and Mathieu Taillefumier committed Sep 18, 2024
1 parent 9ac48c3 commit b8d8c9f
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 22 deletions.
19 changes: 12 additions & 7 deletions deepmd/pt/model/model/spin_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,15 +170,20 @@ def extend_nlist(extended_atype, nlist):
nlist_shift = nlist + nall
nlist[~nlist_mask] = -1
nlist_shift[~nlist_mask] = -1
self_spin = torch.arange(0, nloc, dtype=nlist.dtype, device=nlist.device) + nall
self_spin = self_spin.view(1, -1, 1).expand(nframes, -1, -1)
# self spin + real neighbor + virtual neighbor
self_real = (
torch.arange(0, nloc, dtype=nlist.dtype, device=nlist.device)
.view(1, -1, 1)
.expand(nframes, -1, -1)
)
self_spin = self_real + nall
# real atom's neighbors: self spin + real neighbor + virtual neighbor
# nf x nloc x (1 + nnei + nnei)
real_nlist = torch.cat([self_spin, nlist, nlist_shift], dim=-1)
# spin atom's neighbors: real + real neighbor + virtual neighbor
# nf x nloc x (1 + nnei + nnei)
extended_nlist = torch.cat([self_spin, nlist, nlist_shift], dim=-1)
spin_nlist = torch.cat([self_real, nlist, nlist_shift], dim=-1)
# nf x (nloc + nloc) x (1 + nnei + nnei)
extended_nlist = torch.cat(
[extended_nlist, -1 * torch.ones_like(extended_nlist)], dim=-2
)
extended_nlist = torch.cat([real_nlist, spin_nlist], dim=-2)
# update the index for switch
first_part_index = (nloc <= extended_nlist) & (extended_nlist < nall)
second_part_index = (nall <= extended_nlist) & (extended_nlist < (nall + nloc))
Expand Down
63 changes: 58 additions & 5 deletions doc/model/train-energy-spin.md
Original file line number Diff line number Diff line change
@@ -1,14 +1,35 @@
# Fit spin energy {{ tensorflow_icon }}
# Fit spin energy {{ tensorflow_icon }} {{ pytorch_icon }} {{ dpmodel_icon }}

:::{note}
**Supported backends**: TensorFlow {{ tensorflow_icon }}
**Supported backends**: TensorFlow {{ tensorflow_icon }}, PyTorch {{ pytorch_icon }}, DP {{ dpmodel_icon }}
:::

In this section, we will take `$deepmd_source_dir/examples/NiO/se_e2_a/input.json` as an example of the input file.
To train a model that takes additional spin information as input, you only need to modify the following sections to define the spin-specific settings,
keeping other sections the same as the normal energy model's input script.

:::{warning}
Note that when adding spin into the model, there will be some implicit modifications automatically done by the program:

- In the TensorFlow backend, the `se_e2_a` descriptor will treat those atom types with spin as new (virtual) types,
and duplicate their corresponding selected numbers of neighbors ({ref}`sel <model/descriptor[se_e2_a]/sel>`) from their real atom types.
- In the PyTorch backend, if spin settings are added, all the types (with or without spin) will have their virtual types.
The `se_e2_a` descriptor will thus double the {ref}`sel <model/descriptor[se_e2_a]/sel>` list,
while in other descriptors with mixed types (such as `dpa1` or `dpa2`), the sel number will not be changed for clarity.
If you are using descriptors with mixed types, to achieve better performance,
you should manually extend your sel number (maybe double) depending on the balance between performance and efficiency.
:::

## Spin

The construction of the fitting net is give by section {ref}`spin <model/spin>`
The spin settings are given by the {ref}`spin <model/spin>` section, which sets the magnetism for each type of atoms as described in the following sections.

:::{note}
Note that the construction of spin settings is different between TensorFlow and PyTorch/DP.
:::

### Spin settings in TensorFlow

The implementation in TensorFlow only supports `se_e2_a` descriptor. See examples in `$deepmd_source_dir/examples/spin/se_e2_a/input_tf.json`, the {ref}`spin <model/spin>` section is defined as the following:

```json
"spin" : {
Expand All @@ -18,10 +39,38 @@ The construction of the fitting net is give by section {ref}`spin <model/spin>`
},
```

- {ref}`use_spin <model/spin[ener_spin]/use_spin>` determines whether to turn on the magnetism of the atoms.The index of this option matches option `type_map <model/type_map>`.
- {ref}`use_spin <model/spin[ener_spin]/use_spin>` is a list of boolean values indicating whether to use atomic spin for each atom type.
True for spin and False for not. The index of this option matches option `type_map <model/type_map>`.
- {ref}`virtual_len <model/spin[ener_spin]/virtual_len>` specifies the distance between virtual atom and the belonging real atom.
- {ref}`spin_norm <model/spin[ener_spin]/spin_norm>` gives the magnitude of the magnetic moment for each magnatic atom.

### Spin settings in PyTorch/DP

In PyTorch/DP, the spin implementation is more flexible and so far supports the following descriptors:

- `se_e2_a`
- `dpa1`(`se_atten`)
- `dpa2`

See `se_e2_a` examples in `$deepmd_source_dir/examples/spin/se_e2_a/input_torch.json`, the {ref}`spin <model/spin>` section is defined as the following with a much more clear interface:

```json
"spin": {
"use_spin": [true, false],
"virtual_scale": [0.3140]
},
```

- {ref}`use_spin <model/spin[ener_spin]/use_spin>` is a list of boolean values indicating whether to use atomic spin for each atom type, or a list of type indexes that use atomic spin.
The index of this option matches option `type_map <model/type_map>`.
- {ref}`virtual_len <model/spin[ener_spin]/virtual_scale>` defines the scaling factor to determine the virtual distance
between a virtual atom representing spin and its corresponding real atom
for each atom type with spin. This factor is defined as the virtual distance
divided by the magnitude of atomic spin for each atom type with spin.
The virtual coordinate is defined as the real coordinate plus spin \* virtual_scale.
List of float values with shape of `ntypes` or `ntypes_spin` or one single float value for all types,
only used when {ref}`use_spin <model/spin[ener_spin]/use_spin>` is True for each atom type.

## Spin Loss

The spin loss function $L$ for training energy is given by
Expand Down Expand Up @@ -59,3 +108,7 @@ The {ref}`loss <loss>` section in the `input.json` is
The options {ref}`start_pref_e <loss[ener_spin]/start_pref_e>`, {ref}`limit_pref_e <loss[ener_spin]/limit_pref_e>`, {ref}`start_pref_fr <loss[ener_spin]/start_pref_fr>`, {ref}`limit_pref_fm <loss[ener_spin]/limit_pref_fm>`, {ref}`start_pref_v <loss[ener_spin]/start_pref_v>` and {ref}`limit_pref_v <loss[ener_spin]/limit_pref_v>` determine the start and limit prefactors of energy, atomic force, magnatic force and virial, respectively.

If one does not want to train with virial, then he/she may set the virial prefactors {ref}`start_pref_v <loss[ener_spin]/start_pref_v>` and {ref}`limit_pref_v <loss[ener_spin]/limit_pref_v>` to 0.

## Data preparation

(Need a documentation for data format for TensorFlow and PyTorch/DP.)
30 changes: 22 additions & 8 deletions source/tests/pt/model/test_forward_lower.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,9 @@ def test(
) = extend_input_and_build_neighbor_list(
coord.unsqueeze(0),
atype.unsqueeze(0),
self.model.get_rcut(),
self.model.get_rcut() + 1.0
if test_spin
else self.model.get_rcut(), # buffer region for spin nlist
self.model.get_sel(),
mixed_types=self.model.mixed_types(),
box=cell.unsqueeze(0),
Expand Down Expand Up @@ -128,15 +130,13 @@ class TestEnergyModelSeA(unittest.TestCase, ForwardLowerTest):
def setUp(self):
self.prec = 1e-10
model_params = copy.deepcopy(model_se_e2_a)
self.type_split = False
self.model = get_model(model_params).to(env.DEVICE)


class TestEnergyModelDPA1(unittest.TestCase, ForwardLowerTest):
def setUp(self):
self.prec = 1e-10
model_params = copy.deepcopy(model_dpa1)
self.type_split = True
self.model = get_model(model_params).to(env.DEVICE)


Expand All @@ -151,24 +151,38 @@ def setUp(self):
"repinit_nsel"
]
model_params = copy.deepcopy(model_dpa2)
self.type_split = True
self.model = get_model(model_params).to(env.DEVICE)


class TestEnergyModelZBL(unittest.TestCase, ForwardLowerTest):
def setUp(self):
self.prec = 1e-10
model_params = copy.deepcopy(model_zbl)
self.type_split = False
self.model = get_model(model_params).to(env.DEVICE)


class TestEnergyModelSpinSeA(unittest.TestCase, ForwardLowerTest):
def setUp(self):
# still need to figure out why only 1e-5 rtol and atol
self.prec = 1e-5
self.prec = 1e-10
model_params = copy.deepcopy(model_spin)
self.test_spin = True
self.model = get_model(model_params).to(env.DEVICE)


class TestEnergyModelSpinDPA1(unittest.TestCase, ForwardLowerTest):
def setUp(self):
self.prec = 1e-10
model_params = copy.deepcopy(model_spin)
model_params["descriptor"] = copy.deepcopy(model_dpa1)["descriptor"]
self.test_spin = True
self.model = get_model(model_params).to(env.DEVICE)


class TestEnergyModelSpinDPA2(unittest.TestCase, ForwardLowerTest):
def setUp(self):
self.prec = 1e-10
model_params = copy.deepcopy(model_spin)
self.type_split = False
model_params["descriptor"] = copy.deepcopy(model_dpa2)["descriptor"]
self.test_spin = True
self.model = get_model(model_params).to(env.DEVICE)

Expand Down
4 changes: 2 additions & 2 deletions source/tests/pt/model/test_permutation.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,10 +116,10 @@
"type": "dpa2",
"repinit_rcut": 6.0,
"repinit_rcut_smth": 2.0,
"repinit_nsel": 30,
"repinit_nsel": 100,
"repformer_rcut": 4.0,
"repformer_rcut_smth": 0.5,
"repformer_nsel": 20,
"repformer_nsel": 40,
"repinit_neuron": [2, 4, 8],
"repinit_axis_neuron": 4,
"repinit_activation": "tanh",
Expand Down

0 comments on commit b8d8c9f

Please sign in to comment.