Skip to content

Commit

Permalink
Fix case and add test
Browse files Browse the repository at this point in the history
  • Loading branch information
sivonxay committed Oct 3, 2023
1 parent ae09ac1 commit 7046c74
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 40 deletions.
6 changes: 6 additions & 0 deletions src/NanoParticleTools/inputs/nanoparticle.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,12 @@ def __init__(self,
site for site in constraint.get_host_structure().sites
if site.species_string in dopants_dict.keys()
]

if len(_sites) == 0:
# Need at least one site for a valid pymatgen structure.
# Just keep the first site from the host structure
_sites = [constraint.get_host_structure().sites[0]]

constraint.host_structure = Structure.from_sites(_sites)

self._sites = None
Expand Down
81 changes: 41 additions & 40 deletions tests/inputs/test_nanoparticle.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,22 @@
from NanoParticleTools.inputs.nanoparticle import (
SphericalConstraint,
PrismConstraint,
CubeConstraint,
DopedNanoparticle,
get_nayf4_structure,
get_wse2_structure
)
SphericalConstraint, PrismConstraint, CubeConstraint, DopedNanoparticle,
get_nayf4_structure, get_wse2_structure)
import numpy as np
import pytest


def test_spherical_constraint():
constraint = SphericalConstraint(30)
points = [[30, 0, 0],
[0, 30, 0],
[0, 0, 30],
[0, 30, 30],
[30, 30, 30],
[17, 17, 17],
[18, 18, 18]]
points = [[30, 0, 0], [0, 30, 0], [0, 0, 30], [0, 30, 30], [30, 30, 30],
[17, 17, 17], [18, 18, 18]]

assert constraint.get_host_structure() == get_nayf4_structure()
assert constraint.radius == 30
assert constraint.bounding_box() == [30, 30, 30]

assert np.all(constraint.sites_in_bounds(points) ==
np.array([True, True, True, False, False, True, False]))
assert np.all(
constraint.sites_in_bounds(points) == np.array(
[True, True, True, False, False, True, False]))

structure = get_wse2_structure()
constraint = SphericalConstraint(20, structure)
Expand All @@ -34,31 +25,24 @@ def test_spherical_constraint():

def test_prism_constraint():
constraint = PrismConstraint(60, 60, 80)
points = [[30, 0, 0],
[0, 30, 0],
[0, 0, 40],
[0, 0, 41],
[31, 0, 0],
points = [[30, 0, 0], [0, 30, 0], [0, 0, 40], [0, 0, 41], [31, 0, 0],
[30, 30, 40]]
assert constraint.bounding_box() == [60, 60, 80]

assert np.all(constraint.sites_in_bounds(points) ==
np.array([True, True, True, False, False, True]))
assert np.all(
constraint.sites_in_bounds(points) == np.array(
[True, True, True, False, False, True]))


def test_cube_constraint():
constraint = CubeConstraint(60)
points = [[30, 0, 0],
[0, 30, 0],
[0, 0, 30],
[0, 30, 30],
[30, 30, 30],
[31, 31, 31],
[0, 0, 31]]
points = [[30, 0, 0], [0, 30, 0], [0, 0, 30], [0, 30, 30], [30, 30, 30],
[31, 31, 31], [0, 0, 31]]
assert constraint.bounding_box() == [60, 60, 60]

assert np.all(constraint.sites_in_bounds(points) ==
np.array([True, True, True, True, True, False, False]))
assert np.all(
constraint.sites_in_bounds(points) == np.array(
[True, True, True, True, True, False, False]))


def test_doped_nanoparticle():
Expand All @@ -79,6 +63,23 @@ def test_doped_nanoparticle():
assert len(dnp.sites) == 56


def test_nanoparticle_with_empty():
constraints = [
SphericalConstraint(40),
SphericalConstraint(85),
SphericalConstraint(102.5)
]
dopant_specifications = [(0, 0.5, 'Yb', 'Y'), (0, 0.2, 'Er', 'Y'),
(2, 0.5, 'Yb', 'Y'), (2, 0.1, 'Nd', 'Y')]

dnp = DopedNanoparticle(constraints,
dopant_specifications,
prune_hosts=True)
dnp.generate()
assert len(dnp.dopant_sites) == 18015
assert len(dnp.sites) == 39579


def test_empty_nanoparticle():
"""
All of these tests are expected to throw an error
Expand Down Expand Up @@ -120,21 +121,21 @@ def test_empty_nanoparticle():
def test_get_nayf4_structure():
struct = get_nayf4_structure()

assert ([str(el) for el in struct.species] ==
['Na', 'Na', 'Na', 'Y', 'Y', 'Y', 'F', 'F', 'F',
'F', 'F', 'F', 'F', 'F', 'F', 'F', 'F', 'F'])
lattice = np.array([[6.067, 0.0, 0.0],
[-3.0335, 5.25417612, 0.0],
assert ([str(el) for el in struct.species] == [
'Na', 'Na', 'Na', 'Y', 'Y', 'Y', 'F', 'F', 'F', 'F', 'F', 'F', 'F',
'F', 'F', 'F', 'F', 'F'
])
lattice = np.array([[6.067, 0.0, 0.0], [-3.0335, 5.25417612, 0.0],
[0.0, 0.0, 7.103]])
assert np.allclose(struct.lattice.matrix, lattice)


def test_get_wse2_structure():
struct = get_wse2_structure()

assert ([str(el) for el in struct.species] == ['Se', 'Se', 'Se', 'Se', 'W', 'W'])
assert ([str(el)
for el in struct.species] == ['Se', 'Se', 'Se', 'Se', 'W', 'W'])

lattice = np.array([[3.327, 0.0, 0],
[-1.6635, 2.88126652, 0],
lattice = np.array([[3.327, 0.0, 0], [-1.6635, 2.88126652, 0],
[0.0, 0.0, 15.069]])
assert np.allclose(struct.lattice.matrix, lattice)

0 comments on commit 7046c74

Please sign in to comment.