Skip to content

Commit

Permalink
Discard malformed traffic sent to node-to-node ports (#5889)
Browse files Browse the repository at this point in the history
  • Loading branch information
eddyashton authored Jan 9, 2024
1 parent 6734750 commit e618e51
Show file tree
Hide file tree
Showing 7 changed files with 175 additions and 10 deletions.
8 changes: 8 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,14 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/)
and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html).

## [5.0.0-dev12]

[5.0.0-dev12]: https://github.com/microsoft/CCF/releases/tag/ccf-5.0.0-dev12

### Fixed

- Nodes are now more robust to unexpected traffic on node-to-node ports (#5889).

## [5.0.0-dev11]

[5.0.0-dev11]: https://github.com/microsoft/CCF/releases/tag/ccf-5.0.0-dev11
Expand Down
10 changes: 9 additions & 1 deletion src/enclave/enclave.h
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,15 @@ namespace ccf

// Handler for ccf::node_inbound messages. Anything thrown while processing
// the message is caught here so that a single bad message is logged and
// dropped instead of propagating out of the dispatcher callback.
DISPATCHER_SET_MESSAGE_HANDLER(
  bp, ccf::node_inbound, [this](const uint8_t* data, size_t size) {
    try
    {
      // Single, guarded call: the message must be processed exactly once,
      // and only inside the try block.
      node->recv_node_inbound(data, size);
    }
    catch (const std::exception& e)
    {
      LOG_DEBUG_FMT(
        "Ignoring node_inbound message due to exception: {}", e.what());
    }
  });

DISPATCHER_SET_MESSAGE_HANDLER(
Expand Down
35 changes: 32 additions & 3 deletions src/host/node_connections.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ namespace asynchost
node(node)
{}

void on_read(size_t len, uint8_t*& incoming, sockaddr)
bool on_read(size_t len, uint8_t*& incoming, sockaddr) override
{
LOG_DEBUG_FMT(
"from node {} received {} bytes",
Expand Down Expand Up @@ -71,8 +71,35 @@ namespace asynchost
}

const auto size_pre_headers = size;
auto msg_type = serialized::read<ccf::NodeMsgType>(data, size);
ccf::NodeId from = serialized::read<ccf::NodeId::Value>(data, size);

ccf::NodeMsgType msg_type;
try
{
msg_type = serialized::read<ccf::NodeMsgType>(data, size);
}
catch (const std::exception& e)
{
LOG_DEBUG_FMT(
"Received invalid node-to-node traffic. Unable to read message "
"type ({}). Closing connection.",
e.what());
return false;
}

ccf::NodeId from;
try
{
from = serialized::read<ccf::NodeId::Value>(data, size);
}
catch (const std::exception& e)
{
LOG_DEBUG_FMT(
"Received invalid node-to-node traffic. Unable to read sender "
"node ID ({}). Closing connection.",
e.what());
return false;
}

const auto size_post_headers = size;
const size_t payload_size =
msg_size.value() - (size_pre_headers - size_post_headers);
Expand Down Expand Up @@ -107,6 +134,8 @@ namespace asynchost
{
pending.erase(pending.begin(), pending.begin() + used);
}

return true;
}

virtual void associate_incoming(const ccf::NodeId&) {}
Expand Down
8 changes: 6 additions & 2 deletions src/host/rpc_connections.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ namespace asynchost
cleanup();
}

void on_read(size_t len, uint8_t*& data, sockaddr) override
bool on_read(size_t len, uint8_t*& data, sockaddr) override
{
LOG_DEBUG_FMT("rpc read {}: {}", id, len);

Expand All @@ -125,6 +125,8 @@ namespace asynchost
parent.to_enclave,
id,
serializer::ByteRange{data, len});

return true;
}

void on_disconnect() override
Expand Down Expand Up @@ -195,7 +197,7 @@ namespace asynchost
}
}

void on_read(size_t len, uint8_t*& data, sockaddr addr) override
bool on_read(size_t len, uint8_t*& data, sockaddr addr) override
{
// UDP connections don't have clients, it's all done in the server
if constexpr (isUDP<ConnType>())
Expand All @@ -211,6 +213,8 @@ namespace asynchost
addr_data,
serializer::ByteRange{data, len});
}

return true;
}

void cleanup()
Expand Down
6 changes: 5 additions & 1 deletion src/host/socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,11 @@ namespace asynchost
virtual ~SocketBehaviour() {}

/// To be implemented by clients.
///
/// The default implementation ignores its arguments and keeps the socket
/// open.
///
/// @return true to keep the socket open; false to immediately disconnect
/// the socket.
virtual bool on_read(size_t, uint8_t*&, sockaddr)
{
  return true;
}

/// To be implemented by servers with connections
virtual void on_accept(ConnType&) {}
Expand Down
8 changes: 7 additions & 1 deletion src/host/tcp.h
Original file line number Diff line number Diff line change
Expand Up @@ -789,12 +789,18 @@ namespace asynchost
}

uint8_t* p = (uint8_t*)buf->base;
behaviour->on_read((size_t)sz, p, {});
const bool read_good = behaviour->on_read((size_t)sz, p, {});

if (p != nullptr)
{
on_free(buf);
}

if (!read_good)
{
behaviour->on_disconnect();
return;
}
}

static void on_write(uv_write_t* req, int)
Expand Down
110 changes: 108 additions & 2 deletions tests/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
import functools
import httpx
import os
import socket
import struct
from infra.snp import IS_SNP

from loguru import logger as LOG
Expand Down Expand Up @@ -51,7 +53,7 @@ def interface_caps(i):
}


def run(args):
def run_connection_caps_tests(args):
# Listen on additional RPC interfaces with even lower session caps
for i, node_spec in enumerate(args.nodes):
caps = interface_caps(i)
Expand Down Expand Up @@ -262,14 +264,118 @@ def create_connections_until_exhaustion(
LOG.warning("Expected a fatal crash and saw none!")


@contextlib.contextmanager
def node_tcp_socket(node):
    """Yield a TCP socket connected to ``node``'s node-to-node interface.

    The address is read from ``node.n2n_interface`` (``host`` and ``port``).
    The socket is always closed on exit, including when the connection
    attempt fails or the body of the ``with`` statement raises.
    """
    interface = node.n2n_interface
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        s.connect((interface.host, interface.port))
        yield s
    finally:
        # Runs on normal exit, on connect() failure, and when the caller's
        # block raises, so the descriptor is never leaked.
        s.close()


def run_node_socket_robustness_tests(args):
    """Send malformed and plausible raw TCP traffic to the primary's
    node-to-node port, asserting after each message that
    ``primary.remote.check_done()`` is still false.
    """
    with infra.network.network(
        args.nodes, args.binary_dir, args.debug_nodes, args.perf_nodes, pdb=args.pdb
    ) as network:
        network.start_and_open(args)

        primary, _ = network.find_nodes()

        # Protocol is:
        # - 4 byte message size N (remainder is not processed until this many bytes arrive)
        # - 8 byte message type (valid values are only 0, 1, or 2)
        # - Sender node ID (length-prefixed string), consisting of:
        #   - 8 byte string length S
        #   - S bytes of string content
        # - Message body, of N - 16 - S bytes
        # Note number serialization is little-endian!

        def encode_msg(
            msg_type=0,
            sender="OtherNode",
            body=b"",
            sender_len_override=None,
            total_len_override=None,
        ):
            """Serialise a node-to-node message.

            The ``*_override`` arguments replace the corresponding length
            prefix with a (possibly wrong) value, to build malformed messages.
            """
            b_type = struct.pack("<Q", msg_type)
            sender_bytes = sender.encode()
            # Compare against None rather than relying on truthiness, so that
            # an explicit override of 0 is honoured instead of being discarded.
            sender_len = (
                sender_len_override
                if sender_len_override is not None
                else len(sender_bytes)
            )
            b_sender = struct.pack("<Q", sender_len) + sender_bytes
            total_len = (
                total_len_override
                if total_len_override is not None
                else len(b_type) + len(b_sender) + len(body)
            )
            b_size = struct.pack("<I", total_len)
            return b_size + b_type + b_sender + body

        def try_write(msg_bytes):
            """Open a fresh connection, send ``msg_bytes``, and check the node
            process has not finished."""
            with node_tcp_socket(primary) as sock:
                LOG.debug(
                    f"Sending raw TCP bytes to {primary.local_node_id}'s node-to-node port: {msg_bytes}"
                )
                # sendall keeps writing until every byte has been handed to
                # the kernel; send may stop after a partial write.
                sock.sendall(msg_bytes)
                assert (
                    not primary.remote.check_done()
                ), f"Crashed node with N2N message: {msg_bytes}"
                LOG.success(f"Node {primary.local_node_id} tolerated this message")

        LOG.info("Sending messages which do not contain initial size")
        try_write(b"")
        try_write(b"\x00")
        for size in range(1, 4):
            # NB: Regardless of what these bytes contain!
            for _ in range(5):
                msg = random.getrandbits(8 * size).to_bytes(size, byteorder="little")
                try_write(msg)

        LOG.info("Sending messages which do not contain initial header")
        for size in range(0, 16):
            try_write(struct.pack("<I", size) + b"\x00" * size)

        LOG.info("Sending plausible messages")
        try_write(encode_msg())
        try_write(encode_msg(msg_type=1))
        try_write(encode_msg(msg_type=100))
        try_write(encode_msg(sender="abcd"))
        try_write(encode_msg(body=struct.pack("<QQQQ", 100, 200, 300, 400)))
        try_write(
            encode_msg(
                msg_type=2, sender="abcd", body=struct.pack("<QQQQ", 100, 200, 300, 400)
            )
        )

        LOG.info("Sending messages with incorrect sender length")
        try_write(encode_msg(sender="abcd", sender_len_override=0))
        try_write(encode_msg(sender="abcd", sender_len_override=1))
        try_write(encode_msg(sender="abcd", sender_len_override=5))
        try_write(
            b"\x0b\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00bbbb"
        )

        LOG.info("Sending messages with randomised bodies")
        for _ in range(10):
            body_len = random.randrange(10, 100)
            body = random.getrandbits(body_len * 8).to_bytes(body_len, "little")
            try_write(encode_msg(msg_type=random.randrange(0, 3), body=body))

        # Don't fill the output with failure messages from this probing
        network.ignore_error_pattern_on_shutdown(
            "Exception in bool ccf::Channel::recv_key_exchange_message"
        )
        network.ignore_error_pattern_on_shutdown("Unknown node message type")
        network.ignore_error_pattern_on_shutdown("Unhandled AFT message type")
        network.ignore_error_pattern_on_shutdown("Unknown frontend msg type")


if __name__ == "__main__":
    # Script entry point: run the node-to-node socket robustness tests first,
    # then the connection cap tests, each against a single-node network.
    args = infra.e2e_args.cli_args()
    args.package = "samples/apps/logging/liblogging"

    args.nodes = infra.e2e_args.nodes(args, 1)
    run_node_socket_robustness_tests(args)

    # Set a relatively low cap on max open sessions, so we can saturate it in a reasonable amount of time
    args.max_open_sessions = 40
    args.max_open_sessions_hard = args.max_open_sessions + 5

    args.nodes = infra.e2e_args.nodes(args, 1)
    args.initial_user_count = 1

    # The connection cap tests are invoked once, under their current name.
    run_connection_caps_tests(args)

0 comments on commit e618e51

Please sign in to comment.