Skip to content

Commit

Permalink
Add environment ID to tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ZohebShaikh committed Jan 17, 2025
1 parent 62d2188 commit 2005a4f
Show file tree
Hide file tree
Showing 13 changed files with 105 additions and 51 deletions.
7 changes: 7 additions & 0 deletions docs/reference/openapi.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@ components:
additionalProperties: false
description: State of internal environment.
properties:
environment_id:
description: Unique ID for the environment instance, can be used to differentiate
between a new environment and old that has been torn down
format: uuid
title: Environment Id
type: string
error_message:
anyOf:
- minLength: 1
Expand All @@ -49,6 +55,7 @@ components:
title: Initialized
type: boolean
required:
- environment_id
- initialized
title: EnvironmentResponse
type: object
Expand Down
6 changes: 3 additions & 3 deletions src/blueapi/cli/format.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def fmt_dict(t: dict[str, Any] | Any, ind: int = 1) -> str:
if not isinstance(t, dict):
return f" {t}"
pre = " " * (ind * 4)
return NL + NL.join(f"{pre}{k}:{fmt_dict(v, ind+1)}" for k, v in t.items() if v)
return NL + NL.join(f"{pre}{k}:{fmt_dict(v, ind + 1)}" for k, v in t.items() if v)


class OutputFormat(str, enum.Enum):
Expand Down Expand Up @@ -126,14 +126,14 @@ def _describe_type(spec: dict[Any, Any], required: bool = False):
case None:
if all_of := spec.get("allOf"):
items = (_describe_type(f, False) for f in all_of)
disp += f'{" & ".join(items)}'
disp += f"{' & '.join(items)}"
elif any_of := spec.get("anyOf"):
items = (_describe_type(f, False) for f in any_of)

# Special case: Where the type is <something> | null,
# we should just print <something>
items = (item for item in items if item != "null" or len(any_of) != 2)
disp += f'{" | ".join(items)}'
disp += f"{' | '.join(items)}"
else:
disp += "Any"
case "array":
Expand Down
3 changes: 1 addition & 2 deletions src/blueapi/cli/scratch.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,7 @@ def ensure_repo(remote_url: str, local_directory: Path) -> None:
logging.info(f"Found {local_directory}")
else:
raise KeyError(
f"Unable to open {local_directory} as a git repository because "
"it is a file"
f"Unable to open {local_directory} as a git repository because it is a file"
)


Expand Down
2 changes: 2 additions & 0 deletions src/blueapi/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,9 @@ def reload_environment(
"""

try:
# _ = self._rest.get_environment()
status = self._rest.delete_environment()

except Exception as e:
raise BlueskyRemoteControlError(
"Failed to tear down the environment"
Expand Down
5 changes: 3 additions & 2 deletions src/blueapi/service/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,10 +166,11 @@ async def delete_environment(
runner: WorkerDispatcher = Depends(_runner),
) -> EnvironmentResponse:
"""Delete the current environment, causing internal components to be reloaded."""

if runner.state.initialized or runner.state.error_message is not None:
background_tasks.add_task(runner.reload)
return EnvironmentResponse(initialized=False)
return EnvironmentResponse(
environment_id=runner.state.environment_id, initialized=False
)


@auth_router.get("/config/oidc", tags=["auth"], response_model=OIDCConfig)
Expand Down
4 changes: 2 additions & 2 deletions src/blueapi/service/model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import uuid
from collections.abc import Iterable
from typing import Any
from uuid import UUID

from bluesky.protocols import HasName
from pydantic import Field
Expand Down Expand Up @@ -145,7 +145,7 @@ class EnvironmentResponse(BlueapiBaseModel):
State of internal environment.
"""

environment_id: UUID = Field(
environment_id: uuid.UUID = Field(
description="Unique ID for the environment instance, can be used to "
"differentiate between a new environment and old that has been torn down"
)
Expand Down
5 changes: 2 additions & 3 deletions tests/system_tests/test_blueapi_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
)
from blueapi.service.model import (
DeviceResponse,
EnvironmentResponse,
PlanResponse,
TaskResponse,
WorkerTask,
Expand Down Expand Up @@ -335,9 +334,9 @@ def on_event(event: AnyEvent):


def test_get_current_state_of_environment(client: BlueapiClient):
assert client.get_environment() == EnvironmentResponse(initialized=True)
assert client.get_environment().initialized


def test_delete_current_environment(client: BlueapiClient):
client.reload_environment()
assert client.get_environment() == EnvironmentResponse(initialized=True)
assert client.get_environment().initialized
38 changes: 23 additions & 15 deletions tests/unit_tests/client/test_client.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import uuid
from collections.abc import Callable
from unittest.mock import MagicMock, Mock, call, patch

Expand Down Expand Up @@ -42,7 +43,8 @@
TASK = TrackableTask(task_id="foo", task=Task(name="bar", params={}))
TASKS = TasksListResponse(tasks=[TASK])
ACTIVE_TASK = WorkerTask(task_id="bar")
ENV = EnvironmentResponse(initialized=True)
ENVIRONMENT_ID = uuid.uuid4()
ENV = EnvironmentResponse(environment_id=ENVIRONMENT_ID, initialized=True)
COMPLETE_EVENT = WorkerEvent(
state=WorkerState.IDLE,
task_status=TaskStatus(
Expand Down Expand Up @@ -74,7 +76,9 @@ def mock_rest() -> BlueapiRestClient:
mock.get_all_tasks.return_value = TASKS
mock.get_active_task.return_value = ACTIVE_TASK
mock.get_environment.return_value = ENV
mock.delete_environment.return_value = EnvironmentResponse(initialized=False)
mock.delete_environment.return_value = EnvironmentResponse(
environment_id=ENVIRONMENT_ID, initialized=False
)

return mock

Expand Down Expand Up @@ -268,10 +272,10 @@ def test_reload_environment_no_timeout(
mock_rest: Mock,
):
mock_rest.get_environment.side_effect = [
EnvironmentResponse(initialized=False),
EnvironmentResponse(initialized=False),
EnvironmentResponse(initialized=False),
EnvironmentResponse(initialized=True),
EnvironmentResponse(environment_id=ENVIRONMENT_ID, initialized=False),
EnvironmentResponse(environment_id=ENVIRONMENT_ID, initialized=False),
EnvironmentResponse(environment_id=ENVIRONMENT_ID, initialized=False),
EnvironmentResponse(environment_id=ENVIRONMENT_ID, initialized=True),
]
mock_time.return_value = 100.0
client.reload_environment(timeout=None)
Expand All @@ -287,10 +291,10 @@ def test_reload_environment_with_timeout(
mock_rest: Mock,
):
mock_rest.get_environment.side_effect = [
EnvironmentResponse(initialized=False),
EnvironmentResponse(initialized=False),
EnvironmentResponse(initialized=False),
EnvironmentResponse(initialized=False),
EnvironmentResponse(environment_id=ENVIRONMENT_ID, initialized=False),
EnvironmentResponse(environment_id=ENVIRONMENT_ID, initialized=False),
EnvironmentResponse(environment_id=ENVIRONMENT_ID, initialized=False),
EnvironmentResponse(environment_id=ENVIRONMENT_ID, initialized=False),
]
mock_time.side_effect = [
100.0,
Expand All @@ -315,10 +319,14 @@ def test_reload_environment_ignores_current_environment(
mock_rest: Mock,
):
mock_rest.get_environment.side_effect = [
EnvironmentResponse(initialized=True), # This is the old environment
EnvironmentResponse(initialized=False),
EnvironmentResponse(initialized=False),
EnvironmentResponse(initialized=True), # This is the new environment
EnvironmentResponse(
environment_id=ENVIRONMENT_ID, initialized=True
), # This is the old environment
EnvironmentResponse(environment_id=ENVIRONMENT_ID, initialized=False),
EnvironmentResponse(environment_id=ENVIRONMENT_ID, initialized=False),
EnvironmentResponse(
environment_id=ENVIRONMENT_ID, initialized=True
), # This is the new environment
]
mock_time.return_value = 100.0
client.reload_environment(timeout=None)
Expand All @@ -330,7 +338,7 @@ def test_reload_environment_failure(
mock_rest: Mock,
):
mock_rest.get_environment.return_value = EnvironmentResponse(
initialized=False, error_message="foo"
environment_id=ENVIRONMENT_ID, initialized=False, error_message="foo"
)
with pytest.raises(BlueskyRemoteControlError, match="foo"):
client.reload_environment()
Expand Down
19 changes: 15 additions & 4 deletions tests/unit_tests/client/test_rest.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import uuid
from pathlib import Path
from unittest.mock import Mock, patch

Expand Down Expand Up @@ -53,16 +54,21 @@ def test_auth_request_functionality(
mock_authn_server: responses.RequestsMock,
cached_valid_token: Path,
):
environment_id = uuid.uuid4()
mock_authn_server.stop() # Cannot use multiple RequestsMock context manager
mock_get_env = mock_authn_server.get(
"http://localhost:8000/environment",
json=EnvironmentResponse(initialized=True).model_dump(),
json=EnvironmentResponse(
environment_id=environment_id, initialized=True
).model_dump(mode="json"),
status=200,
)
result = None
with mock_authn_server:
result = rest_with_auth.get_environment()
assert result == EnvironmentResponse(initialized=True)
assert result == EnvironmentResponse(
environment_id=environment_id, initialized=True
)
calls = mock_get_env.calls
assert len(calls) == 1
cacheManager = SessionCacheManager(cached_valid_token)
Expand All @@ -75,16 +81,21 @@ def test_refresh_if_signature_expired(
mock_authn_server: responses.RequestsMock,
cached_valid_refresh: Path,
):
environment_id = uuid.uuid4()
mock_authn_server.stop() # Cannot use multiple RequestsMock context manager
mock_get_env = mock_authn_server.get(
"http://localhost:8000/environment",
json=EnvironmentResponse(initialized=True).model_dump(),
json=EnvironmentResponse(
environment_id=environment_id, initialized=True
).model_dump(mode="json"),
status=200,
)
result = None
with mock_authn_server:
result = rest_with_auth.get_environment()
assert result == EnvironmentResponse(initialized=True)
assert result == EnvironmentResponse(
environment_id=environment_id, initialized=True
)
calls = mock_get_env.calls
assert len(calls) == 1
assert calls[0].request.headers["Authorization"] == "Bearer new_token"
3 changes: 3 additions & 0 deletions tests/unit_tests/service/test_rest_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,12 +535,15 @@ def test_set_state_invalid_transition(mock_runner: Mock, client: TestClient):


def test_get_environment_idle(mock_runner: Mock, client: TestClient) -> None:
environment_id = uuid.uuid4()
mock_runner.state = EnvironmentResponse(
environment_id=environment_id,
initialized=True,
error_message=None,
)

assert client.get("/environment").json() == {
"environment_id": str(environment_id),
"initialized": True,
"error_message": None,
}
Expand Down
15 changes: 11 additions & 4 deletions tests/unit_tests/service/test_runner.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import uuid
from multiprocessing.pool import Pool as PoolClass
from typing import Any, Generic, TypeVar
from unittest.mock import MagicMock, Mock, patch
Expand Down Expand Up @@ -72,7 +73,9 @@ def test_raises_if_used_before_started(runner: WorkerDispatcher):

def test_error_on_runner_setup(runner: WorkerDispatcher, mock_subprocess: Mock):
error_message = "Intentional start_worker exception"
environment_id = uuid.uuid4()
expected_state = EnvironmentResponse(
environment_id=environment_id,
initialized=False,
error_message=error_message,
)
Expand All @@ -83,6 +86,7 @@ def test_error_on_runner_setup(runner: WorkerDispatcher, mock_subprocess: Mock):
# and the runner is not yet initialised
runner.reload()
state = runner.state
expected_state.environment_id = state.environment_id
assert state == expected_state


Expand Down Expand Up @@ -110,14 +114,17 @@ def test_can_reload_after_an_error(pool_mock: MagicMock):

runner = WorkerDispatcher()
runner.start()

current_env = runner.state.environment_id
assert runner.state == EnvironmentResponse(
initialized=False, error_message="invalid code"
environment_id=current_env, initialized=False, error_message="invalid code"
)

runner.reload()

assert runner.state == EnvironmentResponse(initialized=True, error_message=None)
new_env = runner.state.environment_id
assert runner.state == EnvironmentResponse(
environment_id=new_env, initialized=True, error_message=None
)
assert current_env != new_env


@patch("blueapi.service.runner.Pool")
Expand Down
Loading

0 comments on commit 2005a4f

Please sign in to comment.