diff --git a/rustls/src/client/client_conn.rs b/rustls/src/client/client_conn.rs index b84219f27a..e848ce9802 100644 --- a/rustls/src/client/client_conn.rs +++ b/rustls/src/client/client_conn.rs @@ -551,6 +551,24 @@ impl ClientConnection { }) } + /// Make a new ClientConnection with a session id generator. `config` controls how + pub fn new_with_session_id_generator( + config: Arc, + name: ServerName, + generator: Option [u8; 32]>, + ) -> Result { + Ok(Self { + inner: ConnectionCore::for_client_with_session_id_generator( + config, + name, + Vec::new(), + Protocol::Tcp, + generator, + )? + .into(), + }) + } + /// Returns an `io::Write` implementer you can write bytes to /// to send TLS1.3 early data (a.k.a. "0-RTT data") to the server. /// @@ -666,7 +684,33 @@ impl ConnectionCore { data: &mut data, }; - let state = hs::start_handshake(name, extra_exts, config, &mut cx)?; + let state = + hs::start_handshake:: [u8; 32]>(name, extra_exts, config, &mut cx, None)?; + Ok(Self::new(state, data, common_state)) + } + + pub(crate) fn for_client_with_session_id_generator( + config: Arc, + name: ServerName, + extra_exts: Vec, + proto: Protocol, + generator: Option [u8; 32]>, + ) -> Result { + let mut common_state = CommonState::new(Side::Client); + common_state.set_max_fragment_size(config.max_fragment_size)?; + common_state.protocol = proto; + #[cfg(feature = "secret_extraction")] + { + common_state.enable_secret_extraction = config.enable_secret_extraction; + } + let mut data = ClientConnectionData::new(); + + let mut cx = hs::ClientContext { + common: &mut common_state, + data: &mut data, + }; + + let state = hs::start_handshake(name, extra_exts, config, &mut cx, generator)?; Ok(Self::new(state, data, common_state)) } diff --git a/rustls/src/client/hs.rs b/rustls/src/client/hs.rs index 1bef101bf0..4d17e1bcee 100644 --- a/rustls/src/client/hs.rs +++ b/rustls/src/client/hs.rs @@ -10,6 +10,7 @@ use crate::kx; #[cfg(feature = "logging")] use crate::log::{debug, trace}; use crate::msgs::base::Payload; +use crate::msgs::codec::Codec; use crate::msgs::enums::{Compression, ExtensionType}; use crate::msgs::enums::{ECPointFormat, PSKKeyExchangeMode}; use crate::msgs::handshake::ConvertProtocolNameList; @@ -86,12 +87,16 @@ fn find_session( found } -pub(super) fn start_handshake( +pub(super) fn start_handshake( server_name: ServerName, extra_exts: Vec, config: Arc, cx: &mut ClientContext<'_>, -) -> NextStateOrError { + session_id_generator: Option, +) -> NextStateOrError +where + T: Fn(&[u8]) -> [u8; 32], +{ let mut transcript_buffer = HandshakeHashBuffer::new(); if config .client_auth_cert_resolver @@ -113,8 +118,7 @@ pub(super) fn start_handshake( None }; - #[cfg_attr(not(feature = "tls12"), allow(unused_mut))] - let mut session_id = None; + let mut session_id: Option = None; if let Some(_resuming) = &mut resuming { #[cfg(feature = "tls12")] if let ClientSessionValue::Tls12(inner) = &mut _resuming.value { @@ -149,6 +153,7 @@ pub(super) fn start_handshake( extra_exts, may_send_sct_list, None, + session_id_generator, ClientHelloInput { config, resuming, @@ -189,16 +194,20 @@ struct ClientHelloInput { server_name: ServerName, } -fn emit_client_hello_for_retry( +fn emit_client_hello_for_retry( mut transcript_buffer: HandshakeHashBuffer, retryreq: Option<&HelloRetryRequest>, key_share: Option, extra_exts: Vec, may_send_sct_list: bool, suite: Option, + session_id_generator: Option, mut input: ClientHelloInput, cx: &mut ClientContext<'_>, -) -> NextState { +) -> NextState +where + T: Fn(&[u8]) -> [u8; 32], +{ let config = &input.config; let support_tls12 = config.supports_version(ProtocolVersion::TLSv1_2) && !cx.common.is_quic(); let support_tls13 = config.supports_version(ProtocolVersion::TLSv1_3); @@ -308,6 +317,31 @@ fn emit_client_hello_for_retry( None }; + // ref: https://github.com/shadow-tls/rustls/blob/c033c22cdbb6b08adf8b35571ee8427c70512d13/rustls/src/client/hs.rs#L365 + if let Some(generator) = session_id_generator { + let mut buffer = Vec::new(); + match &mut chp.payload { + HandshakePayload::ClientHello(c) => { + c.session_id = SessionId { + len: 32, + data: [0; 32], + }; + } + _ => unreachable!(), + } + chp.encode(&mut buffer); + let session_id = SessionId { + len: 32, + data: generator(&buffer), + }; + match &mut chp.payload { + HandshakePayload::ClientHello(c) => { + c.session_id = session_id; + } + _ => unreachable!(), + } + } + let ch = Message { // "This value MUST be set to 0x0303 for all records generated // by a TLS 1.3 implementation other than an initial ClientHello @@ -841,13 +875,14 @@ impl ExpectServerHelloOrHelloRetryRequest { _ => offered_key_share, }; - Ok(emit_client_hello_for_retry( + Ok(emit_client_hello_for_retry:: [u8; 32]>( transcript_buffer, Some(hrr), Some(key_share), self.extra_exts, may_send_sct_list, Some(cs), + None, self.next.input, cx, )) diff --git a/rustls/src/msgs/handshake.rs b/rustls/src/msgs/handshake.rs index f8be007ed5..3cab19580e 100644 --- a/rustls/src/msgs/handshake.rs +++ b/rustls/src/msgs/handshake.rs @@ -107,8 +107,8 @@ impl From<[u8; 32]> for Random { #[derive(Copy, Clone)] pub struct SessionId { - len: usize, - data: [u8; 32], + pub(crate)len: usize, + pub(crate)data: [u8; 32], } impl fmt::Debug for SessionId { @@ -913,7 +913,7 @@ impl ClientHelloPayload { pub fn get_namedgroups_extension(&self) -> Option<&[NamedGroup]> { let ext = self.find_extension(ExtensionType::EllipticCurves)?; match *ext { - ClientExtension::NamedGroups(ref req) => Some(req), + ClientExtension::NamedGroups(ref req) => Some(req.as_slice()), _ => None, } } @@ -921,7 +921,7 @@ impl ClientHelloPayload { pub fn get_ecpoints_extension(&self) -> Option<&[ECPointFormat]> { let ext = self.find_extension(ExtensionType::ECPointFormats)?; match *ext { - ClientExtension::ECPointFormats(ref req) => Some(req), + ClientExtension::ECPointFormats(ref req) => Some(req.as_slice()), _ => None, } } @@ -998,7 +998,7 @@ impl ClientHelloPayload { pub fn get_psk_modes(&self) -> Option<&[PSKKeyExchangeMode]> { let ext = self.find_extension(ExtensionType::PSKKeyExchangeModes)?; match *ext { - ClientExtension::PresharedKeyModes(ref psk_modes) => Some(psk_modes), + ClientExtension::PresharedKeyModes(ref psk_modes) => Some(psk_modes.as_slice()), _ => None, } } @@ -1254,7 +1254,7 @@ impl ServerHelloPayload { pub fn get_ecpoints_extension(&self) -> Option<&[ECPointFormat]> { let ext = self.find_extension(ExtensionType::ECPointFormats)?; match *ext { - ServerExtension::ECPointFormats(ref fmts) => Some(fmts), + ServerExtension::ECPointFormats(ref fmts) => Some(fmts.as_slice()), _ => None, } }