diff --git a/tests/unit/s2n_ktls_io_test.c b/tests/unit/s2n_ktls_io_test.c index 901aca9ce5e..eeaefad05f7 100644 --- a/tests/unit/s2n_ktls_io_test.c +++ b/tests/unit/s2n_ktls_io_test.c @@ -1262,7 +1262,6 @@ int main(int argc, char **argv) EXPECT_OK(s2n_assert_seq_num_equal(seq_num, expected_seq_num)); /* Test: Send enough data to hit the encryption limit */ - expected_seq_num += large_test_data_records; EXPECT_FAILURE_WITH_ERRNO( s2n_send(conn, large_test_data, sizeof(large_test_data), &blocked), S2N_ERR_KTLS_KEY_LIMIT); @@ -1284,23 +1283,37 @@ int main(int argc, char **argv) EXPECT_FAILURE_WITH_ERRNO( s2n_send(conn, large_test_data, 1, &blocked), S2N_ERR_KTLS_KEY_LIMIT); - EXPECT_OK(s2n_assert_seq_num_equal(seq_num, test_encryption_limit + 1)); + EXPECT_OK(s2n_assert_seq_num_equal(seq_num, test_encryption_limit)); }; /* Test: Limit not tracked with TLS1.2 */ { - conn->actual_protocol_version = S2N_TLS12; - DEFER_CLEANUP(struct s2n_blob seq_num = { 0 }, s2n_blob_zero); EXPECT_OK(s2n_connection_get_sequence_number(conn, conn->mode, &seq_num)); - EXPECT_EQUAL(s2n_send(conn, large_test_data, 1, &blocked), 1); + /* Sequence number not incremented with TLS1.2 */ + conn->actual_protocol_version = S2N_TLS12; + EXPECT_EQUAL( + s2n_send(conn, large_test_data, sizeof(large_test_data), &blocked), + sizeof(large_test_data)); EXPECT_OK(s2n_assert_seq_num_equal(seq_num, 0)); + /* Sequence number incremented with TLS1.3 */ + conn->actual_protocol_version = S2N_TLS13; EXPECT_EQUAL( s2n_send(conn, large_test_data, sizeof(large_test_data), &blocked), sizeof(large_test_data)); - EXPECT_OK(s2n_assert_seq_num_equal(seq_num, 0)); + EXPECT_OK(s2n_assert_seq_num_equal(seq_num, test_encryption_limit)); + + /* Passing the limit with TLS1.3 is an error */ + conn->actual_protocol_version = S2N_TLS13; + EXPECT_FAILURE_WITH_ERRNO( + s2n_send(conn, large_test_data, 1, &blocked), + S2N_ERR_KTLS_KEY_LIMIT); + + /* Passing the limit with TLS1.2 is NOT an error */ + conn->actual_protocol_version = S2N_TLS12; + EXPECT_EQUAL(s2n_send(conn, large_test_data, 1, &blocked), 1); }; } }; diff --git a/tests/unit/s2n_safety_test.c b/tests/unit/s2n_safety_test.c index 41122318a4b..98a4b19ddeb 100644 --- a/tests/unit/s2n_safety_test.c +++ b/tests/unit/s2n_safety_test.c @@ -387,6 +387,57 @@ int main(int argc, char **argv) CHECK_OVF(s2n_add_overflow, uint32_t, 100, ACTUAL_MAX - 99); CHECK_OVF(s2n_add_overflow, uint32_t, 100, ACTUAL_MAX - 1); + /* Test: S2N_ADD_IS_OVERFLOW_SAFE */ + { + const size_t num = 100; + + uint64_t success_test_values[][3] = { + { 0, 0, 0 }, + { 1, 0, 1 }, + { 0, 0, UINT8_MAX }, + { 1, 1, UINT8_MAX }, + { UINT8_MAX, 0, UINT8_MAX }, + { UINT8_MAX - num, num, UINT8_MAX }, + { UINT8_MAX / 2, UINT8_MAX / 2, UINT8_MAX }, + { 1, 1, UINT64_MAX }, + { UINT64_MAX, 0, UINT64_MAX }, + { UINT64_MAX - num, num, UINT64_MAX }, + { UINT64_MAX / 2, UINT64_MAX / 2, UINT64_MAX }, + }; + for (size_t i = 0; i < s2n_array_len(success_test_values); i++) { + uint64_t v1 = success_test_values[i][0]; + uint64_t v2 = success_test_values[i][1]; + uint64_t max = success_test_values[i][2]; + EXPECT_TRUE(S2N_ADD_IS_OVERFLOW_SAFE(v1, v2, max)); + EXPECT_TRUE(S2N_ADD_IS_OVERFLOW_SAFE(v2, v1, max)); + } + + uint64_t failure_test_values[][3] = { + { 1, 0, 0 }, + { UINT8_MAX, 0, 0 }, + { UINT64_MAX, 0, UINT8_MAX }, + { UINT64_MAX, UINT64_MAX, UINT8_MAX }, + { UINT8_MAX, 1, UINT8_MAX }, + { UINT8_MAX - 1, UINT8_MAX - 1, UINT8_MAX }, + { UINT16_MAX, 1, UINT16_MAX }, + { UINT64_MAX, 1, UINT64_MAX }, + { UINT8_MAX, num, UINT8_MAX }, + { UINT16_MAX, num, UINT16_MAX }, + { UINT64_MAX, num, UINT64_MAX }, + { UINT8_MAX, UINT8_MAX, UINT8_MAX }, + { UINT16_MAX, UINT16_MAX, UINT16_MAX }, + { UINT64_MAX, UINT64_MAX, UINT64_MAX }, + { UINT64_MAX - num, UINT64_MAX - num, UINT64_MAX }, + }; + for (size_t i = 0; i < s2n_array_len(failure_test_values); i++) { + uint64_t v1 = failure_test_values[i][0]; + uint64_t v2 = failure_test_values[i][1]; + uint64_t max = failure_test_values[i][2]; + EXPECT_FALSE(S2N_ADD_IS_OVERFLOW_SAFE(v1, v2, max)); + EXPECT_FALSE(S2N_ADD_IS_OVERFLOW_SAFE(v2, v1, max)); + } + } + END_TEST(); return 0; } diff --git a/tests/unit/s2n_send_test.c b/tests/unit/s2n_send_test.c index 3918be6f19a..d9ac4173e6b 100644 --- a/tests/unit/s2n_send_test.c +++ b/tests/unit/s2n_send_test.c @@ -600,5 +600,131 @@ int main(int argc, char **argv) EXPECT_EQUAL(conn->out.blob.size, out_size[S2N_MFL_DEFAULT]); }; + /* Test: s2n_sendv_with_offset_total_size */ + { + const struct iovec test_multiple_bufs[] = { + { .iov_len = 0 }, + { .iov_len = 1 }, + { .iov_len = 2 }, + { .iov_len = 0 }, + { .iov_len = 14 }, + { .iov_len = 0 }, + { .iov_len = 3 }, + { .iov_len = 0 }, + }; + const size_t test_multiple_bufs_total_size = 20; + + /* Safety */ + { + size_t out = 0; + EXPECT_ERROR_WITH_ERRNO( + s2n_sendv_with_offset_total_size(NULL, 0, 0, NULL), + S2N_ERR_NULL); + EXPECT_ERROR_WITH_ERRNO( + s2n_sendv_with_offset_total_size(NULL, 1, 0, &out), + S2N_ERR_NULL); + } + + /* No iovecs */ + { + size_t out = 0; + EXPECT_OK(s2n_sendv_with_offset_total_size(NULL, 0, 0, &out)); + EXPECT_EQUAL(out, 0); + } + + /* Array of zero-length iovecs */ + { + const struct iovec test_bufs[10] = { 0 }; + size_t out = 0; + EXPECT_OK(s2n_sendv_with_offset_total_size( + test_bufs, s2n_array_len(test_bufs), 0, &out)); + EXPECT_EQUAL(out, 0); + } + + /* Single iovec */ + { + const size_t expected_size = 10; + const struct iovec test_buf = { .iov_len = expected_size }; + size_t out = 0; + EXPECT_OK(s2n_sendv_with_offset_total_size(&test_buf, 1, 0, &out)); + EXPECT_EQUAL(out, expected_size); + } + + /* Single iovec with offset */ + { + const struct iovec test_buf = { .iov_len = 10 }; + const ssize_t offset = 5; + size_t out = 0; + EXPECT_OK(s2n_sendv_with_offset_total_size(&test_buf, 1, offset, &out)); + EXPECT_EQUAL(out, test_buf.iov_len - offset); + } + + /* Multiple iovecs */ + { + size_t out = 0; + EXPECT_OK(s2n_sendv_with_offset_total_size( + test_multiple_bufs, s2n_array_len(test_multiple_bufs), 0, &out)); + EXPECT_EQUAL(out, test_multiple_bufs_total_size); + } + + /* Multiple iovecs with offset */ + { + const size_t offset = 10; + size_t out = 0; + EXPECT_OK(s2n_sendv_with_offset_total_size( + test_multiple_bufs, s2n_array_len(test_multiple_bufs), offset, &out)); + EXPECT_EQUAL(out, test_multiple_bufs_total_size - offset); + } + + /* Offset with no data */ + { + const struct iovec test_bufs[10] = { 0 }; + size_t out = 0; + EXPECT_ERROR_WITH_ERRNO( + s2n_sendv_with_offset_total_size(NULL, 0, 1, &out), + S2N_ERR_INVALID_ARGUMENT); + EXPECT_ERROR_WITH_ERRNO( + s2n_sendv_with_offset_total_size(test_bufs, 0, 1, &out), + S2N_ERR_INVALID_ARGUMENT); + EXPECT_ERROR_WITH_ERRNO( + s2n_sendv_with_offset_total_size(test_bufs, s2n_array_len(test_bufs), 1, &out), + S2N_ERR_INVALID_ARGUMENT); + } + + /* Offset larger than available data */ + { + const struct iovec test_buf = { .iov_len = 10 }; + size_t out = 0; + + ssize_t test_buf_offset = test_buf.iov_len + 1; + EXPECT_ERROR_WITH_ERRNO( + s2n_sendv_with_offset_total_size(&test_buf, 1, test_buf_offset, &out), + S2N_ERR_INVALID_ARGUMENT); + + ssize_t test_multiple_bufs_offset = test_multiple_bufs_total_size + 1; + EXPECT_ERROR_WITH_ERRNO( + s2n_sendv_with_offset_total_size(test_multiple_bufs, + s2n_array_len(test_multiple_bufs), test_multiple_bufs_offset, &out), + S2N_ERR_INVALID_ARGUMENT); + } + + /* Too much data to count + * + * This isn't really practically possible since an application would need + * to allocate more than SIZE_MAX memory for the iovec buffers, but we + * should ensure that the inputs don't cause unexpected behavior. + */ + { + const struct iovec test_bufs[] = { + { .iov_len = SIZE_MAX }, + { .iov_len = 1 }, + }; + size_t out = 0; + EXPECT_ERROR_WITH_ERRNO( + s2n_sendv_with_offset_total_size(test_bufs, s2n_array_len(test_bufs), 0, &out), + S2N_ERR_INVALID_ARGUMENT); + } + }; + END_TEST(); } diff --git a/tls/s2n_ktls_io.c b/tls/s2n_ktls_io.c index 68b92e0384e..a1fd655bc9c 100644 --- a/tls/s2n_ktls_io.c +++ b/tls/s2n_ktls_io.c @@ -269,6 +269,21 @@ S2N_RESULT s2n_ktls_recvmsg(void *io_context, uint8_t *record_type, uint8_t *buf return S2N_RESULT_OK; } +/* The RFC defines the encryption limits in terms of "full-size records" sent. + * We can estimate the number of "full-sized records" sent by assuming that + * all records are full-sized. + */ +static S2N_RESULT s2n_ktls_estimate_records(size_t bytes, uint64_t *estimate) +{ + RESULT_ENSURE_REF(estimate); + uint64_t records = bytes / S2N_TLS_MAXIMUM_FRAGMENT_LENGTH; + if (bytes % S2N_TLS_MAXIMUM_FRAGMENT_LENGTH) { + records++; + } + *estimate = records; + return S2N_RESULT_OK; +} + /* ktls does not currently support updating keys, so we should kill the connection * when the key encryption limit is reached. We could get the current record * sequence number from the kernel with getsockopt, but that requires a surprisingly @@ -277,34 +292,51 @@ S2N_RESULT s2n_ktls_recvmsg(void *io_context, uint8_t *record_type, uint8_t *buf * Instead, we track the estimated sequence number and enforce the limit based * on that estimate. */ -static S2N_RESULT s2n_ktls_enforce_estimated_record_limit( - struct s2n_connection *conn, size_t bytes_written) +static S2N_RESULT s2n_ktls_check_estimated_record_limit( + struct s2n_connection *conn, size_t bytes_requested) { RESULT_ENSURE_REF(conn); if (conn->actual_protocol_version < S2N_TLS13) { return S2N_RESULT_OK; } + uint64_t new_records_sent = 0; + RESULT_GUARD(s2n_ktls_estimate_records(bytes_requested, &new_records_sent)); + + uint64_t old_records_sent = 0; struct s2n_blob seq_num = { 0 }; RESULT_GUARD(s2n_connection_get_sequence_number(conn, conn->mode, &seq_num)); + RESULT_GUARD_POSIX(s2n_sequence_number_to_uint64(&seq_num, &old_records_sent)); - /* The RFC states the encryption limits in terms of "full-size records" sent. - * We can estimate the number of "full-sized records" sent by assuming that - * all records are full-sized. - */ - while (bytes_written > 0) { - RESULT_GUARD_POSIX(s2n_increment_sequence_number(&seq_num)); - bytes_written -= MIN(bytes_written, S2N_TLS_MAXIMUM_FRAGMENT_LENGTH); - } - - uint64_t records_sent = 0; - RESULT_GUARD_POSIX(s2n_sequence_number_to_uint64(&seq_num, &records_sent)); + RESULT_ENSURE(S2N_ADD_IS_OVERFLOW_SAFE(old_records_sent, new_records_sent, UINT64_MAX), + S2N_ERR_KTLS_KEY_LIMIT); + uint64_t total_records_sent = old_records_sent + new_records_sent; RESULT_ENSURE_REF(conn->secure); RESULT_ENSURE_REF(conn->secure->cipher_suite); RESULT_ENSURE_REF(conn->secure->cipher_suite->record_alg); uint64_t encryption_limit = conn->secure->cipher_suite->record_alg->encryption_limit; - RESULT_ENSURE(records_sent <= encryption_limit, S2N_ERR_KTLS_KEY_LIMIT); + RESULT_ENSURE(total_records_sent <= encryption_limit, S2N_ERR_KTLS_KEY_LIMIT); + return S2N_RESULT_OK; +} + +static S2N_RESULT s2n_ktls_set_estimated_sequence_number( + struct s2n_connection *conn, size_t bytes_written) +{ + RESULT_ENSURE_REF(conn); + if (conn->actual_protocol_version < S2N_TLS13) { + return S2N_RESULT_OK; + } + + uint64_t new_records_sent = 0; + RESULT_GUARD(s2n_ktls_estimate_records(bytes_written, &new_records_sent)); + + struct s2n_blob seq_num = { 0 }; + RESULT_GUARD(s2n_connection_get_sequence_number(conn, conn->mode, &seq_num)); + + for (size_t i = 0; i < new_records_sent; i++) { + RESULT_GUARD_POSIX(s2n_increment_sequence_number(&seq_num)); + } return S2N_RESULT_OK; } @@ -387,6 +419,10 @@ ssize_t s2n_ktls_sendv_with_offset(struct s2n_connection *conn, const struct iov POSIX_ENSURE(offs_in >= 0, S2N_ERR_INVALID_ARGUMENT); size_t offs = offs_in; + ssize_t total_bytes = 0; + POSIX_GUARD_RESULT(s2n_sendv_with_offset_total_size(bufs, count_in, offs_in, &total_bytes)); + POSIX_GUARD_RESULT(s2n_ktls_check_estimated_record_limit(conn, total_bytes)); + DEFER_CLEANUP(struct s2n_blob new_bufs = { 0 }, s2n_free_or_wipe); uint8_t new_bufs_mem[S2N_MAX_STACK_IOVECS_MEM] = { 0 }; POSIX_GUARD(s2n_blob_init(&new_bufs, new_bufs_mem, sizeof(new_bufs_mem))); @@ -398,11 +434,7 @@ ssize_t s2n_ktls_sendv_with_offset(struct s2n_connection *conn, const struct iov POSIX_GUARD_RESULT(s2n_ktls_sendmsg(conn->send_io_context, TLS_APPLICATION_DATA, bufs, count, blocked, &bytes_written)); - /* Unlike s2n_sendfile, here we could calculate the number of bytes to be sent - * before actually sending them. However, we instead choose to maintain consistent - * behavior across our send methods and always check for the limit after the send. - */ - POSIX_GUARD_RESULT(s2n_ktls_enforce_estimated_record_limit(conn, bytes_written)); + POSIX_GUARD_RESULT(s2n_ktls_set_estimated_sequence_number(conn, bytes_written)); return bytes_written; } @@ -466,6 +498,7 @@ int s2n_sendfile(struct s2n_connection *conn, int in_fd, off_t offset, size_t co *bytes_written = 0; POSIX_ENSURE_REF(conn); POSIX_ENSURE(conn->ktls_send_enabled, S2N_ERR_KTLS_UNSUPPORTED_CONN); + POSIX_GUARD_RESULT(s2n_ktls_check_estimated_record_limit(conn, count)); int out_fd = 0; POSIX_GUARD_RESULT(s2n_ktls_get_file_descriptor(conn, S2N_KTLS_MODE_SEND, &out_fd)); @@ -480,20 +513,8 @@ int s2n_sendfile(struct s2n_connection *conn, int in_fd, off_t offset, size_t co POSIX_BAIL(S2N_ERR_UNIMPLEMENTED); #endif + POSIX_GUARD_RESULT(s2n_ktls_set_estimated_sequence_number(conn, *bytes_written)); *blocked = S2N_NOT_BLOCKED; - - /* Because we pass the input file descriptor to the kernel without examining - * it, we don't know how many bytes actually need to be sent. We therefore - * can't verify that the send is safe with respect to the encryption limit - * before sending the records. Instead, we raise a fatal error afterwards if - * the send violated the encryption limit. - * - * An application should treat S2N_ERR_KTLS_KEY_LIMIT as a very high severity - * error, as it indicates that the application is violating the requirements - * for using TLS1.3 with ktls without a kernel patch to enable KeyUpdates, - * and is therefore operating unsafely. - */ - POSIX_GUARD_RESULT(s2n_ktls_enforce_estimated_record_limit(conn, *bytes_written)); return S2N_SUCCESS; } diff --git a/tls/s2n_send.c b/tls/s2n_send.c index a9ec8aa0d85..c97b773425b 100644 --- a/tls/s2n_send.c +++ b/tls/s2n_send.c @@ -103,6 +103,37 @@ int s2n_flush(struct s2n_connection *conn, s2n_blocked_status *blocked) return 0; } +S2N_RESULT s2n_sendv_with_offset_total_size(const struct iovec *bufs, ssize_t count, + ssize_t offs, ssize_t *total_size_out) +{ + RESULT_ENSURE_REF(total_size_out); + if (count) { + RESULT_ENSURE_REF(bufs); + } + + size_t total_size = 0; + for (ssize_t i = 0; i < count; i++) { + size_t iov_len = bufs[i].iov_len; + /* Account for any offset */ + if (offs) { + size_t offs_consumed = MIN(offs, iov_len); + iov_len -= offs_consumed; + offs -= offs_consumed; + } + RESULT_ENSURE(S2N_ADD_IS_OVERFLOW_SAFE(total_size, iov_len, SSIZE_MAX), + S2N_ERR_INVALID_ARGUMENT); + total_size += iov_len; + } + + /* We must have fully accounted for the offset, or else the offset is larger + * than the available data and our inputs are invalid. + */ + RESULT_ENSURE(offs == 0, S2N_ERR_INVALID_ARGUMENT); + + *total_size_out = total_size; + return S2N_RESULT_OK; +} + ssize_t s2n_sendv_with_offset_impl(struct s2n_connection *conn, const struct iovec *bufs, ssize_t count, ssize_t offs, s2n_blocked_status *blocked) { @@ -137,23 +168,9 @@ ssize_t s2n_sendv_with_offset_impl(struct s2n_connection *conn, const struct iov writer = conn->client; } + POSIX_GUARD_RESULT(s2n_sendv_with_offset_total_size(bufs, count, offs, &total_size)); /* Defensive check against an invalid retry */ - if (offs > 0) { - const struct iovec *_bufs = bufs; - ssize_t _count = count; - while ((size_t) offs >= _bufs->iov_len && _count > 0) { - offs -= _bufs->iov_len; - _bufs++; - _count--; - } - bufs = _bufs; - count = _count; - } - for (ssize_t i = 0; i < count; i++) { - total_size += bufs[i].iov_len; - } - total_size -= offs; - S2N_ERROR_IF(conn->current_user_data_consumed > total_size, S2N_ERR_SEND_SIZE); + POSIX_ENSURE(conn->current_user_data_consumed <= total_size, S2N_ERR_SEND_SIZE); POSIX_GUARD_RESULT(s2n_early_data_validate_send(conn, total_size)); if (conn->dynamic_record_timeout_threshold > 0) { diff --git a/tls/s2n_tls.h b/tls/s2n_tls.h index b59378e971b..ff7670535b7 100644 --- a/tls/s2n_tls.h +++ b/tls/s2n_tls.h @@ -83,6 +83,8 @@ int s2n_handshake_write_header(struct s2n_stuffer *out, uint8_t message_type); int s2n_handshake_finish_header(struct s2n_stuffer *out); S2N_RESULT s2n_handshake_parse_header(struct s2n_stuffer *io, uint8_t *message_type, uint32_t *length); int s2n_read_full_record(struct s2n_connection *conn, uint8_t *record_type, int *isSSLv2); +S2N_RESULT s2n_sendv_with_offset_total_size(const struct iovec *bufs, ssize_t count, + ssize_t offs, ssize_t *total_size_out); extern uint16_t mfl_code_to_length[5]; diff --git a/utils/s2n_safety.h b/utils/s2n_safety.h index f69e7574ea7..f5a9d7a3afe 100644 --- a/utils/s2n_safety.h +++ b/utils/s2n_safety.h @@ -109,3 +109,4 @@ int s2n_mul_overflow(uint32_t a, uint32_t b, uint32_t* out); int s2n_align_to(uint32_t initial, uint32_t alignment, uint32_t* out); int s2n_add_overflow(uint32_t a, uint32_t b, uint32_t* out); int s2n_sub_overflow(uint32_t a, uint32_t b, uint32_t* out); +#define S2N_ADD_IS_OVERFLOW_SAFE(a, b, max) (((max) >= (a)) && ((max) - (a) >= (b)))