Skip to content

Commit

Permalink
Merge branch 'main' into tla-off-by-one
Browse files Browse the repository at this point in the history
  • Loading branch information
heidihoward authored Jan 9, 2024
2 parents 1003ab4 + 287b0ee commit feb9d68
Show file tree
Hide file tree
Showing 9 changed files with 251 additions and 14 deletions.
8 changes: 8 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,14 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/)
and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html).

## [5.0.0-dev12]

[5.0.0-dev12]: https://github.com/microsoft/CCF/releases/tag/ccf-5.0.0-dev12

### Fixed

- Nodes are now more robust to unexpected traffic on node-to-node ports (#5889).

## [5.0.0-dev11]

[5.0.0-dev11]: https://github.com/microsoft/CCF/releases/tag/ccf-5.0.0-dev11
Expand Down
7 changes: 3 additions & 4 deletions src/consensus/aft/test/driver.h
Original file line number Diff line number Diff line change
Expand Up @@ -279,9 +279,7 @@ class RaftDriver

if (_nodes.find(node_id) == _nodes.end())
{
throw std::runtime_error(fmt::format(
"Node {} does not exist yet. Use \"create_new_node, <node_id>\"",
node_id));
create_new_node(node_id_s);
}

configuration.try_emplace(node_id);
Expand Down Expand Up @@ -1053,7 +1051,8 @@ class RaftDriver
idx)
<< std::endl;
throw std::runtime_error(fmt::format(
"Node not at expected commit idx ({}) on line {} : {}",
"Node {} not at expected commit idx ({}) on line {} : {}",
node_id,
idx,
std::to_string((int)lineno),
_nodes.at(node_id).raft->get_committed_seqno()));
Expand Down
10 changes: 9 additions & 1 deletion src/enclave/enclave.h
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,15 @@ namespace ccf

DISPATCHER_SET_MESSAGE_HANDLER(
bp, ccf::node_inbound, [this](const uint8_t* data, size_t size) {
node->recv_node_inbound(data, size);
try
{
node->recv_node_inbound(data, size);
}
catch (const std::exception& e)
{
LOG_DEBUG_FMT(
"Ignoring node_inbound message due to exception: {}", e.what());
}
});

DISPATCHER_SET_MESSAGE_HANDLER(
Expand Down
35 changes: 32 additions & 3 deletions src/host/node_connections.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ namespace asynchost
node(node)
{}

void on_read(size_t len, uint8_t*& incoming, sockaddr)
bool on_read(size_t len, uint8_t*& incoming, sockaddr) override
{
LOG_DEBUG_FMT(
"from node {} received {} bytes",
Expand Down Expand Up @@ -71,8 +71,35 @@ namespace asynchost
}

const auto size_pre_headers = size;
auto msg_type = serialized::read<ccf::NodeMsgType>(data, size);
ccf::NodeId from = serialized::read<ccf::NodeId::Value>(data, size);

ccf::NodeMsgType msg_type;
try
{
msg_type = serialized::read<ccf::NodeMsgType>(data, size);
}
catch (const std::exception& e)
{
LOG_DEBUG_FMT(
"Received invalid node-to-node traffic. Unable to read message "
"type ({}). Closing connection.",
e.what());
return false;
}

ccf::NodeId from;
try
{
from = serialized::read<ccf::NodeId::Value>(data, size);
}
catch (const std::exception& e)
{
LOG_DEBUG_FMT(
"Received invalid node-to-node traffic. Unable to read sender "
"node ID ({}). Closing connection.",
e.what());
return false;
}

const auto size_post_headers = size;
const size_t payload_size =
msg_size.value() - (size_pre_headers - size_post_headers);
Expand Down Expand Up @@ -107,6 +134,8 @@ namespace asynchost
{
pending.erase(pending.begin(), pending.begin() + used);
}

return true;
}

virtual void associate_incoming(const ccf::NodeId&) {}
Expand Down
8 changes: 6 additions & 2 deletions src/host/rpc_connections.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ namespace asynchost
cleanup();
}

void on_read(size_t len, uint8_t*& data, sockaddr) override
bool on_read(size_t len, uint8_t*& data, sockaddr) override
{
LOG_DEBUG_FMT("rpc read {}: {}", id, len);

Expand All @@ -125,6 +125,8 @@ namespace asynchost
parent.to_enclave,
id,
serializer::ByteRange{data, len});

return true;
}

void on_disconnect() override
Expand Down Expand Up @@ -195,7 +197,7 @@ namespace asynchost
}
}

void on_read(size_t len, uint8_t*& data, sockaddr addr) override
bool on_read(size_t len, uint8_t*& data, sockaddr addr) override
{
// UDP connections don't have clients, it's all done in the server
if constexpr (isUDP<ConnType>())
Expand All @@ -211,6 +213,8 @@ namespace asynchost
addr_data,
serializer::ByteRange{data, len});
}

return true;
}

void cleanup()
Expand Down
6 changes: 5 additions & 1 deletion src/host/socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,11 @@ namespace asynchost
virtual ~SocketBehaviour() {}

/// To be implemented by clients
virtual void on_read(size_t, uint8_t*&, sockaddr) {}
/// Return false to immediately disconnect socket.
virtual bool on_read(size_t, uint8_t*&, sockaddr)
{
return true;
}

/// To be implemented by servers with connections
virtual void on_accept(ConnType&) {}
Expand Down
8 changes: 7 additions & 1 deletion src/host/tcp.h
Original file line number Diff line number Diff line change
Expand Up @@ -789,12 +789,18 @@ namespace asynchost
}

uint8_t* p = (uint8_t*)buf->base;
behaviour->on_read((size_t)sz, p, {});
const bool read_good = behaviour->on_read((size_t)sz, p, {});

if (p != nullptr)
{
on_free(buf);
}

if (!read_good)
{
behaviour->on_disconnect();
return;
}
}

static void on_write(uv_write_t* req, int)
Expand Down
110 changes: 108 additions & 2 deletions tests/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
import functools
import httpx
import os
import socket
import struct
from infra.snp import IS_SNP

from loguru import logger as LOG
Expand Down Expand Up @@ -51,7 +53,7 @@ def interface_caps(i):
}


def run(args):
def run_connection_caps_tests(args):
# Listen on additional RPC interfaces with even lower session caps
for i, node_spec in enumerate(args.nodes):
caps = interface_caps(i)
Expand Down Expand Up @@ -262,14 +264,118 @@ def create_connections_until_exhaustion(
LOG.warning("Expected a fatal crash and saw none!")


@contextlib.contextmanager
def node_tcp_socket(node):
interface = node.n2n_interface
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.connect((interface.host, interface.port))
yield s
s.close()


def run_node_socket_robustness_tests(args):
with infra.network.network(
args.nodes, args.binary_dir, args.debug_nodes, args.perf_nodes, pdb=args.pdb
) as network:
network.start_and_open(args)

primary, _ = network.find_nodes()

# Protocol is:
# - 4 byte message size N (remainder is not processed until this many bytes arrive)
# - 8 byte message type (valid values are only 0, 1, or 2)
# - Sender node ID (length-prefixed string), consisting of:
# - 8 byte string length S
# - S bytes of string content
# - Message body, of N - 16 - S bytes
# Note number serialization is little-endian!

def encode_msg(
msg_type=0,
sender="OtherNode",
body=b"",
sender_len_override=None,
total_len_override=None,
):
b_type = struct.pack("<Q", msg_type)
sender_len = sender_len_override or len(sender)
b_sender = struct.pack("<Q", sender_len) + sender.encode()
total_len = total_len_override or len(b_type) + len(b_sender) + len(body)
b_size = struct.pack("<I", total_len)
encoded_msg = b_size + b_type + b_sender + body
return encoded_msg

def try_write(msg_bytes):
with node_tcp_socket(primary) as sock:
LOG.debug(
f"Sending raw TCP bytes to {primary.local_node_id}'s node-to-node port: {msg_bytes}"
)
sock.send(msg_bytes)
assert (
not primary.remote.check_done()
), f"Crashed node with N2N message: {msg_bytes}"
LOG.success(f"Node {primary.local_node_id} tolerated this message")

LOG.info("Sending messages which do not contain initial size")
try_write(b"")
try_write(b"\x00")
for size in range(1, 4):
# NB: Regardless of what these bytes contain!
for i in range(5):
msg = random.getrandbits(8 * size).to_bytes(size, byteorder="little")
try_write(msg)

LOG.info("Sending messages which do not contain initial header")
for size in range(0, 16):
try_write(struct.pack("<I", size) + b"\x00" * size)

LOG.info("Sending plausible messages")
try_write(encode_msg())
try_write(encode_msg(msg_type=1))
try_write(encode_msg(msg_type=100))
try_write(encode_msg(sender="abcd"))
try_write(encode_msg(body=struct.pack("<QQQQ", 100, 200, 300, 400)))
try_write(
encode_msg(
msg_type=2, sender="abcd", body=struct.pack("<QQQQ", 100, 200, 300, 400)
)
)

LOG.info("Sending messages with incorrect sender length")
try_write(encode_msg(sender="abcd", sender_len_override=0))
try_write(encode_msg(sender="abcd", sender_len_override=1))
try_write(encode_msg(sender="abcd", sender_len_override=5))
try_write(
b"\x0b\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00bbbb"
)

LOG.info("Sending messages with randomised bodies")
for _ in range(10):
body_len = random.randrange(10, 100)
body = random.getrandbits(body_len * 8).to_bytes(body_len, "little")
try_write(encode_msg(msg_type=random.randrange(0, 3), body=body))

# Don't fill the output with failure messages from this probing
network.ignore_error_pattern_on_shutdown(
"Exception in bool ccf::Channel::recv_key_exchange_message"
)
network.ignore_error_pattern_on_shutdown("Unknown node message type")
network.ignore_error_pattern_on_shutdown("Unhandled AFT message type")
network.ignore_error_pattern_on_shutdown("Unknown frontend msg type")


if __name__ == "__main__":
args = infra.e2e_args.cli_args()
args.package = "samples/apps/logging/liblogging"

args.nodes = infra.e2e_args.nodes(args, 1)
run_node_socket_robustness_tests(args)

# Set a relatively low cap on max open sessions, so we can saturate it in a reasonable amount of time
args.max_open_sessions = 40
args.max_open_sessions_hard = args.max_open_sessions + 5

args.nodes = infra.e2e_args.nodes(args, 1)
args.initial_user_count = 1
run(args)

run_connection_caps_tests(args)
73 changes: 73 additions & 0 deletions tla/trace2scen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the Apache 2.0 License.

import sys
import json
import os

def comment(action):
return f"# {action['name']} {action['location']['module']}:{action['location']['beginLine']}"

def term(ctx, pre):
return str(pre["currentTerm"][ctx['i']])

def noop(ctx, pre, post):
return ["# Noop"]

MAP = {
"ClientRequest": lambda ctx, pre, post: ["replicate", term(ctx, pre), "42"],
"MCClientRequest": lambda ctx, pre, post: ["replicate", term(ctx, pre), "42"],
"CheckQuorum": lambda ctx, pre, post: ["periodic_one", ctx['i'], "110"],
"Timeout": lambda ctx, pre, post: ["periodic_one", ctx['i'], "110"],
"MCTimeout": lambda ctx, pre, post: ["periodic_one", ctx['i'], "110"],
"RequestVote": noop,
"AppendEntries": lambda _, __, ___: ["dispatch_all"],
"BecomeLeader": noop,
"SignCommittableMessages": lambda ctx, pre, post: ["emit_signature", term(ctx, pre)],
"MCSignCommittableMessages": lambda ctx, pre, post: ["emit_signature", term(ctx, pre)],
"ChangeConfigurationInt": lambda ctx, pre, post: ["replicate_new_configuration", term(ctx, pre), *ctx["newConfiguration"]],
"AdvanceCommitIndex": noop,
"HandleRequestVoteRequest": lambda _, __, ___: ["dispatch_all"],
"HandleRequestVoteResponse": noop,
"RejectAppendEntriesRequest": noop,
"ReturnToFollowerState": noop,
"AppendEntriesAlreadyDone": noop,
"RcvDropIgnoredMessage": noop,
"RcvUpdateTerm": noop,
"RcvRequestVoteRequest": noop,
"RcvRequestVoteResponse": noop,
}

def post_commit(post):
return [["assert_commit_idx", node, str(idx)] for node, idx in post["commitIndex"].items()]

def post_state(post):
entries = []
for node, state in post["state"].items():
if state == "Leader":
entries.append(["assert_is_primary", node])
elif state == "Follower":
entries.append(["assert_is_backup", node])
elif state == "Candidate":
entries.append(["assert_is_candidate", node])
return entries

def step_to_action(pre_state, action, post_state):
return os.linesep.join([
comment(action),
','.join(MAP[action['name']](action['context'], pre_state[1], post_state[1]))])

def asserts(pre_state, action, post_state, assert_gen):
return os.linesep.join([','.join(assertion) for assertion in assert_gen(post_state[1])])

if __name__ == "__main__":
with open(sys.argv[1]) as trace:
steps = json.load(trace)["action"]
initial_state = steps[0][0][1]
initial_node, = [node for node, log in initial_state["log"].items() if log]
print(f"start_node,{initial_node}")
print(f"emit_signature,2")
for step in steps:
print(step_to_action(*step))
print(asserts(*step, post_state))
print(asserts(*steps[-1], post_commit))

0 comments on commit feb9d68

Please sign in to comment.