Skip to content

Commit

Permalink
Hint schemes added to materials and summary (#446)
Browse files Browse the repository at this point in the history
* Add hint schemes to summary and materials

* Bump maggma for hints

* Remove xfails and test fixes

* Fix lru_cache decorator
  • Loading branch information
Jason Munro authored Nov 30, 2021
1 parent f31191a commit c94800f
Show file tree
Hide file tree
Showing 18 changed files with 116 additions and 31 deletions.
2 changes: 1 addition & 1 deletion requirements-server.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
fastapi==0.70.0
maggma==0.32.3
maggma==0.33.0
uvicorn==0.15.0
gunicorn[gevent]==20.1.0
boto3==1.20.15
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
pydantic==1.8.2
pymatgen>=2022.0.16
typing-extensions==4.0.0
maggma==0.32.3
maggma==0.33.0
requests==2.26.0
monty==2021.8.17
emmet-core==0.18.0
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
"requests>=2.23.0",
"monty>=2021.8.17",
"emmet-core",
"maggma>=0.32.2",
"maggma>=0.33.0",
"ratelimit",
],
extras_require={
Expand Down
2 changes: 1 addition & 1 deletion src/mp_api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,7 +575,7 @@ def get_ion_reference_data(self, chemsys: Union[str, List]) -> List[Dict]:
# convert to a tuple which is hashable
return self._get_ion_reference_data(tuple(chemsys)) # type: ignore

@lru_cache # type: ignore
@lru_cache() # type: ignore
def _get_ion_reference_data(self, chemsys: Tuple): # type: ignore
"""
Private, cacheable helper method for get_ion_reference data.
Expand Down
12 changes: 12 additions & 0 deletions src/mp_api/routes/materials/hint_scheme.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from maggma.api.resource import HintScheme


class MaterialsHintScheme(HintScheme):
"""
Hint scheme for the materials endpoint.
"""

def generate_hints(self, query):

if "nelements" in query["criteria"]:
return {"hint": {"nelements": 1}}
2 changes: 2 additions & 0 deletions src/mp_api/routes/materials/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
NumericQuery,
)

from mp_api.routes.materials.hint_scheme import MaterialsHintScheme
from mp_api.routes.materials.query_operators import (
ElementsQuery,
FormulaQuery,
Expand Down Expand Up @@ -68,6 +69,7 @@ def materials_resource(materials_store):
default_fields=["material_id", "formula_pretty", "last_updated"],
),
],
hint_scheme=MaterialsHintScheme(),
tags=["Materials"],
disable_validation=True,
)
Expand Down
15 changes: 15 additions & 0 deletions src/mp_api/routes/summary/hint_scheme.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from maggma.api.resource import HintScheme


class SummaryHintScheme(HintScheme):
"""
Hint scheme for the summary endpoint.
"""

def generate_hints(self, query):

if "nelements" in query["criteria"]:
return {"hint": {"nelements": 1}}

elif "has_props" in query["criteria"]:
return {"hint": {"has_props": 1}}
2 changes: 2 additions & 0 deletions src/mp_api/routes/summary/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
)
from mp_api.routes.oxidation_states.query_operators import PossibleOxiStateQuery
from emmet.core.summary import SummaryStats
from mp_api.routes.summary.hint_scheme import SummaryHintScheme
from mp_api.routes.summary.query_operators import (
HasPropsQuery,
MaterialIDsSearchQuery,
Expand Down Expand Up @@ -49,6 +50,7 @@ def summary_resource(summary_store):
PaginationQuery(),
SparseFieldsQuery(SummaryDoc, default_fields=["material_id"]),
],
hint_scheme=SummaryHintScheme(),
tags=["Summary"],
disable_validation=True,
)
Expand Down
2 changes: 1 addition & 1 deletion src/mp_api/routes/thermo/client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from collections import defaultdict
from typing import Optional, List, Tuple
from mp_api.core.client import BaseRester, MPRestError
from mp_api.core.client import BaseRester
from emmet.core.thermo import ThermoDoc
from pymatgen.analysis.phase_diagram import PhaseDiagram

Expand Down
21 changes: 16 additions & 5 deletions tests/bonds/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,21 @@

sub_doc_fields = [] # type: list

alt_name_dict = {} # type: dict
alt_name_dict = {
"max_bond_length": "bond_length_stats",
"min_bond_length": "bond_length_stats",
"mean_bond_length": "bond_length_stats",
} # type: dict

custom_field_tests = {} # type: dict
custom_field_tests = {
"coordination_envs": ["Mo-S(6)"],
"coordination_envs_anonymous": ["A-B(6)"],
} # type: dict


@pytest.mark.xfail(reason="Needs deployment")
@pytest.mark.skipif(os.environ.get("MP_API_KEY", None) is None, reason="No API key found.")
@pytest.mark.skipif(
os.environ.get("MP_API_KEY", None) is None, reason="No API key found."
)
@pytest.mark.parametrize("rester", resters)
def test_client(rester):
# Get specific search method
Expand Down Expand Up @@ -76,4 +84,7 @@ def test_client(rester):
if sub_field in doc:
doc = doc[sub_field]

assert doc[project_field if project_field is not None else param] is not None
assert (
doc[project_field if project_field is not None else param]
is not None
)
10 changes: 7 additions & 3 deletions tests/dielectric/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@
custom_field_tests = {} # type: dict


@pytest.mark.xfail(reason="Needs deployment")
@pytest.mark.skipif(os.environ.get("MP_API_KEY", None) is None, reason="No API key found.")
@pytest.mark.skipif(
os.environ.get("MP_API_KEY", None) is None, reason="No API key found."
)
@pytest.mark.parametrize("rester", resters)
def test_client(rester):
# Get specific search method
Expand Down Expand Up @@ -76,4 +77,7 @@ def test_client(rester):
if sub_field in doc:
doc = doc[sub_field]

assert doc[project_field if project_field is not None else param] is not None
assert (
doc[project_field if project_field is not None else param]
is not None
)
10 changes: 7 additions & 3 deletions tests/magnetism/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@
custom_field_tests = {"ordering": Ordering.FM} # type: dict


@pytest.mark.xfail(reason="Needs deployment")
@pytest.mark.skipif(os.environ.get("MP_API_KEY", None) is None, reason="No API key found.")
@pytest.mark.skipif(
os.environ.get("MP_API_KEY", None) is None, reason="No API key found."
)
@pytest.mark.parametrize("rester", resters)
def test_client(rester):
# Get specific search method
Expand Down Expand Up @@ -78,4 +79,7 @@ def test_client(rester):
if sub_field in doc:
doc = doc[sub_field]

assert doc[project_field if project_field is not None else param] is not None
assert (
doc[project_field if project_field is not None else param]
is not None
)
8 changes: 8 additions & 0 deletions tests/materials/test_hint_scheme.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from mp_api.routes.materials.hint_scheme import MaterialsHintScheme


def test_materials_hint_scheme():
scheme = MaterialsHintScheme()
assert scheme.generate_hints({"criteria": {"nelements": 3}}) == {
"hint": {"nelements": 1}
}
10 changes: 7 additions & 3 deletions tests/piezo/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@
custom_field_tests = {} # type: dict


@pytest.mark.xfail(reason="Needs deployment")
@pytest.mark.skipif(os.environ.get("MP_API_KEY", None) is None, reason="No API key found.")
@pytest.mark.skipif(
os.environ.get("MP_API_KEY", None) is None, reason="No API key found."
)
@pytest.mark.parametrize("rester", resters)
def test_client(rester):
# Get specific search method
Expand Down Expand Up @@ -77,4 +78,7 @@ def test_client(rester):
if sub_field in doc:
doc = doc[sub_field]

assert doc[project_field if project_field is not None else param] is not None
assert (
doc[project_field if project_field is not None else param]
is not None
)
11 changes: 11 additions & 0 deletions tests/summary/test_hint_scheme.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from mp_api.routes.summary.hint_scheme import SummaryHintScheme


def test_summary_hint_scheme():
scheme = SummaryHintScheme()
assert scheme.generate_hints({"criteria": {"nelements": 3}}) == {
"hint": {"nelements": 1}
}
assert scheme.generate_hints({"criteria": {"has_props": "dos"}}) == {
"hint": {"has_props": 1}
}
1 change: 0 additions & 1 deletion tests/synthesis/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,6 @@ def test_filters_time_range(rester):
@pytest.mark.skipif(
os.environ.get("MP_API_KEY", None) is None, reason="No API key found."
)
@pytest.mark.xfail # Needs fixing
def test_filters_atmosphere(rester):
search_method = None
for entry in inspect.getmembers(rester, predicate=inspect.ismethod):
Expand Down
3 changes: 0 additions & 3 deletions tests/test_mprester.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ def test_get_structure_by_material_id(self, mpr):
with pytest.warns(UserWarning):
mpr.get_structure_by_material_id("mp-698856")

@pytest.mark.xfail(reason="Until deployment")
def test_get_database_version(self, mpr):
db_version = mpr.get_database_version()
assert db_version == MAPISettings().DB_VERSION
Expand Down Expand Up @@ -82,7 +81,6 @@ def test_get_structures(self, mpr):
structs = mpr.get_structures("Mn3O4", final=False)
assert len(structs) > 0

@pytest.mark.xfail(reason="Until deployment")
def test_find_structure(self, mpr):
path = os.path.join(MAPISettings().TEST_FILES, "Si_mp_149.cif")
with open(path) as file:
Expand All @@ -109,7 +107,6 @@ def test_get_entry_by_material_id(self, mpr):
assert isinstance(e[0], ComputedEntry)
assert e[0].composition.reduced_formula == "LiFePO4"

@pytest.mark.xfail(reason="Until deployment")
def test_get_entries(self, mpr):
syms = ["Li", "Fe", "O"]
chemsys = "Li-Fe-O"
Expand Down
32 changes: 24 additions & 8 deletions tests/thermo/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@
} # type: dict


@pytest.mark.skipif(os.environ.get("MP_API_KEY", None) is None, reason="No API key found.")
@pytest.mark.skipif(
os.environ.get("MP_API_KEY", None) is None, reason="No API key found."
)
@pytest.mark.parametrize("rester", resters)
def test_client(rester):
# Get specific search method
Expand All @@ -60,42 +62,56 @@ def test_client(rester):
param: (-100, 100),
"chunk_size": 1,
"num_chunks": 1,
"fields": [project_field if project_field is not None else param],
"fields": [
project_field if project_field is not None else param
],
}
elif param_type is typing.Tuple[float, float]:
project_field = alt_name_dict.get(param, None)
q = {
param: (-100.12, 100.12),
"chunk_size": 1,
"num_chunks": 1,
"fields": [project_field if project_field is not None else param],
"fields": [
project_field if project_field is not None else param
],
}
elif param_type is bool:
project_field = alt_name_dict.get(param, None)
q = {
param: False,
"chunk_size": 1,
"num_chunks": 1,
"fields": [project_field if project_field is not None else param],
"fields": [
project_field if project_field is not None else param
],
}
elif param in custom_field_tests:
project_field = alt_name_dict.get(param, None)
q = {
param: custom_field_tests[param],
"chunk_size": 1,
"num_chunks": 1,
"fields": [project_field if project_field is not None else param],
"fields": [
project_field if project_field is not None else param
],
}

doc = search_method(**q)[0].dict()
for sub_field in sub_doc_fields:
if sub_field in doc:
doc = doc[sub_field]

assert doc[project_field if project_field is not None else param] is not None
assert (
doc[project_field if project_field is not None else param]
is not None
)


@pytest.mark.xfail(reason="Temporary until deployment")
@pytest.mark.xfail(reason="Monty decode issue with phase diagram")
def test_get_phase_diagram_from_chemsys():
# Test that a phase diagram is returned
assert isinstance(ThermoRester().get_phase_diagram_from_chemsys("Fe-Mn-Pt"), PhaseDiagram)

assert isinstance(
ThermoRester().get_phase_diagram_from_chemsys("Hf-Pm"), PhaseDiagram
)

0 comments on commit c94800f

Please sign in to comment.