Skip to content

Commit

Permalink
Add builder tests
Browse files Browse the repository at this point in the history
  • Loading branch information
sivonxay committed Nov 2, 2023
1 parent a070335 commit 8355617
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 26 deletions.
2 changes: 1 addition & 1 deletion src/NanoParticleTools/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,7 @@ def get_items(self):

# count the total items in the dict
total_items = sum([len(v) for v in unduplicated_dict.values()])
if total_items != self.n_docs_filter:
if total_items < self.n_docs_filter:
continue

yield unduplicated_dict, docs_to_avg
Expand Down
24 changes: 12 additions & 12 deletions src/NanoParticleTools/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,18 @@

class NPMCInput(MSONable):

def __init__(self,
spectral_kinetics: SpectralKinetics,
nanoparticle: DopedNanoparticle,
initial_states: Sequence[int] | None = None):

self.spectral_kinetics = spectral_kinetics
self.nanoparticle = nanoparticle
if initial_states is None:
self.initial_states = [0 for _ in self.sites]

Check warning on line 134 in src/NanoParticleTools/core.py

View check run for this annotation

Codecov / codecov/patch

src/NanoParticleTools/core.py#L131-L134

Added lines #L131 - L134 were not covered by tests
else:
self.initial_states = initial_states

Check warning on line 136 in src/NanoParticleTools/core.py

View check run for this annotation

Codecov / codecov/patch

src/NanoParticleTools/core.py#L136

Added line #L136 was not covered by tests

def load_trajectory(self, seed, database_file):
with sqlite3.connect(database_file) as con:
cur = con.cursor()
Expand Down Expand Up @@ -164,18 +176,6 @@ def load_trajectories(self, database_file: str):

self.trajectories = trajectories

def __init__(self,
spectral_kinetics: SpectralKinetics,
nanoparticle: DopedNanoparticle,
initial_states: Sequence[int] | None = None):

self.spectral_kinetics = spectral_kinetics
self.nanoparticle = nanoparticle
if initial_states is None:
self.initial_states = [0 for _ in self.sites]
else:
self.initial_states = initial_states

@property
@lru_cache
def interactions(self):
Expand Down
126 changes: 113 additions & 13 deletions tests/test_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,35 +3,135 @@
MultiFidelityAveragingBuilder)
from maggma.stores import MemoryStore

from pathlib import Path
import pytest
import json
from monty.serialization import MontyDecoder

MODULE_DIR = Path(__file__).absolute().parent
TEST_FILE_DIR = MODULE_DIR / 'test_files'


@pytest.fixture
def raw_docs_store():
store = MemoryStore()
# TODO: add documents to this memory store
store.connect()

# load the documents into the store
with open(TEST_FILE_DIR / 'npmc_docs/raw_documents.json', 'r') as f:
results = json.load(f, cls=MontyDecoder)

store.update(results, key='_id')
return store


@pytest.mark.skip('Not implemented yet, need to add documents')
def test_ucnp_builder(raw_docs_store):
target_store = MemoryStore()

builder = UCNPBuilder(raw_docs_store, target_store)
builder.run()

assert target_store.count() == 2
assert len(list(builder.prechunk(2))) == 2

target_store = MemoryStore()
builder = UCNPBuilder(source=raw_docs_store,
target=target_store,
docs_filter={'data.n_dopant_sites': 1399})
builder.run()

assert target_store.count() == 1
doc = target_store.query_one()
assert doc['avg_simulation_length'] == pytest.approx(40078.375)
assert doc['avg_simulation_time'] == pytest.approx(0.01018022901835798)
assert set(doc['output'].keys()) == {
'summary_keys', 'summary', 'energy_spectrum_x', 'energy_spectrum_y',
'wavelength_spectrum_x', 'wavelength_spectrum_y'
}


pass
def test_ucnp_pop_builder(raw_docs_store):
target_store = MemoryStore()
builder = UCNPPopBuilder(raw_docs_store, target_store)
builder.run()

assert target_store.count() == 2
doc = target_store.query_one()
assert set(doc['output'].keys()) == {
'energy_spectrum_x', 'energy_spectrum_y', 'wavelength_spectrum_x',
'wavelength_spectrum_y', 'avg_total_pop',
'avg_total_pop_by_constraint', 'avg_5ms_total_pop',
'avg_8ms_total_pop', 'avg_5ms_total_pop_by_constraint',
'avg_8ms_total_pop_by_constraint'
}

@pytest.mark.skip('Not implemented yet, need to add documents')
def test_ucnp_pop_builder():
pass

def test_partial_averaging_builder(raw_docs_store):
# Test building 44 averaging
target_store = MemoryStore()
builder = PartialAveragingBuilder(n_orderings=4,
n_sims=4,
source=raw_docs_store,
target=target_store)
builder.run()

@pytest.mark.skip('Not implemented yet, need to add documents')
def test_partial_averaging_builder():
pass
assert target_store.count() == 1
doc = target_store.query_one()
doc['num_averaged'] == 16

# Test building 22 averaging
target_store = MemoryStore()
builder = PartialAveragingBuilder(n_orderings=2,
n_sims=2,
source=raw_docs_store,
target=target_store)
builder.run()

@pytest.mark.skip('Not implemented yet, need to add documents')
def test_multi_fidelity_averaging_builder():
pass
assert target_store.count() == 1
doc = target_store.query_one()
doc['num_averaged'] == 4

# Test building 14 averaging
target_store = MemoryStore()
builder = PartialAveragingBuilder(n_orderings=1,
n_sims=4,
source=raw_docs_store,
target=target_store)
builder.run()

assert target_store.count() == 2
doc = target_store.query_one()
doc['num_averaged'] == 4

# Test building 41 averaging
target_store = MemoryStore()
builder = PartialAveragingBuilder(n_orderings=4,
n_sims=1,
source=raw_docs_store,
target=target_store)
builder.run()

assert target_store.count() == 1
doc = target_store.query_one()
doc['num_averaged'] == 4


def test_multi_fidelity_averaging_builder(raw_docs_store):
target_store = MemoryStore()
builder = MultiFidelityAveragingBuilder(n_docs_filter=4,
source=raw_docs_store,
target=target_store)
builder.run()
assert target_store.count() == 2
docs = list(target_store.query())
assert set(docs[0]['output']['energy_spectrum_y'].keys()) == {
'(1, 1)', '(2, 2)', '(4, 1)', '(1, 4)', '(4, 4)'
}
assert set(
docs[1]['output']['energy_spectrum_y'].keys()) == {'(1, 1)', '(4, 1)'}

target_store = MemoryStore()
builder = MultiFidelityAveragingBuilder(n_docs_filter=16,
source=raw_docs_store,
target=target_store)
builder.run()
assert target_store.count() == 1
1 change: 1 addition & 0 deletions tests/test_files/npmc_docs/raw_documents.json

Large diffs are not rendered by default.

0 comments on commit 8355617

Please sign in to comment.