Skip to content

Commit

Permalink
Add encryption between listener and processes
Browse files Browse the repository at this point in the history
  • Loading branch information
gregorjerse committed Jun 5, 2024
1 parent 034c3b9 commit 3ae6fb2
Show file tree
Hide file tree
Showing 16 changed files with 545 additions and 25 deletions.
2 changes: 2 additions & 0 deletions docs/CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ Changed
- **BACKWARD INCOMPATIBLE:** Make move between collections a background job
- **BACKWARD INCOMPATIBLE:** Remove support for ``Python 3.10``
- Use ``simple_unaccent`` full text search configuration instead of ``simple``
- Authenticate worker with the listener (and vice versa) and encrypt the
communication between them using the ``CurveZMQ`` protocol


===================
Expand Down
9 changes: 9 additions & 0 deletions resolwe/flow/executors/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import argparse
import asyncio
import logging
import os
import sys
import traceback
from contextlib import suppress
Expand All @@ -45,6 +46,11 @@

logger = logging.getLogger(__name__)

# CurveZMQ secrets necessary to connect to the listener service. All three are
# required, so index os.environ directly: a missing variable then fails at
# import with a KeyError naming it, rather than with an AttributeError raised
# by calling .encode() on None.
LISTENER_PUBLIC_KEY = os.environ["LISTENER_PUBLIC_KEY"].encode()
PUBLIC_KEY = os.environ["CURVE_PUBLIC_KEY"].encode()
PRIVATE_KEY = os.environ["CURVE_PRIVATE_KEY"].encode()


def handle_exception(exc_type, exc_value, exc_traceback):
"""Log unhandled exceptions."""
Expand All @@ -70,6 +76,9 @@ async def open_listener_connection(data_id, host, port, protocol) -> ZMQCommunic
"""Connect to the listener service."""
zmq_context = zmq.asyncio.Context.instance()
zmq_socket = zmq_context.socket(zmq.DEALER)
zmq_socket.curve_secretkey = PRIVATE_KEY
zmq_socket.curve_publickey = PUBLIC_KEY
zmq_socket.curve_serverkey = LISTENER_PUBLIC_KEY
zmq_socket.setsockopt(zmq.IDENTITY, f"-{data_id}".encode())
connect_string = f"{protocol}://{host}:{port}"
zmq_socket.connect(connect_string)
Expand Down
8 changes: 8 additions & 0 deletions resolwe/flow/executors/docker/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@
DOCKER_MEMORY_SWAP_RATIO = 2
DOCKER_MEMORY_SWAPPINESS = 1

# CurveZMQ key material for the listener connection, looked up in the
# environment; a variable that is not set yields None.
LISTENER_PUBLIC_KEY = os.environ.get("LISTENER_PUBLIC_KEY")
PUBLIC_KEY = os.environ.get("CURVE_PUBLIC_KEY")
PRIVATE_KEY = os.environ.get("CURVE_PRIVATE_KEY")

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -340,6 +345,9 @@ async def start(self):
)

environment = {
"LISTENER_PUBLIC_KEY": LISTENER_PUBLIC_KEY,
"CURVE_PUBLIC_KEY": PUBLIC_KEY,
"CURVE_PRIVATE_KEY": PRIVATE_KEY,
"LISTENER_SERVICE_HOST": self.listener_connection[0],
"LISTENER_SERVICE_PORT": self.listener_connection[1],
"LISTENER_PROTOCOL": self.listener_connection[2],
Expand Down
8 changes: 8 additions & 0 deletions resolwe/flow/executors/init_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@
LISTENER_PORT = os.getenv("LISTENER_SERVICE_PORT", "53893")
LISTENER_PROTOCOL = os.getenv("LISTENER_PROTOCOL", "tcp")

# CurveZMQ secrets necessary to connect to the listener service. All three are
# required, so index os.environ directly: a missing variable then fails at
# import with a KeyError naming it, rather than with an AttributeError raised
# by calling .encode() on None.
LISTENER_PUBLIC_KEY = os.environ["LISTENER_PUBLIC_KEY"].encode()
PUBLIC_KEY = os.environ["CURVE_PUBLIC_KEY"].encode()
PRIVATE_KEY = os.environ["CURVE_PRIVATE_KEY"].encode()

DATA_ID = int(os.getenv("DATA_ID", "-1"))


Expand Down Expand Up @@ -206,6 +211,9 @@ def _get_communicator() -> ZMQCommunicator:
"""Connect to the listener."""
zmq_context = zmq.asyncio.Context.instance()
zmq_socket = zmq_context.socket(zmq.DEALER)
zmq_socket.curve_secretkey = PRIVATE_KEY
zmq_socket.curve_publickey = PUBLIC_KEY
zmq_socket.curve_serverkey = LISTENER_PUBLIC_KEY
zmq_socket.setsockopt(zmq.IDENTITY, str(DATA_ID).encode())
connect_string = f"{LISTENER_PROTOCOL}://{LISTENER_IP}:{LISTENER_PORT}"
logger.debug("Opening connection to %s", connect_string)
Expand Down
12 changes: 9 additions & 3 deletions resolwe/flow/executors/socket_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ async def async_send_data(

async def async_receive_data(
reader: asyncio.StreamReader, size_bytes: int = 8
) -> Optional[Tuple[PeerIdentity, Any]]:
) -> Tuple[PeerIdentity, Any, Optional[bytes]]:
"""Receive data from the reader.
The data is expected to be bytes-encoded JSON representation of a Python
Expand All @@ -227,7 +227,7 @@ async def async_receive_data(
message_size = int.from_bytes(received, byteorder="big")
received = await reader.readexactly(message_size)
assert len(received) == message_size
return (b"", json.loads(received.decode("utf-8")))
return (b"", json.loads(received.decode("utf-8")), None)


class Message(Generic[MessageDataType]):
Expand All @@ -240,6 +240,7 @@ def __init__(
message_data: MessageDataType,
message_uuid: Optional[str] = None,
sent_timestamp: Optional[float] = None,
client_id: Optional[bytes] = None,
):
"""Initialize.
Expand All @@ -250,6 +251,7 @@ def __init__(
self.type_data = type_data
self.uuid = message_uuid or self._get_random_message_identifier()
self.sent_timestamp = sent_timestamp
self.client_id = client_id

def _get_random_message_identifier(self) -> str:
"""Get a random message identifier.
Expand Down Expand Up @@ -314,13 +316,15 @@ def command(
command_name: str,
message_data: MessageDataType,
message_uuid: Optional[str] = None,
client_id: Optional[bytes] = None,
) -> "Message[MessageDataType]":
"""Construct and return a command."""
return Message(
MessageType.COMMAND,
command_name,
message_data,
message_uuid,
client_id=client_id,
)

@staticmethod
Expand Down Expand Up @@ -370,6 +374,7 @@ def from_dict(message_dict: Dict) -> "Message":
message_dict["data"],
message_dict.get("uuid"),
message_dict.get("timestamp"),
message_dict.get("client_id"),
)

def to_dict(self) -> dict:
Expand Down Expand Up @@ -591,8 +596,9 @@ async def _receive_message(self) -> Optional[Tuple[PeerIdentity, Message]]:
received = None
if received is not None:
assert isinstance(received, tuple)
assert len(received) == 2
assert len(received) == 3
assert isinstance(received[0], bytes)
received[1]["client_id"] = received[2]
assert Message.is_valid(received[1])
result = received[0], Message.from_dict(received[1])
else:
Expand Down
8 changes: 8 additions & 0 deletions resolwe/flow/executors/startup_communication_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,11 @@
strtobool(os.environ.get("RUNNING_IN_KUBERNETES", "False"))
)

# CurveZMQ secrets necessary to connect to the listener service. All three are
# required, so index os.environ directly: a missing variable then fails at
# import with a KeyError naming it, rather than with an AttributeError raised
# by calling .encode() on None.
# NOTE(review): the first constant's name is misspelled ("LISTNER"); it is kept
# as-is because the socket setup in this module refers to it by this spelling.
LISTNER_PUBLIC_KEY = os.environ["LISTENER_PUBLIC_KEY"].encode()
PUBLIC_KEY = os.environ["CURVE_PUBLIC_KEY"].encode()
PRIVATE_KEY = os.environ["CURVE_PRIVATE_KEY"].encode()

# How many file descriptors to receive over socket in a single message.
DESCRIPTOR_CHUNK_SIZE = int(os.environ.get("DESCRIPTOR_CHUNK_SIZE", 100))

Expand Down Expand Up @@ -578,6 +583,9 @@ async def open_listener_connection(self) -> ZMQCommunicator:
"""
zmq_context = zmq.asyncio.Context.instance()
zmq_socket = zmq_context.socket(zmq.DEALER)
zmq_socket.curve_secretkey = PRIVATE_KEY
zmq_socket.curve_publickey = PUBLIC_KEY
zmq_socket.curve_serverkey = LISTNER_PUBLIC_KEY
zmq_socket.setsockopt(zmq.IDENTITY, str(DATA_ID).encode())
connect_string = f"{LISTENER_PROTOCOL}://{LISTENER_IP}:{LISTENER_PORT}"
logger.debug("Opening listener connection to '%s'.", connect_string)
Expand Down
53 changes: 47 additions & 6 deletions resolwe/flow/executors/zeromq_utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
"""Utils for working with zeromq."""

import json
import os
from contextlib import suppress
from logging import Logger
from threading import Lock
from typing import Any, Optional, Tuple

import zmq
import zmq.asyncio
from zmq.auth.asyncio import AsyncioAuthenticator

from .socket_utils import BaseCommunicator, PeerIdentity

Expand Down Expand Up @@ -34,7 +38,7 @@ async def async_zmq_send_data(

async def async_zmq_receive_data(
reader: zmq.asyncio.Socket,
) -> Optional[Tuple[PeerIdentity, Any]]:
) -> Tuple[PeerIdentity, Any, Optional[bytes]]:
"""Receive data from the reader.
The data is expected to be bytes-encoded JSON representation of a Python object.
Expand All @@ -45,13 +49,16 @@ async def async_zmq_receive_data(
:raises zmq.ZMQError: on receive error.
"""
user_id = None
if reader.socket_type == zmq.DEALER:
identity = reader.getsockopt(zmq.IDENTITY)
message = await reader.recv()
identity = str(reader.getsockopt(zmq.IDENTITY)).encode()
message = await reader.recv(copy=False)
else:
identity, message = await reader.recv_multipart()
decoded = json.loads(message.decode())
return (identity, decoded)
received_identity, message = await reader.recv_multipart(copy=False)
identity = received_identity.bytes
user_id = str(message["User-Id"]).encode()
decoded = json.loads(message.bytes.decode())
return (identity, decoded, user_id)


class ZMQCommunicator(BaseCommunicator):
Expand All @@ -72,3 +79,37 @@ def __init__(
async_zmq_send_data,
async_zmq_receive_data,
)


class ZMQAuthenticator(AsyncioAuthenticator):
    """Authenticator shared as a single instance within one process."""

    # The shared instance, the lock guarding its creation and the pid of the
    # process that created it.
    _instance = None
    _instance_lock = Lock()
    _instance_pid: int | None = None

    @classmethod
    def has_instance(cls):
        """Tell whether an instance created by the current process exists.

        An instance recorded under a different pid is treated as missing.
        """
        return cls._instance is not None and cls._instance_pid == os.getpid()

    @classmethod
    def instance(cls, context=None):
        """Return the shared ZMQAuthenticator, creating it on first use.

        :param context: passed to the constructor when the instance has to be
            created; ignored when a valid instance already exists.
        """
        if cls.has_instance():
            return cls._instance

        with cls._instance_lock:
            # Check again while holding the lock: the instance may have been
            # created between the check above and acquiring the lock.
            if not cls.has_instance():
                cls._instance = cls(context=context)
                cls._instance_pid = os.getpid()
        return cls._instance

    def start(self):
        """Start the authenticator, ignoring ZMQ errors when testing."""
        # Imported locally since is_testing is not available in the executor.
        from resolwe.test.utils import is_testing

        if not is_testing():
            super().start()
            return

        with suppress(zmq.error.ZMQError):
            super().start()
9 changes: 8 additions & 1 deletion resolwe/flow/managers/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from channels.db import database_sync_to_async
from channels.exceptions import ChannelFull
from zmq import curve_keypair

from django.conf import settings
from django.core.exceptions import ImproperlyConfigured, PermissionDenied
Expand Down Expand Up @@ -249,7 +250,13 @@ def _prepare_data_dir(self, data: Data):
with transaction.atomic():
# Create Worker object and set its status to preparing if needed.
if not Worker.objects.filter(data=data).exists():
Worker.objects.get_or_create(data=data, status=Worker.STATUS_PREPARING)
public_key, private_key = curve_keypair()
Worker.objects.get_or_create(
data=data,
status=Worker.STATUS_PREPARING,
public_key=public_key,
private_key=private_key,
)

file_storage = FileStorage.objects.create()
# Data produced by the processing container will be uploaded to the
Expand Down
Loading

0 comments on commit 3ae6fb2

Please sign in to comment.