Skip to content

Commit

Permalink
Fix no axis type bug (#120)
Browse files Browse the repository at this point in the history
* add rules of local axis for NoAxisType

* add test for NoAxisType
  • Loading branch information
plumbum082 authored Oct 20, 2023
1 parent 22f0922 commit e0877ee
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 0 deletions.
8 changes: 8 additions & 0 deletions dmff/admp/spatial.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def generate_construct_local_frames(axis_types, axis_indices):
Bisector_filter = (axis_types == Bisector)
ZBisect_filter = (axis_types == ZBisect)
ThreeFold_filter = (axis_types == ThreeFold)
NoAxisType_filter = (axis_types == NoAxisType)

def construct_local_frames(positions, box):
'''
Expand Down Expand Up @@ -139,6 +140,13 @@ def construct_local_frames(positions, box):
vec_x = normalize(vec_x - vec_z * xz_projection, axis=1)
# up to this point, x-axis should be ready
vec_y = jnp.cross(vec_z, vec_x)

# NoAxisType
if np.sum(NoAxisType_filter) > 0:
vec_y = vec_y.at[NoAxisType_filter].set(jnp.array([0,1,0]))
vec_z = vec_z.at[NoAxisType_filter].set(jnp.array([0,0,1]))
vec_x = vec_x.at[NoAxisType_filter].set(jnp.array([1,0,0]))


return jnp.stack((vec_x, vec_y, vec_z), axis=1)

Expand Down
50 changes: 50 additions & 0 deletions tests/test_admp/test_noaxistype.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import jax.numpy as jnp
import numpy as np
import numpy.testing as npt
import pytest
from dmff.admp.spatial import (build_quasi_internal,
generate_construct_local_frames, pbc_shift,
v_pbc_shift)


class TestSpatial:

@pytest.mark.parametrize(
"axis_types, axis_indices, positions, box, expected_local_frames",
[
(
np.array([5]),
np.array(
[
[-1, -1, -1],
]
),
jnp.array(
[
[0.992, 0.068, -0.073],
]
),
jnp.array([[50.000, 0.0, 0.0], [0.0, 50.000, 0.0], [0.0, 0.0, 50.000]]),
np.array(
[
[
[1, 0, 0],
[0, 1, 0],
[0, 0, 1],
],
]
),
)
],
)
def test_generate_construct_local_frames(
self, axis_types, axis_indices, positions, box, expected_local_frames
):
construct_local_frame_fn = generate_construct_local_frames(
axis_types, axis_indices
)
assert construct_local_frame_fn
npt.assert_allclose(
construct_local_frame_fn(positions, box), expected_local_frames, rtol=1e-5
)

0 comments on commit e0877ee

Please sign in to comment.