Skip to content

Commit

Permalink
Support customized connection id generator
Browse files Browse the repository at this point in the history
  • Loading branch information
iyangsj committed Jul 12, 2024
1 parent 7f25fc8 commit 76ba349
Show file tree
Hide file tree
Showing 6 changed files with 101 additions and 17 deletions.
2 changes: 1 addition & 1 deletion cbindgen.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ sys_includes = ["sys/socket.h", "sys/types.h"]
includes = ["openssl/ssl.h", "tquic_def.h"]

[export]
exclude = ["MAX_CID_LEN", "MIN_CLIENT_INITIAL_LEN", "VINT_MAX"]
exclude = ["MIN_CLIENT_INITIAL_LEN", "VINT_MAX"]

[export.rename]
"Config" = "quic_config_t"
Expand Down
42 changes: 42 additions & 0 deletions include/tquic.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@
*/
#define QUIC_VERSION_V1 1

/**
* The Connection ID MUST NOT exceed 20 bytes in QUIC version 1.
* See RFC 9000 Section 17.2
*/
#define MAX_CID_LEN 20

/**
* Available congestion control algorithms.
*/
Expand Down Expand Up @@ -224,6 +230,34 @@ typedef struct quic_packet_send_methods_t {

typedef void *quic_packet_send_context_t;

/**
* Connection Id is an identifier used to identify a QUIC connection
* at an endpoint.
*/
typedef struct ConnectionId {
/**
* length of cid
*/
uint8_t len;
/**
* octets of cid
*/
uint8_t data[MAX_CID_LEN];
} ConnectionId;

typedef struct ConnectionIdGeneratorMethods {
/**
* Generate a new CID
*/
struct ConnectionId (*generate)(void *gctx);
/**
* Return the length of a CID
*/
uint8_t (*cid_len)(void *gctx);
} ConnectionIdGeneratorMethods;

typedef void *ConnectionIdGeneratorContext;

/**
* Meta information of an incoming packet.
*/
Expand Down Expand Up @@ -707,6 +741,14 @@ struct quic_endpoint_t *quic_endpoint_new(struct quic_config_t *config,
*/
void quic_endpoint_free(struct quic_endpoint_t *endpoint);

/**
* Set the connection id generator for the endpoint.
* By default, the random connection id generator is used.
*/
void quic_endpoint_set_cid_generator(struct quic_endpoint_t *endpoint,
const struct ConnectionIdGeneratorMethods *cid_gen_methods,
ConnectionIdGeneratorContext cid_gen_ctx);

/**
* Create a client connection.
* If success, the output parameter `index` carrys the index of the connection.
Expand Down
4 changes: 2 additions & 2 deletions src/connection/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4266,8 +4266,8 @@ pub(crate) mod tests {
server_config: &mut Config,
server_name: &str,
) -> Result<TestPair> {
let mut cli_cid_gen = RandomConnectionIdGenerator::new(client_config.cid_len, None);
let mut srv_cid_gen = RandomConnectionIdGenerator::new(server_config.cid_len, None);
let mut cli_cid_gen = RandomConnectionIdGenerator::new(client_config.cid_len);
let mut srv_cid_gen = RandomConnectionIdGenerator::new(server_config.cid_len);
let client_scid = cli_cid_gen.generate();
let server_scid = srv_cid_gen.generate();
let client_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 9443);
Expand Down
7 changes: 6 additions & 1 deletion src/endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,6 @@ impl Endpoint {
) -> Self {
let cid_gen = Box::new(crate::RandomConnectionIdGenerator {
cid_len: config.cid_len,
cid_lifetime: None,
});
let trace_id = if is_server { "SERVER" } else { "CLIENT" };
let buffer = PacketBuffer::new(config.zerortt_buffer_size);
Expand Down Expand Up @@ -802,6 +801,12 @@ impl Endpoint {
self.conns.clear();
}

/// Set the connection id generator
/// By default, the RandomConnectionIdGenerator is used.
pub fn set_cid_generator(&mut self, cid_gen: Box<dyn ConnectionIdGenerator>) {
self.cid_gen = cid_gen;
}

/// Set the unique trace id for the endpoint
pub fn set_trace_id(&mut self, trace_id: String) {
self.trace_id = trace_id
Expand Down
47 changes: 47 additions & 0 deletions src/ffi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -725,6 +725,21 @@ pub extern "C" fn quic_endpoint_free(endpoint: *mut Endpoint) {
};
}

/// Set the connection id generator for the endpoint.
/// By default, the random connection id generator is used.
#[no_mangle]
pub extern "C" fn quic_endpoint_set_cid_generator(
endpoint: &mut Endpoint,
cid_gen_methods: *const ConnectionIdGeneratorMethods,
cid_gen_ctx: ConnectionIdGeneratorContext,
) {
let cid_generator = Box::new(ConnectionIdGenerator {
methods: cid_gen_methods,
context: cid_gen_ctx,
});
endpoint.set_cid_generator(cid_generator);
}

/// Create a client connection.
/// If success, the output parameter `index` carrys the index of the connection.
/// Note: The `config` specific to the endpoint or server is irrelevant and will be disregarded.
Expand Down Expand Up @@ -1773,6 +1788,38 @@ pub struct PacketOutSpec {
dst_addr_len: socklen_t,
}

#[repr(C)]
pub struct ConnectionIdGeneratorMethods {
/// Generate a new CID
pub generate: fn(gctx: *mut c_void) -> ConnectionId,

/// Return the length of a CID
pub cid_len: fn(gctx: *mut c_void) -> u8,
}

#[repr(transparent)]
pub struct ConnectionIdGeneratorContext(*mut c_void);

/// cbindgen:no-export
#[repr(C)]
pub struct ConnectionIdGenerator {
pub methods: *const ConnectionIdGeneratorMethods,
pub context: ConnectionIdGeneratorContext,
}

impl crate::ConnectionIdGenerator for ConnectionIdGenerator {
/// Generate a new CID
fn generate(&mut self) -> ConnectionId {
unsafe { ((*self.methods).generate)(self.context.0) }
}

/// Return the length of a CID
fn cid_len(&self) -> usize {
let cid_len = unsafe { ((*self.methods).cid_len)(self.context.0) };
cid_len as usize
}
}

/// Create default config for HTTP3.
#[no_mangle]
pub extern "C" fn http3_config_new() -> *mut Http3Config {
Expand Down
16 changes: 3 additions & 13 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ pub type Result<T> = std::result::Result<T, Error>;

/// Connection Id is an identifier used to identify a QUIC connection
/// at an endpoint.
#[repr(C)]
#[derive(Clone, Copy, Eq, PartialEq, Ord, PartialOrd, Hash, Default)]
pub struct ConnectionId {
/// length of cid
Expand Down Expand Up @@ -214,9 +215,6 @@ pub trait ConnectionIdGenerator {
/// Return the length of a CID
fn cid_len(&self) -> usize;

/// Return the lifetime of CID
fn cid_lifetime(&self) -> Option<Duration>;

/// Generate a new CID and associated reset token.
fn generate_cid_and_token(&mut self, reset_token_key: &hmac::Key) -> (ConnectionId, u128) {
let scid = self.generate();
Expand All @@ -229,14 +227,12 @@ pub trait ConnectionIdGenerator {
#[derive(Debug, Clone, Copy)]
pub struct RandomConnectionIdGenerator {
cid_len: usize,
cid_lifetime: Option<Duration>,
}

impl RandomConnectionIdGenerator {
pub fn new(cid_len: usize, cid_lifetime: Option<Duration>) -> Self {
pub fn new(cid_len: usize) -> Self {
Self {
cid_len: cmp::min(cid_len, MAX_CID_LEN),
cid_lifetime,
}
}
}
Expand All @@ -251,10 +247,6 @@ impl ConnectionIdGenerator for RandomConnectionIdGenerator {
fn cid_len(&self) -> usize {
self.cid_len
}

fn cid_lifetime(&self) -> Option<Duration> {
self.cid_lifetime
}
}

/// Meta information about a packet.
Expand Down Expand Up @@ -1085,11 +1077,9 @@ mod tests {

#[test]
fn connection_id() {
let lifetime = Duration::from_secs(3600);
let mut cid_gen = RandomConnectionIdGenerator::new(8, Some(lifetime));
let mut cid_gen = RandomConnectionIdGenerator::new(8);
let cid = cid_gen.generate();
assert_eq!(cid.len(), cid_gen.cid_len());
assert_eq!(Some(lifetime), cid_gen.cid_lifetime());

let cid = ConnectionId {
len: 4,
Expand Down

0 comments on commit 76ba349

Please sign in to comment.