Skip to content

Commit

Permalink
Merge pull request #3359 from sadielbartholomew/curvezmq
Browse files Browse the repository at this point in the history
Workflow service network layer auth & encryption
  • Loading branch information
oliver-sanders authored Nov 25, 2019
2 parents b7a3de9 + aa34375 commit 6919987
Show file tree
Hide file tree
Showing 9 changed files with 348 additions and 143 deletions.
1 change: 0 additions & 1 deletion bin/cylc-check-software
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,6 @@ def main():
req_result = (
check_py_ver(*req_py_ver_range)
and check_py_module_ver('zmq', None)
and check_py_module_ver('jose', None)
and check_py_module_ver('graphene', None)
and check_py_module_ver('colorama', None)
and check_py_module_ver('ansimarkup', None)
Expand Down
60 changes: 10 additions & 50 deletions cylc/flow/network/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,59 +13,19 @@
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""Network authentication layer."""
"""Standard encode and decode methods for the network authentication layer."""

import getpass
import json

from jose import jwt

from cylc.flow.suite_files import SuiteFiles, get_auth_item
def encode_(message):
    """Serialise a message structure to a JSON string.

    Args:
        message: The JSON-serialisable structure (e.g. a dict holding the
            message fields) to be sent.

    Returns:
        str: The JSON text produced by ``json.dumps``.

    """
    return json.dumps(message)


HASH = 'HS256' # Encoding for JWT


def get_secret(suite):
    """Return the secret used for encrypting messages.

    Currently this is the suite passphrase. This means we are sending
    many messages all encrypted with the same hash which isn't great.

    TODO: Upgrade the secret to add forward secrecy.

    Args:
        suite: The suite identifier passed through to ``get_auth_item``.

    Returns:
        The value ``get_auth_item`` gives for the passphrase item when
        called with ``content=True``.

    """
    return get_auth_item(
        SuiteFiles.Service.PASSPHRASE,
        suite, content=True
    )


def decrypt(message, secret):
    """Make a message readable.

    Args:
        message (str): The message to decode - JWT str.
        secret (str): The decrypt key.

    Return:
        dict - The received message plus a `user` field.

    """
    message = jwt.decode(message, secret, algorithms=[HASH])
    # if able to decode assume this is the user
    message['user'] = getpass.getuser()
    return message


def encrypt(message, secret):
    """Make a message unreadable.

    Args:
        message (dict): The message to send, must be serializable.
        secret (str): The encrypt key.

    Return:
        str - JWT str.

    """
    return jwt.encode(message, secret, algorithm=HASH)
def decode_(message):
    """Parse a JSON message string and add a 'user' field to the result.

    Args:
        message (str): JSON text; it must decode to a mapping, because a
            'user' key is assigned on the decoded value.

    Returns:
        dict: The decoded message, with 'user' set to the name returned
        by ``getpass.getuser()``.

    """
    msg = json.loads(message)
    # The message carries no identity of its own, so the user name of the
    # decoding process is recorded instead.
    msg['user'] = getpass.getuser()  # assume this is the user
    return msg
95 changes: 57 additions & 38 deletions cylc/flow/network/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,32 +16,32 @@
"""Client for suite runtime API."""

import asyncio
from functools import partial
import os
from shutil import which
import socket
import sys
from functools import partial
from typing import Union

import jose.exceptions
import zmq
import zmq.asyncio

from shutil import which

import cylc.flow.flags
from cylc.flow import LOG
from cylc.flow.exceptions import (
ClientError,
ClientTimeout,
SuiteServiceFileError
)
from cylc.flow.hostuserutil import get_fqdn_by_host
from cylc.flow.network.authentication import encrypt, decrypt, get_secret
from cylc.flow.network.authentication import encode_, decode_
from cylc.flow.network.server import PB_METHOD_MAP
from cylc.flow.suite_files import (
ContactFileFields,
detect_old_contact_file,
load_contact_file
ensure_user_keys_exist,
get_auth_item,
load_contact_file,
UserFiles
)

# we should only have one ZMQ context per-process
Expand All @@ -54,22 +54,13 @@ class ZMQClient(object):
This class contains the logic for the ZMQ message interface and client -
server communication.
NOTE: Security to be provided via the encode / decode interface.
Args:
host (str):
The host to connect to.
port (int):
The port on the aforementioned host to connect to.
encode_method (function):
Translates outgoing messages into strings to be sent over the
network. ``encode_method(json, secret) -> str``
decode_method (function):
Translates incoming message strings into digestible data.
``decode_method(str, secret) -> dict``
secret_method (function):
Return the secret for use with the encode/decode methods.
Called for each encode / decode.
srv_public_key_loc (function):
Return path of server's public key for server communication.
timeout (float):
Set the default timeout in seconds. The default is
``ZMQClient.DEFAULT_TIMEOUT``.
Expand All @@ -93,11 +84,9 @@ class ZMQClient(object):

DEFAULT_TIMEOUT = 5. # 5 seconds

def __init__(self, host, port, encode_method, decode_method, secret_method,
timeout=None, timeout_handler=None, header=None):
self.encode = encode_method
self.decode = decode_method
self.secret = secret_method
def __init__(
self, host, port, srv_public_key_loc, timeout=None,
timeout_handler=None, header=None):
if timeout is None:
timeout = self.DEFAULT_TIMEOUT
else:
Expand All @@ -107,6 +96,40 @@ def __init__(self, host, port, encode_method, decode_method, secret_method,

# open the ZMQ socket
self.socket = CONTEXT.socket(zmq.REQ)

# check for, & create if nonexistent, user keys in the right location
if not ensure_user_keys_exist():
raise ClientError("Unable to generate user authentication keys.")

client_priv_keyfile = os.path.join(
UserFiles.get_user_certificate_full_path(private=True),
UserFiles.Auth.CLIENT_PRIVATE_KEY_CERTIFICATE)
error_msg = "Failed to find user's private key, so cannot connect."
try:
client_public_key, client_priv_key = zmq.auth.load_certificate(
client_priv_keyfile)
except (OSError, ValueError):
raise ClientError(error_msg)
if client_priv_key is None: # this can't be caught by exception
raise ClientError(error_msg)
self.socket.curve_publickey = client_public_key
self.socket.curve_secretkey = client_priv_key

# A client can only connect to the server if it knows its public key,
# so we grab this from the location it was created on the filesystem:
try:
# 'load_certificate' will try to load both public & private keys
# from a provided file but will return None, not throw an error,
# for the latter item if not there (as for all public key files)
# so it is OK to use; there is no method to load only the
# public key.
server_public_key = zmq.auth.load_certificate(
srv_public_key_loc)[0]
self.socket.curve_serverkey = server_public_key
except (OSError, ValueError): # ValueError raised w/ no public key
raise ClientError(
"Failed to load the suite's public key, so cannot connect.")

self.socket.connect('tcp://%s:%d' % (host, port))
# if there is no server don't keep the client hanging around
self.socket.setsockopt(zmq.LINGER, int(self.DEFAULT_TIMEOUT))
Expand All @@ -132,18 +155,16 @@ async def async_request(self, command, args=None, timeout=None):
if not args:
args = {}

# get secret for this request
# assumes secret won't change during the request
try:
secret = self.secret()
except SuiteServiceFileError:
raise ClientError('could not read suite passphrase')
# Note: we are using CurveZMQ to secure the messages (see
# self.curve_auth, self.socket.curve_...key etc.). We have set up
# public-key cryptography on the ZMQ messaging and sockets, so
# there is no need to encrypt messages ourselves before sending.

# send message
msg = {'command': command, 'args': args}
msg.update(self.header)
LOG.debug('zmq:send %s' % msg)
message = encrypt(msg, secret)
message = encode_(msg)
self.socket.send_string(message)

# receive response
Expand All @@ -157,11 +178,7 @@ async def async_request(self, command, args=None, timeout=None):
if msg['command'] in PB_METHOD_MAP:
response = {'data': res}
else:
try:
response = decrypt(res.decode(), secret)
except jose.exceptions.JWTError:
raise ClientError(
'Could not decrypt response. Has the passphrase changed?')
response = decode_(res.decode())
LOG.debug('zmq:recv %s' % response)

try:
Expand Down Expand Up @@ -250,12 +267,14 @@ def __init__(
port = int(port)
if not (host and port):
host, port = self.get_location(suite, owner, host)

super().__init__(
host=host,
port=port,
encode_method=encrypt,
decode_method=decrypt,
secret_method=partial(get_secret, suite),
srv_public_key_loc=get_auth_item(
UserFiles.Auth.SERVER_PUBLIC_KEY_CERTIFICATE, suite,
content=False
),
timeout=timeout,
header=self.get_header(),
timeout_handler=partial(self._timeout_handler, suite, host, port)
Expand Down
Loading

0 comments on commit 6919987

Please sign in to comment.