Skip to content

Commit

Permalink
feat: add custom session id generator, based on shadow-tls's impl
Browse files Browse the repository at this point in the history
  • Loading branch information
VendettaReborn committed Mar 19, 2024
1 parent 40e4b5d commit 6f57a01
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 14 deletions.
46 changes: 45 additions & 1 deletion rustls/src/client/client_conn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -548,6 +548,24 @@ impl ClientConnection {
})
}

/// Make a new ClientConnection with a session id generator. `config` controls how
pub fn new_with_session_id_generator(
config: Arc<ClientConfig>,
name: ServerName,
generator: Option<impl Fn(&[u8]) -> [u8; 32]>,
) -> Result<Self, Error> {
Ok(Self {
inner: ConnectionCore::for_client_with_session_id_generator(
config,
name,
Vec::new(),
Protocol::Tcp,
generator,
)?
.into(),
})
}

/// Returns an `io::Write` implementer you can write bytes to
/// to send TLS1.3 early data (a.k.a. "0-RTT data") to the server.
///
Expand Down Expand Up @@ -663,7 +681,33 @@ impl ConnectionCore<ClientConnectionData> {
data: &mut data,
};

let state = hs::start_handshake(name, extra_exts, config, &mut cx)?;
let state =
hs::start_handshake::<fn(&[u8]) -> [u8; 32]>(name, extra_exts, config, &mut cx, None)?;
Ok(Self::new(state, data, common_state))
}

pub(crate) fn for_client_with_session_id_generator(
config: Arc<ClientConfig>,
name: ServerName,
extra_exts: Vec<ClientExtension>,
proto: Protocol,
generator: Option<impl Fn(&[u8]) -> [u8; 32]>,
) -> Result<Self, Error> {
let mut common_state = CommonState::new(Side::Client);
common_state.set_max_fragment_size(config.max_fragment_size)?;
common_state.protocol = proto;
#[cfg(feature = "secret_extraction")]
{
common_state.enable_secret_extraction = config.enable_secret_extraction;
}
let mut data = ClientConnectionData::new();

let mut cx = hs::ClientContext {
common: &mut common_state,
data: &mut data,
};

let state = hs::start_handshake(name, extra_exts, config, &mut cx, generator)?;
Ok(Self::new(state, data, common_state))
}

Expand Down
49 changes: 42 additions & 7 deletions rustls/src/client/hs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use crate::kx;
#[cfg(feature = "logging")]
use crate::log::{debug, trace};
use crate::msgs::base::Payload;
use crate::msgs::codec::Codec;
use crate::msgs::enums::{Compression, ExtensionType};
use crate::msgs::enums::{ECPointFormat, PSKKeyExchangeMode};
use crate::msgs::handshake::ConvertProtocolNameList;
Expand Down Expand Up @@ -86,12 +87,16 @@ fn find_session(
found
}

pub(super) fn start_handshake(
pub(super) fn start_handshake<T>(
server_name: ServerName,
extra_exts: Vec<ClientExtension>,
config: Arc<ClientConfig>,
cx: &mut ClientContext<'_>,
) -> NextStateOrError {
session_id_generator: Option<T>,
) -> NextStateOrError
where
T: Fn(&[u8]) -> [u8; 32],
{
let mut transcript_buffer = HandshakeHashBuffer::new();
if config
.client_auth_cert_resolver
Expand All @@ -113,8 +118,7 @@ pub(super) fn start_handshake(
None
};

#[cfg_attr(not(feature = "tls12"), allow(unused_mut))]
let mut session_id = None;
let mut session_id: Option<SessionId> = None;
if let Some(_resuming) = &mut resuming {
#[cfg(feature = "tls12")]
if let ClientSessionValue::Tls12(inner) = &mut _resuming.value {
Expand Down Expand Up @@ -149,6 +153,7 @@ pub(super) fn start_handshake(
extra_exts,
may_send_sct_list,
None,
session_id_generator,
ClientHelloInput {
config,
resuming,
Expand Down Expand Up @@ -189,16 +194,20 @@ struct ClientHelloInput {
server_name: ServerName,
}

fn emit_client_hello_for_retry(
fn emit_client_hello_for_retry<T>(
mut transcript_buffer: HandshakeHashBuffer,
retryreq: Option<&HelloRetryRequest>,
key_share: Option<kx::KeyExchange>,
extra_exts: Vec<ClientExtension>,
may_send_sct_list: bool,
suite: Option<SupportedCipherSuite>,
session_id_generator: Option<T>,
mut input: ClientHelloInput,
cx: &mut ClientContext<'_>,
) -> NextState {
) -> NextState
where
T: Fn(&[u8]) -> [u8; 32],
{
let config = &input.config;
let support_tls12 = config.supports_version(ProtocolVersion::TLSv1_2) && !cx.common.is_quic();
let support_tls13 = config.supports_version(ProtocolVersion::TLSv1_3);
Expand Down Expand Up @@ -308,6 +317,31 @@ fn emit_client_hello_for_retry(
None
};

// ref: https://github.com/shadow-tls/rustls/blob/c033c22cdbb6b08adf8b35571ee8427c70512d13/rustls/src/client/hs.rs#L365
if let Some(generator) = session_id_generator {
let mut buffer = Vec::new();
match &mut chp.payload {
HandshakePayload::ClientHello(c) => {
c.session_id = SessionId {
len: 32,
data: [0; 32],
};
}
_ => unreachable!(),
}
chp.encode(&mut buffer);
let session_id = SessionId {
len: 32,
data: generator(&buffer),
};
match &mut chp.payload {
HandshakePayload::ClientHello(c) => {
c.session_id = session_id;
}
_ => unreachable!(),
}
}

let ch = Message {
// "This value MUST be set to 0x0303 for all records generated
// by a TLS 1.3 implementation other than an initial ClientHello
Expand Down Expand Up @@ -819,13 +853,14 @@ impl ExpectServerHelloOrHelloRetryRequest {
_ => offered_key_share,
};

Ok(emit_client_hello_for_retry(
Ok(emit_client_hello_for_retry::<fn(&[u8]) -> [u8; 32]>(
transcript_buffer,
Some(hrr),
Some(key_share),
self.extra_exts,
may_send_sct_list,
Some(cs),
None,
self.next.input,
cx,
))
Expand Down
12 changes: 6 additions & 6 deletions rustls/src/msgs/handshake.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,8 @@ impl From<[u8; 32]> for Random {

#[derive(Copy, Clone)]
pub struct SessionId {
len: usize,
data: [u8; 32],
pub(crate)len: usize,
pub(crate)data: [u8; 32],
}

impl fmt::Debug for SessionId {
Expand Down Expand Up @@ -913,15 +913,15 @@ impl ClientHelloPayload {
pub fn get_namedgroups_extension(&self) -> Option<&[NamedGroup]> {
let ext = self.find_extension(ExtensionType::EllipticCurves)?;
match *ext {
ClientExtension::NamedGroups(ref req) => Some(req),
ClientExtension::NamedGroups(ref req) => Some(req.as_slice()),
_ => None,
}
}

pub fn get_ecpoints_extension(&self) -> Option<&[ECPointFormat]> {
let ext = self.find_extension(ExtensionType::ECPointFormats)?;
match *ext {
ClientExtension::ECPointFormats(ref req) => Some(req),
ClientExtension::ECPointFormats(ref req) => Some(req.as_slice()),
_ => None,
}
}
Expand Down Expand Up @@ -998,7 +998,7 @@ impl ClientHelloPayload {
pub fn get_psk_modes(&self) -> Option<&[PSKKeyExchangeMode]> {
let ext = self.find_extension(ExtensionType::PSKKeyExchangeModes)?;
match *ext {
ClientExtension::PresharedKeyModes(ref psk_modes) => Some(psk_modes),
ClientExtension::PresharedKeyModes(ref psk_modes) => Some(psk_modes.as_slice()),
_ => None,
}
}
Expand Down Expand Up @@ -1254,7 +1254,7 @@ impl ServerHelloPayload {
pub fn get_ecpoints_extension(&self) -> Option<&[ECPointFormat]> {
let ext = self.find_extension(ExtensionType::ECPointFormats)?;
match *ext {
ServerExtension::ECPointFormats(ref fmts) => Some(fmts),
ServerExtension::ECPointFormats(ref fmts) => Some(fmts.as_slice()),
_ => None,
}
}
Expand Down

0 comments on commit 6f57a01

Please sign in to comment.