Skip to content

Commit

Permalink
MPRester lazily get endpoint and api_key (#936)
Browse files Browse the repository at this point in the history
* get env var lazily

* remove seemingly unused deprecation warn

* remove unused ignore tag

* fix URL case

* access self.endpoint for updated entry

* only patch api key env var in CI env

* add unit test for lazy mp api key

* remove skip decorator

* also check endpoint

* add more tests for endpoint

* don't patch api_key and recover skip mark

* avoid duplicate endpoint

* also test default and invalid api key

* os.environ.get -> os.getenv

* BaseRester also get lazily

* make sure self.endpoint is set

* remove duplicated pytest skip mark

* turn off fail-fast

* NEED CONFIRM: filter get_data_by_id deprecation warning

* cleanup

* linting

* try upper casing secrets name

* use tr for uppercase

* test api key in header

* Revert "test api key in header"

This reverts commit 6435643.

---------

Co-authored-by: Patrick Huck <[email protected]>
  • Loading branch information
DanielYang59 and tschaume authored Oct 3, 2024
1 parent 4865bf1 commit 96b9e37
Show file tree
Hide file tree
Showing 12 changed files with 87 additions and 80 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,12 @@ jobs:
- name: Format API key name (Linux/MacOS)
if: matrix.os == 'ubuntu-latest' || matrix.os == 'macos-latest'
run: |
echo "API_KEY_NAME=$(echo ${{ format('MP_API_KEY_{0}_{1}', matrix.os, matrix.python-version) }} | awk '{gsub(/-|\./, "_"); print}')" >> $GITHUB_ENV
echo "API_KEY_NAME=$(echo ${{ format('MP_API_KEY_{0}_{1}', matrix.os, matrix.python-version) }} | awk '{gsub(/-|\./, "_"); print}' | tr '[:lower:]' '[:upper:]')" >> $GITHUB_ENV
- name: Format API key name (Windows)
if: matrix.os == 'windows-latest'
run: |
echo "API_KEY_NAME=$(echo ${{ format('MP_API_KEY_{0}_{1}', matrix.os, matrix.python-version) }} | awk '{gsub(/-|\./, "_"); print}')" | Out-File -FilePath $Env:GITHUB_ENV -Encoding utf8 -Append
echo "API_KEY_NAME=$(echo ${{ format('MP_API_KEY_{0}_{1}', matrix.os, matrix.python-version) }} | awk '{gsub(/-|\./, "_"); print}' | tr '[:lower:]' '[:upper:]')" | Out-File -FilePath $Env:GITHUB_ENV -Encoding utf8 -Append
- name: Test with pytest
env:
Expand Down
24 changes: 12 additions & 12 deletions mp_api/client/core/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from importlib.metadata import PackageNotFoundError, version
from json import JSONDecodeError
from math import ceil
from typing import Any, Callable, Generic, TypeVar
from typing import TYPE_CHECKING, Generic, TypeVar
from urllib.parse import quote, urljoin

import requests
Expand All @@ -42,18 +42,16 @@
except ImportError:
boto3 = None

if TYPE_CHECKING:
from typing import Any, Callable

try:
__version__ = version("mp_api")
except PackageNotFoundError: # pragma: no cover
__version__ = os.getenv("SETUPTOOLS_SCM_PRETEND_VERSION")

# TODO: think about how to migrate from PMG_MAPI_KEY
DEFAULT_API_KEY = os.environ.get("MP_API_KEY", None)
DEFAULT_ENDPOINT = os.environ.get(
"MP_API_ENDPOINT", "https://api.materialsproject.org/"
)

settings = MAPIClientSettings() # type: ignore
SETTINGS = MAPIClientSettings() # type: ignore

T = TypeVar("T")

Expand All @@ -69,7 +67,7 @@ class BaseRester(Generic[T]):
def __init__(
self,
api_key: str | None = None,
endpoint: str = DEFAULT_ENDPOINT,
endpoint: str | None = None,
include_user_agent: bool = True,
session: requests.Session | None = None,
s3_client: Any | None = None,
Expand All @@ -78,7 +76,7 @@ def __init__(
use_document_model: bool = True,
timeout: int = 20,
headers: dict | None = None,
mute_progress_bars: bool = settings.MUTE_PROGRESS_BARS,
mute_progress_bars: bool = SETTINGS.MUTE_PROGRESS_BARS,
):
"""Initialize the REST API helper class.
Expand Down Expand Up @@ -111,9 +109,11 @@ def __init__(
headers: Custom headers for localhost connections.
mute_progress_bars: Whether to disable progress bars.
"""
self.api_key = api_key or DEFAULT_API_KEY
self.base_endpoint = endpoint
self.endpoint = endpoint
# TODO: think about how to migrate from PMG_MAPI_KEY
self.api_key = api_key or os.getenv("MP_API_KEY")
self.base_endpoint = self.endpoint = endpoint or os.getenv(
"MP_API_ENDPOINT", "https://api.materialsproject.org/"
)
self.debug = debug
self.include_user_agent = include_user_agent
self.monty_decode = monty_decode
Expand Down
35 changes: 15 additions & 20 deletions mp_api/client/mprester.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import warnings
from functools import cache, lru_cache
from json import loads
from typing import Literal
from typing import TYPE_CHECKING

from emmet.core.electronic_structure import BSPathType
from emmet.core.mpid import MPID
Expand Down Expand Up @@ -60,19 +60,12 @@
from mp_api.client.routes.materials.materials import MaterialsRester
from mp_api.client.routes.molecules import MoleculeRester

_DEPRECATION_WARNING = (
"MPRester is being modernized. Please use the new method suggested and "
"read more about these changes at https://docs.materialsproject.org/api. The current "
"methods will be retained until at least January 2022 for backwards compatibility."
)
if TYPE_CHECKING:
from typing import Literal

_EMMET_SETTINGS = EmmetSettings() # type: ignore
_MAPI_SETTINGS = MAPIClientSettings() # typeL ignore # type: ignore

DEFAULT_API_KEY = os.environ.get("MP_API_KEY", None)
DEFAULT_ENDPOINT = os.environ.get(
"MP_API_ENDPOINT", "https://api.materialsproject.org/"
)
_EMMET_SETTINGS = EmmetSettings()
_MAPI_SETTINGS = MAPIClientSettings()


class MPRester:
Expand Down Expand Up @@ -124,7 +117,7 @@ class MPRester:
def __init__(
self,
api_key: str | None = None,
endpoint: str = DEFAULT_ENDPOINT,
endpoint: str | None = None,
notify_db_version: bool = False,
include_user_agent: bool = True,
monty_decode: bool = True,
Expand All @@ -143,10 +136,10 @@ def __init__(
If so, it will use that environment variable. This makes
easier for heavy users to simply add this environment variable to
their setups and MPRester can then be called without any arguments.
endpoint (str): Url of endpoint to access the MaterialsProject REST
endpoint (str): URL of endpoint to access the MaterialsProject REST
interface. Defaults to the standard Materials Project REST
address at "https://api.materialsproject.org", but
can be changed to other urls implementing a similar interface.
can be changed to other URLs implementing a similar interface.
notify_db_version (bool): If True, the current MP database version will
be retrieved and logged locally in the ~/.mprester.log.yaml. If the database
version changes, you will be notified. The current database version is
Expand All @@ -169,7 +162,7 @@ def __init__(
"""
# SETTINGS tries to read API key from ~/.config/.pmgrc.yaml
api_key = api_key or DEFAULT_API_KEY or SETTINGS.get("PMG_MAPI_KEY")
api_key = api_key or os.getenv("MP_API_KEY") or SETTINGS.get("PMG_MAPI_KEY")

if api_key and len(api_key) != 32:
raise ValueError(
Expand All @@ -179,7 +172,9 @@ def __init__(
)

self.api_key = api_key
self.endpoint = endpoint
self.endpoint = endpoint or os.getenv(
"MP_API_ENDPOINT", "https://api.materialsproject.org/"
)
self.headers = headers or {}
self.session = session or BaseRester._create_session(
api_key=self.api_key,
Expand Down Expand Up @@ -257,7 +252,7 @@ def __init__(
core_resters = {
cls.suffix.split("/")[0]: cls(
api_key=api_key,
endpoint=endpoint,
endpoint=self.endpoint,
include_user_agent=include_user_agent,
session=self.session,
monty_decode=monty_decode,
Expand All @@ -280,7 +275,7 @@ def __init__(
if len(suffix_split) == 1:
rester = cls(
api_key=api_key,
endpoint=endpoint,
endpoint=self.endpoint,
include_user_agent=include_user_agent,
session=self.session,
monty_decode=monty_decode
Expand Down Expand Up @@ -310,7 +305,7 @@ def __core_custom_getattr(_self, _attr, _rester_map):
cls = _rester_map[_attr]
rester = cls(
api_key=api_key,
endpoint=endpoint,
endpoint=self.endpoint,
include_user_agent=include_user_agent,
session=self.session,
monty_decode=monty_decode
Expand Down
4 changes: 1 addition & 3 deletions tests/materials/test_materials.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,7 @@ def rester():
} # type: dict


@pytest.mark.skipif(
os.environ.get("MP_API_KEY", None) is None, reason="No API key found."
)
@pytest.mark.skipif(os.getenv("MP_API_KEY") is None, reason="No API key found.")
def test_client(rester):
search_method = rester.search

Expand Down
4 changes: 1 addition & 3 deletions tests/materials/test_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,7 @@
} # type: dict


@pytest.mark.skipif(
os.environ.get("MP_API_KEY", None) is None, reason="No API key found."
)
@pytest.mark.skipif(os.getenv("MP_API_KEY") is None, reason="No API key found.")
def test_client():
search_method = SummaryRester().search

Expand Down
4 changes: 1 addition & 3 deletions tests/materials/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,7 @@ def rester():
} # type: dict


@pytest.mark.skipif(
os.environ.get("MP_API_KEY", None) is None, reason="No API key found."
)
@pytest.mark.skipif(os.getenv("MP_API_KEY") is None, reason="No API key found.")
def test_client(rester):
search_method = rester.search

Expand Down
4 changes: 1 addition & 3 deletions tests/materials/test_thermo.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,7 @@ def rester():
} # type: dict


@pytest.mark.skipif(
os.environ.get("MP_API_KEY", None) is None, reason="No API key found."
)
@pytest.mark.skipif(os.getenv("MP_API_KEY") is None, reason="No API key found.")
def test_client(rester):
search_method = rester.search

Expand Down
4 changes: 1 addition & 3 deletions tests/materials/test_xas.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,7 @@ def rester():


@pytest.mark.skip(reason="Temp skip until timeout update.")
@pytest.mark.skipif(
os.environ.get("MP_API_KEY", None) is None, reason="No API key found."
)
@pytest.mark.skipif(os.getenv("MP_API_KEY") is None, reason="No API key found.")
def test_client(rester):
search_method = rester.search

Expand Down
4 changes: 1 addition & 3 deletions tests/molecules/test_jcesr.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,7 @@ def rester():
} # type: dict


@pytest.mark.skipif(
os.environ.get("MP_API_KEY", None) is None, reason="No API key found."
)
@pytest.mark.skipif(os.getenv("MP_API_KEY") is None, reason="No API key found.")
def test_client(rester):
search_method = rester.search

Expand Down
4 changes: 1 addition & 3 deletions tests/molecules/test_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,7 @@


@pytest.mark.skip(reason="Temporary until data adjustments")
@pytest.mark.skipif(
os.environ.get("MP_API_KEY", None) is None, reason="No API key found."
)
@pytest.mark.skipif(os.getenv("MP_API_KEY") is None, reason="No API key found.")
def test_client():
search_method = MoleculesSummaryRester().search

Expand Down
14 changes: 4 additions & 10 deletions tests/test_client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import warnings

import pytest

Expand Down Expand Up @@ -49,9 +50,7 @@
]


@pytest.mark.skipif(
os.environ.get("MP_API_KEY", None) is None, reason="No API key found."
)
@pytest.mark.skipif(os.getenv("MP_API_KEY") is None, reason="No API key found.")
@pytest.mark.parametrize("rester", resters_to_test)
def test_generic_get_methods(rester):
# -- Test generic search and get_data_by_id methods
Expand All @@ -61,9 +60,8 @@ def test_generic_get_methods(rester):
endpoint=mpr.endpoint,
include_user_agent=True,
session=mpr.session,
monty_decode=True
if rester not in [TaskRester, ProvenanceRester] # type: ignore
else False, # Disable monty decode on nested data which may give errors
# Disable monty decode on nested data which may give errors
monty_decode=rester not in [TaskRester, ProvenanceRester],
use_document_model=True,
)

Expand All @@ -85,7 +83,3 @@ def test_generic_get_methods(rester):
key_only_resters[name], fields=[rester.primary_key]
)
assert isinstance(doc, rester.document_model)


if os.getenv("MP_API_KEY", None) is None:
pytest.mark.skip(test_generic_get_methods)
Loading

0 comments on commit 96b9e37

Please sign in to comment.