From e171e33510a713847cdf682b0de677e57233b6ce Mon Sep 17 00:00:00 2001 From: ivmarkov Date: Mon, 31 Jul 2023 17:34:21 +0000 Subject: [PATCH] Handle out of sessions and out of exchanges --- rs-matter/src/core.rs | 6 + rs-matter/src/error.rs | 2 + rs-matter/src/secure_channel/case.rs | 29 ++-- rs-matter/src/secure_channel/common.rs | 4 +- rs-matter/src/secure_channel/pake.rs | 28 ++-- rs-matter/src/transport/core.rs | 220 ++++++++++++++++++------- rs-matter/src/transport/exchange.rs | 153 +++++++++++++++-- rs-matter/src/transport/session.rs | 87 +++------- 8 files changed, 359 insertions(+), 170 deletions(-) diff --git a/rs-matter/src/core.rs b/rs-matter/src/core.rs index f0196526..25ede51c 100644 --- a/rs-matter/src/core.rs +++ b/rs-matter/src/core.rs @@ -17,6 +17,8 @@ use core::{borrow::Borrow, cell::RefCell}; +use embassy_sync::{blocking_mutex::raw::NoopRawMutex, mutex::Mutex}; + use crate::{ acl::AclMgr, data_model::{ @@ -61,6 +63,8 @@ pub struct Matter<'a> { dev_att: &'a dyn DevAttDataFetcher, pub(crate) port: u16, pub(crate) exchanges: RefCell>, + pub(crate) ephemeral: RefCell>, + pub(crate) ephemeral_mutex: Mutex, pub session_mgr: RefCell, // Public for tests } @@ -108,6 +112,8 @@ impl<'a> Matter<'a> { dev_att, port, exchanges: RefCell::new(heapless::Vec::new()), + ephemeral: RefCell::new(None), + ephemeral_mutex: Mutex::new(()), session_mgr: RefCell::new(SessionMgr::new(epoch, rand)), } } diff --git a/rs-matter/src/error.rs b/rs-matter/src/error.rs index 91ba77e4..5527e1f4 100644 --- a/rs-matter/src/error.rs +++ b/rs-matter/src/error.rs @@ -47,6 +47,8 @@ pub enum ErrorCode { NoMemory, NoSession, NoSpace, + NoSpaceExchanges, + NoSpaceSessions, NoSpaceAckTable, NoSpaceRetransTable, NoTagFound, diff --git a/rs-matter/src/secure_channel/case.rs b/rs-matter/src/secure_channel/case.rs index 155dfbf8..80522e2d 100644 --- a/rs-matter/src/secure_channel/case.rs +++ b/rs-matter/src/secure_channel/case.rs @@ -96,7 +96,7 @@ impl<'a> Case<'a> { ) -> Result<(), Error> { rx.check_proto_opcode(OpCode::CASESigma3 as _)?; - let status = { + let result = { let fabric_mgr = self.fabric_mgr.borrow(); let fabric = fabric_mgr.get_fabric(case_session.local_fabric_idx)?; @@ -133,7 +133,7 @@ impl<'a> Case<'a> { if let Err(e) = Case::validate_certs(fabric, &initiator_noc, initiator_icac_mut) { error!("Certificate Chain doesn't match: {}", e); - SCStatusCodes::InvalidParameter + Err(SCStatusCodes::InvalidParameter) } else if let Err(e) = Case::validate_sigma3_sign( d.initiator_noc.0, d.initiator_icac.map(|a| a.0), @@ -142,30 +142,33 @@ impl<'a> Case<'a> { case_session, ) { error!("Sigma3 Signature doesn't match: {}", e); - SCStatusCodes::InvalidParameter + Err(SCStatusCodes::InvalidParameter) } else { // Only now do we add this message to the TT Hash let mut peer_catids: NocCatIds = Default::default(); initiator_noc.get_cat_ids(&mut peer_catids); case_session.tt_hash.update(rx.as_slice())?; - let clone_data = Case::get_session_clone_data( + + Ok(Case::get_session_clone_data( fabric.ipk.op_key(), fabric.get_node_id(), initiator_noc.get_node_id()?, exchange.with_session(|sess| Ok(sess.get_peer_addr()))?, case_session, &peer_catids, - )?; - - // TODO: Handle NoSpace - exchange - .with_session_mgr_mut(|sess_mgr| sess_mgr.clone_session(&clone_data))?; - - SCStatusCodes::SessionEstablishmentSuccess + )?) } } else { - SCStatusCodes::NoSharedTrustRoots + Err(SCStatusCodes::NoSharedTrustRoots) + } + }; + + let status = match result { + Ok(clone_data) => { + exchange.clone_session(tx, &clone_data).await?; + SCStatusCodes::SessionEstablishmentSuccess } + Err(status) => status, }; complete_with_status(exchange, tx, status, None).await @@ -201,7 +204,7 @@ impl<'a> Case<'a> { return Ok(()); } - let local_sessid = exchange.with_session_mgr_mut(|mgr| Ok(mgr.get_next_sess_id()))?; + let local_sessid = exchange.get_next_sess_id(); case_session.peer_sessid = r.initiator_sessid; case_session.local_sessid = local_sessid; case_session.tt_hash.update(rx_buf)?; diff --git a/rs-matter/src/secure_channel/common.rs b/rs-matter/src/secure_channel/common.rs index 2f00ed45..eeb1e4a7 100644 --- a/rs-matter/src/secure_channel/common.rs +++ b/rs-matter/src/secure_channel/common.rs @@ -78,8 +78,8 @@ pub fn create_sc_status_report( // the session will be closed soon GeneralCode::Success } - SCStatusCodes::Busy - | SCStatusCodes::InvalidParameter + SCStatusCodes::Busy => GeneralCode::Busy, + SCStatusCodes::InvalidParameter | SCStatusCodes::NoSharedTrustRoots | SCStatusCodes::SessionNotFound => GeneralCode::Failure, }; diff --git a/rs-matter/src/secure_channel/pake.rs b/rs-matter/src/secure_channel/pake.rs index 638e93ef..794cb185 100644 --- a/rs-matter/src/secure_channel/pake.rs +++ b/rs-matter/src/secure_channel/pake.rs @@ -167,9 +167,9 @@ impl<'a> Pake<'a> { self.update_timeout(exchange, tx, true).await?; let cA = extract_pasepake_1_or_3_params(rx.as_slice())?; - let (status_code, ke) = spake2p.handle_cA(cA); + let (status, ke) = spake2p.handle_cA(cA); - let clone_data = if status_code == SCStatusCodes::SessionEstablishmentSuccess { + let result = if status == SCStatusCodes::SessionEstablishmentSuccess { // Get the keys let ke = ke.ok_or(ErrorCode::Invalid)?; let mut session_keys: [u8; 48] = [0; 48]; @@ -194,22 +194,22 @@ impl<'a> Pake<'a> { .att_challenge .copy_from_slice(&session_keys[32..48]); - // Queue a transport mgr request to add a new session - Some(clone_data) + Ok(clone_data) } else { - None + Err(status) }; - if let Some(clone_data) = clone_data { - // TODO: Handle NoSpace - exchange.with_session_mgr_mut(|sess_mgr| sess_mgr.clone_session(&clone_data))?; + let status = match result { + Ok(clone_data) => { + exchange.clone_session(tx, &clone_data).await?; + self.pase.borrow_mut().disable_pase_session(mdns)?; - self.pase.borrow_mut().disable_pase_session(mdns)?; - } - - complete_with_status(exchange, tx, status_code, None).await?; + SCStatusCodes::SessionEstablishmentSuccess + } + Err(status) => status, + }; - Ok(()) + complete_with_status(exchange, tx, status, None).await } #[allow(non_snake_case)] @@ -273,7 +273,7 @@ impl<'a> Pake<'a> { let mut our_random: [u8; 32] = [0; 32]; (self.pase.borrow().rand)(&mut our_random); - let local_sessid = exchange.with_session_mgr_mut(|mgr| Ok(mgr.get_next_sess_id()))?; + let local_sessid = exchange.get_next_sess_id(); let spake2p_data: u32 = ((local_sessid as u32) << 16) | a.initiator_ssid as u32; spake2p.set_app_data(spake2p_data); diff --git a/rs-matter/src/transport/core.rs b/rs-matter/src/transport/core.rs index 7300c659..dccf5cb7 100644 --- a/rs-matter/src/transport/core.rs +++ b/rs-matter/src/transport/core.rs @@ -25,6 +25,9 @@ use embassy_time::{Duration, Timer}; use log::{error, info, warn}; +use crate::interaction_model::core::IMStatusCode; +use crate::secure_channel::common::SCStatusCodes; +use crate::secure_channel::status_report::{create_status_report, GeneralCode}; use crate::utils::select::Notification; use crate::CommissioningData; use crate::{ @@ -41,6 +44,7 @@ use crate::{ Matter, }; +use super::exchange::SessionId; use super::{ exchange::{ Exchange, ExchangeCtr, ExchangeCtx, ExchangeId, ExchangeState, Role, MAX_EXCHANGES, @@ -97,7 +101,7 @@ impl RunBuffers { pub struct PacketBuffers { tx: [TxBuf; MAX_EXCHANGES], rx: [RxBuf; MAX_EXCHANGES], - sx: [SxBuf; MAX_EXCHANGES], + sx: [SxBuf; MAX_EXCHANGES + 1], } impl PacketBuffers { @@ -107,7 +111,7 @@ impl PacketBuffers { const TX_INIT: [TxBuf; MAX_EXCHANGES] = [Self::TX_ELEM; MAX_EXCHANGES]; const RX_INIT: [RxBuf; MAX_EXCHANGES] = [Self::RX_ELEM; MAX_EXCHANGES]; - const SX_INIT: [SxBuf; MAX_EXCHANGES] = [Self::SX_ELEM; MAX_EXCHANGES]; + const SX_INIT: [SxBuf; MAX_EXCHANGES + 1] = [Self::SX_ELEM; MAX_EXCHANGES + 1]; #[inline(always)] pub const fn new() -> Self { @@ -266,7 +270,12 @@ impl<'a> Matter<'a> { .unwrap(); } - let mut rx = pin!(self.handle_rx_multiplex(rx_pipe, construction_notification, &channel)); + let mut rx = pin!(self.handle_rx_multiplex( + rx_pipe, + unsafe { buffers.sx[MAX_EXCHANGES].assume_init_mut() }, + construction_notification, + &channel, + )); let result = select(&mut rx, select_slice(&mut handlers)).await; @@ -291,7 +300,7 @@ impl<'a> Matter<'a> { if data.chunk.is_none() { let mut tx = alloc!(Packet::new_tx(data.buf)); - if self.pull_tx(&mut tx).await? { + if self.pull_tx(&mut tx)? { data.chunk = Some(Chunk { start: tx.get_writebuf()?.get_start(), end: tx.get_writebuf()?.get_tail(), @@ -315,12 +324,15 @@ impl<'a> Matter<'a> { pub async fn handle_rx_multiplex<'t, 'e, const N: usize>( &'t self, rx_pipe: &Pipe<'_>, + sts_buf: &mut [u8; MAX_RX_STATUS_BUF_SIZE], construction_notification: &'e Notification, channel: &Channel, N>, ) -> Result<(), Error> where 't: 'e, { + let mut sts_tx = alloc!(Packet::new_tx(sts_buf)); + loop { info!("Transport: waiting for incoming packets"); @@ -331,8 +343,9 @@ impl<'a> Matter<'a> { let mut rx = alloc!(Packet::new_rx(&mut data.buf[chunk.start..chunk.end])); rx.peer = chunk.addr; - if let Some(exchange_ctr) = - self.process_rx(construction_notification, &mut rx)? + if let Some(exchange_ctr) = self + .process_rx(construction_notification, &mut rx, &mut sts_tx) + .await? { let exchange_id = exchange_ctr.id().clone(); @@ -444,24 +457,39 @@ impl<'a> Matter<'a> { self.session_mgr.borrow_mut().reset(); } - pub fn process_rx<'r>( + pub async fn process_rx<'r>( &'r self, construction_notification: &'r Notification, src_rx: &mut Packet<'_>, + sts_tx: &mut Packet<'_>, ) -> Result>, Error> { + src_rx.plain_hdr_decode()?; + self.purge()?; + let (exchange_index, new) = loop { + let result = self.assign_exchange(&mut self.exchanges.borrow_mut(), src_rx); + + match result { + Err(e) => match e.code() { + ErrorCode::Duplicate => { + self.send_notification.signal(()); + return Ok(None); + } + // TODO: NoSession, NoExchange and others + ErrorCode::NoSpaceSessions => self.evict_session(sts_tx).await?, + ErrorCode::NoSpaceExchanges => { + self.send_busy(src_rx, sts_tx).await?; + return Ok(None); + } + _ => break Err(e), + }, + other => break other, + } + }?; + let mut exchanges = self.exchanges.borrow_mut(); - let (ctx, new) = match self.post_recv(&mut exchanges, src_rx) { - Ok((ctx, new)) => (ctx, new), - Err(e) => match e.code() { - ErrorCode::Duplicate => { - self.send_notification.signal(()); - return Ok(None); - } - _ => Err(e)?, - }, - }; + let ctx = &mut exchanges[exchange_index]; src_rx.log("Got packet"); @@ -516,6 +544,8 @@ impl<'a> Matter<'a> { ExchangeState::ExchangeRecv { rx, notification, .. } => { + // TODO: Handle Busy status codes + let rx = unsafe { rx.as_mut() }.unwrap(); rx.load(src_rx)?; @@ -572,12 +602,24 @@ impl<'a> Matter<'a> { Ok(()) } - pub async fn pull_tx(&self, dest_tx: &mut Packet<'_>) -> Result { + pub fn pull_tx(&self, dest_tx: &mut Packet) -> Result { self.purge()?; + let mut ephemeral = self.ephemeral.borrow_mut(); let mut exchanges = self.exchanges.borrow_mut(); - let ctx = exchanges.iter_mut().find(|ctx| { + self.pull_tx_exchanges(ephemeral.iter_mut().chain(exchanges.iter_mut()), dest_tx) + } + + fn pull_tx_exchanges<'i, I>( + &self, + mut exchanges: I, + dest_tx: &mut Packet, + ) -> Result + where + I: Iterator, + { + let ctx = exchanges.find(|ctx| { matches!( &ctx.state, ExchangeState::Acknowledge { .. } @@ -629,10 +671,15 @@ impl<'a> Matter<'a> { let tx = unsafe { tx.as_ref() }.unwrap(); dest_tx.load(tx)?; - *state = ExchangeState::CompleteAcknowledge { - _tx: tx as *const _, - notification: *notification, - }; + if dest_tx.is_reliable() { + *state = ExchangeState::CompleteAcknowledge { + _tx: tx as *const _, + notification: *notification, + }; + } else { + unsafe { notification.as_ref() }.unwrap().signal(()); + ctx.state = ExchangeState::Closed; + } true } @@ -648,8 +695,6 @@ impl<'a> Matter<'a> { if send { dest_tx.log("Sending packet"); - - self.pre_send(ctx, dest_tx)?; self.notify_changed(); return Ok(true); @@ -675,13 +720,88 @@ impl<'a> Matter<'a> { Ok(()) } - fn post_recv<'r>( + pub(crate) async fn evict_session(&self, tx: &mut Packet<'_>) -> Result<(), Error> { + let sess_index = self.session_mgr.borrow().get_session_for_eviction(); + if let Some(sess_index) = sess_index { + let ctx = { + create_status_report( + tx, + GeneralCode::Success, + PROTO_ID_SECURE_CHANNEL as _, + SCStatusCodes::CloseSession as _, + None, + )?; + + let mut session_mgr = self.session_mgr.borrow_mut(); + let session_id = session_mgr.mut_by_index(sess_index).unwrap().id(); + warn!("Evicting session: {:?}", session_id); + + let ctx = ExchangeCtx::prep_ephemeral(session_id, &mut session_mgr, None, tx)?; + + session_mgr.remove(sess_index); + + ctx + }; + + self.send_ephemeral(ctx, tx).await + } else { + Err(ErrorCode::NoSpaceSessions.into()) + } + } + + async fn send_busy(&self, rx: &Packet<'_>, tx: &mut Packet<'_>) -> Result<(), Error> { + warn!("Sending Busy as all exchanges are occupied"); + + create_status_report( + tx, + GeneralCode::Busy, + rx.get_proto_id() as _, + if rx.get_proto_id() == PROTO_ID_SECURE_CHANNEL { + SCStatusCodes::Busy as _ + } else { + IMStatusCode::Busy as _ + }, + None, // TODO: ms + )?; + + let ctx = ExchangeCtx::prep_ephemeral( + SessionId::load(rx), + &mut self.session_mgr.borrow_mut(), + Some(rx), + tx, + )?; + + self.send_ephemeral(ctx, tx).await + } + + async fn send_ephemeral(&self, mut ctx: ExchangeCtx, tx: &mut Packet<'_>) -> Result<(), Error> { + let _guard = self.ephemeral_mutex.lock().await; + + let notification = Notification::new(); + + let tx: &'static mut Packet<'static> = unsafe { core::mem::transmute(tx) }; + + ctx.state = ExchangeState::Complete { + tx, + notification: ¬ification, + }; + + *self.ephemeral.borrow_mut() = Some(ctx); + + self.send_notification.signal(()); + + notification.wait().await; + + *self.ephemeral.borrow_mut() = None; + + Ok(()) + } + + fn assign_exchange( &self, - exchanges: &'r mut heapless::Vec, + exchanges: &mut heapless::Vec, rx: &mut Packet<'_>, - ) -> Result<(&'r mut ExchangeCtx, bool), Error> { - rx.plain_hdr_decode()?; - + ) -> Result<(usize, bool), Error> { // Get the session let mut session_mgr = self.session_mgr.borrow_mut(); @@ -693,8 +813,7 @@ impl<'a> Matter<'a> { session.recv(self.epoch, rx)?; // Get the exchange - // TODO: Handle out of space - let (exch, new) = Self::register( + let (exchange_index, new) = Self::register( exchanges, ExchangeId::load(rx), Role::complementary(rx.proto.is_initiator()), @@ -703,32 +822,9 @@ impl<'a> Matter<'a> { )?; // Message Reliability Protocol - exch.mrp.recv(rx, self.epoch)?; - - Ok((exch, new)) - } - - fn pre_send(&self, ctx: &mut ExchangeCtx, tx: &mut Packet) -> Result<(), Error> { - let mut session_mgr = self.session_mgr.borrow_mut(); - let sess_index = session_mgr - .get( - ctx.id.session_id.id, - ctx.id.session_id.peer_addr, - ctx.id.session_id.peer_nodeid, - ctx.id.session_id.is_encrypted, - ) - .ok_or(ErrorCode::NoSession)?; - - let session = session_mgr.mut_by_index(sess_index).unwrap(); - - tx.proto.exch_id = ctx.id.id; - if ctx.role == Role::Initiator { - tx.proto.set_initiator(); - } + exchanges[exchange_index].mrp.recv(rx, self.epoch)?; - session.pre_send(tx)?; - ctx.mrp.pre_send(tx)?; - session_mgr.send(sess_index, tx) + Ok((exchange_index, new)) } fn register( @@ -736,7 +832,7 @@ impl<'a> Matter<'a> { id: ExchangeId, role: Role, create_new: bool, - ) -> Result<(&mut ExchangeCtx, bool), Error> { + ) -> Result<(usize, bool), Error> { let exchange_index = exchanges .iter_mut() .enumerate() @@ -745,7 +841,7 @@ impl<'a> Matter<'a> { if let Some(exchange_index) = exchange_index { let exchange = &mut exchanges[exchange_index]; if exchange.role == role { - Ok((exchange, false)) + Ok((exchange_index, false)) } else { Err(ErrorCode::NoExchange.into()) } @@ -759,9 +855,11 @@ impl<'a> Matter<'a> { state: ExchangeState::Active, }; - exchanges.push(exchange).map_err(|_| ErrorCode::NoSpace)?; + exchanges + .push(exchange) + .map_err(|_| ErrorCode::NoSpaceExchanges)?; - Ok((exchanges.iter_mut().next_back().unwrap(), true)) + Ok((exchanges.len() - 1, true)) } else { Err(ErrorCode::NoExchange.into()) } diff --git a/rs-matter/src/transport/exchange.rs b/rs-matter/src/transport/exchange.rs index d4d28b23..d6291b04 100644 --- a/rs-matter/src/transport/exchange.rs +++ b/rs-matter/src/transport/exchange.rs @@ -1,7 +1,7 @@ use crate::{ acl::Accessor, error::{Error, ErrorCode}, - utils::select::Notification, + utils::{epoch::Epoch, select::Notification}, Matter, }; @@ -9,7 +9,7 @@ use super::{ mrp::ReliableMessage, network::Address, packet::Packet, - session::{Session, SessionMgr}, + session::{CloneData, Session, SessionMgr}, }; pub const MAX_EXCHANGES: usize = 8; @@ -46,6 +46,101 @@ impl ExchangeCtx { ) -> Option<&'r mut ExchangeCtx> { exchanges.iter_mut().find(|exchange| exchange.id == *id) } + + pub fn new_ephemeral(session_id: SessionId, reply_to: Option<&Packet<'_>>) -> Self { + Self { + id: ExchangeId { + id: if let Some(rx) = reply_to { + rx.proto.exch_id + } else { + 0 + }, + session_id: session_id.clone(), + }, + role: if reply_to.is_some() { + Role::Responder + } else { + Role::Initiator + }, + mrp: ReliableMessage::new(), + state: ExchangeState::Active, + } + } + + pub(crate) fn prep_ephemeral( + session_id: SessionId, + session_mgr: &mut SessionMgr, + reply_to: Option<&Packet<'_>>, + tx: &mut Packet<'_>, + ) -> Result { + let mut ctx = Self::new_ephemeral(session_id.clone(), reply_to); + + let sess_index = session_mgr.get( + session_id.id, + session_id.peer_addr, + session_id.peer_nodeid, + session_id.is_encrypted, + ); + + let epoch = session_mgr.epoch; + let rand = session_mgr.rand; + + if let Some(rx) = reply_to { + ctx.mrp.recv(rx, epoch)?; + } else { + tx.proto.set_initiator(); + } + + tx.unset_reliable(); + + if let Some(sess_index) = sess_index { + let session = session_mgr.mut_by_index(sess_index).unwrap(); + ctx.pre_send_sess(session, tx, epoch)?; + } else { + let mut session = + Session::new(session_id.peer_addr, session_id.peer_nodeid, epoch, rand); + ctx.pre_send_sess(&mut session, tx, epoch)?; + } + + Ok(ctx) + } + + pub(crate) fn pre_send( + &mut self, + session_mgr: &mut SessionMgr, + tx: &mut Packet, + ) -> Result<(), Error> { + let epoch = session_mgr.epoch; + + let sess_index = session_mgr + .get( + self.id.session_id.id, + self.id.session_id.peer_addr, + self.id.session_id.peer_nodeid, + self.id.session_id.is_encrypted, + ) + .ok_or(ErrorCode::NoSession)?; + + let session = session_mgr.mut_by_index(sess_index).unwrap(); + + self.pre_send_sess(session, tx, epoch) + } + + pub(crate) fn pre_send_sess( + &mut self, + session: &mut Session, + tx: &mut Packet, + epoch: Epoch, + ) -> Result<(), Error> { + tx.proto.exch_id = self.id.id; + if self.role == Role::Initiator { + tx.proto.set_initiator(); + } + + session.pre_send(tx)?; + self.mrp.pre_send(tx)?; + session.send(epoch, tx) + } } #[derive(Debug, Clone)] @@ -192,15 +287,6 @@ impl<'a> Exchange<'a> { self.with_session_mut(|sess| f(sess)) } - pub fn with_session_mgr_mut(&self, f: F) -> Result - where - F: FnOnce(&mut SessionMgr) -> Result, - { - let mut session_mgr = self.matter.session_mgr.borrow_mut(); - - f(&mut session_mgr) - } - pub async fn acknowledge(&mut self) -> Result<(), Error> { let wait = self.with_ctx_mut(|_self, ctx| { if !matches!(ctx.state, ExchangeState::Active) { @@ -226,8 +312,12 @@ impl<'a> Exchange<'a> { Ok(()) } - pub async fn exchange(&mut self, tx: &Packet<'_>, rx: &mut Packet<'_>) -> Result<(), Error> { - let tx: &Packet<'static> = unsafe { core::mem::transmute(tx) }; + pub async fn exchange( + &mut self, + tx: &mut Packet<'_>, + rx: &mut Packet<'_>, + ) -> Result<(), Error> { + let tx: &mut Packet<'static> = unsafe { core::mem::transmute(tx) }; let rx: &mut Packet<'static> = unsafe { core::mem::transmute(rx) }; self.with_ctx_mut(|_self, ctx| { @@ -235,6 +325,9 @@ impl<'a> Exchange<'a> { Err(ErrorCode::NoExchange)?; } + let mut session_mgr = _self.matter.session_mgr.borrow_mut(); + ctx.pre_send(&mut session_mgr, tx)?; + ctx.state = ExchangeState::ExchangeSend { tx: tx as *const _, rx: rx as *mut _, @@ -250,18 +343,21 @@ impl<'a> Exchange<'a> { Ok(()) } - pub async fn complete(mut self, tx: &Packet<'_>) -> Result<(), Error> { + pub async fn complete(mut self, tx: &mut Packet<'_>) -> Result<(), Error> { self.send_complete(tx).await } - pub async fn send_complete(&mut self, tx: &Packet<'_>) -> Result<(), Error> { - let tx: &Packet<'static> = unsafe { core::mem::transmute(tx) }; + pub async fn send_complete(&mut self, tx: &mut Packet<'_>) -> Result<(), Error> { + let tx: &mut Packet<'static> = unsafe { core::mem::transmute(tx) }; self.with_ctx_mut(|_self, ctx| { if !matches!(ctx.state, ExchangeState::Active) { Err(ErrorCode::NoExchange)?; } + let mut session_mgr = _self.matter.session_mgr.borrow_mut(); + ctx.pre_send(&mut session_mgr, tx)?; + ctx.state = ExchangeState::Complete { tx: tx as *const _, notification: &_self.notification as *const _, @@ -276,6 +372,31 @@ impl<'a> Exchange<'a> { Ok(()) } + pub(crate) fn get_next_sess_id(&mut self) -> u16 { + self.matter.session_mgr.borrow_mut().get_next_sess_id() + } + + pub(crate) async fn clone_session( + &mut self, + tx: &mut Packet<'_>, + clone_data: &CloneData, + ) -> Result { + loop { + let result = self + .matter + .session_mgr + .borrow_mut() + .clone_session(clone_data); + + match result { + Err(err) if err.code() == ErrorCode::NoSpaceSessions => { + self.matter.evict_session(tx).await? + } + other => break other, + } + } + } + fn with_ctx(&self, f: F) -> Result where F: FnOnce(&Self, &ExchangeCtx) -> Result, diff --git a/rs-matter/src/transport/session.rs b/rs-matter/src/transport/session.rs index 41fbc497..b25631f6 100644 --- a/rs-matter/src/transport/session.rs +++ b/rs-matter/src/transport/session.rs @@ -19,13 +19,13 @@ use crate::data_model::sdm::noc::NocData; use crate::utils::epoch::Epoch; use crate::utils::rand::Rand; use core::fmt; -use core::ops::{Deref, DerefMut}; use core::time::Duration; use crate::{error::*, transport::plain_hdr}; use log::info; use super::dedup::RxCtrState; +use super::exchange::SessionId; use super::{network::Address, packet::Packet}; pub const MAX_CAT_IDS_PER_NOC: usize = 3; @@ -151,6 +151,15 @@ impl Session { } } + pub fn id(&self) -> SessionId { + SessionId { + id: self.local_sess_id, + peer_addr: self.peer_addr, + peer_nodeid: self.peer_nodeid, + is_encrypted: self.is_encrypted(), + } + } + pub fn set_noc_data(&mut self, data: NocData) { self.data = Some(data); } @@ -251,7 +260,7 @@ impl Session { Ok(()) } - fn send(&mut self, epoch: Epoch, tx: &mut Packet) -> Result<(), Error> { + pub(crate) fn send(&mut self, epoch: Epoch, tx: &mut Packet) -> Result<(), Error> { self.last_use = epoch(); tx.proto_encode( @@ -291,8 +300,8 @@ pub const MAX_SESSIONS: usize = 16; pub struct SessionMgr { next_sess_id: u16, sessions: heapless::Vec, MAX_SESSIONS>, - epoch: Epoch, - rand: Rand, + pub(crate) epoch: Epoch, + pub(crate) rand: Rand, } impl SessionMgr { @@ -327,7 +336,11 @@ impl SessionMgr { } // Ensure the currently selected id doesn't match any existing session - if self.get_with_id(next_sess_id).is_none() { + if self.sessions.iter().all(|sess| { + sess.as_ref() + .map(|sess| sess.get_local_sess_id() != next_sess_id) + .unwrap_or(true) + }) { break; } } @@ -381,12 +394,12 @@ impl SessionMgr { } else if self.sessions.len() < MAX_SESSIONS { self.sessions .push(Some(session)) - .map_err(|_| ErrorCode::NoSpace) + .map_err(|_| ErrorCode::NoSpaceSessions) .unwrap(); Ok(self.sessions.len() - 1) } else { - Err(ErrorCode::NoSpace.into()) + Err(ErrorCode::NoSpaceSessions.into()) } } @@ -419,14 +432,6 @@ impl SessionMgr { }) } - pub fn get_with_id(&mut self, sess_id: u16) -> Option { - let index = self - .sessions - .iter_mut() - .position(|x| x.as_ref().map(|s| s.local_sess_id) == Some(sess_id))?; - Some(self.get_session_handle(index)) - } - pub fn get_or_add( &mut self, sess_id: u16, @@ -472,13 +477,6 @@ impl SessionMgr { .ok_or(ErrorCode::NoSession)? .send(self.epoch, tx) } - - pub fn get_session_handle(&mut self, sess_idx: usize) -> SessionHandle { - SessionHandle { - sess_mgr: self, - sess_idx, - } - } } impl fmt::Display for SessionMgr { @@ -492,45 +490,6 @@ impl fmt::Display for SessionMgr { } } -pub struct SessionHandle<'a> { - pub(crate) sess_mgr: &'a mut SessionMgr, - sess_idx: usize, -} - -impl<'a> SessionHandle<'a> { - pub fn session(&self) -> &Session { - self.sess_mgr.sessions[self.sess_idx].as_ref().unwrap() - } - - pub fn session_mut(&mut self) -> &mut Session { - self.sess_mgr.sessions[self.sess_idx].as_mut().unwrap() - } - - pub fn reserve_new_sess_id(&mut self) -> u16 { - self.sess_mgr.get_next_sess_id() - } - - pub fn send(&mut self, tx: &mut Packet) -> Result<(), Error> { - self.sess_mgr.send(self.sess_idx, tx) - } -} - -impl<'a> Deref for SessionHandle<'a> { - type Target = Session; - - fn deref(&self) -> &Self::Target { - // There is no other option but to panic if this is None - self.session() - } -} - -impl<'a> DerefMut for SessionHandle<'a> { - fn deref_mut(&mut self) -> &mut Self::Target { - // There is no other option but to panic if this is None - self.session_mut() - } -} - #[cfg(test)] mod tests { @@ -545,12 +504,12 @@ mod tests { fn test_next_sess_id_doesnt_reuse() { let mut sm = SessionMgr::new(dummy_epoch, dummy_rand); let sess_idx = sm.add(Address::default(), None).unwrap(); - let mut sess = sm.get_session_handle(sess_idx); + let sess = sm.mut_by_index(sess_idx).unwrap(); sess.set_local_sess_id(1); assert_eq!(sm.get_next_sess_id(), 2); assert_eq!(sm.get_next_sess_id(), 3); let sess_idx = sm.add(Address::default(), None).unwrap(); - let mut sess = sm.get_session_handle(sess_idx); + let sess = sm.mut_by_index(sess_idx).unwrap(); sess.set_local_sess_id(4); assert_eq!(sm.get_next_sess_id(), 5); } @@ -559,7 +518,7 @@ mod tests { fn test_next_sess_id_overflows() { let mut sm = SessionMgr::new(dummy_epoch, dummy_rand); let sess_idx = sm.add(Address::default(), None).unwrap(); - let mut sess = sm.get_session_handle(sess_idx); + let sess = sm.mut_by_index(sess_idx).unwrap(); sess.set_local_sess_id(1); assert_eq!(sm.get_next_sess_id(), 2); sm.next_sess_id = 65534;