Skip to content

Commit

Permalink
feat: add custom session id generator, based on shadow-tls's impl
Browse files Browse the repository at this point in the history
VendettaReborn committed Sep 3, 2024
1 parent 8c04dba commit f84c0f8
Showing 3 changed files with 103 additions and 13 deletions.
44 changes: 43 additions & 1 deletion rustls/src/client/client_conn.rs
Original file line number Diff line number Diff line change
@@ -683,6 +683,24 @@ mod connection {
})
}

/// Make a new ClientConnection with a session id generator. `config` controls how
pub fn new_with_session_id_generator(
config: Arc<ClientConfig>,
name: ServerName<'static>,
generator: Option<impl Fn(&[u8]) -> [u8; 32]>,
) -> Result<Self, Error> {
Ok(Self {
inner: ConnectionCore::for_client_with_session_id_generator(
config,
name,
Vec::new(),
Protocol::Tcp,
generator,
)?
.into(),
})
}

/// Returns an `io::Write` implementer you can write bytes to
/// to send TLS1.3 early data (a.k.a. "0-RTT data") to the server.
///
@@ -814,7 +832,31 @@ impl ConnectionCore<ClientConnectionData> {
sendable_plaintext: None,
};

let state = hs::start_handshake(name, extra_exts, config, &mut cx)?;
let state =
hs::start_handshake::<fn(&[u8]) -> [u8; 32]>(name, extra_exts, config, &mut cx, None)?;
Ok(Self::new(state, data, common_state))
}

pub(crate) fn for_client_with_session_id_generator(
config: Arc<ClientConfig>,
name: ServerName<'static>,
extra_exts: Vec<ClientExtension>,
proto: Protocol,
generator: Option<impl Fn(&[u8]) -> [u8; 32]>,
) -> Result<Self, Error> {
let mut common_state = CommonState::new(Side::Client);
common_state.set_max_fragment_size(config.max_fragment_size)?;
common_state.protocol = proto;
common_state.enable_secret_extraction = config.enable_secret_extraction;
let mut data = ClientConnectionData::new();

let mut cx = hs::ClientContext {
common: &mut common_state,
data: &mut data,
sendable_plaintext: None,
};

let state = hs::start_handshake(name, extra_exts, config, &mut cx, generator)?;
Ok(Self::new(state, data, common_state))
}

60 changes: 54 additions & 6 deletions rustls/src/client/hs.rs
Original file line number Diff line number Diff line change
@@ -25,6 +25,7 @@ use crate::error::{Error, PeerIncompatible, PeerMisbehaved};
use crate::hash_hs::HandshakeHashBuffer;
use crate::log::{debug, trace};
use crate::msgs::base::Payload;
use crate::msgs::codec::Codec;
use crate::msgs::enums::{Compression, ECPointFormat, ExtensionType, PSKKeyExchangeMode};
use crate::msgs::handshake::{
CertificateStatusRequest, ClientExtension, ClientHelloPayload, ClientSessionTicket,
@@ -91,12 +92,16 @@ fn find_session(
found
}

pub(super) fn start_handshake(
pub(super) fn start_handshake<T>(
server_name: ServerName<'static>,
extra_exts: Vec<ClientExtension>,
config: Arc<ClientConfig>,
cx: &mut ClientContext<'_>,
) -> NextStateOrError<'static> {
session_id_generator: Option<T>,
) -> NextStateOrError<'static>
where
T: Fn(&[u8]) -> [u8; 32],
{
let mut transcript_buffer = HandshakeHashBuffer::new();
if config
.client_auth_cert_resolver
@@ -117,7 +122,19 @@ pub(super) fn start_handshake(
None
};

let session_id = if let Some(_resuming) = &mut resuming {
let mut session_id: Option<SessionId> = None;
if let Some(_resuming) = &mut resuming {
#[cfg(feature = "tls12")]
if let ClientSessionValue::Tls12(inner) = &mut _resuming.value {
// If we have a ticket, we use the sessionid as a signal that
// we're doing an abbreviated handshake. See section 3.4 in
// RFC5077.
if !inner.ticket().is_empty() {
inner.session_id = SessionId::random(config.provider.secure_random)?;
}
session_id = Some(inner.session_id);
}

debug!("Resuming session");

match &mut _resuming.value {
@@ -169,6 +186,7 @@ pub(super) fn start_handshake(
key_share,
extra_exts,
None,
session_id_generator,
ClientHelloInput {
config,
resuming,
@@ -213,16 +231,20 @@ struct ClientHelloInput {
prev_ech_ext: Option<ClientExtension>,
}

fn emit_client_hello_for_retry(
fn emit_client_hello_for_retry<T>(
mut transcript_buffer: HandshakeHashBuffer,
retryreq: Option<&HelloRetryRequest>,
key_share: Option<Box<dyn ActiveKeyExchange>>,
extra_exts: Vec<ClientExtension>,
suite: Option<SupportedCipherSuite>,
session_id_generator: Option<T>,
mut input: ClientHelloInput,
cx: &mut ClientContext<'_>,
mut ech_state: Option<EchState>,
) -> NextStateOrError<'static> {
) -> NextStateOrError<'static>
where
T: Fn(&[u8]) -> [u8; 32],
{
let config = &input.config;
// Defense in depth: the ECH state should be None if ECH is disabled based on config
// builder semantics.
@@ -461,6 +483,31 @@ fn emit_client_hello_for_retry(
_ => None,
};

// ref: https://github.com/shadow-tls/rustls/blob/c033c22cdbb6b08adf8b35571ee8427c70512d13/rustls/src/client/hs.rs#L365
if let Some(generator) = session_id_generator {
let mut buffer = Vec::new();
match &mut chp.payload {
HandshakePayload::ClientHello(c) => {
c.session_id = SessionId {
len: 32,
data: [0; 32],
};
}
_ => unreachable!(),
}
chp.encode(&mut buffer);
let session_id = SessionId {
len: 32,
data: generator(&buffer),
};
match &mut chp.payload {
HandshakePayload::ClientHello(c) => {
c.session_id = session_id;
}
_ => unreachable!(),
}
}

let ch = Message {
version: match retryreq {
// <https://datatracker.ietf.org/doc/html/rfc8446#section-5.1>:
@@ -1044,12 +1091,13 @@ impl ExpectServerHelloOrHelloRetryRequest {
_ => offered_key_share,
};

emit_client_hello_for_retry(
emit_client_hello_for_retry::<fn(&[u8]) -> [u8; 32]>(
transcript_buffer,
Some(hrr),
Some(key_share),
self.extra_exts,
Some(cs),
None,
self.next.input,
cx,
self.next.ech_state,
12 changes: 6 additions & 6 deletions rustls/src/msgs/handshake.rs
Original file line number Diff line number Diff line change
@@ -115,8 +115,8 @@ impl From<[u8; 32]> for Random {

#[derive(Copy, Clone)]
pub struct SessionId {
len: usize,
data: [u8; 32],
pub(crate)len: usize,
pub(crate)data: [u8; 32],
}

impl fmt::Debug for SessionId {
@@ -985,7 +985,7 @@ impl ClientHelloPayload {
pub(crate) fn namedgroups_extension(&self) -> Option<&[NamedGroup]> {
let ext = self.find_extension(ExtensionType::EllipticCurves)?;
match *ext {
ClientExtension::NamedGroups(ref req) => Some(req),
ClientExtension::NamedGroups(ref req) => Some(req.as_slice()),
_ => None,
}
}
@@ -994,7 +994,7 @@ impl ClientHelloPayload {
pub(crate) fn ecpoints_extension(&self) -> Option<&[ECPointFormat]> {
let ext = self.find_extension(ExtensionType::ECPointFormats)?;
match *ext {
ClientExtension::EcPointFormats(ref req) => Some(req),
ClientExtension::EcPointFormats(ref req) => Some(req.as_slice()),
_ => None,
}
}
@@ -1068,7 +1068,7 @@ impl ClientHelloPayload {
pub(crate) fn psk_modes(&self) -> Option<&[PSKKeyExchangeMode]> {
let ext = self.find_extension(ExtensionType::PSKKeyExchangeModes)?;
match *ext {
ClientExtension::PresharedKeyModes(ref psk_modes) => Some(psk_modes),
ClientExtension::PresharedKeyModes(ref psk_modes) => Some(psk_modes.as_slice()),
_ => None,
}
}
@@ -1367,7 +1367,7 @@ impl ServerHelloPayload {
pub(crate) fn ecpoints_extension(&self) -> Option<&[ECPointFormat]> {
let ext = self.find_extension(ExtensionType::ECPointFormats)?;
match *ext {
ServerExtension::EcPointFormats(ref fmts) => Some(fmts),
ServerExtension::EcPointFormats(ref fmts) => Some(fmts.as_slice()),
_ => None,
}
}

0 comments on commit f84c0f8

Please sign in to comment.