Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Aerie 2.18.0 #147

Merged
merged 7 commits into from
Dec 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .env
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
DOCKER_TAG=v2.11.0
DOCKER_TAG=v2.18.0
REPOSITORY_DOCKER_URL=ghcr.io/nasa-ammos

AERIE_USERNAME=aerie
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ jobs:
strategy:
matrix:
python-version: ["3.6.15", "3.11"]
aerie-version: ["2.11.0"]
aerie-version: ["2.18.0"]
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
Expand Down
2 changes: 1 addition & 1 deletion src/aerie_cli/aerie_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1862,7 +1862,7 @@ def __expand_activity_arguments(self, plan: ActivityPlanRead, full_args: str = N
for activity in plan.activities:
if expand_all or activity.type in expand_types:
query = """
query ($args: ActivityArguments!, $act_type: String!, $model_id: ID!) {
query ($args: ActivityArguments!, $act_type: String!, $model_id: Int!) {
getActivityEffectiveArguments(
activityArguments: $args,
activityTypeName: $act_type,
Expand Down
43 changes: 40 additions & 3 deletions src/aerie_cli/aerie_host.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,13 @@

from attrs import define, field

COMPATIBLE_AERIE_VERSIONS = [
"2.18.0"
]

class AerieHostVersionError(RuntimeError):
pass


def process_gateway_response(resp: requests.Response) -> dict:
"""Throw a RuntimeError if the Gateway response is malformed or contains errors
Expand All @@ -18,12 +25,12 @@ def process_gateway_response(resp: requests.Response) -> dict:
dict: Contents of response JSON
"""
if not resp.ok:
raise RuntimeError(f"Bad response from Aerie Gateway.")
raise RuntimeError("Bad response from Aerie Gateway")

try:
resp_json = resp.json()
except requests.exceptions.JSONDecodeError:
raise RuntimeError(f"Failed to get response JSON")
raise RuntimeError("Bad response from Aerie Gateway")

if "success" in resp_json.keys() and not resp_json["success"]:
raise RuntimeError(f"Aerie Gateway request was not successful")
Expand Down Expand Up @@ -260,7 +267,15 @@ def is_auth_enabled(self) -> bool:

return True

def authenticate(self, username: str, password: str = None):
def authenticate(self, username: str, password: str = None, force: bool = False):

try:
self.check_aerie_version()
except AerieHostVersionError as e:
if force:
print("Warning: " + e.args[0])
else:
raise
Mythicaeda marked this conversation as resolved.
Show resolved Hide resolved

resp = self.session.post(
self.gateway_url + "/auth/login",
Expand All @@ -278,6 +293,28 @@ def authenticate(self, username: str, password: str = None):
if not self.check_auth():
raise RuntimeError(f"Failed to open session")

def check_aerie_version(self) -> None:
"""Assert that the Aerie host is a compatible version

Raises a `RuntimeError` if the host appears to be incompatible.
"""

resp = self.session.get(self.gateway_url + "/version")

try:
resp_json = process_gateway_response(resp)
host_version = resp_json["version"]
except (RuntimeError, KeyError):
# If the Gateway responded, the route doesn't exist
if resp.text and "Aerie Gateway" in resp.text:
raise AerieHostVersionError("Incompatible Aerie version: host version unknown")
Mythicaeda marked this conversation as resolved.
Show resolved Hide resolved

# Otherwise, it could just be a failed connection
raise

if host_version not in COMPATIBLE_AERIE_VERSIONS:
raise AerieHostVersionError(f"Incompatible Aerie version: {host_version}")


@define
class ExternalAuthConfiguration:
Expand Down
5 changes: 3 additions & 2 deletions src/aerie_cli/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,8 @@ def activate_session(
),
role: str = typer.Option(
None, "--role", "-r", help="Specify a non-default role", metavar="ROLE"
)
),
force: bool = typer.Option(False, "--force", help="Force connection to Aerie host and ignore version compatibility")
):
"""
Activate a session with an Aerie host using a given configuration
Expand All @@ -102,7 +103,7 @@ def activate_session(

conf = PersistentConfigurationManager.get_configuration_by_name(name)

session = start_session_from_configuration(conf, username)
session = start_session_from_configuration(conf, username, force=force)

if role is not None:
if role in session.aerie_jwt.allowed_roles:
Expand Down
6 changes: 4 additions & 2 deletions src/aerie_cli/utils/sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,8 @@ def start_session_from_configuration(
configuration: AerieHostConfiguration,
username: str = None,
password: str = None,
secret_post_vars: Dict[str, str] = None
secret_post_vars: Dict[str, str] = None,
force: bool = False
):
"""Start and authenticate an Aerie Host session, with prompts if necessary

Expand All @@ -136,6 +137,7 @@ def start_session_from_configuration(
username (str, optional): Aerie username.
password (str, optional): Aerie password.
secret_post_vars (Dict[str, str], optional): Optionally provide values for some or all secret post request variable values. Defaults to None.
force (bool, optional): Force connection to Aerie host and ignore version compatibility. Defaults to False.

Returns:
AerieHost:
Expand All @@ -162,6 +164,6 @@ def start_session_from_configuration(
if password is None and hs.is_auth_enabled():
password = typer.prompt("Aerie Password", hide_input=True)

hs.authenticate(username, password)
hs.authenticate(username, password, force)

return hs
2 changes: 1 addition & 1 deletion tests/integration_tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
CONFIGURATIONS_PATH = os.path.join(FILES_PATH, "configuration")
CONFIGURATION_PATH = os.path.join(CONFIGURATIONS_PATH, "localhost_config.json")
MODELS_PATH = os.path.join(FILES_PATH, "models")
MODEL_VERSION = os.environ.get("AERIE_VERSION", "2.11.0")
MODEL_VERSION = os.environ.get("AERIE_VERSION", "2.18.0")
MODEL_JAR = os.path.join(MODELS_PATH, f"banananation-{MODEL_VERSION}.jar")
MODEL_NAME = "banananation"
MODEL_VERSION = "0.0.1"
Expand Down
Binary file not shown.
106 changes: 106 additions & 0 deletions tests/unit_tests/test_aerie_host.py
cartermak marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
from typing import Dict
import pytest
import requests

from aerie_cli.aerie_host import AerieHost, COMPATIBLE_AERIE_VERSIONS, AerieJWT


class MockJWT:
def __init__(self, *args, **kwargs):
self.default_role = 'viewer'

class MockResponse:
def __init__(self, json: Dict, text: str = None, ok: bool = True) -> None:
self.json_data = json
self.text = text
self.ok = ok

def json(self) -> Dict:
if self.json_data is None:
raise requests.exceptions.JSONDecodeError("", "", 0)
return self.json_data


class MockSession:

def __init__(self, mock_response: MockResponse) -> None:
self.mock_response = mock_response

def get(self, *args, **kwargs) -> MockResponse:
return self.mock_response

def post(self, *args, **kwargs) -> MockResponse:
return self.mock_response


def get_mock_aerie_host(json: Dict = None, text: str = None, ok: bool = True) -> AerieHost:
mock_response = MockResponse(json, text, ok)
mock_session = MockSession(mock_response)
return AerieHost("", "", mock_session)


def test_check_aerie_version():
aerie_host = get_mock_aerie_host(
json={"version": COMPATIBLE_AERIE_VERSIONS[0]})

aerie_host.check_aerie_version()
Mythicaeda marked this conversation as resolved.
Show resolved Hide resolved


def test_authenticate_invalid_version(capsys, monkeypatch):
ah = AerieHost("", "")

def mock_get(*_, **__):
return MockResponse({"version": "1.0.0"})
def mock_post(*_, **__):
return MockResponse({"token": ""})
def mock_check_auth(*_, **__):
return True

monkeypatch.setattr(requests.Session, "get", mock_get)
monkeypatch.setattr(requests.Session, "post", mock_post)
monkeypatch.setattr(AerieHost, "check_auth", mock_check_auth)
monkeypatch.setattr(AerieJWT, "__init__", MockJWT.__init__)

with pytest.raises(RuntimeError) as e:
ah.authenticate("")

assert "Incompatible Aerie version: 1.0.0" in str(e.value)


def test_authenticate_invalid_version_force(capsys, monkeypatch):
ah = AerieHost("", "")

def mock_get(*_, **__):
return MockResponse({"version": "1.0.0"})
def mock_post(*_, **__):
return MockResponse({"token": ""})
def mock_check_auth(*_, **__):
return True

monkeypatch.setattr(requests.Session, "get", mock_get)
monkeypatch.setattr(requests.Session, "post", mock_post)
monkeypatch.setattr(AerieHost, "check_auth", mock_check_auth)
monkeypatch.setattr(AerieJWT, "__init__", MockJWT.__init__)

ah.authenticate("", force=True)

assert capsys.readouterr().out == "Warning: Incompatible Aerie version: 1.0.0\n"


def test_no_version_endpoint():
aerie_host = get_mock_aerie_host(text="blah Aerie Gateway blah", ok=True)

with pytest.raises(RuntimeError) as e:
aerie_host.check_aerie_version()

assert "Incompatible Aerie version: host version unknown" in str(e.value)


def test_version_broken_gateway():
aerie_host = get_mock_aerie_host(
text="502 Bad Gateway or something", ok=True)

with pytest.raises(RuntimeError) as e:
aerie_host.check_aerie_version()

assert "Bad response from Aerie Gateway" in str(e.value)