Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Discard malformed traffic sent to node-to-node ports #5889

Merged
merged 11 commits into from
Jan 9, 2024
Merged
10 changes: 9 additions & 1 deletion src/enclave/enclave.h
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,15 @@ namespace ccf

DISPATCHER_SET_MESSAGE_HANDLER(
bp, ccf::node_inbound, [this](const uint8_t* data, size_t size) {
node->recv_node_inbound(data, size);
try
{
node->recv_node_inbound(data, size);
}
catch (const std::exception& e)
{
LOG_DEBUG_FMT(
"Ignoring node_inbound message due to exception: {}", e.what());
}
});

DISPATCHER_SET_MESSAGE_HANDLER(
Expand Down
35 changes: 32 additions & 3 deletions src/host/node_connections.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ namespace asynchost
node(node)
{}

void on_read(size_t len, uint8_t*& incoming, sockaddr)
bool on_read(size_t len, uint8_t*& incoming, sockaddr) override
{
LOG_DEBUG_FMT(
"from node {} received {} bytes",
Expand Down Expand Up @@ -71,8 +71,35 @@ namespace asynchost
}

const auto size_pre_headers = size;
auto msg_type = serialized::read<ccf::NodeMsgType>(data, size);
ccf::NodeId from = serialized::read<ccf::NodeId::Value>(data, size);

ccf::NodeMsgType msg_type;
try
{
msg_type = serialized::read<ccf::NodeMsgType>(data, size);
}
catch (const std::exception& e)
{
LOG_DEBUG_FMT(
"Received invalid node-to-node traffic. Unable to read message "
"type ({}). Closing connection.",
e.what());
return false;
}

ccf::NodeId from;
try
{
from = serialized::read<ccf::NodeId::Value>(data, size);
}
catch (const std::exception& e)
{
LOG_DEBUG_FMT(
"Received invalid node-to-node traffic. Unable to read sender "
"node ID ({}). Closing connection.",
e.what());
return false;
}

const auto size_post_headers = size;
const size_t payload_size =
msg_size.value() - (size_pre_headers - size_post_headers);
Expand Down Expand Up @@ -107,6 +134,8 @@ namespace asynchost
{
pending.erase(pending.begin(), pending.begin() + used);
}

return true;
}

virtual void associate_incoming(const ccf::NodeId&) {}
Expand Down
8 changes: 6 additions & 2 deletions src/host/rpc_connections.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ namespace asynchost
cleanup();
}

void on_read(size_t len, uint8_t*& data, sockaddr) override
bool on_read(size_t len, uint8_t*& data, sockaddr) override
{
LOG_DEBUG_FMT("rpc read {}: {}", id, len);

Expand All @@ -125,6 +125,8 @@ namespace asynchost
parent.to_enclave,
id,
serializer::ByteRange{data, len});

return true;
}

void on_disconnect() override
Expand Down Expand Up @@ -195,7 +197,7 @@ namespace asynchost
}
}

void on_read(size_t len, uint8_t*& data, sockaddr addr) override
bool on_read(size_t len, uint8_t*& data, sockaddr addr) override
{
// UDP connections don't have clients, it's all done in the server
if constexpr (isUDP<ConnType>())
Expand All @@ -211,6 +213,8 @@ namespace asynchost
addr_data,
serializer::ByteRange{data, len});
}

return true;
}

void cleanup()
Expand Down
6 changes: 5 additions & 1 deletion src/host/socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,11 @@ namespace asynchost
virtual ~SocketBehaviour() {}

/// To be implemented by clients
virtual void on_read(size_t, uint8_t*&, sockaddr) {}
/// Return false to immediately disconnect socket.
virtual bool on_read(size_t, uint8_t*&, sockaddr)
{
return true;
}

/// To be implemented by servers with connections
virtual void on_accept(ConnType&) {}
Expand Down
8 changes: 7 additions & 1 deletion src/host/tcp.h
Original file line number Diff line number Diff line change
Expand Up @@ -789,12 +789,18 @@ namespace asynchost
}

uint8_t* p = (uint8_t*)buf->base;
behaviour->on_read((size_t)sz, p, {});
const bool read_good = behaviour->on_read((size_t)sz, p, {});

if (p != nullptr)
{
on_free(buf);
}

if (!read_good)
{
behaviour->on_disconnect();
return;
}
}

static void on_write(uv_write_t* req, int)
Expand Down
119 changes: 113 additions & 6 deletions tests/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
import functools
import httpx
import os
import copy
import socket
import struct
from infra.snp import IS_SNP

from loguru import logger as LOG
Expand Down Expand Up @@ -51,7 +54,13 @@ def interface_caps(i):
}


def run(args):
def run_connection_caps_tests(args):
args = copy.deepcopy(args)

# Set a relatively low cap on max open sessions, so we can saturate it in a reasonable amount of time
args.max_open_sessions = 40
args.max_open_sessions_hard = args.max_open_sessions + 5

# Listen on additional RPC interfaces with even lower session caps
for i, node_spec in enumerate(args.nodes):
caps = interface_caps(i)
Expand Down Expand Up @@ -262,14 +271,112 @@ def create_connections_until_exhaustion(
LOG.warning("Expected a fatal crash and saw none!")


@contextlib.contextmanager
def node_tcp_socket(node):
interface = node.n2n_interface
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.connect((interface.host, interface.port))
yield s
s.close()


def run_node_socket_robustness_tests(args):
with infra.network.network(
args.nodes, args.binary_dir, args.debug_nodes, args.perf_nodes, pdb=args.pdb
) as network:
network.start_and_open(args)

primary, _ = network.find_nodes()

# Protocol is:
# - 4 byte message size N (remainder is not processed until this many bytes arrive)
# - 8 byte message type (valid values are only 0, 1, or 2)
# - Sender node ID (length-prefixed string), consisting of:
# - 8 byte string length S
# - S bytes of string content
# - Message body, of N - 16 - S bytes
# Note number serialization is little-endian!

def encode_msg(
msg_type=0,
sender="OtherNode",
body=b"",
sender_len_override=None,
total_len_override=None,
):
b_type = struct.pack("<Q", msg_type)
sender_len = sender_len_override or len(sender)
b_sender = struct.pack("<Q", sender_len) + sender.encode()
total_len = total_len_override or len(b_type) + len(b_sender) + len(body)
b_size = struct.pack("<I", total_len)
encoded_msg = b_size + b_type + b_sender + body
return encoded_msg

def try_write(msg_bytes):
with node_tcp_socket(primary) as sock:
LOG.debug(
f"Sending raw TCP bytes to {primary.local_node_id}'s node-to-node port: {msg_bytes}"
)
sock.send(msg_bytes)
assert (
not primary.remote.check_done()
), f"Crashed node with N2N message: {msg_bytes}"
LOG.success(f"Node {primary.local_node_id} tolerated this message")

LOG.info("Sending messages which do not contain initial size")
try_write(b"")
try_write(b"\x00")
for size in range(1, 4):
# NB: Regardless of what these bytes contain!
for i in range(5):
msg = random.getrandbits(8 * size).to_bytes(size, byteorder="little")
try_write(msg)

LOG.info("Sending messages which do not contain initial header")
for size in range(0, 16):
try_write(struct.pack("<I", size) + b"\x00" * size)

LOG.info("Sending plausible messages")
try_write(encode_msg())
try_write(encode_msg(msg_type=1))
try_write(encode_msg(msg_type=100))
try_write(encode_msg(sender="abcd"))
try_write(encode_msg(body=struct.pack("<QQQQ", 100, 200, 300, 400)))
try_write(
encode_msg(
msg_type=2, sender="abcd", body=struct.pack("<QQQQ", 100, 200, 300, 400)
)
)

LOG.info("Sending messages with incorrect sender length")
try_write(encode_msg(sender="abcd", sender_len_override=0))
try_write(encode_msg(sender="abcd", sender_len_override=1))
try_write(encode_msg(sender="abcd", sender_len_override=5))
try_write(
b"\x0b\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00bbbb"
)

LOG.info("Sending messages with randomised bodies")
for _ in range(10):
body_len = random.randrange(10, 100)
body = random.getrandbits(body_len * 8).to_bytes(body_len, "little")
try_write(encode_msg(msg_type=random.randrange(0, 3), body=body))

# Don't fill the output with failure messages from this probing
network.ignore_error_pattern_on_shutdown(
"Exception in bool ccf::Channel::recv_key_exchange_message"
)
network.ignore_error_pattern_on_shutdown("Unknown node message type")
network.ignore_error_pattern_on_shutdown("Unhandled AFT message type")
network.ignore_error_pattern_on_shutdown("Unknown frontend msg type")


if __name__ == "__main__":
args = infra.e2e_args.cli_args()
args.package = "samples/apps/logging/liblogging"

# Set a relatively low cap on max open sessions, so we can saturate it in a reasonable amount of time
args.max_open_sessions = 40
args.max_open_sessions_hard = args.max_open_sessions + 5

args.nodes = infra.e2e_args.nodes(args, 1)
args.initial_user_count = 1
run(args)

run_connection_caps_tests(args)
run_node_socket_robustness_tests(args)
Loading