From f84c0f8020b252978e9b157179e9a99233cd33aa Mon Sep 17 00:00:00 2001 From: VendettaReborn Date: Tue, 19 Mar 2024 20:42:44 +0800 Subject: [PATCH] feat: add custom session id generator, based on shadow-tls's impl --- rustls/src/client/client_conn.rs | 44 ++++++++++++++++++++++- rustls/src/client/hs.rs | 60 ++++++++++++++++++++++++++++---- rustls/src/msgs/handshake.rs | 12 +++---- 3 files changed, 103 insertions(+), 13 deletions(-) diff --git a/rustls/src/client/client_conn.rs b/rustls/src/client/client_conn.rs index af49032ffc..4455325ba2 100644 --- a/rustls/src/client/client_conn.rs +++ b/rustls/src/client/client_conn.rs @@ -683,6 +683,24 @@ mod connection { }) } + /// Make a new ClientConnection with a session id generator. `config` controls how + pub fn new_with_session_id_generator( + config: Arc, + name: ServerName<'static>, + generator: Option [u8; 32]>, + ) -> Result { + Ok(Self { + inner: ConnectionCore::for_client_with_session_id_generator( + config, + name, + Vec::new(), + Protocol::Tcp, + generator, + )? + .into(), + }) + } + /// Returns an `io::Write` implementer you can write bytes to /// to send TLS1.3 early data (a.k.a. "0-RTT data") to the server. /// @@ -814,7 +832,31 @@ impl ConnectionCore { sendable_plaintext: None, }; - let state = hs::start_handshake(name, extra_exts, config, &mut cx)?; + let state = + hs::start_handshake:: [u8; 32]>(name, extra_exts, config, &mut cx, None)?; + Ok(Self::new(state, data, common_state)) + } + + pub(crate) fn for_client_with_session_id_generator( + config: Arc, + name: ServerName<'static>, + extra_exts: Vec, + proto: Protocol, + generator: Option [u8; 32]>, + ) -> Result { + let mut common_state = CommonState::new(Side::Client); + common_state.set_max_fragment_size(config.max_fragment_size)?; + common_state.protocol = proto; + common_state.enable_secret_extraction = config.enable_secret_extraction; + let mut data = ClientConnectionData::new(); + + let mut cx = hs::ClientContext { + common: &mut common_state, + data: &mut data, + sendable_plaintext: None, + }; + + let state = hs::start_handshake(name, extra_exts, config, &mut cx, generator)?; Ok(Self::new(state, data, common_state)) } diff --git a/rustls/src/client/hs.rs b/rustls/src/client/hs.rs index 60b3c30e28..660f5e5308 100644 --- a/rustls/src/client/hs.rs +++ b/rustls/src/client/hs.rs @@ -25,6 +25,7 @@ use crate::error::{Error, PeerIncompatible, PeerMisbehaved}; use crate::hash_hs::HandshakeHashBuffer; use crate::log::{debug, trace}; use crate::msgs::base::Payload; +use crate::msgs::codec::Codec; use crate::msgs::enums::{Compression, ECPointFormat, ExtensionType, PSKKeyExchangeMode}; use crate::msgs::handshake::{ CertificateStatusRequest, ClientExtension, ClientHelloPayload, ClientSessionTicket, @@ -91,12 +92,16 @@ fn find_session( found } -pub(super) fn start_handshake( +pub(super) fn start_handshake( server_name: ServerName<'static>, extra_exts: Vec, config: Arc, cx: &mut ClientContext<'_>, -) -> NextStateOrError<'static> { + session_id_generator: Option, +) -> NextStateOrError<'static> +where + T: Fn(&[u8]) -> [u8; 32], +{ let mut transcript_buffer = HandshakeHashBuffer::new(); if config .client_auth_cert_resolver @@ -117,7 +122,19 @@ pub(super) fn start_handshake( None }; - let session_id = if let Some(_resuming) = &mut resuming { + let mut session_id: Option = None; + if let Some(_resuming) = &mut resuming { + #[cfg(feature = "tls12")] + if let ClientSessionValue::Tls12(inner) = &mut _resuming.value { + // If we have a ticket, we use the sessionid as a signal that + // we're doing an abbreviated handshake. See section 3.4 in + // RFC5077. + if !inner.ticket().is_empty() { + inner.session_id = SessionId::random(config.provider.secure_random)?; + } + session_id = Some(inner.session_id); + } + debug!("Resuming session"); match &mut _resuming.value { @@ -169,6 +186,7 @@ pub(super) fn start_handshake( key_share, extra_exts, None, + session_id_generator, ClientHelloInput { config, resuming, @@ -213,16 +231,20 @@ struct ClientHelloInput { prev_ech_ext: Option, } -fn emit_client_hello_for_retry( +fn emit_client_hello_for_retry( mut transcript_buffer: HandshakeHashBuffer, retryreq: Option<&HelloRetryRequest>, key_share: Option>, extra_exts: Vec, suite: Option, + session_id_generator: Option, mut input: ClientHelloInput, cx: &mut ClientContext<'_>, mut ech_state: Option, -) -> NextStateOrError<'static> { +) -> NextStateOrError<'static> +where + T: Fn(&[u8]) -> [u8; 32], +{ let config = &input.config; // Defense in depth: the ECH state should be None if ECH is disabled based on config // builder semantics. @@ -461,6 +483,31 @@ fn emit_client_hello_for_retry( _ => None, }; + // ref: https://github.com/shadow-tls/rustls/blob/c033c22cdbb6b08adf8b35571ee8427c70512d13/rustls/src/client/hs.rs#L365 + if let Some(generator) = session_id_generator { + let mut buffer = Vec::new(); + match &mut chp.payload { + HandshakePayload::ClientHello(c) => { + c.session_id = SessionId { + len: 32, + data: [0; 32], + }; + } + _ => unreachable!(), + } + chp.encode(&mut buffer); + let session_id = SessionId { + len: 32, + data: generator(&buffer), + }; + match &mut chp.payload { + HandshakePayload::ClientHello(c) => { + c.session_id = session_id; + } + _ => unreachable!(), + } + } + let ch = Message { version: match retryreq { // : @@ -1044,12 +1091,13 @@ impl ExpectServerHelloOrHelloRetryRequest { _ => offered_key_share, }; - emit_client_hello_for_retry( + emit_client_hello_for_retry:: [u8; 32]>( transcript_buffer, Some(hrr), Some(key_share), self.extra_exts, Some(cs), + None, self.next.input, cx, self.next.ech_state, diff --git a/rustls/src/msgs/handshake.rs b/rustls/src/msgs/handshake.rs index c9badca011..1c6dc9da08 100644 --- a/rustls/src/msgs/handshake.rs +++ b/rustls/src/msgs/handshake.rs @@ -115,8 +115,8 @@ impl From<[u8; 32]> for Random { #[derive(Copy, Clone)] pub struct SessionId { - len: usize, - data: [u8; 32], + pub(crate)len: usize, + pub(crate)data: [u8; 32], } impl fmt::Debug for SessionId { @@ -985,7 +985,7 @@ impl ClientHelloPayload { pub(crate) fn namedgroups_extension(&self) -> Option<&[NamedGroup]> { let ext = self.find_extension(ExtensionType::EllipticCurves)?; match *ext { - ClientExtension::NamedGroups(ref req) => Some(req), + ClientExtension::NamedGroups(ref req) => Some(req.as_slice()), _ => None, } } @@ -994,7 +994,7 @@ impl ClientHelloPayload { pub(crate) fn ecpoints_extension(&self) -> Option<&[ECPointFormat]> { let ext = self.find_extension(ExtensionType::ECPointFormats)?; match *ext { - ClientExtension::EcPointFormats(ref req) => Some(req), + ClientExtension::EcPointFormats(ref req) => Some(req.as_slice()), _ => None, } } @@ -1068,7 +1068,7 @@ impl ClientHelloPayload { pub(crate) fn psk_modes(&self) -> Option<&[PSKKeyExchangeMode]> { let ext = self.find_extension(ExtensionType::PSKKeyExchangeModes)?; match *ext { - ClientExtension::PresharedKeyModes(ref psk_modes) => Some(psk_modes), + ClientExtension::PresharedKeyModes(ref psk_modes) => Some(psk_modes.as_slice()), _ => None, } } @@ -1367,7 +1367,7 @@ impl ServerHelloPayload { pub(crate) fn ecpoints_extension(&self) -> Option<&[ECPointFormat]> { let ext = self.find_extension(ExtensionType::ECPointFormats)?; match *ext { - ServerExtension::EcPointFormats(ref fmts) => Some(fmts), + ServerExtension::EcPointFormats(ref fmts) => Some(fmts.as_slice()), _ => None, } }