Skip to content

Commit

Permalink
Enable caching for get_ion_reference_data (#429)
Browse files Browse the repository at this point in the history
  • Loading branch information
rkingsbury authored Nov 30, 2021
1 parent a5214e9 commit f31191a
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 8 deletions.
16 changes: 10 additions & 6 deletions src/mp_api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import warnings
from functools import lru_cache
from os import environ
from typing import List, Optional, Tuple, Union
from typing import List, Optional, Tuple, Union, Dict
from typing_extensions import Literal

from emmet.core.mpid import MPID
Expand Down Expand Up @@ -533,10 +533,7 @@ def get_pourbaix_entries(

return pbx_entries

# TODO - @lru_cache causes this method to fail when chemsys is given as a list,
# with an 'unhashable type' error
# @lru_cache
def get_ion_reference_data(self, chemsys: Union[str, List]) -> List[dict]:
def get_ion_reference_data(self, chemsys: Union[str, List]) -> List[Dict]:
"""
Download aqueous ion reference data used in the construction of Pourbaix diagrams.
Expand Down Expand Up @@ -575,7 +572,14 @@ def get_ion_reference_data(self, chemsys: Union[str, List]) -> List[dict]:
# capitalize and sort the elements
chemsys = sorted(e.capitalize() for e in chemsys)

# TODO - see if there is a way to avoid querying the entire collection
# convert to a tuple which is hashable
return self._get_ion_reference_data(tuple(chemsys)) # type: ignore

@lru_cache # type: ignore
def _get_ion_reference_data(self, chemsys: Tuple): # type: ignore
"""
Private, cacheable helper method for get_ion_reference data.
"""
ion_data = [
d
for d in self.contribs.contributions.get_entries(
Expand Down
4 changes: 2 additions & 2 deletions src/mp_api/core/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ def get_data_by_id(

try:
results = self._query_resource_data(
criteria=criteria, fields=fields, suburl=document_id,
criteria=criteria, fields=fields, suburl=document_id, # type: ignore
)
except MPRestError:

Expand All @@ -388,7 +388,7 @@ def get_data_by_id(
document_id = new_document_id

results = self._query_resource_data(
criteria=criteria, fields=fields, suburl=document_id,
criteria=criteria, fields=fields, suburl=document_id, # type: ignore
)

if not results:
Expand Down

0 comments on commit f31191a

Please sign in to comment.