Skip to content

Commit

Permalink
Adapt mock to utilize the new schema generation
Browse files Browse the repository at this point in the history
This commit adapts the mock to utilize the new schema generation.
Pycapnp is not able to provide the node schema as requiered by
the reflection schema. Therefore we have written a small script
to generate the schema in c++ and export it to bytes-packed.
The scripts creates a python module. The module `hpk_schema` is
hardcoded.
  • Loading branch information
tobiasah committed Feb 4, 2024
1 parent 918fe08 commit 6a491b1
Show file tree
Hide file tree
Showing 10 changed files with 181 additions and 100 deletions.
3 changes: 3 additions & 0 deletions src/labone/core/reflection/capnp_dynamic_type_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,9 @@ def _build_types_from_node(
if node.name == "":
logger.debug("Skipping node %s because it has no name", node_id)
return
if node.file_of_origin == "capnp/c++.capnp":
logger.debug("Skipping node %s because it is in capnp/c++.capnp", node_id)
return
logger.debug("Loading %s into module %s", node.name, module)
submodule = _build_one_type(node.name, node.schema)
setattr(module, node.name, submodule)
Expand Down
6 changes: 3 additions & 3 deletions src/labone/core/reflection/parsed_wire_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,11 @@ def _load_encoded_schema(
for serialized_node in encoded_schema:
node = self._loader.load_dynamic(serialized_node)
node_proto = node.get_proto()
if ":" not in node_proto.displayName:
continue
splitted_name = node_proto.displayName.split(":")
full_name = splitted_name[1]
name = full_name if "." not in full_name else full_name.split(".")[-1]
loaded_node = LoadedNode(
name=name,
name=splitted_name[1].split(".")[-1],
file_of_origin=splitted_name[0],
schema=node,
)
Expand Down
19 changes: 11 additions & 8 deletions src/labone/mock/entry_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,17 @@

from labone.core.reflection.server import ReflectionServer
from labone.core.session import Session
from labone.mock.mock_server import MockServer
from labone.mock.mock_server import start_local_mock
from labone.mock.session_mock_template import SessionMockTemplate

if TYPE_CHECKING:
import capnp

from labone.core.helper import CapnpCapability
from labone.mock.session_mock_template import SessionMockFunctionality

from labone.mock.hpk_schema import get_schema

SESSION_REFLECTION_BIN = Path(__file__).parent.parent / "resources" / "session.bin"


Expand All @@ -29,7 +33,7 @@ class MockSession(Session):

def __init__(
self,
mock_server: MockServer,
mock_server: capnp.TwoPartyServer,
capnp_session: CapnpCapability,
*,
reflection_server: ReflectionServer,
Expand Down Expand Up @@ -61,14 +65,13 @@ async def spawn_hpk_mock(
capnp.lib.capnp.KjException: If the schema is invalid. Or the id
of the concrete server is not in the schema.
"""
mock_server = MockServer(
capability_bytes=SESSION_REFLECTION_BIN,
concrete_server=SessionMockTemplate(functionality),
server, client = await start_local_mock(
schema=get_schema(),
mock=SessionMockTemplate(functionality),
)
client_connection = await mock_server.start()
reflection_client = await ReflectionServer.create_from_connection(client_connection)
reflection_client = await ReflectionServer.create_from_connection(client)
return MockSession(
mock_server,
server,
reflection_client.session, # type: ignore[attr-defined]
reflection_server=reflection_client,
)
113 changes: 113 additions & 0 deletions src/labone/mock/hpk_schema.py

Large diffs are not rendered by default.

118 changes: 41 additions & 77 deletions src/labone/mock/mock_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,16 @@
from __future__ import annotations

import socket
import typing as t
from abc import ABC
from typing import TYPE_CHECKING

import capnp

from labone.core.helper import ensure_capnp_event_loop
from labone.core.helper import CapnpStructReader, ensure_capnp_event_loop
from labone.core.reflection.parsed_wire_schema import ParsedWireSchema
from labone.core.reflection.server import reflection_capnp

if TYPE_CHECKING:
from pathlib import Path

from capnp.lib.capnp import _CallContext, _DynamicStructBuilder, _InterfaceModule
from capnp.lib.capnp import _CallContext, _DynamicStructBuilder


class ServerTemplate(ABC):
Expand All @@ -36,37 +32,46 @@ class ServerTemplate(ABC):
concrete server.
"""

id_: int
server_id: int
type_id: int


def capnp_server_factory( # noqa: ANN201
interface: _InterfaceModule,
def capnp_server_factory(
stream: capnp.AsyncIoStream,
schema: CapnpStructReader,
mock: ServerTemplate,
schema_parsed_dict: dict[str, t.Any],
):
) -> capnp.TwoPartyServer:
"""Dynamically create a capnp server.
As a reflection schema is used, the concrete server interface
is only known at runtime. This function is the
at-runtime-approach to creating the concrete server.
Args:
interface: Capnp interface for the server.
stream: Stream for the server.
schema: Parsed capnp schema (`reflection_capnp.CapSchema`).
mock: The concrete server implementation.
schema_parsed_dict: The parsed capnp schema as a dictionary.
Returns:
Dynamically created capnp server.
"""

class MockServerImpl(interface.Server):
schema_parsed_dict = schema.to_dict()
parsed_schema = ParsedWireSchema(schema.theSchema)
capnp_interface = capnp.lib.capnp._InterfaceModule( # noqa: SLF001
parsed_schema.full_schema[mock.server_id].schema.as_interface(),
parsed_schema.full_schema[mock.server_id].name,
)

class MockServerImpl(capnp_interface.Server): # type: ignore[name-defined]
"""Dynamically created capnp server.
Redirects all calls (except getTheSchema) to the concrete server implementation.
"""

def __init__(self) -> None:
self._mock = mock
# parsed schema needs to stay alive as long as the server is.
self._parsed_schema = parsed_schema

def __getattr__(
self,
Expand Down Expand Up @@ -97,74 +102,33 @@ async def getTheSchema( # noqa: N802
# Use `from_dict` to benefit from pycapnp lifetime management
# Otherwise the underlying capnp object need to be copied manually to avoid
# segfaults
return _context.results.theSchema.from_dict(schema_parsed_dict)
_context.results.theSchema.from_dict(schema_parsed_dict)
_context.results.theSchema.typeId = mock.type_id

return MockServerImpl
return capnp.TwoPartyServer(stream, bootstrap=MockServerImpl())


class MockServer:
"""Abstracr reflection server.
async def start_local_mock(
schema: CapnpStructReader,
mock: ServerTemplate,
) -> tuple[capnp.TwoPartyServer, capnp.AsyncIoStream]:
"""Starting a local mock server.
Takes in another server implementation defining the specific functionality.
This is equivalent to the `capnp_server_factory` but with the addition that
a local socket pair is created for the server.
Args:
capability_bytes: Path to the binary schema file.
concrete_server: ServerTemplate with the actual functionality.
schema: Parsed capnp schema (`reflection_capnp.CapSchema`).
mock: The concrete server implementation.
Returns:
A MockServer instance which can be started with `start`.
Raises:
FileNotFoundError: If the file does not exist.
PermissionError: If the file cannot be read.
capnp.lib.capnp.KjException: If the schema is invalid. Or the id
of the concrete server is not in the schema.
The server and the client connection.
"""

def __init__(
self,
*,
capability_bytes: Path,
concrete_server: ServerTemplate,
):
self._concrete_server = concrete_server
with capability_bytes.open("rb") as f:
schema_bytes = f.read()
with reflection_capnp.CapSchema.from_bytes(schema_bytes) as schema:
self._schema_parsed_dict = schema.to_dict()
self._schema = ParsedWireSchema(schema.theSchema)
self._capnp_interface = capnp.lib.capnp._InterfaceModule( # noqa: SLF001
self._schema.full_schema[concrete_server.id_].schema.as_interface(),
self._schema.full_schema[concrete_server.id_].name,
)
self._server = None

async def start(self) -> capnp.AsyncIoStream:
"""Starting the server and returning the client connection.
Returns:
The client connection.
Raises:
RuntimeError: If the server is already started.
"""
if self._server is not None: # pragma: no cover
msg = "Server already started." # pragma: no cover
raise RuntimeError(msg) # pragma: no cover
await ensure_capnp_event_loop()
# create local socket pair
# Since there is only a single client there is no need to use a asyncio server
read, write = socket.socketpair()
reader = await capnp.AsyncIoStream.create_connection(sock=read)
writer = await capnp.AsyncIoStream.create_connection(sock=write)
# create server for the local socket pair
self._server = capnp.TwoPartyServer(
writer,
bootstrap=capnp_server_factory(
self._capnp_interface,
self._concrete_server,
self._schema_parsed_dict,
)(),
)
return reader
await ensure_capnp_event_loop()
# create local socket pair
# Since there is only a single client there is no need to use a asyncio server
read, write = socket.socketpair()
reader = await capnp.AsyncIoStream.create_connection(sock=read)
writer = await capnp.AsyncIoStream.create_connection(sock=write)
# create server for the local socket pair
return capnp_server_factory(writer, schema, mock), reader
5 changes: 3 additions & 2 deletions src/labone/mock/session_mock_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@


HPK_SCHEMA_ID = 11970870220622790664
SESSION_SCHEMA_ID = 13390403837104530780
SERVER_ERROR = "SERVER_ERROR"


Expand Down Expand Up @@ -249,8 +250,8 @@ class SessionMockTemplate(ServerTemplate):
functionality: The implementation of the mock server behavior.
"""

# unique capnp id of the Hpk schema
id_ = HPK_SCHEMA_ID
server_id = HPK_SCHEMA_ID
type_id = SESSION_SCHEMA_ID

def __init__(self, functionality: SessionMockFunctionality) -> None:
self._functionality = functionality
Expand Down
3 changes: 0 additions & 3 deletions src/labone/nodetree/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,12 @@ def __setitem__(self, key: str, item: T | NestedDict) -> None: ...

def keys(self) -> t.KeysView[str]:
"""..."""
...

def items(self) -> t.ItemsView[str, T | NestedDict[T]]:
"""..."""
...

def __iter__(self) -> t.Iterator[str]:
"""..."""
...


FlatPathDict: TypeAlias = t.Dict[
Expand Down
Binary file removed src/labone/resources/session.bin
Binary file not shown.
2 changes: 1 addition & 1 deletion tests/core/test_annotated_value_to_capnp.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def test_value_from_python_types_vector_data_complex_double(reflection_server, i
assert vec_data.data == inp.tobytes()


@given(arrays(dtype=(np.string_), shape=(1, 2)))
@given(arrays(dtype=(np.bytes_), shape=(1, 2)))
@settings(suppress_health_check=(HealthCheck.function_scoped_fixture,))
def test_value_from_python_types_vector_data_invalid(reflection_server, inp):
with pytest.raises(ValueError):
Expand Down
12 changes: 6 additions & 6 deletions tests/core/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,9 @@
from labone.core.subscription import DataQueue
from labone.core.value import AnnotatedValue
from labone.mock import AutomaticSessionFunctionality, spawn_hpk_mock
from labone.mock.entry_point import SESSION_REFLECTION_BIN, MockSession
from labone.mock.mock_server import MockServer
from labone.mock.entry_point import MockSession
from labone.mock.hpk_schema import get_schema
from labone.mock.mock_server import start_local_mock
from labone.mock.session_mock_template import SessionMockTemplate

from .resources import session_protocol_capnp, testfile_capnp, value_capnp
Expand Down Expand Up @@ -974,11 +975,10 @@ async def getSessionVersion(self, _context): # noqa: N802
)
@pytest.mark.asyncio()
async def test_ensure_compatibility_mismatch(version, should_fail):
mock_server = MockServer(
capability_bytes=SESSION_REFLECTION_BIN,
concrete_server=DummyServerVersionTest(version),
mock_server, client_connection = await start_local_mock(
schema=get_schema(),
mock=DummyServerVersionTest(version),
)
client_connection = await mock_server.start()
reflection_client = await ReflectionServer.create_from_connection(client_connection)
session = MockSession(
mock_server,
Expand Down

0 comments on commit 6a491b1

Please sign in to comment.