From 7d128aca510e61193e9ed66559e2465eaed72433 Mon Sep 17 00:00:00 2001 From: ivmarkov Date: Wed, 24 Apr 2024 08:55:06 +0300 Subject: [PATCH] Multiple improvements to the exchange and transport layers (#147) * Multiple improvements to the exchange and transport layers Bugfixing in subscriptions Mdns shares buffers with the main transport Complete subscription logic (incl change notification) Bugfixing Bugfixing Best effort for Google controller subscriptions to stay alive Re-publish the mDNS broadcast when an entry is removed too Google controller expects revoke commissioning to be supported Eagerly close subscriptions that don't report anything Cleanup in transport mgr Restore the correct subscription id Simplify transport mgr Tests typecheck Minor renames Docu Docu Docu, clippy fix the tests Handle session close Fix buffer sizes for subscriptions Report responer memory RFC (work in progress) RFC (work in progress) RFC (work in progress) RFC (work in progress) RFC (work in progress) RFC (work in progress) RFC (work in progress) RFC (work in progress) RFC (work in progress) RFC (work in progress) RFC (work in progress) RFC (work in progress) RFC (work in progress) RFC (work in progress) RFC (WIP) Buffers for IM; more flexible Exchange; renames in Exchange Address several bugs Matter header for Notification Address several bugs Unit tests work Address several bugs Small updates to the RFC Small updates to the RFC Make IM compatible with unit tests Enable operation over reliable protocols Support for large buffers (TCP) Unify the synchronization primitives Try to reduce a bit the consumed memory Reduce the change delta Update RFC Extra comments Fix the build WIP - RFC Address several bugs Fix lifetime issues with subscriptions notifications Updates to the RFC * Document the handler API * Clarify a commented out line * Document the await optimization * Remove a level of indentation * Leave a TODO that trhe subscription notification logic is incomplete * Add a warning for an unanticipated opcode * Skip the doctest * Change semantics the of recv and recv_fetch to return the last fetched message, if there is any * Address feedback from code review * std::net not necessary as it is now just re-exporting core::net * Address feedback from code review * Address feedback from code review * Address feedback from code review * Address feedback from code review * Address feedback from code review * Address feedback from code review * Address feedback from code review * Address feedback from code review * Address feedback from code review * Address feedback from code review * Incorporate changes to the RFC from HackMD --- ...C_Transport_Exchange_Layer_Improvements.md | 616 +++++++ examples/onoff_light/src/dev_att.rs | 2 +- examples/onoff_light/src/main.rs | 110 +- rs-matter/Cargo.toml | 2 + rs-matter/src/core.rs | 102 +- rs-matter/src/data_model/cluster_on_off.rs | 4 + rs-matter/src/data_model/core.rs | 897 ++++++++-- rs-matter/src/data_model/mod.rs | 1 + rs-matter/src/data_model/objects/handler.rs | 203 ++- rs-matter/src/data_model/objects/node.rs | 32 +- rs-matter/src/data_model/sdm/noc.rs | 6 +- rs-matter/src/data_model/subscriptions.rs | 144 ++ rs-matter/src/error.rs | 1 + rs-matter/src/interaction_model/busy.rs | 78 + rs-matter/src/interaction_model/core.rs | 639 +------ rs-matter/src/interaction_model/messages.rs | 8 +- rs-matter/src/interaction_model/mod.rs | 1 + rs-matter/src/lib.rs | 1 + rs-matter/src/mdns/builtin.rs | 38 +- rs-matter/src/respond.rs | 298 ++++ rs-matter/src/secure_channel/busy.rs | 81 + rs-matter/src/secure_channel/case.rs | 245 ++- rs-matter/src/secure_channel/common.rs | 118 +- rs-matter/src/secure_channel/core.rs | 47 +- rs-matter/src/secure_channel/mod.rs | 1 + rs-matter/src/secure_channel/pake.rs | 219 ++- rs-matter/src/secure_channel/spake2p.rs | 10 +- rs-matter/src/secure_channel/status_report.rs | 53 +- rs-matter/src/tlv/parser.rs | 112 +- rs-matter/src/transport/core.rs | 1498 ++++++++++------- rs-matter/src/transport/dedup.rs | 80 +- rs-matter/src/transport/exchange.rs | 1232 ++++++++++---- rs-matter/src/transport/mrp.rs | 190 ++- rs-matter/src/transport/network.rs | 30 +- rs-matter/src/transport/packet.rs | 296 +--- rs-matter/src/transport/plain_hdr.rs | 183 +- rs-matter/src/transport/proto_hdr.rs | 197 ++- rs-matter/src/transport/session.rs | 637 ++++--- rs-matter/src/utils/buf.rs | 162 +- rs-matter/src/utils/ifmutex.rs | 227 +++ rs-matter/src/utils/mod.rs | 3 + rs-matter/src/utils/notification.rs | 46 + rs-matter/src/utils/parsebuf.rs | 4 + rs-matter/src/utils/select.rs | 157 +- rs-matter/src/utils/signal.rs | 97 ++ rs-matter/src/utils/writebuf.rs | 8 +- rs-matter/tests/common/handlers.rs | 3 +- rs-matter/tests/common/im_engine.rs | 297 ++-- rs-matter/tests/common/mod.rs | 4 +- rs-matter/tests/data_model/long_reads.rs | 6 +- 50 files changed, 6490 insertions(+), 2936 deletions(-) create mode 100644 docs/RFC_Transport_Exchange_Layer_Improvements.md create mode 100644 rs-matter/src/data_model/subscriptions.rs create mode 100644 rs-matter/src/interaction_model/busy.rs create mode 100644 rs-matter/src/respond.rs create mode 100644 rs-matter/src/secure_channel/busy.rs create mode 100644 rs-matter/src/utils/ifmutex.rs create mode 100644 rs-matter/src/utils/notification.rs create mode 100644 rs-matter/src/utils/signal.rs diff --git a/docs/RFC_Transport_Exchange_Layer_Improvements.md b/docs/RFC_Transport_Exchange_Layer_Improvements.md new file mode 100644 index 00000000..771e6043 --- /dev/null +++ b/docs/RFC_Transport_Exchange_Layer_Improvements.md @@ -0,0 +1,616 @@ +# RFC - Transport / Exchange Layer Improvements + +## Terminology + +Throughout this document, the code in `rs-matter` which is responsible for dealing with the network (sending and receiving packets, including their decoding and encoding) is called the _transport layer_. It is also sometimes interchangeably named the _exchange layer_. The _exchange layer_ name is emphasizing the API aspects of the functionality, whereas it is providing the "exchange" API to user code / upper layers. + +When comparing with the "Layered Architecture" diagram in the Matter spec, by "transport / exchange layer" we mean all of: +* (Probably) Action Framing +* Security +* Message Framing + Routing +* IP Framing + Transport Management + +By _user code_ / _upper layer(s)_ (which mean the same thing actually throughout the document), we mean any code within or outside the `rs-matter` codebase that is using exclusively or mostly the _public_ "exchange" API of the such-defined transport layer. + +When comparing with the "Layered Architecture" diagram in the Matter spec, by "user code / upper layers" we mean all of: +* Interaction Model +* Data Model +* Application Layer + +## Intro + +Back in 2023Q2, the new exchange concept - represented by the [`Exchange` struct](https://github.com/project-chip/rs-matter/blob/main/rs-matter/src/transport/exchange.rs#L248) was introduced to `rs-matter`. + +Unlike the C++ Matter SDK and the previous `rs-matter` transport code, the `Exchange` struct API allows for a straightforward, sequential looking sequence of sending and receiving messages with the other peer. No callbacks complexity, no explicit and error-prone state machines' management. I.e. + +```rust +async fn handle_an_im_request(exchange: &mut Exchange<'_>, tx_buf: &mut Packet<'_>, rx_buf: &mut Packet<'_>) -> Result<(), Error> { + // Get the first received packet via the exchange + exchange.recv(rx_buf).await?; + + // Do something with the RX packet + // ... + + // Write something in the TX packet + tx_buf.reset(); + let tlv = TLVWriter::new(tx.write_buf()?)?; + // ... + + // Send the TX packet and wait for the reply + exchange.exchange(tx_buf, rx_buf).await?; + + // Do something with the reply which is in the RX packet + // ... + + // ... and so on + + // That's all folks, w.r.t. this exchange! + Ok(()) +} +``` + +In the absence of blocking IO and multithreading, the above is really only possible by utilizing the Rust `async` syntax of course. Which is - after all - nothing else but a way to use "linear", sequential syntax, yet still get a bunch of (single-threaded, yet concurrent-IO friendly) state machines. + +## Status Quo + +While the worrying memory consumption of the auto-generated Rust state machines is still around ([here](https://github.com/rust-lang/rust/issues/62958), [here](https://github.com/rust-lang/rust/issues/108906), [here](https://github.com/rust-lang/rust/issues/59087) and a [potential first fix](https://github.com/rust-lang/rust/pull/120168), in the meantime we know how to control it at least to some extent (use references to pre-allocated large objects within the async code and don't allocate the large objects from _within_ the async code). + +With the long-awaited `async-fn-in-trait` feature now part of Rust stable, with `gen async` on the horizon, as well as some other async features, we are betting on the right horse w.r.t. `async`. + +Nevertheless, the initial `Exchange` implementation left a lot to be desired. Yet, we believe the main metaphor is solid, so the suggested / implemented changes do **not** change the metaphor, but rather enhance it and try to address all "duck taped places" / outstanding issues summarized below. + +## Issue 1: Unsafe implementation that can actually cause Undefined Behavior (i.e. crash) + +**TL;DR**: The current Exchange implementation is unsound, because it tries to implement a "completion" API in Rust. Basically, we have implemented [the Linux io-uring metaphor in Rust](https://www.cloudwego.io/blog/2023/04/17/introducing-monoio-a-high-performance-rust-runtime-based-on-io-uring/#pure-async-io-interface-based-on-gat) (or ["the DMA with non-owning buffers"](https://hackmd.io/@rust-ctcft/ryivZ5c85?print-pdf#/) - see p17 at the end), which is _impossible_ to do safely _with borrowed buffers_ (as we currently do!) with the existing Rust type system. + +Fortunately, there is a relatively straightforward fix, which requires the internal Matter transport implementation to _own_ the RX/TX buffers rather than what is happening now - user code or the `Exchange` objects own the buffers and "lend" / borrow their `&mut` refs (actually worse - `*mut` refs) to the internal Matter implementation. + +The [details](#appendix-a-undefined-behavior-in-current-exchange-impl) of the problem are at the end of the document. + +### Solution + +The solution is also how the "DMA" and "io-uring" problems are solved in general - by [the transport singleton **owning** the buffers and the `async` notification mechanism](https://github.com/ivmarkov/rs-matter/blob/next/rs-matter/src/transport/core.rs#L77), rather than the other way around which is the status quo: +* The transport singleton **owns** the RX/TX buffers as well as the `async` notification mechanism +* All `Exchange` objects keep a reference `&Matter` to the Matter stack and thus to the transport impl too. So they cannot "outlive" the transport singleton +* The RX/TX buffers and the notification are protected with an `async` Mutex-like synchronization primitive, so that at any point in time, either the transport singleton reads/writes from/to its own RX/TX buffers, or the user code, via one of the exchanges +* User code - via the `Exchange` structs - awaits for the RX/TX buffers to become available, and then receives an `async` Muterx Guard to these. More details below, in [Issue2](#issue-2-too-many-rxtx-buffers) + +All of the above is implemented with a safe-only code. + +## Issue 2: Too many RX/TX buffers + +The current implementation of `Exchange::send`, `Exchange::recv` and `Exchange::exchange` all take user-supplied buffers. This means that in the general case, N active exchanges require N, or even N*2 buffers (one RX and one TX, because of `Exchange::exchange(&mut tx, &mut rx)` which needs both). + +With 8 active exchanges this means 8 * (1583 + 1280) = ~23K memory just for the buffers. + +Granted, user (i.e. upper layer)-supplied buffers might be necessary in some cases _anyway_. Imagine answering long reads in the Interaction Model where the device is `await`-ing when e.g. reading data from the device HAL. For that case you **do** need to copy the RX data into a user supplied buffer, so that while you are populating the various attribute values in the TX data and potentially `await`-ing the HAL layer in-between, the transport can continue to operate and dispatch RX packets for _other_ exchanges. + +The problem however is that extra RX/TX buffers are _not_ always necessary, yet with the current API we are _always_ paying the price. + +Another problem is that these buffers currently _always_ have the shape of an `[u8]` slice, while the upper layer might ultimately need a different buffer. + +For example, the Secure Channel implementation never awaits the device HAL, as it does not communicate with the device HAL in the first place. It is a pure computational code. What that means, is that - theoretically - it can read _directly_ from the RX buffer of the transport impl and write _directly_ to the TX buffer of the transport impl. Even if that means that while it does so, no new incoming UDP packets will be accepted by the transport layer and all other exchanges willing to send stuff would be `async`-waiting for the (single) TX buffer to become available. This is tolerable in that case because - again - the Secure Channel is a pure computational layer without doing any IO. So at least in theory, it should complete fast (putting aside delays due to complex elliptic curve calcs). + +In practice, the Secure Channel does currently need interim buffers too, as in [here](https://github.com/ivmarkov/rs-matter/blob/next/rs-matter/src/secure_channel/case.rs#L74), [here](https://github.com/ivmarkov/rs-matter/blob/next/rs-matter/src/secure_channel/case.rs#L544), [here](https://github.com/ivmarkov/rs-matter/blob/next/rs-matter/src/secure_channel/core.rs#L56) and [here](https://github.com/ivmarkov/rs-matter/blob/next/rs-matter/src/secure_channel/core.rs#L60). (Which by the way need to be optimized.) But these are **on top** of the additional `[u8]` RX/TX buffers that the `Exchange` API currently requires. And they have a different shape. So it becomes a bit of a "too many buffers" situation. + +Where we're getting with all of that is that the Exchange layer should not assume how the Secure Channel or Interaction Model layers operate. They might or might not need extra buffers. However, requiring them to always use extra `[u8]`-shaped buffers means we are trading off extra memory of a fixed shape for responsiveness _even when we don't necessarily have a responsiveness problem in the first place_ or we don't know what type of memory layout the upper layer actually needs for buffers. + +Some extra info in [Appendix B](#appendix-b-digression-do-we-need-the-data-model-handlers-to-be-async-in-the-first-place). + +### Solution + +The improved transport/exchange layer operates off from a **single** pair of RX/TX buffers. + +Its philosophy is that it is up to the upper layers (Interaction Model and Secure Channel) not to hold on to its single pair of RX/TX buffers for too long. Since the knowledge whether the upper layers would be holding on for too long on its RX/TX buffers is with these upper layers, it is _up to them to decide_ when or if to use additional buffers in the first place. It is also up to them to decide what the _shape_ and _lifetime_ of these additional buffers would be. + +The exchange layer deciding on behalf of the upper layers is basically the current status quo, where the exchange layer is always immediately copying data from/to buffers supplied by the upper layers. Fine, but that means over-provisioning of memory which we are trying to solve in the first place. + +#### Receiving Details + +* The transport layer is concurrently and asynchronously trying to get a `&mut` ref to its own RX buffer, but only when the RX buffer is _already emptied or empty_. If the RX buffer is full with a previous packet for _some_ exchange, the transport layer waits until the corresponding `Exchange` instance consumes the content of the RX buffer and signals back that the RX buffer is empty again. +* At the same time, all active `Exchange` instances which are `await`-ing inside their `Exchange::recv` method, are concurrently and asynchronously trying to lock the `async` mutex protecting the RX buffer singleton. An `Exchange` instance will succeed doing so _only when the RX buffer is full_. Moreover, only when the RX buffer is full with data designated for _that concrete concrete Exchange_ which is trying to get hold of the RX buffer. +* `Exchange::recv().await` returns an `async` Mutex Guard in disguise. The user (e.g. the upper layer) can read freely the data in the buffer protected by this guard, including `await`-ing the HAL while operating on that data. However, the transport layer will _not_ be receiving other packets at that time (as there is a single RX buffer), potentially causing UDP packets from other peers to be dropped and re-transmitted if the OS packet queue is full. So the buffer returned from `Exchange::recv` should not be held for too long and if so (i.e. the "network bridge" case), the upper IM/Secure Channel layer shold pull the data in an interim buffer and drop the `async` Mutex guard it got via `Exchange::recv` thus singnalling the RX packet singleton as empty. + +#### Sending Details +* The transport layer is concurrently and asynchronously trying to get a `&mut` ref to its own TX buffer, _but only when the TX buffer is full_. If the TX buffer is not full, this means no exchange has prepared data for sending. When the transport layer gets access to the (already full) TX buffer, it copies the data in there over UDP (or other protocols in future), then marks the buffer as empty and signals/wakes all exchanges potentially `await`-ing on the TX buffer, that it is releasing the `async` lock on it. +* At the same time, all active `Exchange` instances which are inside their `Exchange::init_send` methods, are concurrently and asynchronously trying to lock the `async` mutex protecting the TX buffer singleton. An `Exchange` instance will succeed doing so _only when the TX buffer is empty_, and only one exchange instance would succeed doing so, and the others would continue to wait. +* `Exchange::init_send().await` returns an `async` Mutex Guard in disguise as well. The user can write freely into the buffer protected by this guard, including `await`-ing the HAL while operating on that data. However, if it is slow in doing that, it would delay all other exchanges willing to send at that time. Since the transport layer is automatically sending ACKs for re-transmitted packets this is not the end of the world, but if an exchange is delayed too much, it might cause this or other peers to eventually time out the whole exchange. Therefore, the buffer returned from `Exchange::init_send` should not be held for too long and if so (i.e. the "network bridge" case), the upper IM/Secure Channel layer shold first prepare the data to be sent in its own buffer, and only then try to lock the common TX buffer when the data is ready to be sent. + +#### Deadlock Avoidance + +Given that the transport layer is offering the upper layers two separate async mutexes in disguise - one for the RX buffer, and another - for the TX buffer, how are we avoiding a deadlock situation. E.g.: +* Exchange 1 has successfully locked the RX buffer by calling `exchange.recv().await` and now tries to lock the TX buffer by awaiting `exchange.init_send().await` to complete +* Exchange 2 did the opposite: it had locked the TX buffer by completing `exchange.init_send().await` and is now awaiting for the RX buffer with `exchange.recv().await`? + +The answer is that the new API currently simply does not allow this: +* Method `Exchange::recv` takes a `&mut self` of the `Exchange` struct, and - most importantly - the returned Guard wrapper looks as if it keeps a `&mut` ref to the `Exchange` object while the guard wrapper is still alive. +* Metod `Exchange::init_send` does exactly the same. +* Similarly for the variations of the above methods - namely - `Exchange::recv_fetch`, `Exchange::sender`, `Exchange::send_with` and so on + +What the above means is that the upper layer can either operate on the RX buffer, or on the TX buffer, but not simultaneously on both. In a way that also means that we are _forcing_ the upper layers to actually use interim buffers, as they can't really "write into the TX buffer while reading from the RX buffer". Even if we relaxed our deadlock-avoiding locking scheme so as the upper layers to additionally be allowed to lock the RX buffer first, and then - as a second and only as a second step - the TX buffer as well - that would be problematic when addressing "Issue 3" (packet re-transmission). More on that below. + +So even though it seems we are "back to square one" and in a way forcing the upper layers to re-introduce additional buffers, this is not exactly the same as the current status quo, as we are no longer dictating the _shape_ or the _lifetime duration_ of these interim buffers. For example: +* Pase needs `Spake2p` instance, but this "buffer" is (a) valid throughout the whole Pase exchange (b) needs the data to be massaged first before pushing it into it +* Ditto for Case with its `CaseSession` +* Ditto for Case that does need extra buffers to first encrypt and/or sign content before pushing into TX +* The IM layer almost always needs a TX `&mut [u8]`-shaped buffer where it can stream the data to be returned in the response, but it might not necessarily need an RX buffer for the incoming request, as long as the clusters it needs to query / write to / invoke are not `await`-ing +* Some of the operations in the IM layer do not need any buffers at all - for example, processing a timeout request/response. Or processing a status respone to a chunked `ReportData` response. Or answering with a `Busy` / `Resource Exhausted` status code in case the IM layer cannot handle the incoming request +* etc. + +#### I hear the arguments, but if we have to, can we revert to the old scheme, just in case? + +If we decide so, that's easy: +* Make public only `Exchange::recv_into`, `Exchange::send` and `Exchange::send_from` +* Make `Exchange::recv_fetch`, `Exchange::rx`, `Exchange::recv`, `Sender` and `Exchange::init_send` private + +That way upper layers would be _required_ to provide raw, `[u8]`-shaped buffers and the data will be read from/written to those immediately. +Yet, all other issues except the memory one would still be solved. And - a solution for Issues 1 and 6 in particular might anyway require a scheme similar to the proposed one. + +## Issue 3: No packet re-transmission + +The current transport layer does not have packet re-transmission implemented. + +### Solution + +The improved transport layer has packet re-transmission implemented. + +The question here rather is - how is it possible to implement packet re-transmissions for _multiple_ exchanges by using a _single_ pair of RX/TX buffers in the first place? + +We are simply trading less memory usage for extra computation and some extra burden on the user / upper layers. To put it simply, the upper layers are _required_ to be capable of re-generating the TX payload of their packet (and then the exchange layer would re-encode it and re-send it again), until the exchange layer tells them it is no longer necessary (i.e. when an ACK is received). + +While this sounds like a lot of lift and shift, the new public `Exchange` API provides plenty of utilities to get the job done: +* `Exchange::send_with(f: FnMut(&Exchange, &mut WriteBuf) -> Option)` + * The transport layer will call the provided closure as many times as necessary (or just once for non-reliable packets); upper layer is only required to behave idempotently and generate the _same_ content and message meta-data every time +* `let sender = Sender::new(&mut exchange)` and then `while let Some(tx) = sender.tx() { let payload = tx.payload(); let mut wb = WriteBuf::new(payload); ... }` + * Same as above, but also allows the upper layer to await while generating the content, as a non-async `FnMut` closure is not necessary (and Rust still lacks `AsyncFnMut`-style closures) +* `Exchange::send(payload: &[u8], meta: MessageMeta)` + * The "old style" API where the message payload is prepared in a separate buffer, and then handed to the exchange layer for sending (and re-sending) + +Here are a few examples from the actual `DataModel` IM layer, as to how packet retransmission looks like from the POV of layers above the transport one: + +#### Example 1: Handling an IM "Timed" request + +A "Timed" request might precede a "Write" or "Invoke" request. It only contains a "timeout" scalar `u32` value. As such, its processing and the (re)transmission of a response which is just a status response can be done without any intermediate buffers. + + +Here's how the "Timed" request-response interaction is coded: +```rust= +async fn timed(&self, exchange: &mut Exchange<'_>) -> Result { + // Get access to the transport layer RX packet and convert it to a TimedReq struct + let req = TimedReq::from_tlv(&get_root_node_struct(exchange.rx()?.payload())?)?; + debug!("IM: Timed request: {:?}", req); + + // Extract the timeout value. In a way, we _do_ use a buffer between the + // above RX operation and the below TX operation. The buffer is `timeout_instant`. + let timeout_instant = req.timeout_instant(exchange.matter().epoch); + + // Send (with re-transmission) a status response + Self::send_status(exchange, IMStatusCode::Success).await?; + + Ok(timeout_instant) +} +``` + +As for `send_status`: +```rust= +async fn send_status(exchange: &mut Exchange<'_>, status: IMStatusCode) -> Result<(), Error> { + exchange + .send_with(|_, wb| { + StatusResp::write(wb, status)?; + + Ok(Some(OpCode::StatusResponse.into())) + }) + .await +} +``` + +Do note how `exchange.send_with` takes a (`FnMut`) closure. What this means is that once we call `send_with` and thus call the transport layer, we should be prepared our closure to be called multiple times, due to packet retransmissions, and until the transport layer receives an ACK for the packet we are transmitting. So our closure should be idempotent and generate the same payload every time it is called. Since the response is a simple status message, this is not a problem in this case. + +#### Example 2: Answering an IM `Invoke` request: + +Here's how the "Invoke" request-response interaction is coded: +```rust= +async fn invoke( + &self, + exchange: &mut Exchange<'_>, + timeout_instant: Option, + ) -> Result<(), Error> { + let req = InvReq::from_tlv(&get_root_node_struct(exchange.rx()?.payload())?)?; + debug!("IM: Invoke request: {:?}", req); + + // (Handling timeouts is skipped for brevity) + + // To easily handle idempotent re-transmissions, we + // simply allocate a TX buffer here and prepare the response inside it + let Some(mut tx) = self.tx_buffer(exchange).await? else { + return Ok(()); + }; + + let mut wb = WriteBuf::new(&mut tx); + + let metadata = self.handler.lock().await; + + // Get the request shape by parsing the RX payload as TLV + let req = InvReq::from_tlv(&get_root_node_struct(exchange.rx()?.payload())?)?; + + // Will the clusters that are to be invoked await? + let awaits = metadata + .node() + .invoke(&req, &exchange.accessor()?) + .any(|item| { + item.map(|(cmd, _)| self.handler.invoke_awaits(&cmd)) + .unwrap_or(false) + }); + + if awaits { + // Yes, they will + // Allocate a separate RX buffer then and copy the RX packet + // into this buffer, so as not to hold on to the transport layer + // (single) RX packet for too long and block send / receive + // for everybody + let Some(rx) = self.rx_buffer(exchange).await? else { + // Allocating an RX buffer failed. + // However, `rx_buffer` already had sent a status response + // "Busy" to the remote peer. We can therefore simply unroll + // our stack by returning. + return Ok(()); + }; + + // Re-parse the incoming request + let req = InvReq::from_tlv(&get_root_node_struct(&rx)?)?; + + // Call the clusters and at the same time populate our TX + // buffer + req.respond(&self.handler, exchange, &metadata.node(), &mut wb) + .await?; + } else { + // No, they won't. Answer the invoke requests by directly using + // the RX packet of the transport layer, as the operation won't await + // Same as per above, call the clusters and at the same time + // populate our TX buffer + req.respond(&self.handler, exchange, &metadata.node(), &mut wb) + .await?; + } + + // Now that the clusters are invoked and we have their response in `wb`, + // call the transport (exchange) layer to send the response + // + // Note that `exchange.send` will NOT complete until it receives an + // ACK for the message it sends. Therefore, it might transmit our + // `wb.as_slice()` payload multiple times, with multiple messages + // But we don't care about that. Thanks to `async`, this re-transmission + // loop is hidden from us. All that we need to provide is the message + // payload in an idempotent way (as a `&[u8]` slice in this case + // that can be read from multiple times), so that the transport layer + // can do its re-transmission logic. + exchange.send(OpCode::InvokeResponse, wb.as_slice()).await?; + + Ok(()) +} +``` + +`self.tx_buffer(exchange)` and `self.rx_buffer(exchange)` are also interesting, as these are `async` calls, and in fact, allocating an intermediate TX or RX buffers can fail. Here's the TX buffer allocation: +```rust= +async fn tx_buffer(&self, exchange: &mut Exchange<'_>) -> Result>, Error> { + if let Some(mut buffer) = self.buffers.get().await { + // Getting a TX buffer (potentially after some time!) succeeded + // Size it and return it. + // + // NOTE: How much (and even if) allocating a buffer can await + // for a free buffer is up to the `BufferAccess` implementation, + // but it should be in the order of a few milliseconds, as + // while awaiting here we are potentially blocking the single + // RX/TX buffers of the transport layer. + // + // The default `BufferAccess` impl does not await. + buffer.resize_default(MAX_EXCHANGE_TX_BUF_SIZE).unwrap(); + + Ok(Some(buffer)) + } else { + // Getting a TX buffer failed. + // + // Before returning, call `send_status` (the method we looked at + // during the examination of the "Timed" req handling) + // to return to the client a status code that we are "Busy" + // (i.e. it should retry later, when we might have buffers) + Self::send_status(exchange, IMStatusCode::Busy).await?; + + // Return `None` so that the upper function can unroll its stack + Ok(None) + } +} +``` + +## Issue 4: Responding to exchanges is "locked" and hard-coded inside the transport layer implementation + +Method `Matter::run` currently is not only running the exchanges' transport logic (as in dispatching RX packets to `Exchange` objects and sending their TX packets). It is also managing the lifecycle of all "responder" exchanges and keeps them locked in a cage. + +Worse, the _concrete_ IM and SC implementations of the upper layers [are hard-coded](https://github.com/project-chip/rs-matter/blob/main/rs-matter/src/transport/core.rs#L352). + +### Solution + +While we can implement a "callback / handler style" API so that the user can plug-in their own protocol handlers, that might not be the best _base-level_ API, as it - by necessity - would hard-code how multiple exchanges are executed _concurrently_ (i.e. their execution and lifecycle model). Also it is not quite an idiomatic Rust, which seems to favor a non-closure based, "external iteration" style base-level APIs, and then optionally provide closure/handler API on top of the base-level ones. + +(Also there is the question of how do we create "initiator" exchanges, that we need for handling subscriptions which is addressed in the next issue.) + +Instead of a callback/handler API, the base-level "responder" exchange API for the upper layer is as follows: +* [`let exchange = Exchange::accept(&matter).await`](https://github.com/ivmarkov/rs-matter/blob/next/rs-matter/src/transport/exchange.rs#L739) + * ...where `&matter` is a reference to a `Matter` instance + * This would wait until a new exchange is initiated by a remote peer, and then that exchange would be returned to the upper layer / user code + * Obviously, multiple async tasks (or futures aggregated in a bigger future) can concurrently call `Exchange::accept`; the more tasks/futures do that "in parallel", the more exchanges would be handled "simultaneously" (w.r.t. IO, as everything is dispatched off from a single thread still) + * Also obviously, somebody needs to run in another task/future `Matter::run` or else the Matter transport layer will not run, and therefore no responder exchanges would ever be created, as there would be no networking traffic + +#### Users just want to run their on-off cluster, not deal with the complexity of accepting responder exchanges! + +Sure, and for this we still have "cage" callback-style utilities built on top of the above base-level API, thanks to the new `async-fn-in-trait` functionality in Rust: + +* [`Responder::run`](https://github.com/ivmarkov/rs-matter/blob/next/rs-matter/src/respond.rs#L120), which takes a `&matter` reference and organizes a pool of "handler" futures to concurrently call `Exchange::accept(&matter)` and then apply on each accepted exchange a user-provided `ExchangeHandler` trait callback + * `DataModel` and `SecureChannel` are retrofitted to implement the single-method `ExchangeHandler::handle(&mut Exchange)` API and are thus "exchange handlers" +* [`DefaultResponder`](https://github.com/ivmarkov/rs-matter/blob/next/rs-matter/src/respond.rs#L229), which internally uses `Responder` from above with an `ExchangeHandler` instance which is a composition of the default `DataModel` and `SecureChannel` protocol handlers + +The key difference between this new and the old arrangment being that `Responder` and `DefaultResponder` - just like `DataModel` and `SecureChannel` are **not** part of the main `Matter` instance and as such are replaceable with equivalents by the user. + +... which brings the question of _what are the roles and respoonsibilities of the `Matter` object then_, which is discussed in [Issue 8](#issue-8-current-exchange--transport-code-is-all-over-the-place). + +## Issue 5: No way to initiate an exchange + +As per above, we do need this so as to implement reporting data on active Interaction Model subscriptions. + +### Solution + +Similarly to `Exchange::accept`, the improved exchange layer now has `Exchange::initiate`: +* Call `let exchange = Exchange::initiate(&matter, node_id, is_secure)` + * This would create a new initiator exchange, as long as there is a valid session (of the provided type) with the node whose ID is provided, or fail otherwise + * Initiating an exchange to a node with which we don't have a valid session yet is left for the future, as that would require implementing the client side of the Secure Channel protocol + +## Issue 6: Robust error handling + +The current exchange layer is not really behaving well with regards to error conditions. The error conditions that need to be handled: +* a) A new responder exchange was created successfully, however it is not getting accepted by anybody because there are not enough exchange handlers in e.g. the upper layer +* b) A responder or initiator exchange fails "mid-flight" after sending and/or receiving some message(s). Or the user code holding an `Exchange` object fails due to some other error unrelated to Matter. In any case, the user code stack that uses the `Exchange` instance is unrolled, and as a result, the `Exchange` instance is dropped! +* c) A new responder or initiator exchange cannot be even created because the transport layer had ran out of `ExchangeState` exchange slots (which track the MRP and the exchange role) +* d) A new responder exchange was created successfully, yet the upper layers find out they are out of resources (say, not enough memory or other resource) + +### Solution + +To bring some order in the error handling process, the improved transport/exchange layer distinguishes between two main states an exchange could be: + +#### State 1: Exchange is "owned" by the transport layer itself + +What this means is that the internal structure in the transport layer that tracks the exchange (`ExchangeState` - see below) is created, however: +* (`Role::Responder(ResponderState::AcceptPending)`): The actual `Exchange` object that is used by the upper layers is either not created yet (by somebody in the upper layers completing an `Exchange::accept().await` call for that exchange)... + * This is error condition (a) from above +* (`Role::Responder(ResponderState::Dropped)` and `Role::Initiator(InitiatorState::Dropped)`) ...or the `Exchange` object was dropped after being created (and after possibly being used). This case can happen for both initiator and responder exchanges and in fact *will always happen* for every single exchange, as all `Exchange` objects are dropped sooner or later + * This is error condition (b) from above, which might actually not be an error condition even, but a normal exchange completion +* (No `Role` instance, no `ExchangeState` instance): It might also happen that the transport layer cannot even create its own `ExchangeState` structure instance for an RX packet that looks like a new exchange, as it had ran out of `ExchangeState` exchange slots for that session (by default, they are up to 5 per session, as suggested by the Matter spec) + * This is error condition (c) from above + +What is important for State 1 is that in State 1 it is the _transport layer_ who is _ultimately responsible_ for somehow handling the error conditions and ultimately completing the exchange, possibly by even responding itself to the peer that initiated the exchange. + +The TL;DR for how the transport layer handles exchanges it owns is that the only thing it does is nothing more than to **close the Exchange by following the "Closing an Exchange" procedure in the Matter spec**. + +In details: +* For issue (a): if the new exchange is not accepted within 500ms, the exchange is closed as per the "Closing an Exchange" procedure in the Matter spec - i.e., a Standalone ACK is sent if one was not sent yet and that's it - the `ExchangeState` instance is removed and the exchange slot in the session is free'd +* For issue (b): + * If the exchange has no "in-flight" re-transmission, then the exchange is closed just like when handling issue (a) except without any delays + * If the exchange has an "in-flight" retransmission of a reliable message which is not yet ACKed by the remote peer, **the whole session is closed abruptly** by sending a SC `CloseSession` status non-reliable message. The transport layer really cannot recover better from that case, as it does not have the payload of the message which is in a re-transmission state in the first place! (As the transport layer is operating off from a single shared with everybody TX buffer after all) +* For issue (c): + * Exactly as in issue (a) - send (via using the RX buffer that is locked by the transport layer in that case but so what!) an ACK if the exchange has an outstanding ACK waiting and then remove the `ExchangeState` object + +#### State 2: Exchange is "owned" by the upper layer + +`Role::Responder(ResponderState::Owned)` and `Role::Initiator(InitiatorState::Owned)`) + +What this means is that the `Exchange` object for the particular exchange was created successfully by the upper layer (with `Exchange::accept` or `Exchange::initiate`) and is still alive (not dropped). + +In that state, the only message that the transport layer generates on its own for a particular exchange is a Standalone ACK after a timeout, or when receiving a duplicate packet due to the other peer re-transmitting. All other RX/TX is only done when the upper layer explicitly uses the send/receive functions on the `Exchange` object. + +_Any_ errors (except network failures) that might happen during the RX/TX of messages in this state _are reported back to the upper layer code_ as an `Error` code which is returned from the sending/receiving methods of the `Exchange` struct. The thing is, if the upper layer code cannot deal with those, or if it cannot deal with any _other_ error conditions stemming from elsewhere (like it being low on resources), it would eventually unroll its stack, thus dropping the `Exchange` object, and thus transferring the ownership of the exchange back to the transport layer! That would mean that the exchange would enter State 1 again, and will be completed by the transport layer, following the above rules. + +#### But handling of "low on resources" conditions and other errors should be more intelligent, at least most of the time! + +I.e. rather than just ACKing and then silently closing the exchange leaving the remote peer maybe waiting for a response by us, and timing out the exchange after a long time, we should ideally be sending: +* SC BUSY status code for all Secure Channel payloads which try to initiate a Pase or Case session +* IM BUSY status code for all Interaction Model payloads which try to initiate a Read, Write or Invoke interaction +* IM "Resource Exhausted" status code for all Interaction Model payloads which try to initiate a Subscribe interaction +* IM Failure for all other incoming requests which look like an exchange which is "mid-flight" yet they are unexpected by the Interaction Model layer + +This is a fair statement, however it is questionable whether it is the duty of the transport layer to do this, as it does _not_ understand IM or SC protocols' details. Moreover, some of the above messages need to be send in a reliable manner with (re)transmission, which is difficult to do from within the transport layer "inner guts" itself. + +In general, the transport layer only understands: +* `MRPStandaloneAck` +* SC `CloseSession` status code +* (In future) Session counters sync req/resp messages + +So this problem is solved in a different way, as part of the upper layers: +* In addition to having `DataModel` and `SecureChannel`, the `rs-matter` framework now offers two additional `ExchangeHandler` implementations: + * `BusyDataModel` + * `BusySecureChannel` +* These handlers are very simple - they send `Busy` to the incoming messages that are the opening ones for a responder exchange, and `Failure` for everything else. Being so simple, these handlers don't need any additional buffers and _operate off completely from the transport layer RX/TX buffers_. What this means is that when these busy handlers are wrapped in a `Responder`, the `Responder` instance can create _a lot_ of handling futures for these, as they take so little memory, so that almost every exchange which is not handled by the main handlers, would be answered with a small delay by the busy handlers. +* Finally, the `DefaultHandler` struct mentioned in [Issue 4](#issue-4-responding-to-exchanges-is-locked-and-hard-coded-inside-the-transport-layer-implementation) actually runs _two_ `Responder` instances - the one which uses the "real" `DataModel` and `SecureChannel` handlers, and another one - answering with ~ 100ms delay - which runs the `BusyDataModel` and `BusySecureChannel` handlers. Thus - and in a natural way - if the main responder is low on resources and cannot (or is unwilling to) accept an exchange on time, the "busy" responder would kick in, answering with "Busy" or "Failure". Finally, if even the busy responder cannot handle the storm of exchanges, the transport layer would kick in after a ~ 500ms delay by ACKing the RX message for the exchange and then dropping the exchange slot from its session. + +## Issue 7: Low level details revealed to upper layers + +Currently, the IM and SC handler take an `Exchange` struct and then additionally - a pair of `&mut Packet<'_>` references for TX/RX. +This is not ideal, as the upper layers should not be concerned with the low level details of the transport layer packet structure. Ideally, they should: +* Have a way to read (for RX) or create (for TX) the message payload (TLV or other), where the RX payload is already decoded (decrypted) by the transport layer, and the TX payload would be automatically and transparently encoded (encrypted) by the transport layer +* Have a way to specify the protocol ID and the protocol OpCode +* Indicate if the message is reliable or not + +### Solution + +Instead of dealing with a `Packet` structure, the upper layers know about: +* An `MessageMeta` structure, that only captures the protocol ID, the protocol Opcode and whether the message is reliable or not +* A `[u8]` slice for the RX payload +* A `&mut [u8]` slice for the un-encoded TX payload they have to build (and then return the stanrt and end of the payload as well as the message meta-data) + +For presenting these types of structures to the upper layers, as well as - and more importantly - for solving [Issue 2](#issue-2-too-many-rxtx-buffers), the transport layer **no longer has the notion of a `Packet` structure**. Instead, it only has the notion of a packet _header_ structure (`PacketHdr` which is just a concatenation of a `PlainHdr` and `ProtoHdr`) as well as utility methods on `PacketHdr` for decoding / encoding a packet from/to _user supplied_ `ParseBuf` and `WriteBuf` respecitvely. + +This gives us the freedom to encode / decode the final UDP/TCP/etc packet either in-place, in an owned `heapless::Vec`, or elsewhere. + +## Issue 8: Current exchange / transport code is all over the place + +Or in more details: +* There is no `ExchangeMgr`. Exchange slots are owned directly by the `Matter` object +* All transport code is implemented directly on the `Matter` object, albeit in the `transport` module. So the `Matter` object currently has two implementations: one in `rs_matter::core`, and then another - in `rs_matter::transport::core`, which is weird, but acceptable in Rust + +### Solution + +The code is (re)organized as follows: +* `TransportMgr` is back! + * It is owned by the `Matter` object and aggregates all transport layer code; lives in `rs_matter::transport::core` + * `TransportMgr` - in turn - now owns `SessionMgr`, as sessions are part of the transport layer + * There is no a separate `ExchangeMgr`. The exchange slots are **not** owned by any `*Mgr`. They are owned and managed by their `Session` instance instead. Each session can have up to `MAX_EXCHANGES` exchange slots, which is by default set to 5 (as per the suggested maximum in the Matter spec). Regardless, `TransportMgr` is intimately aware about the notion of an exchange, as well as the notion of sessions (via its aggregated `SessionMgr` instance) + * `TransportMgr` now also owns the `MdnsImpl` which is the mDNS service in use by the Matter stack, as it is also considered a part odf the tranport layer + * The `Exchange` struct which is the main interface to the transport layer for the upper layers / user code lives in `rs_matter::transport::exchange` as before. There is now also `MessageMeta` (as per [Issue 7](#issue-7-low-level-details-revealed-to-upper-layers)), `ExchangeId` (an internal ID of each exchange - a concatenation of the internal session ID and the exchange index in the slots' array of the `Session`) as well as internal, transport-layer-only structs, like `ExchangeState` (the exchange slot, used to be called `ExchangeCtx`), `Role` and a few others +* `TransportMgr`'s responsibilities are as follows: + * Run the network layer, via `TransportMgr::run` and `TransportMgr::run_builtin_mdns` + * Provide means for the upper layers to accept and initiate exchanges (via `TransportMgr::accept` and `TransportMgr::initiate` which are crate-public and exposed via `Exchange::accept` and `Exchange::initiate` instead) + +Since `TransportMgr` is an internal detail of the `Matter` object, its crate-public `run`, `run_builtin_mdns`, `accept` and `initiate` methods are exposed either on the `Matter` object or on the `Exchange` object as well, as public methods. + +### What is the `Matter` object responsible for, in the end? + +Clear: +* Aggregating the transport layer, and exposing it to SC and IM protocol handlers via a handful of structures and methods: `Exchange`, `Matter::run`, `Matter::run_builtin_mdns` +* Providing the `rand` and `epoch` functions + +Unclear: +* Providing basic configuration in the form of `BasicClusterInfo` (TBD: do we still need this as part of the `Matter` object) +* Providing the notion of fabrics, in the form of `FabricMgr` (TBD: do we need to publicly expose this?) +* Providing the notion of IM ACLs, in the form of `AclMgr` (TBD: shouldn't it be owned by the `DataModel` IM implementation?) +* `PaseMgr` (TBD: Should this be part of the transport layer?) + +### What the `Matter` object should NOT be responsible for, in the end? + +Clear: +* Not responsible or aware of the Interaction Model layer, its payload / details (but the Interaction Model is aware of `Matter` and its transport layer - via a well defined small set of public exchange APIs, as per above) +* Not responsible or aware of the Secure Channel layer, its payload / details, except for a handful of messages related to session management and packet re-transmission (but the Secure Channel is aware of `Matter` and its transport layer - via a well defined small set of public exchange APIs, as per above) +* Not responsible for organizing in any way the response to exchanges. This is the role of the `Responder` utility, or user-defined exchange processing, using "real" async executors like `tokio` etc. +* Not responsible for organizing in any way initiation of exchanges + +## Appendix A: Undefined Behavior In Current `Exchange` Impl + +(Hold tight, this is a bit long and tricky.) + +The underlying transport implementation below the `Exchange` API currently implements the following metaphor (only receiving will be examined here, but sending has similar issues): +* For each active exchange, the user owns an RX packet/buffer. How and where this packet is allocated is not a concern of the Exchange API. +* The user operates on this buffer freely (as in mainly reading from it of course) +* When the user wants to receive, the user calls `Exchange::recv(&mut rx).await` or `Exchange::exchange(&mut tx, &mut rx).await`, supplying a *mutable reference* to the RX buffer. +* [The code for the above implementation uses `unsafe` to avoid lifetime-related compiler errors](https://github.com/project-chip/rs-matter/blob/main/rs-matter/src/transport/core.rs#L424), as the above pattern is un-expressible with the existing Rust lifetime rules. This is really important: should we've NOT used `unsafe`, this pattern would've been impossible to implement, and the problem would've not been here in the first place! +* This mutable RX reference - together with a `*mut` ref to the `async` Notification primitive owned by the concrete `Exchange` instance is recorded in an internal central singleton structure (the Matter transport impl), and when a packet for that particular exchange arrives, it is copied into the recorded RX buffer *mutable reference*, and then the corresponding `Exchange` object is awoken from `await`-ing, by `signal`-ing the recorded `*mut` ref of the `Notification` structure. + +... and that's the crux of the issue - that the mutable RX reference and the mutable Notification object reference "are recorded", i.e. they are kept around **accross** await points! + +This can lead to crashes - i.e. what happens if the `Exchange` instance had given the `*mut` refs of the user's RX buffer and its `Notification` instance to the internal transport impl, is now `await`-ing, and the user just "cancels" the `await` "mid-flight" by stopping to poll the future? + +If the cancellation of the future (which means the future is just dropped or forgotten) calls the exchange `drop` destructor - all is right - the `*mut` references will be de-registered from the internal transport impl, so no dangling references to memory which might no longer be around. + +The problem is when the user `core::mem::forget`s the future or some of its parent futures (forgetting in Rust IS a safe API!). Or uses `Arc`/`Rc`s that end up with cycles which leads to memory leaks as well. In that case, the compiler would think the RX packet specifically no longer has a mutable reference and it can be e.g. dropped (or mutated freely by somebody else), yet that's not the case, as a `*mut` reference to the RX packet is still registered in the transport impl! + +Now, I do realize this all spounds a bit long, unclear, theoretical and a corner case, stemming from the fact that Rust destructors are currently not guaranteed to run (the so called "leakpocalypse" discussion that had happened ~ 2015), but the problem is in there, real. We were recently hit by it in `esp-idf-hal` in two separate places by that very same problem (SPI driver with DMA; non-`'static` closures passed to hidden OS threads) and one of these places was found by users, not us. + +Anyway, the solution is to have the RX buffer owned by the internal transport impl. NOT by the user and NOT by the concrete `Exchange` instance. This way, no `*mut` pointer with unknown lifetime is used accross await points. If the user wants a copy of the RX packet in its own buffer, that's still possible, but the internal transport impl does not "record" the mut ref to the user's packet anywhere - it first receives the data in its own internal buffer, and only when the data is in there, it copies it to the user buffer, without keeping the mut ref around. + +## Appendix B: (Digression) Do we need the Data Model handlers to be `async` in the first place? + +Imagine that the HAL layer is actually _not_ requiring `await`s. I would say this might in fact be the norm rather than the exception for typical IO devices. Matter clusters seem to be semantically organized in such a way, that in fact no HAL `await` is necessary: +* When an on-off cluster is reporting its state, it is reporting its _current_ state, which requires e.g. a non-blocking read of an input pin. It is not awaiting anything +* When an on-off cluster is supposed to toggle from "on" to "off" or the other way around, it is setting an output pin to "high" or "low" without "waiting" for anything else +* When a window blinds cluster receives an "open" command, the semantics of that command is **not** that we should reply to that command only when the blinds are completely opened 20 or so seconds later. The semantics is that we should _turn on_ the blinds' motor and then reply immediately. Which is also e.g. setting an output pin to "high" and then responding the IM command without `await`-ing anything. In other words, the command is just "_start_ opening", not "start opening and wait with the command response until the blinds are opened completely" +* Ditto for reading the current state of the window blinds - we are supposed to report the current opening/closing _progress_ (as in e.g. "50% opened") +* Ditto for complex clusters like the multimedia ones + +Does that mean that we should retire our `async` [`AsyncHandler` Data Model trait](https://github.com/project-chip/rs-matter/blob/main/rs-matter/src/data_model/objects/handler.rs#L299) and only support [the non-`async` `Handler` one](https://github.com/project-chip/rs-matter/blob/main/rs-matter/src/data_model/objects/handler.rs#L35)? +No because we might have a HAL that is really much easier to express with `await`-ing. Imagine a Matter bridge device that communicates with the non-Matter devices it is bridging over the network. It is very attractice and simple to have the possibility of e.g. an `async` `AsyncHandler::invoke` on-off cluster implementation, that - while inside the `invoke` method - opens an HTTP REST request to the remove device, sends the request using `async` IO and awaits the `200 OK` response using `async` IO. Contrast this with a complex caching logic where you need to notify an interim layer that it needs to - at some point - send an HTTP request; and then we would be reporting back "the light went on" even if - in fact - it *didn't*, due to the device being temporarily offline or whatever. (Not that some Matter controllers don't operate like that anyway! :) ) + +So in conclusion, I think we have to preserve the current asynchronous `AsyncHandler` contract, as it is a superset of what the user might actually need. If we ~~(ever)~~ (UPDATE: I did) implement an intelligent buffer management scheme in the Interaction Model, we might introduce a new set of methods in the `AsyncHandler` trait: `AsyncHandler::xxx_awaits(&self) -> bool`. This way the user would be able to indicate if their cluster(s) are really needing asynchrony - and if not - the Interaction Layer might use this information to skip on using extra buffers for sending. For one, all clusters in Endpoint 0 are purely computational (just like the whole Secure Channel impl), so they do not really need an extra TX buffer. Or even an extra RX buffer, for that matter. + +## Appendix C: High level summary of code changes + +This is a non-exhaustive summary of the changes accompanying this RFC. +All changes are avilable in a branch [here](https://github.com/ivmarkov/rs-matter/tree/next). + +#### New modules / types (or existing ones which were almost re-written) + +Transport layer: +* [`rs_matter::transport::core::*`] + * All transport code is re-assembled under a new `TransportMgr` type. Heavily modified so better to assume it is brand new +* [`rs_matter::transport::exchange::*`] + * This is the improved `Exchange` instance and all acompanying types, like `ExchangeId`, `RxMessage`, `TxMessage`, `MessageMeta`. A lot of these types are brand new, or so modified in terms of impl as if these are brand new + +Exchange responders / exchange handlers: +* [`rs_matter::responder::Responder`] + * A generic way to respond/accept multiple incoming exchanges simultaneously without using async executor and utilizing only intra-task concurrency (i.e. `select`/`join`). Responders need an `ExchangeHandler` instance so as to apply it to the incoming exchanges. +* [`rs_matter::responder::ExchangeHandler`] - and its composition - `CompositeExchangeHandler` + * Something that can handle exchanges. Intuitively, this is a protocol handler, like the ones for IM and SC. `DataModel` and `SecureChannel` - the two protocol handlers provided by `rs-matter` out of the box implement the simplistic `ExchangeHandler` contract. +* [`rs_matter::responder::DefaultResponder`] + * "Out of the box" composition of the IM and secure channel implementations in `rs-matter` into an exchange responder. + +IM / SecureChannel: +* [`rs_matter::interaction_model::busy::BusyInteractionModel`] + * A very simple Interaction Model implementation that answers with an IM Status Code `Busy` to every incoming request that initiates a new exchange. +* [`rs_matter::secure_channel::busy::BusySecureChannel`] + * Ditto, but for Secure Channel. + +Utilities: +* [`rs_matter::utils::signal::Signal`] + * A new async primitive. Used directly in the `Subscriptions` implementation of IM subscriptions, and indirectly - by `Notification` and `IfMutex` +* [`rs_matter::utils::notification::Notification`] + * A `Notification` primitive which used to be based on `embassy_sync::Signal` and is now based on our own `utils::signal::Signal`. Implementation details not so important as the new and the impls behave identically +* [`rs_matter::utils::ifmutex::IfMutex`] + * An `IfMutex` primitive which is essentially an asynchronous mutex, except slightly more powerful than the `embassy_sync::Mutex` primitive after which it is modeled, in that `IfMutex` can conditionally lock the mutex, unlike its `embassy_sync::Mutex` counterpart +* [`rs_matter::utils::buf::BufferAccess`], [`rs_matter::utils::buf::PooledBuffers`] + * `BufferAccess` is a trait for a "slab" allocator that can allocate memory of the same size and shape. Asynchronous (i.e. depending on the implementation, calling code might await until a buffer is available) + * `PooledBuffers` is a simple implementation of the `BufferAccess` contract that allocates memory from a fixed, pre-allocated pool (i.e. no heap operations). + +#### Heavily modified modules / types + +Transport layer: +* [`rs_matter::transport::packet::*`] + * The notion of a packet header is now disconnected from the notion of a (decoded or encoded) packet payload. + * What this means is that the packet proto and plain headers can now be decoded/encoded from/to any container, where the container is still represented by a `WriteBuf` / `ParseBuf`, except that these instances are taken during encoding / decoding as method _parameters_. In the past, the `Packet` struct (which no longer ecists and is superceded by the `PacketHdr` struct) always assumed that the payload is encoded/decoded in a pair of `WriteBuf`/`ParseBuf` instances which were **owned** by the `Packet` struct. This - in turn - introduced lifetime issues when we switched to a single pair of RX + TX packets owned by the `TransportMgr` instance. + +IM layer: +* `rs_matter::data_model::core::DataModel` + * Modified so as to support subscriptions (new struct - `Subscriptions`) + * Modified to support better buffer utilization: + * Interactions that don't need TX or RX buffers (like timeout request/response) don't request these buffers + * Interactions that only involve non-awaiting clusters don't need and use an extra RX buffer (as these interactions don't await) + +Secure Channel layer: +* `Case`, `Pake` - modified so as _not_ to need additional `&mut [u8]`-shaped buffers +* Modified not to need `CaseSession`. All session state is kept in a regular `Session` instance which gets reserved prior to the exchange beginning +* Still pending for a future PR is an optimization of the other buffers it still uses + +IM Integration tests' `ImEngine`: +* These tests are now more end-to-end - meaning - both sides (the initiator and the responder) use the Matter transport layer +* The communication is organized as two separate `Matter` instances, where the "remote" one is the "server" or the "device", and the "local" one is the client +* TBD: We have to decide what to do with the IM layer integration tests. I'm kind of drifting these towards end to end tests, but the opposite is a completely valid direction as well - we might decide to move these back into unit tests that live inside the `rs_matter::data_model` module and only know / test the `DataModel` layer + +#### Slightly modified modules / types + +* `rs_matter::core::Matter` + * New methods, `reset`, `run` and `run_builtin_mdns` which delegate to `TransportMgr` + +* `rs_matter::data_model::objects::AsyncHandler` + * New functions which return `true` by default - `read_awaits`, `write_awaits` and `invoke_awaits` + +* `rs_matter::data_model::objects::HandlerCompat` (mapping of non-awaiting `Handler` to `AsyncHandler`) + * New functions which return `false` by default - `read_awaits`, `write_awaits` and `invoke_awaits` + +* `rs_matter::data_model::objects::Node` + * `read` method adjusted to take `ReportDataReq` enum, which is an enum that represents either a read request, or a subscribe request + +* `rs_matter::mdns::builtin::MdnsImpl` + * Changes due to `BufferAccess` now being generic on the type of buffer it offers + +* `rs_matter::transport::network::Address` + * Now capable of modeling TCP transport addresses; the transport layer is adjusted so as not to retransmit messages when these are sent over a reliable protocol + * Note that a TCP-based implementation of `NetworkSend` and `NetworkReceive` is still pending though, but that's external to the core Matter transport impl, which should now (in theory) work over reliable transports as well + +* `BufferAccess` + * Buffer type generified (could be an e.g. `&mut [u8]`, or a `&mut heapless::Vec` or something else) + +* `EitherUnwrap` + * Renamed to `Coalesce` and can now be used with `join*` in addition to `select*` future combinators \ No newline at end of file diff --git a/examples/onoff_light/src/dev_att.rs b/examples/onoff_light/src/dev_att.rs index 6c6a1501..9174ccbe 100644 --- a/examples/onoff_light/src/dev_att.rs +++ b/examples/onoff_light/src/dev_att.rs @@ -21,7 +21,7 @@ use rs_matter::error::{Error, ErrorCode}; pub struct HardCodedDevAtt {} impl HardCodedDevAtt { - pub fn new() -> Self { + pub const fn new() -> Self { Self {} } } diff --git a/examples/onoff_light/src/main.rs b/examples/onoff_light/src/main.rs index c99b80a0..6b076868 100644 --- a/examples/onoff_light/src/main.rs +++ b/examples/onoff_light/src/main.rs @@ -19,23 +19,29 @@ use core::borrow::Borrow; use core::pin::pin; use std::net::UdpSocket; -use embassy_futures::select::select3; +use embassy_futures::select::{select, select4}; +use embassy_sync::blocking_mutex::raw::NoopRawMutex; +use embassy_time::{Duration, Timer}; use log::info; use rs_matter::core::{CommissioningData, Matter}; use rs_matter::data_model::cluster_basic_information::BasicInfoConfig; use rs_matter::data_model::cluster_on_off; +use rs_matter::data_model::core::IMBuffer; use rs_matter::data_model::device_types::DEV_TYPE_ON_OFF_LIGHT; use rs_matter::data_model::objects::*; use rs_matter::data_model::root_endpoint; +use rs_matter::data_model::subscriptions::Subscriptions; use rs_matter::data_model::system_model::descriptor; use rs_matter::error::Error; use rs_matter::mdns::MdnsService; use rs_matter::persist::Psm; +use rs_matter::respond::DefaultResponder; use rs_matter::secure_channel::spake2p::VerifierData; -use rs_matter::transport::core::{PacketBuffers, MATTER_SOCKET_BIND_ADDR}; -use rs_matter::utils::select::EitherUnwrap; +use rs_matter::transport::core::MATTER_SOCKET_BIND_ADDR; +use rs_matter::utils::buf::PooledBuffers; +use rs_matter::utils::select::Coalesce; use rs_matter::MATTER_PORT; mod dev_att; @@ -49,7 +55,7 @@ fn main() -> Result<(), Error> { // e.g., an opt-level of "0" will require a several times' larger stack. // // Optimizing/lowering `rs-matter` memory consumption is an ongoing topic. - .stack_size(180 * 1024) + .stack_size(95 * 1024) .spawn(run) .unwrap(); @@ -62,9 +68,9 @@ fn run() -> Result<(), Error> { ); info!( - "Matter memory: Matter={}, PacketBuffers={}", + "Matter memory: Matter={}B, IM Buffers={}B", core::mem::size_of::(), - core::mem::size_of::(), + core::mem::size_of::>() ); let dev_det = BasicInfoConfig { @@ -81,57 +87,92 @@ fn run() -> Result<(), Error> { let dev_att = dev_att::HardCodedDevAtt::new(); - // NOTE: - // For `no_std` environments, provide your own epoch and rand functions here - let epoch = rs_matter::utils::epoch::sys_epoch; - let rand = rs_matter::utils::rand::sys_rand; - let matter = Matter::new( - // vid/pid should match those in the DAC &dev_det, &dev_att, + // NOTE: + // For `no_std` environments, provide your own epoch and rand functions here MdnsService::Builtin, - epoch, - rand, + rs_matter::utils::epoch::sys_epoch, + rs_matter::utils::rand::sys_rand, MATTER_PORT, ); + matter.initialize_transport_buffers()?; + info!("Matter initialized"); - let handler = HandlerCompat(handler(&matter)); + let buffers = PooledBuffers::<10, NoopRawMutex, _>::new(0); + + info!("IM buffers initialized"); + + let mut mdns = pin!(run_mdns(&matter)); + + let on_off = cluster_on_off::OnOffCluster::new(*matter.borrow()); + + let subscriptions = Subscriptions::<3>::new(); + + // Assemble our Data Model handler by composing the predefined Root Endpoint handler with our custom On/Off clusters + let dm_handler = HandlerCompat(dm_handler(&matter, &on_off)); + + // Create a default responder capable of handling up to 3 subscriptions + // All other subscription requests will be turned down with "resource exhausted" + let responder = DefaultResponder::new(&matter, &buffers, &subscriptions, dm_handler); + info!( + "Responder memory: Responder={}B, Runner={}B", + core::mem::size_of_val(&responder), + core::mem::size_of_val(&responder.run::<4, 4>()) + ); + + // Run the responder with up to 4 handlers (i.e. 4 exchanges can be handled simultenously) + // Clients trying to open more exchanges than the ones currently running will get "I'm busy, please try again later" + let mut respond = pin!(responder.run::<4, 4>()); + + // This is a sample code that simulates state changes triggered by the HAL + // Changes will be properly communicated to the Matter controllers and other Matter apps (i.e. Google Home, Alexa), thanks to subscriptions + let mut device = pin!(async { + loop { + Timer::after(Duration::from_secs(5)).await; + + on_off.set(!on_off.get()); + subscriptions.notify_changed(); + + info!("Lamp toggled"); + } + }); // NOTE: // When using a custom UDP stack (e.g. for `no_std` environments), replace with a UDP socket bind for your custom UDP stack // The returned socket should be splittable into two halves, where each half implements `UdpSend` and `UdpReceive` respectively let socket = async_io::Async::::bind(MATTER_SOCKET_BIND_ADDR)?; - let mut packet_buffers = PacketBuffers::new(); - let mut runner = pin!(matter.run( + // Run the Matter and mDNS transports + let mut transport = pin!(matter.run( &socket, &socket, - &mut packet_buffers, - CommissioningData { + Some(CommissioningData { // TODO: Hard-coded for now verifier: VerifierData::new_with_pw(123456, *matter.borrow()), discriminator: 250, - }, - &handler, + }), )); - let mut mdns_runner = pin!(run_mdns(&matter)); - // NOTE: // Replace with your own persister for e.g. `no_std` environments let mut psm = Psm::new(&matter, std::env::temp_dir().join("rs-matter"))?; - let mut psm_runner = pin!(psm.run()); - - let runner = select3(&mut runner, &mut mdns_runner, &mut psm_runner); + let mut persist = pin!(psm.run()); + + // Combine all async tasks in a single one + let all = select4( + &mut transport, + &mut mdns, + &mut persist, + select(&mut respond, &mut device).coalesce(), + ); // NOTE: // Replace with a different executor for e.g. `no_std` environments - futures_lite::future::block_on(runner).unwrap()?; - - Ok(()) + futures_lite::future::block_on(all.coalesce()) } const NODE: Node<'static> = Node { @@ -146,7 +187,10 @@ const NODE: Node<'static> = Node { ], }; -fn handler<'a>(matter: &'a Matter<'a>) -> impl Metadata + NonBlockingHandler + 'a { +fn dm_handler<'a>( + matter: &'a Matter<'a>, + on_off: &'a cluster_on_off::OnOffCluster, +) -> impl Metadata + NonBlockingHandler + 'a { ( NODE, root_endpoint::handler(0, matter) @@ -155,11 +199,7 @@ fn handler<'a>(matter: &'a Matter<'a>) -> impl Metadata + NonBlockingHandler + ' descriptor::ID, descriptor::DescriptorCluster::new(*matter.borrow()), ) - .chain( - 1, - cluster_on_off::ID, - cluster_on_off::OnOffCluster::new(*matter.borrow()), - ), + .chain(1, cluster_on_off::ID, on_off), ) } diff --git a/rs-matter/Cargo.toml b/rs-matter/Cargo.toml index 98605ab3..d6689157 100644 --- a/rs-matter/Cargo.toml +++ b/rs-matter/Cargo.toml @@ -20,6 +20,7 @@ alloc = [] openssl = ["alloc", "dep:openssl", "foreign-types", "hmac", "sha2"] mbedtls = ["alloc", "dep:mbedtls"] rustcrypto = ["alloc", "sha2", "hmac", "pbkdf2", "hkdf", "aes", "ccm", "p256", "elliptic-curve", "crypto-bigint", "x509-cert", "rand_core"] +large-buffers = [] # TCP support [dependencies] rs-matter-macros = { version = "0.1", path = "../rs-matter-macros" } @@ -45,6 +46,7 @@ domain = { version = "0.9", default-features = false, features = ["heapless"] } octseq = { version = "0.3", default-features = false } portable-atomic = "1" qrcodegen-no-heap = "1.8" +scopeguard = "1" # crypto openssl = { version = "0.10", optional = true } diff --git a/rs-matter/src/core.rs b/rs-matter/src/core.rs index f6bd0b6f..ec13fe4c 100644 --- a/rs-matter/src/core.rs +++ b/rs-matter/src/core.rs @@ -17,7 +17,7 @@ use core::{borrow::Borrow, cell::RefCell}; -use embassy_sync::{blocking_mutex::raw::NoopRawMutex, mutex::Mutex}; +use embassy_sync::blocking_mutex::raw::NoopRawMutex; use crate::{ acl::AclMgr, @@ -27,15 +27,14 @@ use crate::{ }, error::*, fabric::FabricMgr, - mdns::{Mdns, MdnsImpl, MdnsService}, + mdns::{Mdns, MdnsService}, pairing::{print_pairing_code_and_qr, DiscoveryCapabilities}, secure_channel::{pake::PaseMgr, spake2p::VerifierData}, transport::{ - exchange::{ExchangeCtx, MAX_EXCHANGES}, - packet::{MAX_RX_BUF_SIZE, MAX_TX_BUF_SIZE}, - session::SessionMgr, + core::{PacketBufferExternalAccess, TransportMgr}, + network::{NetworkReceive, NetworkSend}, }, - utils::{buf::BufferAccessImpl, epoch::Epoch, rand::Rand, select::Notification}, + utils::{buf::BufferAccess, epoch::Epoch, notification::Notification, rand::Rand}, }; /* The Matter Port */ @@ -55,20 +54,13 @@ pub struct Matter<'a> { pub acl_mgr: RefCell, // Public for tests pub(crate) pase_mgr: RefCell, pub(crate) failsafe: RefCell, - persist_notification: Notification, - pub(crate) send_notification: Notification, - pub(crate) mdns: MdnsImpl<'a>, - pub(crate) tx_buf: BufferAccessImpl, - pub(crate) rx_buf: BufferAccessImpl, + pub transport_mgr: TransportMgr<'a>, // Public for tests + persist_notification: Notification, pub(crate) epoch: Epoch, pub(crate) rand: Rand, dev_det: &'a BasicInfoConfig<'a>, dev_att: &'a dyn DevAttDataFetcher, pub(crate) port: u16, - pub(crate) exchanges: RefCell>, - pub(crate) ephemeral: RefCell>, - pub(crate) ephemeral_mutex: Mutex, - pub session_mgr: RefCell, // Public for tests } impl<'a> Matter<'a> { @@ -106,23 +98,20 @@ impl<'a> Matter<'a> { acl_mgr: RefCell::new(AclMgr::new()), pase_mgr: RefCell::new(PaseMgr::new(epoch, rand)), failsafe: RefCell::new(FailSafe::new()), + transport_mgr: TransportMgr::new(mdns.new_impl(dev_det, port), epoch, rand), persist_notification: Notification::new(), - send_notification: Notification::new(), - mdns: mdns.new_impl(dev_det, port), - rx_buf: BufferAccessImpl::new(), - tx_buf: BufferAccessImpl::new(), epoch, rand, dev_det, dev_att, port, - exchanges: RefCell::new(heapless::Vec::new()), - ephemeral: RefCell::new(None), - ephemeral_mutex: Mutex::new(()), - session_mgr: RefCell::new(SessionMgr::new(epoch, rand)), } } + pub fn initialize_transport_buffers(&self) -> Result<(), Error> { + self.transport_mgr.initialize_buffers() + } + pub fn dev_det(&self) -> &BasicInfoConfig<'_> { self.dev_det } @@ -136,7 +125,9 @@ impl<'a> Matter<'a> { } pub fn load_fabrics(&self, data: &[u8]) -> Result<(), Error> { - self.fabric_mgr.borrow_mut().load(data, &self.mdns) + self.fabric_mgr + .borrow_mut() + .load(data, &self.transport_mgr.mdns) } pub fn load_acls(&self, data: &[u8]) -> Result<(), Error> { @@ -155,7 +146,7 @@ impl<'a> Matter<'a> { self.acl_mgr.borrow().is_changed() || self.fabric_mgr.borrow().is_changed() } - pub fn start_comissioning( + fn start_comissioning( &self, dev_comm: CommissioningData, buf: &mut [u8], @@ -172,7 +163,7 @@ impl<'a> Matter<'a> { self.pase_mgr.borrow_mut().enable_pase_session( dev_comm.verifier, dev_comm.discriminator, - &self.mdns, + &self.transport_mgr.mdns, )?; Ok(true) @@ -181,12 +172,67 @@ impl<'a> Matter<'a> { } } + pub fn reset(&self) { + self.transport_mgr.reset(); + } + + pub async fn run( + &self, + send: S, + recv: R, + dev_comm: Option, + ) -> Result<(), Error> + where + S: NetworkSend, + R: NetworkReceive, + { + if let Some(dev_comm) = dev_comm { + let buf_access = PacketBufferExternalAccess(&self.transport_mgr.rx); + let mut buf = buf_access.get().await.ok_or(ErrorCode::NoSpace)?; + + self.start_comissioning(dev_comm, &mut buf)?; + } + + self.transport_mgr.run(send, recv).await + } + + #[cfg(not(all( + feature = "std", + any(target_os = "macos", all(feature = "zeroconf", target_os = "linux")) + )))] + pub async fn run_builtin_mdns( + &self, + send: S, + recv: R, + host: crate::mdns::Host<'_>, + interface: Option, + ) -> Result<(), Error> + where + S: NetworkSend, + R: NetworkReceive, + { + self.transport_mgr + .run_builtin_mdns(send, recv, host, interface) + .await + } + + /// Notify that the ACLs or Fabrics _might_ have changed + /// This method is supposed to be called after processing SC and IM messages that might affect the ACLs or Fabrics. + /// + /// The default IM and SC handlers (`DataModel` and `SecureChannel`) do call this method after processing the messages. + /// + /// TODO: Fix the method name as it is not clear enough. Potentially revamp the whole persistence notification logic pub fn notify_changed(&self) { if self.is_changed() { - self.persist_notification.signal(()); + self.persist_notification.notify(); } } + /// A hook for user persistence code to wait for potential changes in ACLs and/or Fabrics. + /// Once this future resolves, user code is supposed to inspect ACLs and Fabrics for changes, and + /// if there are changes, persist them. + /// + /// TODO: Fix the method name as it is not clear enough. Potentially revamp the whole persistence notification logic pub async fn wait_changed(&self) { self.persist_notification.wait().await } @@ -230,7 +276,7 @@ impl<'a> Borrow for Matter<'a> { impl<'a> Borrow for Matter<'a> { fn borrow(&self) -> &(dyn Mdns + 'a) { - &self.mdns + &self.transport_mgr.mdns } } diff --git a/rs-matter/src/data_model/cluster_on_off.rs b/rs-matter/src/data_model/cluster_on_off.rs index b045d3f3..211b6644 100644 --- a/rs-matter/src/data_model/cluster_on_off.rs +++ b/rs-matter/src/data_model/cluster_on_off.rs @@ -74,6 +74,10 @@ impl OnOffCluster { } } + pub fn get(&self) -> bool { + self.on.get() + } + pub fn set(&self, on: bool) { if self.on.get() != on { self.on.set(on); diff --git a/rs-matter/src/data_model/core.rs b/rs-matter/src/data_model/core.rs index 12b67c63..f010e81d 100644 --- a/rs-matter/src/data_model/core.rs +++ b/rs-matter/src/data_model/core.rs @@ -15,142 +15,821 @@ * limitations under the License. */ -use portable_atomic::{AtomicU32, Ordering}; +use core::cell::{Cell, RefCell}; +use core::iter::Peekable; +use core::pin::pin; +use core::time::Duration; -use super::objects::*; -use crate::{ - alloc, - error::*, - interaction_model::core::Interaction, - transport::{exchange::Exchange, packet::Packet}, +use embassy_futures::select::select; +use embassy_time::{Instant, Timer}; +use log::{debug, error, info, warn}; + +use crate::acl::Accessor; +use crate::interaction_model::messages::ib::AttrStatus; +use crate::utils::buf::BufferAccess; +use crate::{error::*, Matter}; + +use crate::interaction_model::core::{ + IMStatusCode, OpCode, ReportDataReq, PROTO_ID_INTERACTION_MODEL, }; +use crate::interaction_model::messages::msg::{ + InvReq, InvRespTag, ReadReq, ReportDataTag, StatusResp, SubscribeReq, SubscribeResp, TimedReq, + WriteReq, WriteRespTag, +}; +use crate::respond::ExchangeHandler; +use crate::tlv::{get_root_node_struct, FromTLV, TLVWriter, TagType}; +use crate::transport::exchange::{Exchange, MAX_EXCHANGE_RX_BUF_SIZE, MAX_EXCHANGE_TX_BUF_SIZE}; +use crate::utils::writebuf::WriteBuf; -// TODO: For now... -static SUBS_ID: AtomicU32 = AtomicU32::new(1); +use super::objects::*; +use super::subscriptions::Subscriptions; /// The Maximum number of expanded writer request per transaction /// /// The write requests are first wildcard-expanded, and these many number of /// write requests per-transaction will be supported. -pub const MAX_WRITE_ATTRS_IN_ONE_TRANS: usize = 7; +const MAX_WRITE_ATTRS_IN_ONE_TRANS: usize = 7; + +pub type IMBuffer = heapless::Vec; + +struct SubscriptionBuffer { + node_id: u64, + id: u32, + buffer: B, +} + +/// An `ExchangeHandler` implementation capable of handling responder exchanges for the Interaction Model protocol. +/// The implementation needs a `DataModelHandler` instance to interact with the underlying clusters of the data model. +pub struct DataModel<'a, const N: usize, B, T> +where + B: BufferAccess, +{ + handler: T, + subscriptions: &'a Subscriptions, + subscriptions_buffers: RefCell>, N>>, + buffers: &'a B, +} + +impl<'a, const N: usize, B, T> DataModel<'a, N, B, T> +where + B: BufferAccess, + T: DataModelHandler, +{ + /// Create the handler. + /// + /// The parameters are as follows: + /// * `buffers` - a reference to an implementation of `BufferAccess` which is used for allocating RX and TX buffers on the fly, when necessary + /// * `subscriptions` - a reference to a `Subscriptions` struct which is used for managing subscriptions. `N` designates the maximum + /// number of subscriptions that can be managed by this handler. + /// * `handler` - an instance of type `T` which implements the `DataModelHandler` trait. This instance is used for interacting with the underlying + /// clusters of the data model. + #[inline(always)] + pub const fn new(buffers: &'a B, subscriptions: &'a Subscriptions, handler: T) -> Self { + Self { + handler, + subscriptions, + subscriptions_buffers: RefCell::new(heapless::Vec::new()), + buffers, + } + } + + /// Answer a responding exchange using the `DataModelHandler` instance wrapped by this exchange handler. + pub async fn handle(&self, exchange: &mut Exchange<'_>) -> Result<(), Error> { + let mut timeout_instant = None; + + loop { + let mut repeat = false; + + if exchange.rx().is_err() { + exchange.recv_fetch().await?; + } + + let meta = exchange.rx()?.meta(); + if meta.proto_id != PROTO_ID_INTERACTION_MODEL { + Err(ErrorCode::InvalidProto)?; + } + + match meta.opcode::()? { + OpCode::ReadRequest => self.read(exchange).await?, + OpCode::WriteRequest => { + repeat = self.write(exchange, timeout_instant.take()).await?; + } + OpCode::InvokeRequest => self.invoke(exchange, timeout_instant.take()).await?, + OpCode::SubscribeRequest => self.subscribe(exchange).await?, + OpCode::TimedRequest => { + timeout_instant = Some(self.timed(exchange).await?); + repeat = true; + } + opcode => { + error!("Invalid opcode: {:?}", opcode); + Err(ErrorCode::InvalidOpcode)? + } + } + + if !repeat { + break; + } + } + + exchange.acknowledge().await?; + exchange.matter().notify_changed(); + + Ok(()) + } + + async fn read(&self, exchange: &mut Exchange<'_>) -> Result<(), Error> { + let Some(mut tx) = self.tx_buffer(exchange).await? else { + return Ok(()); + }; + + let mut wb = WriteBuf::new(&mut tx); -pub struct DataModel(T); + let metadata = self.handler.lock().await; -impl DataModel { - pub fn new(handler: T) -> Self { - Self(handler) + let req = ReadReq::from_tlv(&get_root_node_struct(exchange.rx()?.payload())?)?; + debug!("IM: Read request: {:?}", req); + + let req = ReportDataReq::Read(&req); + + let accessor = exchange.accessor()?; + + // Will the clusters that are to be invoked await? + let awaits = metadata.node().read(&req, None, &accessor).any(|item| { + item.map(|attr| self.handler.read_awaits(&attr)) + .unwrap_or(false) + }); + + if !awaits { + // No, they won't. Answer the request by directly using the RX packet + // of the transport layer, as the operation won't await. + + let node = metadata.node(); + let mut attrs = node.read(&req, None, &accessor).peekable(); + + if !req + .respond(&self.handler, None, &mut attrs, &mut wb, true) + .await? + { + drop(attrs); + + exchange.send(OpCode::ReportData, wb.as_slice()).await?; + + // TODO: We are unconditionally using `suppress_resp = true` here. + // However, the spec is a bit unclear when `suppress_resp = true` is allowed. + // + // At one place, it says this is a decision of the caller (i.e. what we do) + // At another place, it says it is a decision of the caller, but _only_ if the + // sets of attributes and events to be reported are both empty. + // + // I've also noticed the other peer (Google Controller) to reply with a status code + // (that we don't expect due to `suppress_resp = true`) in the case of malformed response... + // + // Resolve this discrepancy in future. + // Self::recv_status(exchange).await?; + + return Ok(()); + } + } + + // The clusters will await. + // Allocate a separate RX buffer then and copy the RX packet into this buffer, + // so as not to hold on to the transport layer (single) RX packet for too long + // and block send / receive for everybody + + let Some(rx) = self.rx_buffer(exchange).await? else { + return Ok(()); + }; + + let req = ReadReq::from_tlv(&get_root_node_struct(&rx)?)?; + let req = ReportDataReq::Read(&req); + + let node = metadata.node(); + let mut attrs = node.read(&req, None, &accessor).peekable(); + + loop { + let more_chunks = req + .respond(&self.handler, None, &mut attrs, &mut wb, true) + .await?; + + exchange.send(OpCode::ReportData, wb.as_slice()).await?; + + if more_chunks && !Self::recv_status_success(exchange).await? { + break; + } + + if !more_chunks { + break; + } + } + + Ok(()) } - pub async fn handle<'r, 'p>( + async fn write( &self, - exchange: &'r mut Exchange<'_>, - rx: &'r mut Packet<'p>, - tx: &'r mut Packet<'p>, - rx_status: &'r mut Packet<'p>, - ) -> Result<(), Error> - where - T: DataModelHandler, - { - let timeout = Interaction::timeout(exchange, rx, tx).await?; - - let mut interaction = alloc!(Interaction::new( - exchange, - rx, - tx, - rx_status, - || SUBS_ID.fetch_add(1, Ordering::SeqCst), - timeout, - )?); - - #[cfg(feature = "alloc")] - let interaction = &mut *interaction; - - #[cfg(not(feature = "alloc"))] - let interaction = &mut interaction; - - let metadata = self.0.lock().await; - - if interaction.start().await? { - match interaction { - Interaction::Read { - req, - ref mut driver, - } => { - let accessor = driver.accessor()?; - - 'outer: for item in metadata.node().read(req, None, &accessor) { - while !AttrDataEncoder::handle_read(&item, &self.0, &mut driver.writer()?) - .await? - { - if !driver.send_chunk(req).await? { - break 'outer; - } - } - } + exchange: &mut Exchange<'_>, + timeout_instant: Option, + ) -> Result { + let req = WriteReq::from_tlv(&get_root_node_struct(exchange.rx()?.payload())?)?; + debug!("IM: Write request: {:?}", req); + + let timed = req.timed_request.unwrap_or(false); + + if self.timed_out(exchange, timeout_instant, timed).await? { + return Ok(false); + } - driver.complete(req).await?; + let Some(mut tx) = self.tx_buffer(exchange).await? else { + return Ok(false); + }; + + let mut wb = WriteBuf::new(&mut tx); + + let metadata = self.handler.lock().await; + + let req = WriteReq::from_tlv(&get_root_node_struct(exchange.rx()?.payload())?)?; + + // Will the clusters that are to be invoked await? + let awaits = metadata + .node() + .write(&req, &exchange.accessor()?) + .any(|item| { + item.map(|(attr, _)| self.handler.write_awaits(&attr)) + .unwrap_or(false) + }); + + let more_chunks = if awaits { + // Yes, they will + // Allocate a separate RX buffer then and copy the RX packet into this buffer, + // so as not to hold on to the transport layer (single) RX packet for too long + // and block send / receive for everybody + + let Some(rx) = self.rx_buffer(exchange).await? else { + return Ok(false); + }; + + let req = WriteReq::from_tlv(&get_root_node_struct(&rx)?)?; + + req.respond( + &self.handler, + &exchange.accessor()?, + &metadata.node(), + &mut wb, + ) + .await? + } else { + // No, they won't. Answer the request by directly using the RX packet + // of the transport layer, as the operation won't await. + + req.respond( + &self.handler, + &exchange.accessor()?, + &metadata.node(), + &mut wb, + ) + .await? + }; + + exchange.send(OpCode::WriteResponse, wb.as_slice()).await?; + + Ok(more_chunks) + } + + async fn invoke( + &self, + exchange: &mut Exchange<'_>, + timeout_instant: Option, + ) -> Result<(), Error> { + let req = InvReq::from_tlv(&get_root_node_struct(exchange.rx()?.payload())?)?; + debug!("IM: Invoke request: {:?}", req); + + let timed = req.timed_request.unwrap_or(false); + + if self.timed_out(exchange, timeout_instant, timed).await? { + return Ok(()); + } + + let Some(mut tx) = self.tx_buffer(exchange).await? else { + return Ok(()); + }; + + let mut wb = WriteBuf::new(&mut tx); + + let metadata = self.handler.lock().await; + + let req = InvReq::from_tlv(&get_root_node_struct(exchange.rx()?.payload())?)?; + + // Will the clusters that are to be invoked await? + let awaits = metadata + .node() + .invoke(&req, &exchange.accessor()?) + .any(|item| { + item.map(|(cmd, _)| self.handler.invoke_awaits(&cmd)) + .unwrap_or(false) + }); + + if awaits { + // Yes, they will + // Allocate a separate RX buffer then and copy the RX packet into this buffer, + // so as not to hold on to the transport layer (single) RX packet for too long + // and block send / receive for everybody + + let Some(rx) = self.rx_buffer(exchange).await? else { + return Ok(()); + }; + + let req = InvReq::from_tlv(&get_root_node_struct(&rx)?)?; + + req.respond(&self.handler, exchange, &metadata.node(), &mut wb) + .await?; + } else { + // No, they won't. Answer the request by directly using the RX packet + // of the transport layer, as the operation won't await. + + req.respond(&self.handler, exchange, &metadata.node(), &mut wb) + .await?; + } + + exchange.send(OpCode::InvokeResponse, wb.as_slice()).await?; + + Ok(()) + } + + async fn subscribe(&self, exchange: &mut Exchange<'_>) -> Result<(), Error> { + let Some(rx) = self.rx_buffer(exchange).await? else { + return Ok(()); + }; + + let Some(mut tx) = self.tx_buffer(exchange).await? else { + return Ok(()); + }; + + let req = SubscribeReq::from_tlv(&get_root_node_struct(&rx)?)?; + debug!("IM: Subscribe request: {:?}", req); + + let node_id = exchange + .with_session(|sess| sess.get_peer_node_id().ok_or(ErrorCode::Invalid.into()))?; + + if !req.keep_subs { + self.subscriptions.remove(Some(node_id), None); + self.subscriptions_buffers + .borrow_mut() + .retain(|sb| sb.node_id != node_id); + + info!("All subscriptions for node {node_id:x} removed"); + } + + let max_int_secs = core::cmp::max(req.max_int_ceil, 40); // Say we need at least 4 secs for potential latencies + let min_int_secs = req.min_int_floor; + + let Some(id) = self.subscriptions.add(node_id, min_int_secs, max_int_secs) else { + return Self::send_status(exchange, IMStatusCode::ResourceExhausted).await; + }; + + let subscribed = Cell::new(false); + + let _guard = scopeguard::guard((), |_| { + if !subscribed.get() { + self.subscriptions.remove(None, Some(id)); + } + }); + + let primed = self + .report_data(id, node_id, &rx, &mut tx, exchange) + .await?; + + if primed { + exchange + .send_with(|_, wb| { + SubscribeResp::write(wb, id, max_int_secs)?; + Ok(Some(OpCode::SubscribeResponse.into())) + }) + .await?; + + info!("Subscription {node_id:x}::{id} created"); + + if self.subscriptions.mark_reported(id) { + let _ = self + .subscriptions_buffers + .borrow_mut() + .push(SubscriptionBuffer { + node_id, + id, + buffer: rx, + }); + + subscribed.set(true); + } + } + + Ok(()) + } + + pub async fn process_subscriptions(&self, matter: &Matter<'_>) -> Result<(), Error> { + loop { + // TODO: Un-hardcode these 4 seconds of waiting when the more precise change detection logic is implemented + let mut timeout = pin!(Timer::after(embassy_time::Duration::from_secs(4))); + let mut notification = pin!(self.subscriptions.notification.wait()); + + select(&mut notification, &mut timeout).await; + + let now = Instant::now(); + + { + while let Some((node_id, id)) = self.subscriptions.find_expired(now) { + self.subscriptions.remove(None, Some(id)); + self.subscriptions_buffers + .borrow_mut() + .retain(|sb| sb.id != id); + + info!("Subscription {node_id:x}::{id} removed due to inactivity"); } - Interaction::Write { - req, - ref mut driver, - } => { - let accessor = driver.accessor()?; - // The spec expects that a single write request like DeleteList + AddItem - // should cause all ACLs of that fabric to be deleted and the new one to be added (Case 1). - // - // This is in conflict with the immediate-effect expectation of ACL: an ACL - // write should instantaneously update the ACL so that immediate next WriteAttribute - // *in the same WriteRequest* should see that effect (Case 2). - // - // As with the C++ SDK, here we do all the ACLs checks first, before any write begins. - // Thus we support the Case1 by doing this. It does come at the cost of maintaining an - // additional list of expanded write requests as we start processing those. - let node = metadata.node(); - let write_attrs: heapless::Vec<_, MAX_WRITE_ATTRS_IN_ONE_TRANS> = - node.write(req, &accessor).collect(); - - for item in write_attrs { - AttrDataEncoder::handle_write(&item, &self.0, &mut driver.writer()?) + } + + loop { + let sub = self.subscriptions.find_report_due(now); + + if let Some((node_id, id)) = sub { + info!("About to report data for subscription {node_id:x}::{id}"); + + let subscribed = Cell::new(false); + + let _guard = scopeguard::guard((), |_| { + if !subscribed.get() { + self.subscriptions.remove(None, Some(id)); + } + }); + + // TODO: Do a more sophisticated check whether something had actually changed w.r.t. this subscription + + let index = self + .subscriptions_buffers + .borrow() + .iter() + .position(|sb| sb.id == id) + .unwrap(); + let rx = self.subscriptions_buffers.borrow_mut().remove(index).buffer; + + let mut req = SubscribeReq::from_tlv(&get_root_node_struct(&rx)?)?; + + // Only used when priming the subscription + req.dataver_filters = None; + + let mut exchange = Exchange::initiate(matter, node_id, true).await?; + + if let Some(mut tx) = self.buffers.get().await { + let primed = self + .report_data(id, node_id, &rx, &mut tx, &mut exchange) .await?; - } - driver.complete(req).await?; + exchange.acknowledge().await?; + + if primed && self.subscriptions.mark_reported(id) { + let _ = + self.subscriptions_buffers + .borrow_mut() + .push(SubscriptionBuffer { + node_id, + id, + buffer: rx, + }); + subscribed.set(true); + } + } + } else { + break; } - Interaction::Invoke { - req, - ref mut driver, - } => { - let accessor = driver.accessor()?; + } + } + } - for item in metadata.node().invoke(req, &accessor) { - let (mut tw, exchange) = driver.writer_exchange()?; + async fn timed(&self, exchange: &mut Exchange<'_>) -> Result { + let req = TimedReq::from_tlv(&get_root_node_struct(exchange.rx()?.payload())?)?; + debug!("IM: Timed request: {:?}", req); - CmdDataEncoder::handle(&item, &self.0, &mut tw, exchange).await?; - } + let timeout_instant = req.timeout_instant(exchange.matter().epoch); + + Self::send_status(exchange, IMStatusCode::Success).await?; + + Ok(timeout_instant) + } + + async fn timed_out( + &self, + exchange: &mut Exchange<'_>, + timeout_instant: Option, + timed_req: bool, + ) -> Result { + let status = { + if timed_req != timeout_instant.is_some() { + Some(IMStatusCode::TimedRequestMisMatch) + } else if timeout_instant + .map(|timeout_instant| (exchange.matter().epoch)() > timeout_instant) + .unwrap_or(false) + { + Some(IMStatusCode::Timeout) + } else { + None + } + }; + + if let Some(status) = status { + Self::send_status(exchange, status).await?; + + Ok(true) + } else { + Ok(false) + } + } + + async fn report_data( + &self, + id: u32, + node_id: u64, + rx: &[u8], + tx: &mut [u8], + exchange: &mut Exchange<'_>, + ) -> Result + where + T: DataModelHandler, + { + let mut wb = WriteBuf::new(tx); + + let req = SubscribeReq::from_tlv(&get_root_node_struct(rx)?)?; + let req = ReportDataReq::Subscribe(&req); + + let metadata = self.handler.lock().await; + + let accessor = exchange.accessor()?; + + { + let node = metadata.node(); + let mut attrs = node.read(&req, None, &accessor).peekable(); + + loop { + let more_chunks = req + .respond(&self.handler, Some(id), &mut attrs, &mut wb, false) + .await?; - driver.complete(req).await?; + exchange.send(OpCode::ReportData, wb.as_slice()).await?; + + if !Self::recv_status_success(exchange).await? { + info!("Subscription {node_id:x}::{id} removed during reporting"); + return Ok(false); } - Interaction::Subscribe { - req, - ref mut driver, - } => { - let accessor = driver.accessor()?; - - 'outer: for item in metadata.node().subscribing_read(req, None, &accessor) { - while !AttrDataEncoder::handle_read(&item, &self.0, &mut driver.writer()?) - .await? - { - if !driver.send_chunk(req).await? { - break 'outer; - } - } - } - driver.complete(req).await?; + if !more_chunks { + break; } } } + Ok(true) + } + + async fn rx_buffer(&self, exchange: &mut Exchange<'_>) -> Result>, Error> { + if let Some(mut buffer) = self.buffer(exchange).await? { + let rx = exchange.rx()?; + + buffer.clear(); + + // Safe to unwrap, as `IMBuffer` is defined to be `MAX_EXCHANGE_RX_BUF_SIZE`, i.e. it cannot be overflown + // by the payload of the received exchange. + buffer.extend_from_slice(rx.payload()).unwrap(); + + exchange.rx_done()?; + + Ok(Some(buffer)) + } else { + Ok(None) + } + } + + async fn tx_buffer(&self, exchange: &mut Exchange<'_>) -> Result>, Error> { + if let Some(mut buffer) = self.buffer(exchange).await? { + // Always safe as `IMBuffer` is defined to be `MAX_EXCHANGE_RX_BUF_SIZE`, which is bigger than `MAX_EXCHANGE_TX_BUF_SIZE` + buffer.resize_default(MAX_EXCHANGE_TX_BUF_SIZE).unwrap(); + + Ok(Some(buffer)) + } else { + Self::send_status(exchange, IMStatusCode::Busy).await?; + + Ok(None) + } + } + + async fn buffer(&self, exchange: &mut Exchange<'_>) -> Result>, Error> { + if let Some(buffer) = self.buffers.get().await { + Ok(Some(buffer)) + } else { + Self::send_status(exchange, IMStatusCode::Busy).await?; + + Ok(None) + } + } + + async fn recv_status_success(exchange: &mut Exchange<'_>) -> Result { + let rx = exchange.recv().await?; + let opcode = rx.meta().proto_opcode; + + if opcode != OpCode::StatusResponse as u8 { + warn!( + "Got opcode {opcode:02x}, while expecting status code {:02x}", + OpCode::StatusResponse as u8 + ); + + return Err(ErrorCode::Invalid.into()); + } + + let resp = StatusResp::from_tlv(&get_root_node_struct(rx.payload())?)?; + + if resp.status == IMStatusCode::Success { + Ok(true) + } else { + warn!( + "Got status response {:?}, aborting interaction", + resp.status + ); + + drop(rx); + exchange.acknowledge().await?; + + Ok(false) + } + } + + async fn send_status(exchange: &mut Exchange<'_>, status: IMStatusCode) -> Result<(), Error> { + exchange + .send_with(|_, wb| { + StatusResp::write(wb, status)?; + + Ok(Some(OpCode::StatusResponse.into())) + }) + .await + } +} + +impl<'a, const N: usize, B, T> ExchangeHandler for DataModel<'a, N, B, T> +where + T: DataModelHandler, + B: BufferAccess, +{ + async fn handle(&self, exchange: &mut Exchange<'_>) -> Result<(), Error> { + DataModel::handle(self, exchange).await + } +} + +impl<'a> ReportDataReq<'a> { + // This is the amount of space we reserve for other things to be attached towards + // the end of long reads. + const LONG_READS_TLV_RESERVE_SIZE: usize = 24; + + pub(crate) async fn respond( + &self, + handler: T, + subscription_id: Option, + attrs: &mut Peekable, + wb: &mut WriteBuf<'_>, + suppress_resp: bool, + ) -> Result + where + T: DataModelHandler, + I: Iterator, AttrStatus>>, + { + wb.reset(); + wb.shrink(Self::LONG_READS_TLV_RESERVE_SIZE)?; + + let mut tw = TLVWriter::new(wb); + + tw.start_struct(TagType::Anonymous)?; + + if let Some(subscription_id) = subscription_id { + assert!(matches!(self, ReportDataReq::Subscribe(_))); + tw.u32( + TagType::Context(ReportDataTag::SubscriptionId as u8), + subscription_id, + )?; + } else { + assert!(matches!(self, ReportDataReq::Read(_))); + } + + let has_requests = self.attr_requests().is_some(); + + if has_requests { + tw.start_array(TagType::Context(ReportDataTag::AttributeReports as u8))?; + } + + while let Some(item) = attrs.peek() { + if AttrDataEncoder::handle_read(item, &handler, &mut tw).await? { + attrs.next(); + } else { + break; + } + } + + wb.expand(Self::LONG_READS_TLV_RESERVE_SIZE)?; + let mut tw = TLVWriter::new(wb); + + if has_requests { + tw.end_container()?; + } + + let more_chunks = attrs.peek().is_some(); + + if more_chunks { + tw.bool(TagType::Context(ReportDataTag::MoreChunkedMsgs as u8), true)?; + } + + if !more_chunks && suppress_resp { + tw.bool(TagType::Context(ReportDataTag::SupressResponse as u8), true)?; + } + + tw.end_container()?; + + Ok(more_chunks) + } +} + +impl<'a> WriteReq<'a> { + async fn respond( + &self, + handler: T, + accessor: &Accessor<'_>, + node: &Node<'_>, + wb: &mut WriteBuf<'_>, + ) -> Result + where + T: DataModelHandler, + { + wb.reset(); + + let mut tw = TLVWriter::new(wb); + + tw.start_struct(TagType::Anonymous)?; + tw.start_array(TagType::Context(WriteRespTag::WriteResponses as u8))?; + + // The spec expects that a single write request like DeleteList + AddItem + // should cause all ACLs of that fabric to be deleted and the new one to be added (Case 1). + // + // This is in conflict with the immediate-effect expectation of ACL: an ACL + // write should instantaneously update the ACL so that immediate next WriteAttribute + // *in the same WriteRequest* should see that effect (Case 2). + // + // As with the C++ SDK, here we do all the ACLs checks first, before any write begins. + // Thus we support the Case1 by doing this. It does come at the cost of maintaining an + // additional list of expanded write requests as we start processing those. + let write_attrs: heapless::Vec<_, MAX_WRITE_ATTRS_IN_ONE_TRANS> = + node.write(self, accessor).collect(); + + for item in write_attrs { + AttrDataEncoder::handle_write(&item, &handler, &mut tw).await?; + } + + tw.end_container()?; + tw.end_container()?; + + Ok(self.more_chunked.unwrap_or(false)) + } +} + +impl<'a> InvReq<'a> { + async fn respond( + &self, + handler: T, + exchange: &Exchange<'_>, + node: &Node<'_>, + wb: &mut WriteBuf<'_>, + ) -> Result<(), Error> + where + T: DataModelHandler, + { + wb.reset(); + + let mut tw = TLVWriter::new(wb); + + tw.start_struct(TagType::Anonymous)?; + + // Suppress Response -> TODO: Need to revisit this for cases where we send a command back + tw.bool(TagType::Context(InvRespTag::SupressResponse as u8), false)?; + + let has_requests = self.inv_requests.is_some(); + + if has_requests { + tw.start_array(TagType::Context(InvRespTag::InvokeResponses as u8))?; + } + + let accessor = exchange.accessor()?; + + for item in node.invoke(self, &accessor) { + CmdDataEncoder::handle(&item, &handler, &mut tw, exchange).await?; + } + + if has_requests { + tw.end_container()?; + } + + tw.end_container()?; + Ok(()) } } diff --git a/rs-matter/src/data_model/mod.rs b/rs-matter/src/data_model/mod.rs index c76e07cf..a71eb0fc 100644 --- a/rs-matter/src/data_model/mod.rs +++ b/rs-matter/src/data_model/mod.rs @@ -25,4 +25,5 @@ pub mod cluster_on_off; pub mod cluster_template; pub mod root_endpoint; pub mod sdm; +pub mod subscriptions; pub mod system_model; diff --git a/rs-matter/src/data_model/objects/handler.rs b/rs-matter/src/data_model/objects/handler.rs index 73f1b25f..b76b0ec7 100644 --- a/rs-matter/src/data_model/objects/handler.rs +++ b/rs-matter/src/data_model/objects/handler.rs @@ -28,10 +28,14 @@ pub use asynch::*; pub trait DataModelHandler: super::asynch::AsyncMetadata + asynch::AsyncHandler {} impl DataModelHandler for T where T: super::asynch::AsyncMetadata + asynch::AsyncHandler {} +// TODO: Re-assess once once proper cluster change notifications are implemented. pub trait ChangeNotifier { fn consume_change(&mut self) -> Option; } +/// A version of the `AsyncHandler` trait that never awaits any operation. +/// +/// Prefer this trait when implementing handlers that are known to be non-blocking. pub trait Handler { fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error>; @@ -96,6 +100,7 @@ where } } +// TODO: Re-assess the need for this trait. pub trait NonBlockingHandler: Handler {} impl NonBlockingHandler for &T where T: NonBlockingHandler {} @@ -127,9 +132,19 @@ where impl NonBlockingHandler for (M, H) where H: NonBlockingHandler {} +/// A handler that always fails with attribute / command not found. +/// +/// Useful when chaining multiple handlers together as the end of the chain. pub struct EmptyHandler; impl EmptyHandler { + /// Chain the empty handler with another handler thus providing an "end of handler chain" + /// fallback that errors out. + /// + /// The returned chained handler works as follows: + /// - It will call the provided `handler` instance if the endpoint and cluster + /// of the incoming request do match the `handler_endpoint` and `handler_cluster` provided here. + /// - Otherwise, the empty handler would be invoked, causing the operation to error out. pub const fn chain( self, handler_endpoint: u16, @@ -159,6 +174,7 @@ impl ChangeNotifier<(u16, u32)> for EmptyHandler { } } +/// A handler that chains two handlers together in a composite handler. pub struct ChainedHandler { pub handler_endpoint: u16, pub handler_cluster: u32, @@ -167,18 +183,32 @@ pub struct ChainedHandler { } impl ChainedHandler { + /// Construct a chained handler that works as follows: + /// - It will call the provided `handler` instance if the endpoint and cluster + /// of the incoming request do match the `handler_endpoint` and `handler_cluster` provided here. + /// - Otherwise, it will call the `next` handler + pub const fn new(handler_endpoint: u16, handler_cluster: u32, handler: H, next: T) -> Self { + Self { + handler_endpoint, + handler_cluster, + handler, + next, + } + } + + /// Chain itself with another handler. + /// + /// The returned chained handler works as follows: + /// - It will call the provided `handler` instance if the endpoint and cluster + /// of the incoming request do match the `handler_endpoint` and `handler_cluster` provided here. + /// - Otherwise, it will call the `self` handler pub const fn chain

( self, handler_endpoint: u16, handler_cluster: u32, handler: H2, ) -> ChainedHandler { - ChainedHandler { - handler_endpoint, - handler_cluster, - handler, - next: self, - } + ChainedHandler::new(handler_endpoint, handler_cluster, handler, self) } } @@ -241,6 +271,8 @@ where /// Wrap your `NonBlockingHandler` or `AsyncHandler` implementation in this struct /// to get your code compilable with and without the `nightly` feature +/// +/// TODO: Re-assess the need for this struct now that we no longer use a nightly compiler. pub struct HandlerCompat(pub T); impl Handler for HandlerCompat @@ -268,6 +300,24 @@ where impl NonBlockingHandler for HandlerCompat where T: NonBlockingHandler {} +/// A helper macro that makes it easier to specify the full type of a `ChainedHandler` instantiation, +/// which can be quite annoying in the case of long chains of handlers. +/// +/// Use with type aliases: +/// ```ignore +/// pub type RootEndpointHandler<'a> = handler_chain_type!( +/// DescriptorCluster<'static>, +/// BasicInfoCluster<'a>, +/// GenCommCluster<'a>, +/// NwCommCluster, +/// AdminCommCluster<'a>, +/// NocCluster<'a>, +/// AccessControlCluster<'a>, +/// GenDiagCluster, +/// EthNwDiagCluster, +/// GrpKeyMgmtCluster +/// ); +/// ``` #[allow(unused_macros)] #[macro_export] macro_rules! handler_chain_type { @@ -296,13 +346,65 @@ mod asynch { use super::{ChainedHandler, EmptyHandler, Handler, HandlerCompat, NonBlockingHandler}; + /// A handler for processing a single IM operation: + /// read an attribute, write an attribute, or invoke a command. + /// + /// Handlers are typically implemented by user-defined clusters, but there is no 1:1 correspondence between + /// a handler and a cluster, as a single handler can handle multiple clusters and even multiple endpoints. + /// + /// Moreover, the `DataModel` implementation expects a single `AsyncHandler` instance, so the expectation + /// is that the user will compose multiple handlers into a single `AsyncHandler` instance, using `ChainedHandler` + /// or other means. pub trait AsyncHandler { + /// Provides information whether the handler will internally await while reading + /// the current value of the provided attribute. + /// + /// Handlers which report `false` via this method provide an opportunity + /// for the Data Model processing to use less memory by not storing the incoming request + /// in an intermediate buffer. + /// + /// The default implementation unconditionally returns `true` i.e. the handler is assumed to + /// await while reading any attribute. + fn read_awaits(&self, _attr: &AttrDetails) -> bool { + true + } + + /// Provides information whether the handler will internally await while updating + /// the value of the provided attribute. + /// + /// Handlers which report `false` via this method provide an opportunity + /// for the Data Model processing to use less memory by not storing the incoming request + /// in an intermediate buffer. + /// + /// The default implementation unconditionally returns `true` i.e. the handler is assumed to + /// await while writing any attribute. + fn write_awaits(&self, _attr: &AttrDetails) -> bool { + true + } + + /// Provides information whether the handler will internally await while invoking + /// the provided command. + /// + /// Handlers which report `false` via this method provide an opportunity + /// for the Data Model processing to use less memory by not storing the incoming request + /// in an intermediate buffer. + /// + /// The default implementation unconditionally returns `true` i.e. the handler is assumed to + /// await while invoking any command. + fn invoke_awaits(&self, _cmd: &CmdDetails) -> bool { + true + } + + /// Reads from the requested attribute and encodes the result using the provided encoder. async fn read<'a>( &'a self, attr: &'a AttrDetails<'_>, encoder: AttrDataEncoder<'a, '_, '_>, ) -> Result<(), Error>; + /// Writes into the requested attribute using the provided data. + /// + /// The default implementation errors out with `ErrorCode::AttributeNotFound`. async fn write<'a>( &'a self, _attr: &'a AttrDetails<'_>, @@ -311,6 +413,9 @@ mod asynch { Err(ErrorCode::AttributeNotFound.into()) } + /// Invokes the requested command with the provided data and encodes the result using the provided encoder. + /// + /// The default implementation errors out with `ErrorCode::CommandNotFound`. async fn invoke<'a>( &'a self, _exchange: &'a Exchange<'_>, @@ -326,6 +431,18 @@ mod asynch { where T: AsyncHandler, { + fn read_awaits(&self, attr: &AttrDetails) -> bool { + (**self).read_awaits(attr) + } + + fn write_awaits(&self, attr: &AttrDetails) -> bool { + (**self).write_awaits(attr) + } + + fn invoke_awaits(&self, cmd: &CmdDetails) -> bool { + (**self).invoke_awaits(cmd) + } + async fn read<'a>( &'a self, attr: &'a AttrDetails<'_>, @@ -357,6 +474,18 @@ mod asynch { where T: AsyncHandler, { + fn read_awaits(&self, attr: &AttrDetails) -> bool { + (**self).read_awaits(attr) + } + + fn write_awaits(&self, attr: &AttrDetails) -> bool { + (**self).write_awaits(attr) + } + + fn invoke_awaits(&self, cmd: &CmdDetails) -> bool { + (**self).invoke_awaits(cmd) + } + async fn read<'a>( &'a self, attr: &'a AttrDetails<'_>, @@ -388,6 +517,18 @@ mod asynch { where H: AsyncHandler, { + fn read_awaits(&self, attr: &AttrDetails) -> bool { + self.1.read_awaits(attr) + } + + fn write_awaits(&self, attr: &AttrDetails) -> bool { + self.1.write_awaits(attr) + } + + fn invoke_awaits(&self, cmd: &CmdDetails) -> bool { + self.1.invoke_awaits(cmd) + } + async fn read<'a>( &'a self, attr: &'a AttrDetails<'_>, @@ -419,6 +560,18 @@ mod asynch { where T: NonBlockingHandler, { + fn read_awaits(&self, _attr: &AttrDetails) -> bool { + false + } + + fn write_awaits(&self, _attr: &AttrDetails) -> bool { + false + } + + fn invoke_awaits(&self, _cmd: &CmdDetails) -> bool { + false + } + async fn read<'a>( &'a self, attr: &'a AttrDetails<'_>, @@ -447,6 +600,18 @@ mod asynch { } impl AsyncHandler for EmptyHandler { + fn read_awaits(&self, _attr: &AttrDetails) -> bool { + false + } + + fn write_awaits(&self, _attr: &AttrDetails) -> bool { + false + } + + fn invoke_awaits(&self, _cmd: &CmdDetails) -> bool { + false + } + async fn read<'a>( &'a self, _attr: &'a AttrDetails<'_>, @@ -461,6 +626,32 @@ mod asynch { H: AsyncHandler, T: AsyncHandler, { + fn read_awaits(&self, attr: &AttrDetails) -> bool { + if self.handler_endpoint == attr.endpoint_id && self.handler_cluster == attr.cluster_id + { + self.handler.read_awaits(attr) + } else { + self.next.read_awaits(attr) + } + } + + fn write_awaits(&self, attr: &AttrDetails) -> bool { + if self.handler_endpoint == attr.endpoint_id && self.handler_cluster == attr.cluster_id + { + self.handler.write_awaits(attr) + } else { + self.next.write_awaits(attr) + } + } + + fn invoke_awaits(&self, cmd: &CmdDetails) -> bool { + if self.handler_endpoint == cmd.endpoint_id && self.handler_cluster == cmd.cluster_id { + self.handler.invoke_awaits(cmd) + } else { + self.next.invoke_awaits(cmd) + } + } + async fn read<'a>( &'a self, attr: &'a AttrDetails<'_>, diff --git a/rs-matter/src/data_model/objects/node.rs b/rs-matter/src/data_model/objects/node.rs index 1ffa8967..4054bb0a 100644 --- a/rs-matter/src/data_model/objects/node.rs +++ b/rs-matter/src/data_model/objects/node.rs @@ -20,10 +20,10 @@ use crate::{ alloc, data_model::objects::Endpoint, interaction_model::{ - core::IMStatusCode, + core::{IMStatusCode, ReportDataReq}, messages::{ ib::{AttrPath, AttrStatus, CmdStatus, DataVersionFilter}, - msg::{InvReq, ReadReq, SubscribeReq, WriteReq}, + msg::{InvReq, WriteReq}, GenericPath, }, }, @@ -67,7 +67,7 @@ pub struct Node<'a> { impl<'a> Node<'a> { pub fn read<'s, 'm>( &'s self, - req: &'m ReadReq, + req: &'m ReportDataReq, from: Option, accessor: &'m Accessor<'m>, ) -> impl Iterator> + 'm @@ -75,31 +75,11 @@ impl<'a> Node<'a> { 's: 'm, { self.read_attr_requests( - req.attr_requests + req.attr_requests() .iter() .flat_map(|attr_requests| attr_requests.iter()), - req.dataver_filters.as_ref(), - req.fabric_filtered, - accessor, - from, - ) - } - - pub fn subscribing_read<'s, 'm>( - &'s self, - req: &'m SubscribeReq, - from: Option, - accessor: &'m Accessor<'m>, - ) -> impl Iterator> + 'm - where - 's: 'm, - { - self.read_attr_requests( - req.attr_requests - .iter() - .flat_map(|attr_requests| attr_requests.iter()), - req.dataver_filters.as_ref(), - req.fabric_filtered, + req.dataver_filters(), + req.fabric_filtered(), accessor, from, ) diff --git a/rs-matter/src/data_model/sdm/noc.rs b/rs-matter/src/data_model/sdm/noc.rs index 1abaad2d..6aea6a78 100644 --- a/rs-matter/src/data_model/sdm/noc.rs +++ b/rs-matter/src/data_model/sdm/noc.rs @@ -327,7 +327,7 @@ impl<'a> NocCluster<'a> { data: &TLVElement, ) -> Result { let noc_data = exchange - .with_session_mut(|sess| Ok(sess.take_noc_data()))? + .with_session(|sess| Ok(sess.take_noc_data()))? .ok_or(NocStatus::MissingCsr)?; if !self @@ -596,7 +596,7 @@ impl<'a> NocCluster<'a> { let noc_data = NocData::new(noc_keypair); // Store this in the session data instead of cluster data, so it gets cleared // if the session goes away for some reason - exchange.with_session_mut(|sess| { + exchange.with_session(|sess| { sess.set_noc_data(noc_data); Ok(()) })?; @@ -605,7 +605,7 @@ impl<'a> NocCluster<'a> { } fn add_rca_to_session_noc_data(exchange: &Exchange, data: &TLVElement) -> Result<(), Error> { - exchange.with_session_mut(|sess| { + exchange.with_session(|sess| { let noc_data = sess.get_noc_data().ok_or(ErrorCode::NoSession)?; let req = CommonReq::from_tlv(data).map_err(Error::map_invalid_command)?; diff --git a/rs-matter/src/data_model/subscriptions.rs b/rs-matter/src/data_model/subscriptions.rs new file mode 100644 index 00000000..9f31a9a2 --- /dev/null +++ b/rs-matter/src/data_model/subscriptions.rs @@ -0,0 +1,144 @@ +/* + * + * Copyright (c) 2020-2022 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +use core::cell::RefCell; + +use embassy_sync::blocking_mutex::raw::NoopRawMutex; +use embassy_time::Instant; + +use portable_atomic::{AtomicU32, Ordering}; + +use crate::utils::notification::Notification; + +struct Subscription { + node_id: u64, + id: u32, + // We use u16 instead of embassy::Duration to save some storage + min_int_secs: u16, + // Ditto + max_int_secs: u16, + reported_at: Instant, + changed: bool, +} + +impl Subscription { + pub fn report_due(&self, now: Instant) -> bool { + self.changed + && self.reported_at + embassy_time::Duration::from_secs(self.min_int_secs as _) <= now + } + + pub fn is_expired(&self, now: Instant) -> bool { + self.reported_at + embassy_time::Duration::from_secs(self.max_int_secs as _) <= now + } +} + +/// A utility for tracking subscriptions accepted by the data model. +/// +/// The `N` type parameter specifies the maximum number of subscriptions that can be tracked at the same time. +/// Additional subscriptions are rejected by the data model with a "resource exhausted" IM status message. +pub struct Subscriptions { + next_subscription_id: AtomicU32, + subscriptions: RefCell>, + pub(crate) notification: Notification, +} + +impl Subscriptions { + /// Create the instance. + #[inline(always)] + pub const fn new() -> Self { + Self { + next_subscription_id: AtomicU32::new(1), + subscriptions: RefCell::new(heapless::Vec::new()), + notification: Notification::new(), + } + } + + /// Notify the instance that some data in the data model has changed and that it should re-evaluate the subscriptions + /// and report on those that concern the changed data. + /// + /// This method is supposed to be called by the application code whenever it changes the data model. + pub fn notify_changed(&self) { + for sub in self.subscriptions.borrow_mut().iter_mut() { + sub.changed = true; + } + + self.notification.notify(); + } + + pub(crate) fn add(&self, node_id: u64, min_int_secs: u16, max_int_secs: u16) -> Option { + let id = self.next_subscription_id.fetch_add(1, Ordering::SeqCst); + + self.subscriptions + .borrow_mut() + .push(Subscription { + node_id, + id, + min_int_secs, + max_int_secs, + reported_at: Instant::MAX, + changed: false, + }) + .map(|_| id) + .ok() + } + + /// Mark the subscription with the given ID as reported. + /// + /// Will return `false` if the subscription with the given ID does no longer exist, as it might be + /// removed by a concurrent transaction while being reported on. + pub(crate) fn mark_reported(&self, id: u32) -> bool { + let mut subscriptions = self.subscriptions.borrow_mut(); + + if let Some(sub) = subscriptions.iter_mut().find(|sub| sub.id == id) { + sub.reported_at = Instant::now(); + sub.changed = false; + + true + } else { + false + } + } + + pub(crate) fn remove(&self, node_id: Option, id: Option) { + let mut subscriptions = self.subscriptions.borrow_mut(); + while let Some(index) = subscriptions.iter().position(|sub| { + sub.node_id == node_id.unwrap_or(sub.node_id) && sub.id == id.unwrap_or(sub.id) + }) { + subscriptions.swap_remove(index); + } + } + + pub(crate) fn find_expired(&self, now: Instant) -> Option<(u64, u32)> { + self.subscriptions + .borrow() + .iter() + .find_map(|sub| sub.is_expired(now).then_some((sub.node_id, sub.id))) + } + + /// Note that this method has a side effect: + /// it updates the `reported_at` field of the subscription that is returned. + pub(crate) fn find_report_due(&self, now: Instant) -> Option<(u64, u32)> { + self.subscriptions + .borrow_mut() + .iter_mut() + .find(|sub| sub.report_due(now)) + .map(|sub| { + sub.reported_at = now; + (sub.node_id, sub.id) + }) + } +} diff --git a/rs-matter/src/error.rs b/rs-matter/src/error.rs index 052170eb..71a52abe 100644 --- a/rs-matter/src/error.rs +++ b/rs-matter/src/error.rs @@ -60,6 +60,7 @@ pub enum ErrorCode { InvalidData, InvalidKeyLength, InvalidOpcode, + InvalidProto, InvalidPeerAddr, // Invalid Auth Key in the Matter Certificate InvalidAuthKey, diff --git a/rs-matter/src/interaction_model/busy.rs b/rs-matter/src/interaction_model/busy.rs new file mode 100644 index 00000000..ef4aa7e7 --- /dev/null +++ b/rs-matter/src/interaction_model/busy.rs @@ -0,0 +1,78 @@ +/* + * + * Copyright (c) 2020-2022 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +use crate::error::*; +use crate::respond::ExchangeHandler; +use crate::transport::exchange::Exchange; + +use super::{ + core::{IMStatusCode, OpCode, PROTO_ID_INTERACTION_MODEL}, + messages::msg::StatusResp, +}; + +/// A Interaction Model implementation that is only capable of sending Busy status codes +/// +/// Use with e.g. +/// +/// ```ignore +/// let matter = Matter::new(...); +/// +/// // ... +/// +/// let busy_responder = Responder::new("IM Busy Responder", BusyInteractionModel::new(), &matter, 200/*ms*/); +/// busy_responder.run::<10>().await?; +/// ``` +/// +/// ... to respond with "I'm busy, please try later" or similar status codes to all incoming IM messages, which were +/// not accepted in time by the actual Interaction Model responder, due to all its handlers being occupied with work. +pub struct BusyInteractionModel(()); + +impl BusyInteractionModel { + #[inline(always)] + pub const fn new() -> Self { + Self(()) + } + + pub async fn handle(&self, exchange: &mut Exchange<'_>) -> Result<(), Error> { + let meta = exchange.recv().await?.meta(); + if meta.proto_id != PROTO_ID_INTERACTION_MODEL { + Err(ErrorCode::InvalidProto)?; + } + + let status = match meta.opcode()? { + OpCode::ReadRequest + | OpCode::WriteRequest + | OpCode::SubscribeRequest + | OpCode::InvokeRequest => IMStatusCode::Busy, + _ => IMStatusCode::Failure, + }; + + exchange + .send_with(|_, wb| { + StatusResp::write(wb, status)?; + + Ok(Some(OpCode::StatusResponse.meta())) + }) + .await + } +} + +impl ExchangeHandler for BusyInteractionModel { + async fn handle(&self, exchange: &mut Exchange<'_>) -> Result<(), Error> { + BusyInteractionModel::handle(self, exchange).await + } +} diff --git a/rs-matter/src/interaction_model/core.rs b/rs-matter/src/interaction_model/core.rs index 678eabf3..772b0411 100644 --- a/rs-matter/src/interaction_model/core.rs +++ b/rs-matter/src/interaction_model/core.rs @@ -18,19 +18,16 @@ use core::time::Duration; use crate::{ - acl::Accessor, error::*, - tlv::{get_root_node_struct, FromTLV, TLVElement, TLVWriter, TagType, ToTLV}, - transport::{exchange::Exchange, packet::Packet}, - utils::epoch::Epoch, + tlv::{FromTLV, TLVArray, TLVElement, TLVWriter, TagType, ToTLV}, + transport::exchange::MessageMeta, + utils::{epoch::Epoch, writebuf::WriteBuf}, }; -use log::error; use num::FromPrimitive; use num_derive::FromPrimitive; -use super::messages::msg::{ - self, InvReq, ReadReq, StatusResp, SubscribeReq, SubscribeResp, TimedReq, WriteReq, -}; +use super::messages::ib::{AttrPath, DataVersionFilter}; +use super::messages::msg::{ReadReq, StatusResp, SubscribeReq, SubscribeResp, TimedReq}; #[macro_export] macro_rules! cmd_enter { @@ -121,613 +118,87 @@ pub enum OpCode { TimedRequest = 10, } -/* Interaction Model ID as per the Matter Spec */ -pub const PROTO_ID_INTERACTION_MODEL: u16 = 0x01; - -// This is the amount of space we reserve for other things to be attached towards -// the end of long reads. -const LONG_READS_TLV_RESERVE_SIZE: usize = 24; - -impl<'a> ReadReq<'a> { - pub fn tx_start<'r, 'p>(&self, tx: &'r mut Packet<'p>) -> Result, Error> { - tx.reset(); - tx.set_proto_id(PROTO_ID_INTERACTION_MODEL); - tx.set_proto_opcode(OpCode::ReportData as u8); - - let mut tw = Self::reserve_long_read_space(tx)?; - - tw.start_struct(TagType::Anonymous)?; - - if self.attr_requests.is_some() { - tw.start_array(TagType::Context(msg::ReportDataTag::AttributeReports as u8))?; - } - - Ok(tw) - } - - pub fn tx_finish_chunk(&self, tx: &mut Packet) -> Result<(), Error> { - self.complete(tx, true) - } - - pub fn tx_finish(&self, tx: &mut Packet) -> Result<(), Error> { - self.complete(tx, false) - } - - fn complete(&self, tx: &mut Packet<'_>, more_chunks: bool) -> Result<(), Error> { - let mut tw = Self::restore_long_read_space(tx)?; - - if self.attr_requests.is_some() { - tw.end_container()?; - } - - if more_chunks { - tw.bool( - TagType::Context(msg::ReportDataTag::MoreChunkedMsgs as u8), - true, - )?; - } - - tw.bool( - TagType::Context(msg::ReportDataTag::SupressResponse as u8), - !more_chunks, - )?; - - tw.end_container() - } - - fn reserve_long_read_space<'p, 'b>(tx: &'p mut Packet<'b>) -> Result, Error> { - let wb = tx.get_writebuf()?; - wb.shrink(LONG_READS_TLV_RESERVE_SIZE)?; - - Ok(TLVWriter::new(wb)) - } - - fn restore_long_read_space<'p, 'b>(tx: &'p mut Packet<'b>) -> Result, Error> { - let wb = tx.get_writebuf()?; - wb.expand(LONG_READS_TLV_RESERVE_SIZE)?; - - Ok(TLVWriter::new(wb)) - } -} - -impl<'a> WriteReq<'a> { - pub fn tx_start<'r, 'p>( - &self, - tx: &'r mut Packet<'p>, - epoch: Epoch, - timeout: Option, - ) -> Result>, Error> { - if has_timed_out(epoch, timeout) { - Interaction::status_response(tx, IMStatusCode::Timeout)?; - - Ok(None) - } else { - tx.reset(); - tx.set_proto_id(PROTO_ID_INTERACTION_MODEL); - tx.set_proto_opcode(OpCode::WriteResponse as u8); - - let mut tw = TLVWriter::new(tx.get_writebuf()?); - - tw.start_struct(TagType::Anonymous)?; - tw.start_array(TagType::Context(msg::WriteRespTag::WriteResponses as u8))?; - - Ok(Some(tw)) - } - } - - pub fn tx_finish(&self, tx: &mut Packet<'_>) -> Result<(), Error> { - let mut tw = TLVWriter::new(tx.get_writebuf()?); - - tw.end_container()?; - tw.end_container() - } -} - -impl<'a> InvReq<'a> { - pub fn tx_start<'r, 'p>( - &self, - tx: &'r mut Packet<'p>, - epoch: Epoch, - timeout: Option, - ) -> Result>, Error> { - if has_timed_out(epoch, timeout) { - Interaction::status_response(tx, IMStatusCode::Timeout)?; - - Ok(None) - } else { - let timed_tx = timeout.map(|_| true); - let timed_request = self.timed_request.filter(|a| *a); - - // Either both should be None, or both should be Some(true) - if timed_tx != timed_request { - Interaction::status_response(tx, IMStatusCode::TimedRequestMisMatch)?; - - Ok(None) - } else { - tx.reset(); - tx.set_proto_id(PROTO_ID_INTERACTION_MODEL); - tx.set_proto_opcode(OpCode::InvokeResponse as u8); - - let mut tw = TLVWriter::new(tx.get_writebuf()?); - - tw.start_struct(TagType::Anonymous)?; - - // Suppress Response -> TODO: Need to revisit this for cases where we send a command back - tw.bool( - TagType::Context(msg::InvRespTag::SupressResponse as u8), - false, - )?; - - if self.inv_requests.is_some() { - tw.start_array(TagType::Context(msg::InvRespTag::InvokeResponses as u8))?; - } - - Ok(Some(tw)) - } +impl OpCode { + pub fn meta(&self) -> MessageMeta { + MessageMeta { + proto_id: PROTO_ID_INTERACTION_MODEL, + proto_opcode: *self as u8, + reliable: true, } } - pub fn tx_finish(&self, tx: &mut Packet<'_>) -> Result<(), Error> { - let mut tw = TLVWriter::new(tx.get_writebuf()?); - - if self.inv_requests.is_some() { - tw.end_container()?; - } - - tw.end_container() + pub fn is_tlv(&self) -> bool { + !matches!(self, Self::Reserved) } } -impl TimedReq { - pub fn timeout(&self, epoch: Epoch) -> Duration { - epoch() - .checked_add(Duration::from_millis(self.timeout as _)) - .unwrap() - } - - pub fn tx_process(self, tx: &mut Packet<'_>, epoch: Epoch) -> Result { - Interaction::status_response(tx, IMStatusCode::Success)?; - - Ok(epoch() - .checked_add(Duration::from_millis(self.timeout as _)) - .unwrap()) +impl From for MessageMeta { + fn from(opcode: OpCode) -> Self { + opcode.meta() } } -impl<'a> SubscribeReq<'a> { - pub fn tx_start<'r, 'p>( - &self, - tx: &'r mut Packet<'p>, - subscription_id: u32, - ) -> Result, Error> { - tx.reset(); - tx.set_proto_id(PROTO_ID_INTERACTION_MODEL); - tx.set_proto_opcode(OpCode::ReportData as u8); - - let mut tw = ReadReq::reserve_long_read_space(tx)?; - - tw.start_struct(TagType::Anonymous)?; - - tw.u32( - TagType::Context(msg::ReportDataTag::SubscriptionId as u8), - subscription_id, - )?; - - if self.attr_requests.is_some() { - tw.start_array(TagType::Context(msg::ReportDataTag::AttributeReports as u8))?; - } - - Ok(tw) - } - - pub fn tx_finish_chunk(&self, tx: &mut Packet<'_>, more_chunks: bool) -> Result<(), Error> { - let mut tw = ReadReq::restore_long_read_space(tx)?; - - if self.attr_requests.is_some() { - tw.end_container()?; - } - - if more_chunks { - tw.bool( - TagType::Context(msg::ReportDataTag::MoreChunkedMsgs as u8), - true, - )?; - } - - tw.bool( - TagType::Context(msg::ReportDataTag::SupressResponse as u8), - false, - )?; - - tw.end_container() - } - - pub fn tx_process_final(&self, tx: &mut Packet, subscription_id: u32) -> Result<(), Error> { - tx.reset(); - tx.set_proto_id(PROTO_ID_INTERACTION_MODEL); - tx.set_proto_opcode(OpCode::SubscribeResponse as u8); - - let mut tw = TLVWriter::new(tx.get_writebuf()?); - - let resp = SubscribeResp::new(subscription_id, 40); - resp.to_tlv(&mut tw, TagType::Anonymous) - } -} +/* Interaction Model ID as per the Matter Spec */ +pub const PROTO_ID_INTERACTION_MODEL: u16 = 0x01; -pub struct ReadDriver<'a, 'r, 'p> { - exchange: &'r mut Exchange<'a>, - tx: &'r mut Packet<'p>, - rx: &'r mut Packet<'p>, - completed: bool, +/// A wrapper enum for `ReadReq` and `SubscribeReq` that allows downstream code to +/// treat the two in a unified manner with regards to `OpCode::ReportDataResp` type responses. +pub enum ReportDataReq<'a> { + Read(&'a ReadReq<'a>), + Subscribe(&'a SubscribeReq<'a>), } -impl<'a, 'r, 'p> ReadDriver<'a, 'r, 'p> { - fn new(exchange: &'r mut Exchange<'a>, tx: &'r mut Packet<'p>, rx: &'r mut Packet<'p>) -> Self { - Self { - exchange, - tx, - rx, - completed: false, +impl<'a> ReportDataReq<'a> { + pub fn attr_requests(&self) -> &Option> { + match self { + ReportDataReq::Read(req) => &req.attr_requests, + ReportDataReq::Subscribe(req) => &req.attr_requests, } } - fn start(&mut self, req: &ReadReq) -> Result<(), Error> { - req.tx_start(self.tx)?; - - Ok(()) - } - - pub fn accessor(&self) -> Result, Error> { - self.exchange.accessor() - } - - pub fn writer(&mut self) -> Result, Error> { - if self.completed { - Err(ErrorCode::Invalid.into()) // TODO - } else { - Ok(TLVWriter::new(self.tx.get_writebuf()?)) + pub fn dataver_filters(&self) -> Option<&TLVArray<'_, DataVersionFilter>> { + match self { + ReportDataReq::Read(req) => req.dataver_filters.as_ref(), + ReportDataReq::Subscribe(req) => req.dataver_filters.as_ref(), } } - pub async fn send_chunk(&mut self, req: &ReadReq<'_>) -> Result { - req.tx_finish_chunk(self.tx)?; - - if exchange_confirm(self.exchange, self.tx, self.rx).await? != IMStatusCode::Success { - self.completed = true; - Ok(false) - } else { - req.tx_start(self.tx)?; - - Ok(true) + pub fn fabric_filtered(&self) -> bool { + match self { + ReportDataReq::Read(req) => req.fabric_filtered, + ReportDataReq::Subscribe(req) => req.fabric_filtered, } } - - pub async fn complete(&mut self, req: &ReadReq<'_>) -> Result<(), Error> { - req.tx_finish(self.tx)?; - - self.exchange.send_complete(self.tx).await - } } -pub struct WriteDriver<'a, 'r, 'p> { - exchange: &'r mut Exchange<'a>, - tx: &'r mut Packet<'p>, - epoch: Epoch, - timeout: Option, -} - -impl<'a, 'r, 'p> WriteDriver<'a, 'r, 'p> { - fn new( - exchange: &'r mut Exchange<'a>, - epoch: Epoch, - timeout: Option, - tx: &'r mut Packet<'p>, - ) -> Self { - Self { - exchange, - tx, - epoch, - timeout, - } - } - - async fn start(&mut self, req: &WriteReq<'_>) -> Result { - if req.tx_start(self.tx, self.epoch, self.timeout)?.is_some() { - Ok(true) - } else { - self.exchange.send_complete(self.tx).await?; - - Ok(false) - } - } - - pub fn accessor(&self) -> Result, Error> { - self.exchange.accessor() - } - - pub fn writer(&mut self) -> Result, Error> { - Ok(TLVWriter::new(self.tx.get_writebuf()?)) - } +impl StatusResp { + pub fn write(wb: &mut WriteBuf, status: IMStatusCode) -> Result<(), Error> { + let mut tw = TLVWriter::new(wb); - pub async fn complete(&mut self, req: &WriteReq<'_>) -> Result<(), Error> { - if !req.supress_response.unwrap_or_default() { - req.tx_finish(self.tx)?; - self.exchange.send_complete(self.tx).await?; - } - - Ok(()) + let status = Self { status }; + status.to_tlv(&mut tw, TagType::Anonymous) } } -pub struct InvokeDriver<'a, 'r, 'p> { - exchange: &'r mut Exchange<'a>, - tx: &'r mut Packet<'p>, - epoch: Epoch, - timeout: Option, -} - -impl<'a, 'r, 'p> InvokeDriver<'a, 'r, 'p> { - fn new( - exchange: &'r mut Exchange<'a>, - epoch: Epoch, - timeout: Option, - tx: &'r mut Packet<'p>, - ) -> Self { - Self { - exchange, - tx, - epoch, - timeout, - } - } - - async fn start(&mut self, req: &InvReq<'_>) -> Result { - if req.tx_start(self.tx, self.epoch, self.timeout)?.is_some() { - Ok(true) - } else { - self.exchange.send_complete(self.tx).await?; - - Ok(false) - } - } - - pub fn accessor(&self) -> Result, Error> { - self.exchange.accessor() - } - - pub fn writer(&mut self) -> Result, Error> { - Ok(TLVWriter::new(self.tx.get_writebuf()?)) - } - - pub fn writer_exchange(&mut self) -> Result<(TLVWriter<'_, 'p>, &Exchange<'a>), Error> { - Ok((TLVWriter::new(self.tx.get_writebuf()?), (self.exchange))) - } - - pub async fn complete(&mut self, req: &InvReq<'_>) -> Result<(), Error> { - if !req.suppress_response.unwrap_or_default() { - req.tx_finish(self.tx)?; - self.exchange.send_complete(self.tx).await?; - } - - Ok(()) +impl TimedReq { + pub fn timeout_instant(&self, epoch: Epoch) -> Duration { + epoch() + .checked_add(Duration::from_millis(self.timeout as _)) + .unwrap() } } -pub struct SubscribeDriver<'a, 'r, 'p> { - exchange: &'r mut Exchange<'a>, - tx: &'r mut Packet<'p>, - rx: &'r mut Packet<'p>, - subscription_id: u32, - completed: bool, -} - -impl<'a, 'r, 'p> SubscribeDriver<'a, 'r, 'p> { - fn new( - exchange: &'r mut Exchange<'a>, +impl SubscribeResp { + pub fn write<'a>( + wb: &'a mut WriteBuf, subscription_id: u32, - tx: &'r mut Packet<'p>, - rx: &'r mut Packet<'p>, - ) -> Self { - Self { - exchange, - tx, - rx, - subscription_id, - completed: false, - } - } - - fn start(&mut self, req: &SubscribeReq) -> Result<(), Error> { - req.tx_start(self.tx, self.subscription_id)?; - - Ok(()) - } - - pub fn accessor(&self) -> Result, Error> { - self.exchange.accessor() - } - - pub fn writer(&mut self) -> Result, Error> { - if self.completed { - Err(ErrorCode::Invalid.into()) // TODO - } else { - Ok(TLVWriter::new(self.tx.get_writebuf()?)) - } - } - - pub async fn send_chunk(&mut self, req: &SubscribeReq<'_>) -> Result { - req.tx_finish_chunk(self.tx, true)?; - - if exchange_confirm(self.exchange, self.tx, self.rx).await? != IMStatusCode::Success { - self.completed = true; - Ok(false) - } else { - req.tx_start(self.tx, self.subscription_id)?; - - Ok(true) - } - } - - pub async fn complete(&mut self, req: &SubscribeReq<'_>) -> Result<(), Error> { - if !self.completed { - req.tx_finish_chunk(self.tx, false)?; - - if exchange_confirm(self.exchange, self.tx, self.rx).await? != IMStatusCode::Success { - self.completed = true; - } else { - req.tx_process_final(self.tx, self.subscription_id)?; - self.exchange.send_complete(self.tx).await?; - } - } - - Ok(()) - } -} + max_int: u16, + ) -> Result<&'a [u8], Error> { + let mut tw = TLVWriter::new(wb); -pub enum Interaction<'a, 'r, 'p> { - Read { - req: ReadReq<'r>, - driver: ReadDriver<'a, 'r, 'p>, - }, - Write { - req: WriteReq<'r>, - driver: WriteDriver<'a, 'r, 'p>, - }, - Invoke { - req: InvReq<'r>, - driver: InvokeDriver<'a, 'r, 'p>, - }, - Subscribe { - req: SubscribeReq<'r>, - driver: SubscribeDriver<'a, 'r, 'p>, - }, -} - -impl<'a, 'r, 'p> Interaction<'a, 'r, 'p> { - pub async fn timeout( - exchange: &mut Exchange<'_>, - rx: &mut Packet<'_>, - tx: &mut Packet<'_>, - ) -> Result, Error> { - let epoch = exchange.matter.epoch; - - let mut opcode: OpCode = rx.get_proto_opcode()?; - - let mut timeout = None; - - while opcode == OpCode::TimedRequest { - let rx_data = rx.as_slice(); - let req = TimedReq::from_tlv(&get_root_node_struct(rx_data)?)?; - - timeout = Some(req.tx_process(tx, epoch)?); + let resp = Self::new(subscription_id, max_int); + resp.to_tlv(&mut tw, TagType::Anonymous)?; - exchange.exchange(tx, rx).await?; - - opcode = rx.get_proto_opcode()?; - } - - Ok(timeout) + Ok(wb.as_slice()) } - - #[inline(always)] - pub fn new( - exchange: &'r mut Exchange<'a>, - rx: &'r Packet<'p>, - tx: &'r mut Packet<'p>, - rx_status: &'r mut Packet<'p>, - subscription_id: S, - timeout: Option, - ) -> Result, Error> - where - S: FnOnce() -> u32, - { - let epoch = exchange.matter.epoch; - - let opcode = rx.get_proto_opcode()?; - let rx_data = rx.as_slice(); - - match opcode { - OpCode::ReadRequest => { - let req = ReadReq::from_tlv(&get_root_node_struct(rx_data)?)?; - let driver = ReadDriver::new(exchange, tx, rx_status); - - Ok(Self::Read { req, driver }) - } - OpCode::WriteRequest => { - let req = WriteReq::from_tlv(&get_root_node_struct(rx_data)?)?; - let driver = WriteDriver::new(exchange, epoch, timeout, tx); - - Ok(Self::Write { req, driver }) - } - OpCode::InvokeRequest => { - let req = InvReq::from_tlv(&get_root_node_struct(rx_data)?)?; - let driver = InvokeDriver::new(exchange, epoch, timeout, tx); - - Ok(Self::Invoke { req, driver }) - } - OpCode::SubscribeRequest => { - let req = SubscribeReq::from_tlv(&get_root_node_struct(rx_data)?)?; - let driver = SubscribeDriver::new(exchange, subscription_id(), tx, rx_status); - - Ok(Self::Subscribe { req, driver }) - } - _ => { - error!("Opcode not handled: {:?}", opcode); - Err(ErrorCode::InvalidOpcode.into()) - } - } - } - - pub async fn start(&mut self) -> Result { - let started = match self { - Self::Read { req, driver } => { - driver.start(req)?; - true - } - Self::Write { req, driver } => driver.start(req).await?, - Self::Invoke { req, driver } => driver.start(req).await?, - Self::Subscribe { req, driver } => { - driver.start(req)?; - true - } - }; - - Ok(started) - } - - fn status_response(tx: &mut Packet, status: IMStatusCode) -> Result<(), Error> { - tx.reset(); - tx.set_proto_id(PROTO_ID_INTERACTION_MODEL); - tx.set_proto_opcode(OpCode::StatusResponse as u8); - - let mut tw = TLVWriter::new(tx.get_writebuf()?); - - let status = StatusResp { status }; - status.to_tlv(&mut tw, TagType::Anonymous) - } -} - -async fn exchange_confirm( - exchange: &mut Exchange<'_>, - tx: &mut Packet<'_>, - rx: &mut Packet<'_>, -) -> Result { - exchange.exchange(tx, rx).await?; - - let opcode: OpCode = rx.get_proto_opcode()?; - - if opcode == OpCode::StatusResponse { - let resp = StatusResp::from_tlv(&get_root_node_struct(rx.as_slice())?)?; - Ok(resp.status) - } else { - Interaction::status_response(tx, IMStatusCode::Busy)?; // TODO - - exchange.send_complete(tx).await?; - - Err(ErrorCode::Invalid.into()) // TODO - } -} - -fn has_timed_out(epoch: Epoch, timeout: Option) -> bool { - timeout.map(|timeout| epoch() > timeout).unwrap_or(false) } diff --git a/rs-matter/src/interaction_model/messages.rs b/rs-matter/src/interaction_model/messages.rs index 44c26e4d..a1372965 100644 --- a/rs-matter/src/interaction_model/messages.rs +++ b/rs-matter/src/interaction_model/messages.rs @@ -125,7 +125,7 @@ pub mod msg { } } - #[derive(FromTLV, ToTLV)] + #[derive(FromTLV, ToTLV, Debug)] pub struct TimedReq { pub timeout: u16, } @@ -141,7 +141,7 @@ pub mod msg { InvokeRequests = 2, } - #[derive(FromTLV, ToTLV)] + #[derive(FromTLV, ToTLV, Debug)] #[tlvargs(lifetime = "'a")] pub struct InvReq<'a> { pub suppress_response: Option, @@ -191,9 +191,9 @@ pub mod msg { #[tlvargs(lifetime = "'a")] pub struct WriteReq<'a> { pub supress_response: Option, - timed_request: Option, + pub timed_request: Option, pub write_requests: TLVArray<'a, AttrData<'a>>, - more_chunked: Option, + pub more_chunked: Option, } impl<'a> WriteReq<'a> { diff --git a/rs-matter/src/interaction_model/mod.rs b/rs-matter/src/interaction_model/mod.rs index 22e0ee96..42bfc2f1 100644 --- a/rs-matter/src/interaction_model/mod.rs +++ b/rs-matter/src/interaction_model/mod.rs @@ -15,5 +15,6 @@ * limitations under the License. */ +pub mod busy; pub mod core; pub mod messages; diff --git a/rs-matter/src/lib.rs b/rs-matter/src/lib.rs index 86f74747..aa12f3a6 100644 --- a/rs-matter/src/lib.rs +++ b/rs-matter/src/lib.rs @@ -85,6 +85,7 @@ pub mod interaction_model; pub mod mdns; pub mod pairing; pub mod persist; +pub mod respond; pub mod secure_channel; pub mod tlv; pub mod transport; diff --git a/rs-matter/src/mdns/builtin.rs b/rs-matter/src/mdns/builtin.rs index b643c443..22b3b7e2 100644 --- a/rs-matter/src/mdns/builtin.rs +++ b/rs-matter/src/mdns/builtin.rs @@ -12,10 +12,7 @@ use crate::transport::network::{ Address, Ipv4Addr, Ipv6Addr, NetworkReceive, NetworkSend, SocketAddr, SocketAddrV4, SocketAddrV6, }; -use crate::utils::{ - buf::BufferAccess, - select::{EitherUnwrap, Notification}, -}; +use crate::utils::{buf::BufferAccess, notification::Notification, select::Coalesce}; use super::{Service, ServiceMode}; @@ -38,7 +35,7 @@ pub struct MdnsImpl<'a> { dev_det: &'a BasicInfoConfig<'a>, matter_port: u16, services: RefCell, ServiceMode), 4>>, - notification: Notification, + notification: Notification, } impl<'a> MdnsImpl<'a> { @@ -64,7 +61,7 @@ impl<'a> MdnsImpl<'a> { .push((service.try_into().unwrap(), mode)) .map_err(|_| ErrorCode::NoSpace)?; - self.notification.signal(()); + self.notification.notify(); Ok(()) } @@ -74,7 +71,7 @@ impl<'a> MdnsImpl<'a> { services.retain(|(name, _)| name != service); - self.notification.signal(()); + self.notification.notify(); Ok(()) } @@ -106,15 +103,15 @@ impl<'a> MdnsImpl<'a> { where S: NetworkSend, R: NetworkReceive, - SB: BufferAccess, - RB: BufferAccess, + SB: BufferAccess<[u8]>, + RB: BufferAccess<[u8]>, { let send = Mutex::::new(send); let mut broadcast = pin!(self.broadcast(&send, &tx_buf, &host, interface)); let mut respond = pin!(self.respond(&send, recv, &tx_buf, &rx_buf, &host, interface)); - select(&mut broadcast, &mut respond).await.unwrap() + select(&mut broadcast, &mut respond).coalesce().await } async fn broadcast( @@ -126,14 +123,13 @@ impl<'a> MdnsImpl<'a> { ) -> Result<(), Error> where S: NetworkSend, - B: BufferAccess, + B: BufferAccess<[u8]>, { loop { - select( - self.notification.wait(), - Timer::after(Duration::from_secs(30)), - ) - .await; + let mut notification = pin!(self.notification.wait()); + let mut timeout = pin!(Timer::after(Duration::from_secs(30))); + + select(&mut notification, &mut timeout).await; for addr in core::iter::once(SocketAddr::V4(SocketAddrV4::new( MDNS_IPV4_BROADCAST_ADDR, @@ -151,7 +147,7 @@ impl<'a> MdnsImpl<'a> { }) .into_iter(), ) { - let mut buf = buffer.get().await; + let mut buf = buffer.get().await.ok_or(ErrorCode::NoSpace)?; let mut send = send.lock().await; let len = host.broadcast(self, &mut buf, 60)?; @@ -176,17 +172,17 @@ impl<'a> MdnsImpl<'a> { where S: NetworkSend, R: NetworkReceive, - SB: BufferAccess, - RB: BufferAccess, + SB: BufferAccess<[u8]>, + RB: BufferAccess<[u8]>, { loop { recv.wait_available().await?; { - let mut rx = rx_buf.get().await; + let mut rx = rx_buf.get().await.ok_or(ErrorCode::NoSpace)?; let (len, addr) = recv.recv_from(&mut rx).await?; - let mut tx = tx_buf.get().await; + let mut tx = tx_buf.get().await.ok_or(ErrorCode::NoSpace)?; let mut send = send.lock().await; let len = match host.respond(self, &rx[..len], &mut tx, 60) { diff --git a/rs-matter/src/respond.rs b/rs-matter/src/respond.rs new file mode 100644 index 00000000..794b4521 --- /dev/null +++ b/rs-matter/src/respond.rs @@ -0,0 +1,298 @@ +/* + * + * Copyright (c) 2020-2022 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +use core::fmt::Display; +use core::pin::pin; + +use embassy_futures::select::{select3, select_slice}; + +use log::{error, info}; + +use crate::data_model::core::{DataModel, IMBuffer}; +use crate::data_model::objects::DataModelHandler; +use crate::data_model::subscriptions::Subscriptions; +use crate::error::Error; +use crate::interaction_model::busy::BusyInteractionModel; +use crate::interaction_model::core::PROTO_ID_INTERACTION_MODEL; +use crate::secure_channel::busy::BusySecureChannel; +use crate::secure_channel::core::SecureChannel; +use crate::transport::exchange::Exchange; +use crate::utils::buf::BufferAccess; +use crate::utils::select::Coalesce; +use crate::Matter; + +/// A trait modeling a generic handler for an exchange. +pub trait ExchangeHandler { + async fn handle(&self, exchange: &mut Exchange<'_>) -> Result<(), Error>; +} + +impl ExchangeHandler for &T +where + T: ExchangeHandler, +{ + async fn handle(&self, exchange: &mut Exchange<'_>) -> Result<(), Error> { + (*self).handle(exchange).await + } +} + +/// A struct for chaining two exchange handlers into a single one, +/// where each handler is handling one specific protocol (i.e. SC vs IM) in a sequential fashion. +/// I.e. if the first exchange handler refuses to handle the exchange, the second one is tried. +pub struct ChainedExchangeHandler { + pub handler_proto: u16, + pub handler: H, + pub next: T, +} + +impl ChainedExchangeHandler { + /// Construct a chained handler that works as follows: + /// - It will call the provided `handler` instance if the protocol ID of the incoming message does match the supplied `handler_proto` value. + /// - Otherwise, it will call the `next` handler + pub const fn new(handler_proto: u16, handler: H, next: T) -> Self { + Self { + handler_proto, + handler, + next, + } + } + + /// Chain itself with another exchange handler. + /// + /// The returned chained handler works as follows: + /// - It will call the provided `handler` instance if the protocol ID of the incoming message does match the supplied `handler_proto` value. + /// - Otherwise, it will call the `self` handler + pub const fn chain

( + self, + handler_proto: u16, + handler: H2, + ) -> ChainedExchangeHandler { + ChainedExchangeHandler::new(handler_proto, handler, self) + } +} + +impl ExchangeHandler for ChainedExchangeHandler +where + H: ExchangeHandler, + T: ExchangeHandler, +{ + async fn handle(&self, exchange: &mut Exchange<'_>) -> Result<(), Error> { + let rx = exchange.recv_fetch().await?; + + if rx.meta().proto_id == self.handler_proto { + self.handler.handle(exchange).await + } else { + self.next.handle(exchange).await + } + } +} + +/// A generic responder utility for accepting and handling exchanges received by the provided `Matter` stack, +/// by applying the provided `ExchangeHandler` instance to each accepted exchange. +/// +/// This responder uses an intra-task concurrency model - without an external executor - where all handling is done as a single future. +pub struct Responder<'a, T> { + name: &'a str, + handler: T, + matter: &'a Matter<'a>, + respond_after_ms: u64, +} + +impl<'a, T> Responder<'a, T> +where + T: ExchangeHandler, +{ + /// Create a new responder. + /// + /// The `respond_after_ms` parameter instructs the responder how much time to wait before accepting an exchange. + /// + /// This is useful when utilizing multiple responders on a single `Matter` instance, where e.g. the first (main) responder is the actual one, + /// responsible for handling the incoming exchanges, while e.g. another one - with a non-zero `respond_after_ms` - is answerring all exchanges + /// not accepted in time by the main responder with a simple "I'm busy, try again later" handling. + #[inline(always)] + pub const fn new( + name: &'a str, + handler: T, + matter: &'a Matter<'a>, + respond_after_ms: u64, + ) -> Self { + Self { + name, + handler, + matter, + respond_after_ms, + } + } + + /// Get a reference to the `ExchangeHandler` instance used by this responder + pub fn handler(&self) -> &T { + &self.handler + } + + /// Run the responder with a given number of handlers. + pub async fn run(&self) -> Result<(), Error> { + info!("{}: Creating {N} handlers", self.name); + + let mut handlers = heapless::Vec::<_, N>::new(); + info!( + "{}: Handlers size: {}B", + self.name, + core::mem::size_of_val(&handlers) + ); + + for handler_id in 0..N { + handlers + .push(self.handle(handler_id)) + .map_err(|_| ()) + .unwrap(); // Cannot fail because the vector has size N + } + + select_slice(&mut handlers).await.0 + } + + #[inline(always)] + async fn handle(&self, handler_id: impl Display) -> Result<(), Error> { + loop { + // Ignore the error as it had been logged already + let _ = self.respond_once(&handler_id).await; + } + } + + /// Respond to a single exchange. + /// Useful in e.g. integration tests, where we know that we are expecting to respond to a single exchange within the run of the test. + #[inline(always)] + pub async fn respond_once(&self, handler_id: impl Display) -> Result<(), Error> { + let mut exchange = Exchange::accept_after(self.matter, self.respond_after_ms).await?; + + info!( + "{}: Handler {handler_id} / exchange {}: Starting", + self.name, + exchange.id() + ); + + let result = self.handler.handle(&mut exchange).await; + + if let Err(err) = &result { + error!( + "{}: Handler {handler_id} / exchange {}: Abandoned because of error {err:?}", + self.name, + exchange.id() + ); + } else { + info!( + "{}: Handler {handler_id} / exchange {}: Completed", + self.name, + exchange.id() + ); + } + + result + } +} + +impl<'a, const N: usize, B, T> + Responder<'a, ChainedExchangeHandler, SecureChannel>> +where + B: BufferAccess, +{ + /// Creates a "default" responder. This is a responder that composes and uses the `rs-matter`-provided `ExchangeHandler` implementations + /// (`SecureChannel` and `DataModel`) for handling the Secure Channel protocol and the Interaction Model protocol. + #[inline(always)] + pub const fn new_default( + matter: &'a Matter<'a>, + buffers: &'a B, + subscriptions: &'a Subscriptions, + dm_handler: T, + ) -> Self + where + T: DataModelHandler, + { + Self::new( + "Responder", + ChainedExchangeHandler::new( + PROTO_ID_INTERACTION_MODEL, + DataModel::new(buffers, subscriptions, dm_handler), + SecureChannel::new(), + ), + matter, + 0, + ) + } +} + +impl<'a> Responder<'a, ChainedExchangeHandler> { + /// Creates a simple "busy" responder, which is answering all exchanges with a simple "I'm busy, try again later" handling. + /// The resonder is using the `rs-matter`-provided `ExchangeHandler` instances (`BusySecureChannel` and `BusyInteractionModel`) + /// capable of answering with "busy" messages the SC and IM protocols, respectively. + /// + /// Exchanges which are not accepted after 200ms are answered by this responder, as the assumption is that the main responder is + /// busy and cannot answer these right now. + #[inline(always)] + pub const fn new_busy(matter: &'a Matter<'a>) -> Self { + Self::new( + "Busy Responder", + ChainedExchangeHandler::new( + PROTO_ID_INTERACTION_MODEL, + BusyInteractionModel::new(), + BusySecureChannel::new(), + ), + matter, + 200, + ) + } +} + +/// A composition of the `Responder::new_default` and `Responder::new_busy` responders. +pub struct DefaultResponder<'a, const N: usize, B, T> +where + B: BufferAccess, +{ + responder: Responder<'a, ChainedExchangeHandler, SecureChannel>>, + busy_responder: Responder<'a, ChainedExchangeHandler>, +} + +impl<'a, const N: usize, B, T> DefaultResponder<'a, N, B, T> +where + B: BufferAccess, + T: DataModelHandler, +{ + /// Creates the responder composition. + #[inline(always)] + pub const fn new( + matter: &'a Matter<'a>, + buffers: &'a B, + subscriptions: &'a Subscriptions, + dm_handler: T, + ) -> Self { + Self { + responder: Responder::new_default(matter, buffers, subscriptions, dm_handler), + busy_responder: Responder::new_busy(matter), + } + } + + /// Run the responder. + pub async fn run(&self) -> Result<(), Error> { + let mut actual = pin!(self.responder.run::()); + let mut busy = pin!(self.busy_responder.run::()); + let mut sub = pin!(self + .responder + .handler() + .handler + .process_subscriptions(self.responder.matter)); + + select3(&mut actual, &mut busy, &mut sub).coalesce().await + } +} diff --git a/rs-matter/src/secure_channel/busy.rs b/rs-matter/src/secure_channel/busy.rs new file mode 100644 index 00000000..8a8421f2 --- /dev/null +++ b/rs-matter/src/secure_channel/busy.rs @@ -0,0 +1,81 @@ +/* + * + * Copyright (c) 2020-2022 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +use log::error; + +use crate::error::*; +use crate::respond::ExchangeHandler; +use crate::transport::exchange::Exchange; + +use super::common::{sc_write, OpCode, SCStatusCodes, PROTO_ID_SECURE_CHANNEL}; + +/// A Secure Channel implementation that is only capable of sending Busy status codes +/// +/// Use with e.g. +/// +/// ```ignore +/// let matter = Matter::new(...); +/// +/// // ... +/// +/// let busy_responder = Responder::new("SC Busy Responder", BusySecureChannel::new(), &matter, 200/*ms*/); +/// busy_responder.run::<10>().await?; +/// ``` +/// +/// ... to respond with "I'm busy, please try later" status code to all incoming Secure Channel messages, which were +/// not accepted in time by the actual Secure Channel responder, due to all its handlers being occupied with work. +pub struct BusySecureChannel(()); + +impl BusySecureChannel { + const BUSY_RETRY_DELAY_MS: u16 = 500; + + #[inline(always)] + pub const fn new() -> Self { + Self(()) + } + + pub async fn handle(&self, exchange: &mut Exchange<'_>) -> Result<(), Error> { + let meta = exchange.recv().await?.meta(); + if meta.proto_id != PROTO_ID_SECURE_CHANNEL { + Err(ErrorCode::InvalidProto)?; + } + + match meta.opcode()? { + OpCode::PBKDFParamRequest | OpCode::CASESigma1 => { + exchange + .send_with(|_, wb| { + sc_write( + wb, + SCStatusCodes::Busy, + &u16::to_le_bytes(Self::BUSY_RETRY_DELAY_MS), + ) + }) + .await + } + proto_opcode => { + error!("OpCode not handled: {:?}", proto_opcode); + Err(ErrorCode::InvalidOpcode.into()) + } + } + } +} + +impl ExchangeHandler for BusySecureChannel { + async fn handle(&self, exchange: &mut Exchange<'_>) -> Result<(), Error> { + BusySecureChannel::handle(self, exchange).await + } +} diff --git a/rs-matter/src/secure_channel/case.rs b/rs-matter/src/secure_channel/case.rs index 49d527fc..38f95cc5 100644 --- a/rs-matter/src/secure_channel/case.rs +++ b/rs-matter/src/secure_channel/case.rs @@ -23,23 +23,20 @@ use crate::{ crypto::{self, KeyPair, Sha256}, error::{Error, ErrorCode}, fabric::Fabric, - secure_channel::common::{self, OpCode, PROTO_ID_SECURE_CHANNEL}, - secure_channel::common::{complete_with_status, SCStatusCodes}, + secure_channel::common::{complete_with_status, OpCode, SCStatusCodes}, tlv::{get_root_node_struct, FromTLV, OctetStr, TLVWriter, TagType}, transport::{ exchange::Exchange, - network::Address, - packet::Packet, - session::{CaseDetails, CloneData, NocCatIds, SessionMode}, + session::{CaseDetails, NocCatIds, ReservedSession, SessionMode}, }, utils::{rand::Rand, writebuf::WriteBuf}, }; #[derive(Debug, Clone)] -struct CaseSession { +pub struct CaseSession { peer_sessid: u16, local_sessid: u16, - tt_hash: Sha256, + tt_hash: Option, shared_secret: [u8; crypto::ECDH_SHARED_SECRET_LEN_BYTES], our_pub_key: [u8; crypto::EC_POINT_LEN_BYTES], peer_pub_key: [u8; crypto::EC_POINT_LEN_BYTES], @@ -48,16 +45,16 @@ struct CaseSession { impl CaseSession { #[inline(always)] - pub fn new() -> Result { - Ok(Self { + pub const fn new() -> Self { + Self { peer_sessid: 0, local_sessid: 0, - tt_hash: Sha256::new()?, + tt_hash: None, shared_secret: [0; crypto::ECDH_SHARED_SECRET_LEN_BYTES], our_pub_key: [0; crypto::EC_POINT_LEN_BYTES], peer_pub_key: [0; crypto::EC_POINT_LEN_BYTES], local_fabric_idx: 0, - }) + } } } @@ -72,31 +69,37 @@ impl Case { pub async fn handle( &mut self, exchange: &mut Exchange<'_>, - rx: &mut Packet<'_>, - tx: &mut Packet<'_>, + case_session: &mut CaseSession, ) -> Result<(), Error> { - let mut session = alloc!(CaseSession::new()?); + let session = ReservedSession::reserve(exchange.matter()).await?; + + self.handle_casesigma1(exchange, case_session).await?; + + exchange.recv_fetch().await?; - self.handle_casesigma1(exchange, rx, tx, &mut session) + self.handle_casesigma3(exchange, case_session, session) .await?; - self.handle_casesigma3(exchange, rx, tx, &mut session).await + + exchange.acknowledge().await?; + exchange.matter().notify_changed(); + + Ok(()) } async fn handle_casesigma3( &mut self, exchange: &mut Exchange<'_>, - rx: &Packet<'_>, - tx: &mut Packet<'_>, case_session: &mut CaseSession, + mut session: ReservedSession<'_>, ) -> Result<(), Error> { - rx.check_proto_opcode(OpCode::CASESigma3 as _)?; + exchange.rx()?.meta().check_opcode(OpCode::CASESigma3)?; - let result = { - let fabric_mgr = exchange.matter.fabric_mgr.borrow(); + let status = { + let fabric_mgr = exchange.matter().fabric_mgr.borrow(); let fabric = fabric_mgr.get_fabric(case_session.local_fabric_idx)?; if let Some(fabric) = fabric { - let root = get_root_node_struct(rx.as_slice())?; + let root = get_root_node_struct(exchange.rx()?.payload())?; let encrypted = root.find_tag(1)?.slice()?; let mut decrypted = alloc!([0; 800]); @@ -128,7 +131,7 @@ impl Case { if let Err(e) = Case::validate_certs(fabric, &initiator_noc, initiator_icac_mut) { error!("Certificate Chain doesn't match: {}", e); - Err(SCStatusCodes::InvalidParameter) + SCStatusCodes::InvalidParameter } else if let Err(e) = Case::validate_sigma3_sign( d.initiator_noc.0, d.initiator_icac.map(|a| a.0), @@ -137,73 +140,90 @@ impl Case { case_session, ) { error!("Sigma3 Signature doesn't match: {}", e); - Err(SCStatusCodes::InvalidParameter) + SCStatusCodes::InvalidParameter } else { // Only now do we add this message to the TT Hash let mut peer_catids: NocCatIds = Default::default(); initiator_noc.get_cat_ids(&mut peer_catids); - case_session.tt_hash.update(rx.as_slice())?; - - Ok(Case::get_session_clone_data( + case_session + .tt_hash + .as_mut() + .unwrap() + .update(exchange.rx()?.payload())?; + + let mut session_keys = [0_u8; 3 * crypto::SYMM_KEY_LEN_BYTES]; + Case::get_session_keys( fabric.ipk.op_key(), + case_session.tt_hash.as_ref().unwrap(), + &case_session.shared_secret, + &mut session_keys, + )?; + + let peer_addr = exchange.with_session(|sess| Ok(sess.get_peer_addr()))?; + + session.update( fabric.get_node_id(), initiator_noc.get_node_id()?, - exchange.with_session(|sess| Ok(sess.get_peer_addr()))?, - case_session, - &peer_catids, - )?) + case_session.peer_sessid, + case_session.local_sessid, + peer_addr, + SessionMode::Case(CaseDetails::new( + case_session.local_fabric_idx as u8, + &peer_catids, + )), + Some(&session_keys[0..16]), + Some(&session_keys[16..32]), + Some(&session_keys[32..48]), + )?; + + session.complete(); + + SCStatusCodes::SessionEstablishmentSuccess } } else { - Err(SCStatusCodes::NoSharedTrustRoots) - } - }; - - let status = match result { - Ok(clone_data) => { - exchange.clone_session(tx, &clone_data).await?; - SCStatusCodes::SessionEstablishmentSuccess + SCStatusCodes::NoSharedTrustRoots } - Err(status) => status, }; - complete_with_status(exchange, tx, status, None).await + complete_with_status(exchange, status, &[]).await } async fn handle_casesigma1( &mut self, exchange: &mut Exchange<'_>, - rx: &mut Packet<'_>, - tx: &mut Packet<'_>, case_session: &mut CaseSession, ) -> Result<(), Error> { - rx.check_proto_opcode(OpCode::CASESigma1 as _)?; + exchange.rx()?.meta().check_opcode(OpCode::CASESigma1)?; - let rx_buf = rx.as_slice(); - let root = get_root_node_struct(rx_buf)?; + let root = get_root_node_struct(exchange.rx()?.payload())?; let r = Sigma1Req::from_tlv(&root)?; let local_fabric_idx = exchange - .matter + .matter() .fabric_mgr .borrow_mut() .match_dest_id(r.initiator_random.0, r.dest_id.0); if local_fabric_idx.is_err() { error!("Fabric Index mismatch"); - complete_with_status( - exchange, - tx, - common::SCStatusCodes::NoSharedTrustRoots, - None, - ) - .await?; + complete_with_status(exchange, SCStatusCodes::NoSharedTrustRoots, &[]).await?; return Ok(()); } - let local_sessid = exchange.get_next_sess_id(); + let local_sessid = exchange + .matter() + .transport_mgr + .session_mgr + .borrow_mut() + .get_next_sess_id(); case_session.peer_sessid = r.initiator_sessid; case_session.local_sessid = local_sessid; - case_session.tt_hash.update(rx_buf)?; + case_session.tt_hash = Some(Sha256::new()?); + case_session + .tt_hash + .as_mut() + .unwrap() + .update(exchange.rx()?.payload())?; case_session.local_fabric_idx = local_fabric_idx?; if r.peer_pub_key.0.len() != crypto::EC_POINT_LEN_BYTES { error!("Invalid public key length"); @@ -216,7 +236,7 @@ impl Case { ); // Create an ephemeral Key Pair - let key_pair = KeyPair::new(exchange.matter.rand)?; + let key_pair = KeyPair::new(exchange.matter().rand)?; let _ = key_pair.get_public_key(&mut case_session.our_pub_key)?; // Derive the Shared Secret @@ -228,7 +248,7 @@ impl Case { // println!("Derived secret: {:x?} len: {}", secret, len); let mut our_random: [u8; 32] = [0; 32]; - (exchange.matter.rand)(&mut our_random); + (exchange.matter().rand)(&mut our_random); // Derive the Encrypted Part const MAX_ENCRYPTED_SIZE: usize = 800; @@ -236,8 +256,8 @@ impl Case { let mut encrypted = alloc!([0; MAX_ENCRYPTED_SIZE]); let mut signature = alloc!([0u8; crypto::EC_SIGNATURE_LEN_BYTES]); - let fabric_found = { - let fabric_mgr = exchange.matter.fabric_mgr.borrow(); + let encrypted_len = { + let fabric_mgr = exchange.matter().fabric_mgr.borrow(); let fabric = fabric_mgr.get_fabric(case_session.local_fabric_idx)?; if let Some(fabric) = fabric { @@ -263,85 +283,50 @@ impl Case { let encrypted_len = Case::get_sigma2_encryption( fabric, - exchange.matter.rand, + exchange.matter().rand, &our_random, case_session, signature, encrypted_mut, )?; - let encrypted = &encrypted[0..encrypted_len]; - - // Generate our Response Body - tx.reset(); - tx.set_proto_id(PROTO_ID_SECURE_CHANNEL); - tx.set_proto_opcode(OpCode::CASESigma2 as u8); - - let mut tw = TLVWriter::new(tx.get_writebuf()?); - tw.start_struct(TagType::Anonymous)?; - tw.str8(TagType::Context(1), &our_random)?; - tw.u16(TagType::Context(2), local_sessid)?; - tw.str8(TagType::Context(3), &case_session.our_pub_key)?; - tw.str16(TagType::Context(4), encrypted)?; - tw.end_container()?; - - case_session.tt_hash.update(tx.as_mut_slice())?; - - true + Some(encrypted_len) } else { - false + None } }; - if fabric_found { - exchange.exchange(tx, rx).await + if let Some(encrypted_len) = encrypted_len { + let mut hash_updated = false; + let encrypted = &encrypted[0..encrypted_len]; + + exchange + .send_with(|_, wb| { + let mut tw = TLVWriter::new(wb); + tw.start_struct(TagType::Anonymous)?; + tw.str8(TagType::Context(1), &our_random)?; + tw.u16(TagType::Context(2), local_sessid)?; + tw.str8(TagType::Context(3), &case_session.our_pub_key)?; + tw.str16(TagType::Context(4), encrypted)?; + tw.end_container()?; + + if !hash_updated { + case_session + .tt_hash + .as_mut() + .unwrap() + .update(wb.as_mut_slice())?; + hash_updated = true; + } + + Ok(Some(OpCode::CASESigma2.into())) + }) + .await } else { - complete_with_status( - exchange, - tx, - common::SCStatusCodes::NoSharedTrustRoots, - None, - ) - .await + complete_with_status(exchange, SCStatusCodes::NoSharedTrustRoots, &[]).await } } - fn get_session_clone_data( - ipk: &[u8], - local_nodeid: u64, - peer_nodeid: u64, - peer_addr: Address, - case_session: &CaseSession, - peer_catids: &NocCatIds, - ) -> Result { - let mut session_keys = [0_u8; 3 * crypto::SYMM_KEY_LEN_BYTES]; - Case::get_session_keys( - ipk, - &case_session.tt_hash, - &case_session.shared_secret, - &mut session_keys, - )?; - - let mut clone_data = CloneData::new( - local_nodeid, - peer_nodeid, - case_session.peer_sessid, - case_session.local_sessid, - peer_addr, - SessionMode::Case(CaseDetails::new( - case_session.local_fabric_idx as u8, - peer_catids, - )), - ); - - clone_data.dec_key.copy_from_slice(&session_keys[0..16]); - clone_data.enc_key.copy_from_slice(&session_keys[16..32]); - clone_data - .att_challenge - .copy_from_slice(&session_keys[32..48]); - Ok(clone_data) - } - fn validate_sigma3_sign( initiator_noc: &[u8], initiator_icac: Option<&[u8]>, @@ -425,7 +410,7 @@ impl Case { let mut sigma3_key = [0_u8; crypto::SYMM_KEY_LEN_BYTES]; Case::get_sigma3_key( ipk, - &case_session.tt_hash, + case_session.tt_hash.as_ref().unwrap(), &case_session.shared_secret, &mut sigma3_key, )?; @@ -482,7 +467,7 @@ impl Case { salt.extend_from_slice(our_random).unwrap(); salt.extend_from_slice(&case_session.our_pub_key).unwrap(); - let tt = case_session.tt_hash.clone(); + let tt = case_session.tt_hash.as_ref().unwrap().clone(); let mut tt_hash = [0u8; crypto::SHA256_HASH_LEN_BYTES]; tt.finish(&mut tt_hash)?; @@ -572,6 +557,12 @@ impl Case { } } +impl Default for Case { + fn default() -> Self { + Self::new() + } +} + #[derive(FromTLV)] #[tlvargs(start = 1, lifetime = "'a")] struct Sigma1Req<'a> { diff --git a/rs-matter/src/secure_channel/common.rs b/rs-matter/src/secure_channel/common.rs index eeb1e4a7..4d4d3921 100644 --- a/rs-matter/src/secure_channel/common.rs +++ b/rs-matter/src/secure_channel/common.rs @@ -17,12 +17,11 @@ use num_derive::FromPrimitive; -use crate::{ - error::Error, - transport::{exchange::Exchange, packet::Packet}, -}; +use crate::error::Error; +use crate::transport::exchange::{Exchange, MessageMeta}; +use crate::utils::writebuf::WriteBuf; -use super::status_report::{create_status_report, GeneralCode}; +use super::status_report::{GeneralCode, StatusReport}; /* Interaction Model ID as per the Matter Spec */ pub const PROTO_ID_SECURE_CHANNEL: u16 = 0x00; @@ -44,7 +43,33 @@ pub enum OpCode { StatusReport = 0x40, } -#[derive(PartialEq)] +impl OpCode { + pub fn meta(&self) -> MessageMeta { + MessageMeta { + proto_id: PROTO_ID_SECURE_CHANNEL, + proto_opcode: *self as u8, + reliable: !matches!(self, Self::MRPStandAloneAck), + } + } + + pub fn is_tlv(&self) -> bool { + !matches!( + self, + Self::MRPStandAloneAck + | Self::StatusReport + | Self::MsgCounterSyncReq + | Self::MsgCounterSyncResp + ) + } +} + +impl From for MessageMeta { + fn from(op: OpCode) -> Self { + op.meta() + } +} + +#[derive(Debug, Copy, Clone, PartialEq, Eq)] pub enum SCStatusCodes { SessionEstablishmentSuccess = 0, NoSharedTrustRoots = 1, @@ -54,56 +79,49 @@ pub enum SCStatusCodes { SessionNotFound = 5, } +impl SCStatusCodes { + pub fn reliable(&self) -> bool { + // CloseSession and Busy are sent without the R flag raised + !matches!(self, SCStatusCodes::CloseSession | SCStatusCodes::Busy) + } + + pub fn as_report<'a>(&self, payload: &'a [u8]) -> StatusReport<'a> { + let general_code = match self { + SCStatusCodes::SessionEstablishmentSuccess => GeneralCode::Success, + SCStatusCodes::CloseSession => GeneralCode::Success, + SCStatusCodes::Busy => GeneralCode::Busy, + SCStatusCodes::InvalidParameter + | SCStatusCodes::NoSharedTrustRoots + | SCStatusCodes::SessionNotFound => GeneralCode::Failure, + }; + + StatusReport { + general_code, + proto_id: PROTO_ID_SECURE_CHANNEL as u32, + proto_code: *self as u16, + proto_data: payload, + } + } +} + pub async fn complete_with_status( exchange: &mut Exchange<'_>, - tx: &mut Packet<'_>, status_code: SCStatusCodes, - proto_data: Option<&[u8]>, + payload: &[u8], ) -> Result<(), Error> { - create_sc_status_report(tx, status_code, proto_data)?; - - exchange.send_complete(tx).await + exchange + .send_with(|_, wb| sc_write(wb, status_code, payload)) + .await } -pub fn create_sc_status_report( - proto_tx: &mut Packet, +pub fn sc_write( + wb: &mut WriteBuf, status_code: SCStatusCodes, - proto_data: Option<&[u8]>, -) -> Result<(), Error> { - let general_code = match status_code { - SCStatusCodes::SessionEstablishmentSuccess => GeneralCode::Success, - SCStatusCodes::CloseSession => { - proto_tx.unset_reliable(); - // No time to manage reliable delivery for close session - // the session will be closed soon - GeneralCode::Success - } - SCStatusCodes::Busy => GeneralCode::Busy, - SCStatusCodes::InvalidParameter - | SCStatusCodes::NoSharedTrustRoots - | SCStatusCodes::SessionNotFound => GeneralCode::Failure, - }; - - create_status_report( - proto_tx, - general_code, - PROTO_ID_SECURE_CHANNEL as u32, - status_code as u16, - proto_data, - ) -} + payload: &[u8], +) -> Result, Error> { + status_code.as_report(payload).write(wb)?; -pub fn create_mrp_standalone_ack(proto_tx: &mut Packet) { - proto_tx.reset(); - proto_tx.set_proto_id(PROTO_ID_SECURE_CHANNEL); - proto_tx.set_proto_opcode(OpCode::MRPStandAloneAck as u8); - proto_tx.unset_reliable(); + Ok(Some( + OpCode::StatusReport.meta().reliable(status_code.reliable()), + )) } - -// TODO -// pub fn send_mrp_standalone_ack(exch: &mut Exchange, sess: &mut SessionHandle) -> Result<(), Error> { -// info!("Sending standalone ACK"); -// let mut ack_packet = Slab::::try_new(Packet::new_tx()?).ok_or(Error::NoMemory)?; -// create_mrp_standalone_ack(&mut ack_packet); -// exch.send(ack_packet, sess) -// } diff --git a/rs-matter/src/secure_channel/core.rs b/rs-matter/src/secure_channel/core.rs index ee1d36c0..b5d77259 100644 --- a/rs-matter/src/secure_channel/core.rs +++ b/rs-matter/src/secure_channel/core.rs @@ -18,12 +18,17 @@ use log::error; use crate::{ + alloc, error::*, + respond::ExchangeHandler, secure_channel::{common::*, pake::Pake}, - transport::{exchange::Exchange, packet::Packet}, + transport::exchange::Exchange, }; -use super::case::Case; +use super::{ + case::{Case, CaseSession}, + spake2p::Spake2P, +}; /* Handle messages related to the Secure Channel */ @@ -36,17 +41,27 @@ impl SecureChannel { Self(()) } - pub async fn handle( - &self, - exchange: &mut Exchange<'_>, - rx: &mut Packet<'_>, - tx: &mut Packet<'_>, - ) -> Result<(), Error> { - match rx.get_proto_opcode()? { - OpCode::PBKDFParamRequest => Pake::new().handle(exchange, rx, tx).await, - OpCode::CASESigma1 => Case::new().handle(exchange, rx, tx).await, - proto_opcode => { - error!("OpCode not handled: {:?}", proto_opcode); + pub async fn handle(&self, exchange: &mut Exchange<'_>) -> Result<(), Error> { + if exchange.rx().is_err() { + exchange.recv_fetch().await?; + } + + let meta = exchange.rx()?.meta(); + if meta.proto_id != PROTO_ID_SECURE_CHANNEL { + Err(ErrorCode::InvalidProto)?; + } + + match meta.opcode()? { + OpCode::PBKDFParamRequest => { + let mut spake2p = alloc!(Spake2P::new()); + Pake::new().handle(exchange, &mut spake2p).await + } + OpCode::CASESigma1 => { + let mut case_session = alloc!(CaseSession::new()); + Case::new().handle(exchange, &mut case_session).await + } + opcode => { + error!("Invalid opcode: {:?}", opcode); Err(ErrorCode::InvalidOpcode.into()) } } @@ -58,3 +73,9 @@ impl Default for SecureChannel { Self::new() } } + +impl ExchangeHandler for SecureChannel { + async fn handle(&self, exchange: &mut Exchange<'_>) -> Result<(), Error> { + SecureChannel::handle(self, exchange).await + } +} diff --git a/rs-matter/src/secure_channel/mod.rs b/rs-matter/src/secure_channel/mod.rs index 9b538b60..70aa6dd8 100644 --- a/rs-matter/src/secure_channel/mod.rs +++ b/rs-matter/src/secure_channel/mod.rs @@ -28,6 +28,7 @@ pub mod crypto_openssl; #[cfg(feature = "rustcrypto")] pub mod crypto_rustcrypto; +pub mod busy; pub mod core; pub mod crypto; pub mod pake; diff --git a/rs-matter/src/secure_channel/pake.rs b/rs-matter/src/secure_channel/pake.rs index 56c30543..2590bff7 100644 --- a/rs-matter/src/secure_channel/pake.rs +++ b/rs-matter/src/secure_channel/pake.rs @@ -18,19 +18,18 @@ use core::{fmt::Write, time::Duration}; use super::{ - common::{SCStatusCodes, PROTO_ID_SECURE_CHANNEL}, - spake2p::{Spake2P, VerifierData}, + common::SCStatusCodes, + spake2p::{Spake2P, VerifierData, MAX_SALT_SIZE_BYTES}, }; use crate::{ - alloc, crypto, + crypto, error::{Error, ErrorCode}, mdns::{Mdns, ServiceMode}, secure_channel::common::{complete_with_status, OpCode}, tlv::{self, get_root_node_struct, FromTLV, OctetStr, TLVWriter, TagType, ToTLV}, transport::{ exchange::{Exchange, ExchangeId}, - packet::Packet, - session::{CloneData, SessionMode}, + session::{ReservedSession, SessionMode}, }, utils::{epoch::Epoch, rand::Rand}, }; @@ -122,7 +121,7 @@ impl Timeout { fn new(exchange: &Exchange, epoch: Epoch) -> Self { Self { start_time: epoch(), - exch_id: exchange.id().clone(), + exch_id: exchange.id(), } } @@ -142,30 +141,48 @@ impl Pake { pub async fn handle( &mut self, exchange: &mut Exchange<'_>, - rx: &mut Packet<'_>, - tx: &mut Packet<'_>, + spake2p: &mut Spake2P, ) -> Result<(), Error> { - let mut spake2p = alloc!(Spake2P::new()); + let session = ReservedSession::reserve(exchange.matter()).await?; + + if !self.update_timeout(exchange, true).await? { + return Ok(()); + } + + self.handle_pbkdfparamrequest(exchange, spake2p).await?; + + exchange.recv_fetch().await?; + + if !self.update_timeout(exchange, false).await? { + return Ok(()); + } + + self.handle_pasepake1(exchange, spake2p).await?; + + exchange.recv_fetch().await?; + + if !self.update_timeout(exchange, false).await? { + return Ok(()); + } - self.handle_pbkdfparamrequest(exchange, rx, tx, &mut spake2p) - .await?; - self.handle_pasepake1(exchange, rx, tx, &mut spake2p) - .await?; - self.handle_pasepake3(exchange, rx, tx, &mut spake2p).await + self.handle_pasepake3(exchange, session, spake2p).await?; + + exchange.acknowledge().await?; + exchange.matter().notify_changed(); + + Ok(()) } #[allow(non_snake_case)] async fn handle_pasepake3( &mut self, exchange: &mut Exchange<'_>, - rx: &Packet<'_>, - tx: &mut Packet<'_>, + mut session: ReservedSession<'_>, spake2p: &mut Spake2P, ) -> Result<(), Error> { - rx.check_proto_opcode(OpCode::PASEPake3 as _)?; - self.update_timeout(exchange, tx, true).await?; + exchange.rx()?.meta().check_opcode(OpCode::PASEPake3)?; - let cA = extract_pasepake_1_or_3_params(rx.as_slice())?; + let cA = extract_pasepake_1_or_3_params(exchange.rx()?.payload())?; let (status, ke) = spake2p.handle_cA(cA); let result = if status == SCStatusCodes::SessionEstablishmentSuccess { @@ -179,117 +196,119 @@ impl Pake { let data = spake2p.get_app_data(); let peer_sessid: u16 = (data & 0xffff) as u16; let local_sessid: u16 = ((data >> 16) & 0xffff) as u16; - let mut clone_data = CloneData::new( + let peer_addr = exchange.with_session(|sess| Ok(sess.get_peer_addr()))?; + + session.update( 0, 0, peer_sessid, local_sessid, - exchange.with_session(|sess| Ok(sess.get_peer_addr()))?, + peer_addr, SessionMode::Pase, - ); - clone_data.dec_key.copy_from_slice(&session_keys[0..16]); - clone_data.enc_key.copy_from_slice(&session_keys[16..32]); - clone_data - .att_challenge - .copy_from_slice(&session_keys[32..48]); - - Ok(clone_data) + Some(&session_keys[0..16]), + Some(&session_keys[16..32]), + Some(&session_keys[32..48]), + )?; + + Ok(()) } else { Err(status) }; let status = match result { - Ok(clone_data) => { - let mdns = &exchange.matter.mdns; + Ok(()) => { + let mdns = &exchange.matter().transport_mgr.mdns; - exchange.clone_session(tx, &clone_data).await?; exchange - .matter + .matter() .pase_mgr .borrow_mut() .disable_pase_session(mdns)?; + session.complete(); SCStatusCodes::SessionEstablishmentSuccess } Err(status) => status, }; - complete_with_status(exchange, tx, status, None).await + complete_with_status(exchange, status, &[]).await } #[allow(non_snake_case)] async fn handle_pasepake1( &mut self, exchange: &mut Exchange<'_>, - rx: &mut Packet<'_>, - tx: &mut Packet<'_>, spake2p: &mut Spake2P, ) -> Result<(), Error> { - rx.check_proto_opcode(OpCode::PASEPake1 as _)?; - self.update_timeout(exchange, tx, false).await?; + exchange.rx()?.meta().check_opcode(OpCode::PASEPake1)?; + + let pA = extract_pasepake_1_or_3_params(exchange.rx()?.payload())?; + let mut pB: [u8; 65] = [0; 65]; + let mut cB: [u8; 32] = [0; 32]; { - let pase = exchange.matter.pase_mgr.borrow(); + let pase = exchange.matter().pase_mgr.borrow(); let session = pase.session.as_ref().ok_or(ErrorCode::NoSession)?; - let pA = extract_pasepake_1_or_3_params(rx.as_slice())?; - let mut pB: [u8; 65] = [0; 65]; - let mut cB: [u8; 32] = [0; 32]; spake2p.start_verifier(&session.verifier)?; spake2p.handle_pA(pA, &mut pB, &mut cB, pase.rand)?; - - // Generate response - tx.reset(); - tx.set_proto_id(PROTO_ID_SECURE_CHANNEL); - tx.set_proto_opcode(OpCode::PASEPake2 as u8); - - let mut tw = TLVWriter::new(tx.get_writebuf()?); - let resp = Pake1Resp { - pb: OctetStr(&pB), - cb: OctetStr(&cB), - }; - resp.to_tlv(&mut tw, TagType::Anonymous)?; } - exchange.exchange(tx, rx).await + exchange + .send_with(|_, wb| { + let resp = Pake1Resp { + pb: OctetStr(&pB), + cb: OctetStr(&cB), + }; + resp.to_tlv(&mut TLVWriter::new(wb), TagType::Anonymous)?; + + Ok(Some(OpCode::PASEPake2.into())) + }) + .await } async fn handle_pbkdfparamrequest( &mut self, exchange: &mut Exchange<'_>, - rx: &mut Packet<'_>, - tx: &mut Packet<'_>, spake2p: &mut Spake2P, ) -> Result<(), Error> { - rx.check_proto_opcode(OpCode::PBKDFParamRequest as _)?; - self.update_timeout(exchange, tx, true).await?; + let rx = exchange.rx()?; + rx.meta().check_opcode(OpCode::PBKDFParamRequest)?; - { - let pase = exchange.matter.pase_mgr.borrow(); + let mut our_random = [0; 32]; + let mut initiator_random = [0; 32]; + let mut salt = [0; MAX_SALT_SIZE_BYTES]; + + let resp = { + let pase = exchange.matter().pase_mgr.borrow(); let session = pase.session.as_ref().ok_or(ErrorCode::NoSession)?; - let root = tlv::get_root_node(rx.as_slice())?; + let root = tlv::get_root_node(rx.payload())?; let a = PBKDFParamReq::from_tlv(&root)?; if a.passcode_id != 0 { error!("Can't yet handle passcode_id != 0"); Err(ErrorCode::Invalid)?; } - let mut our_random: [u8; 32] = [0; 32]; - (exchange.matter.rand)(&mut our_random); + (exchange.matter().pase_mgr.borrow().rand)(&mut our_random); - let local_sessid = exchange.get_next_sess_id(); + let local_sessid = exchange + .matter() + .transport_mgr + .session_mgr + .borrow_mut() + .get_next_sess_id(); let spake2p_data: u32 = ((local_sessid as u32) << 16) | a.initiator_ssid as u32; spake2p.set_app_data(spake2p_data); - // Generate response - tx.reset(); - tx.set_proto_id(PROTO_ID_SECURE_CHANNEL); - tx.set_proto_opcode(OpCode::PBKDFParamResponse as u8); + initiator_random[..a.initiator_random.0.len()].copy_from_slice(a.initiator_random.0); + let initiator_random = &initiator_random[..a.initiator_random.0.len()]; + + salt.copy_from_slice(&session.verifier.salt); - let mut tw = TLVWriter::new(tx.get_writebuf()?); + // Generate response let mut resp = PBKDFParamResp { - init_random: a.initiator_random, + init_random: OctetStr::new(initiator_random), our_random: OctetStr(&our_random), local_sessid, params: None, @@ -297,28 +316,43 @@ impl Pake { if !a.has_params { let params_resp = PBKDFParamRespParams { count: session.verifier.count, - salt: OctetStr(&session.verifier.salt), + salt: OctetStr(&salt), }; resp.params = Some(params_resp); } - resp.to_tlv(&mut tw, TagType::Anonymous)?; - spake2p.set_context(rx.as_slice(), tx.as_mut_slice())?; - } + resp + }; + + spake2p.set_context()?; + spake2p.update_context(rx.payload())?; + + let mut context_set = false; + exchange + .send_with(|_, wb| { + resp.to_tlv(&mut TLVWriter::new(wb), TagType::Anonymous)?; + + if !context_set { + spake2p.update_context(wb.as_slice())?; + context_set = true; + } - exchange.exchange(tx, rx).await + Ok(Some(OpCode::PBKDFParamResponse.into())) + }) + .await } async fn update_timeout( &mut self, exchange: &mut Exchange<'_>, - tx: &mut Packet<'_>, new: bool, - ) -> Result<(), Error> { - self.check_session(exchange, tx).await?; + ) -> Result { + if !self.check_session(exchange).await? { + return Ok(false); + } let status = { - let mut pase = exchange.matter.pase_mgr.borrow_mut(); + let mut pase = exchange.matter().pase_mgr.borrow_mut(); if pase .timeout @@ -330,7 +364,7 @@ impl Pake { } if let Some(sd) = pase.timeout.as_mut() { - if &sd.exch_id != exchange.id() { + if sd.exch_id != exchange.id() { info!("Other PAKE session in progress"); Some(SCStatusCodes::Busy) } else { @@ -345,26 +379,25 @@ impl Pake { }; if let Some(status) = status { - complete_with_status(exchange, tx, status, None).await - } else { - let mut pase = exchange.matter.pase_mgr.borrow_mut(); + complete_with_status(exchange, status, &[]).await?; + Ok(false) + } else { + let mut pase = exchange.matter().pase_mgr.borrow_mut(); pase.timeout = Some(Timeout::new(exchange, pase.epoch)); - Ok(()) + Ok(true) } } - async fn check_session( - &mut self, - exchange: &mut Exchange<'_>, - tx: &mut Packet<'_>, - ) -> Result<(), Error> { - if exchange.matter.pase_mgr.borrow().session.is_none() { + async fn check_session(&mut self, exchange: &mut Exchange<'_>) -> Result { + if exchange.matter().pase_mgr.borrow().session.is_none() { error!("PASE not enabled"); - complete_with_status(exchange, tx, SCStatusCodes::InvalidParameter, None).await + complete_with_status(exchange, SCStatusCodes::InvalidParameter, &[]).await?; + + Ok(false) } else { - Ok(()) + Ok(true) } } } diff --git a/rs-matter/src/secure_channel/spake2p.rs b/rs-matter/src/secure_channel/spake2p.rs index 1ee00b6b..4c61880f 100644 --- a/rs-matter/src/secure_channel/spake2p.rs +++ b/rs-matter/src/secure_channel/spake2p.rs @@ -73,7 +73,7 @@ const CRYPTO_GROUP_SIZE_BYTES: usize = 32; const CRYPTO_W_SIZE_BYTES: usize = CRYPTO_GROUP_SIZE_BYTES + 8; const CRYPTO_PUBLIC_KEY_SIZE_BYTES: usize = (2 * CRYPTO_GROUP_SIZE_BYTES) + 1; -const MAX_SALT_SIZE_BYTES: usize = 32; +pub const MAX_SALT_SIZE_BYTES: usize = 32; const VERIFIER_SIZE_BYTES: usize = CRYPTO_GROUP_SIZE_BYTES + CRYPTO_PUBLIC_KEY_SIZE_BYTES; fn crypto_spake2_new() -> Result { @@ -150,15 +150,17 @@ impl Spake2P { self.app_data } - pub fn set_context(&mut self, buf1: &[u8], buf2: &[u8]) -> Result<(), Error> { + pub fn set_context(&mut self) -> Result<(), Error> { let mut context = Sha256::new()?; context.update(&SPAKE2P_CONTEXT_PREFIX)?; - context.update(buf1)?; - context.update(buf2)?; self.context = Some(context); Ok(()) } + pub fn update_context(&mut self, buf: &[u8]) -> Result<(), Error> { + self.context.as_mut().unwrap().update(buf) + } + #[inline(always)] fn get_w0w1s(pw: u32, iter: u32, salt: &[u8], w0w1s: &mut [u8]) { let mut pw_str: [u8; 4] = [0; 4]; diff --git a/rs-matter/src/secure_channel/status_report.rs b/rs-matter/src/secure_channel/status_report.rs index e8378746..d365dd1a 100644 --- a/rs-matter/src/secure_channel/status_report.rs +++ b/rs-matter/src/secure_channel/status_report.rs @@ -15,11 +15,15 @@ * limitations under the License. */ -use super::common::*; -use crate::{error::Error, transport::packet::Packet}; +use num_derive::FromPrimitive; + +use crate::{ + error::{Error, ErrorCode}, + utils::{parsebuf::ParseBuf, writebuf::WriteBuf}, +}; #[allow(dead_code)] -#[derive(Debug, Copy, Clone)] +#[derive(FromPrimitive, PartialEq, Eq, Debug, Copy, Clone)] pub enum GeneralCode { Success = 0, Failure = 1, @@ -40,23 +44,32 @@ pub enum GeneralCode { DataLoss = 16, } -pub fn create_status_report( - proto_tx: &mut Packet, - general_code: GeneralCode, - proto_id: u32, - proto_code: u16, - proto_data: Option<&[u8]>, -) -> Result<(), Error> { - proto_tx.reset(); - proto_tx.set_proto_id(PROTO_ID_SECURE_CHANNEL); - proto_tx.set_proto_opcode(OpCode::StatusReport as u8); - let wb = proto_tx.get_writebuf()?; - wb.le_u16(general_code as u16)?; - wb.le_u32(proto_id)?; - wb.le_u16(proto_code)?; - if let Some(s) = proto_data { - wb.copy_from_slice(s)?; +/// Represents a Status Report message, as per "Appendix D: Status Report Messages" of the Matter Spec. +#[derive(Debug, Clone)] +pub struct StatusReport<'a> { + pub general_code: GeneralCode, + pub proto_id: u32, + pub proto_code: u16, + pub proto_data: &'a [u8], +} + +impl<'a> StatusReport<'a> { + pub fn read(pb: &'a mut ParseBuf) -> Result { + Ok(Self { + general_code: num::FromPrimitive::from_u16(pb.le_u16()?) + .ok_or(ErrorCode::InvalidOpcode)?, + proto_id: pb.le_u32()?, + proto_code: pb.le_u16()?, + proto_data: pb.as_slice(), + }) } - Ok(()) + pub fn write(&self, wb: &mut WriteBuf) -> Result<(), Error> { + wb.le_u16(self.general_code as u16)?; + wb.le_u32(self.proto_id)?; + wb.le_u16(self.proto_code)?; + wb.copy_from_slice(self.proto_data)?; + + Ok(()) + } } diff --git a/rs-matter/src/tlv/parser.rs b/rs-matter/src/tlv/parser.rs index 5e6964c3..826653d5 100644 --- a/rs-matter/src/tlv/parser.rs +++ b/rs-matter/src/tlv/parser.rs @@ -18,7 +18,7 @@ use crate::error::{Error, ErrorCode}; use byteorder::{ByteOrder, LittleEndian}; -use core::fmt; +use core::fmt::{self, Display}; use log::{error, info}; use super::{TagType, MAX_TAG_INDEX, TAG_MASK, TAG_SHIFT_BITS, TAG_SIZE_MAP, TYPE_MASK}; @@ -33,6 +33,64 @@ impl<'a> TLVList<'a> { } } +impl<'a> Display for TLVList<'a> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let tlvlist = self; + + const MAX_DEPTH: usize = 9; + const SPACE_BUF: &str = " "; + + let space: [&str; MAX_DEPTH] = [ + &SPACE_BUF[0..0], + &SPACE_BUF[0..4], + &SPACE_BUF[0..8], + &SPACE_BUF[0..12], + &SPACE_BUF[0..16], + &SPACE_BUF[0..20], + &SPACE_BUF[0..24], + &SPACE_BUF[0..28], + &SPACE_BUF[0..32], + ]; + + let mut stack: [char; MAX_DEPTH] = [' '; MAX_DEPTH]; + let mut index = 0_usize; + let iter = tlvlist.iter(); + for a in iter { + match a.element_type { + ElementType::Struct(_) => { + if index < MAX_DEPTH { + writeln!(f, "{}{}", space[index], a)?; + stack[index] = '}'; + index += 1; + } else { + writeln!(f, "<>")?; + } + } + ElementType::Array(_) | ElementType::List(_) => { + if index < MAX_DEPTH { + writeln!(f, "{}{}", space[index], a)?; + stack[index] = ']'; + index += 1; + } else { + writeln!(f, "<>")?; + } + } + ElementType::EndCnt => { + if index > 0 { + index -= 1; + writeln!(f, "{}{}", space[index], stack[index])?; + } else { + writeln!(f, "<>")?; + } + } + _ => writeln!(f, "{}{}", space[index], a)?, + } + } + + Ok(()) + } +} + #[derive(Debug, Clone, PartialEq)] pub enum ElementType<'a> { S8(i8), @@ -722,57 +780,7 @@ pub fn get_root_node_list(b: &[u8]) -> Result { } pub fn print_tlv_list(b: &[u8]) { - let tlvlist = TLVList::new(b); - - const MAX_DEPTH: usize = 9; - info!("TLV list:"); - let space_buf = " "; - let space: [&str; MAX_DEPTH] = [ - &space_buf[0..0], - &space_buf[0..4], - &space_buf[0..8], - &space_buf[0..12], - &space_buf[0..16], - &space_buf[0..20], - &space_buf[0..24], - &space_buf[0..28], - &space_buf[0..32], - ]; - let mut stack: [char; MAX_DEPTH] = [' '; MAX_DEPTH]; - let mut index = 0_usize; - let iter = tlvlist.iter(); - for a in iter { - match a.element_type { - ElementType::Struct(_) => { - if index < MAX_DEPTH { - info!("{}{}", space[index], a); - stack[index] = '}'; - index += 1; - } else { - error!("Too Deep"); - } - } - ElementType::Array(_) | ElementType::List(_) => { - if index < MAX_DEPTH { - info!("{}{}", space[index], a); - stack[index] = ']'; - index += 1; - } else { - error!("Too Deep"); - } - } - ElementType::EndCnt => { - if index > 0 { - index -= 1; - info!("{}{}", space[index], stack[index]); - } else { - error!("Incorrect TLV List"); - } - } - _ => info!("{}{}", space[index], a), - } - } - info!("---------"); + info!("TLV list:\n{}\n---------", TLVList::new(b)); } #[cfg(test)] diff --git a/rs-matter/src/transport/core.rs b/rs-matter/src/transport/core.rs index 3aeb8b0d..3dfdd63e 100644 --- a/rs-matter/src/transport/core.rs +++ b/rs-matter/src/transport/core.rs @@ -15,79 +15,205 @@ * limitations under the License. */ -use core::borrow::Borrow; -use core::mem::MaybeUninit; +use core::cell::RefCell; +use core::fmt::{self, Display}; +use core::ops::{Deref, DerefMut}; use core::pin::pin; -use embassy_futures::select::{select, select_slice, Either}; -use embassy_sync::{blocking_mutex::raw::NoopRawMutex, channel::Channel}; -use embassy_time::{Duration, Timer}; +use embassy_futures::select::{select, select3}; +use embassy_sync::blocking_mutex::raw::NoopRawMutex; +use embassy_time::Timer; -use log::{error, info, warn}; +use log::{debug, error, info, trace, warn}; -use crate::interaction_model::core::IMStatusCode; -use crate::mdns::Mdns; -use crate::secure_channel::common::SCStatusCodes; -use crate::secure_channel::status_report::{create_status_report, GeneralCode}; +use crate::error::{Error, ErrorCode}; +use crate::mdns::{Mdns, MdnsImpl}; +use crate::secure_channel::common::{sc_write, OpCode, SCStatusCodes, PROTO_ID_SECURE_CHANNEL}; +use crate::secure_channel::status_report::StatusReport; +use crate::tlv::TLVList; use crate::utils::buf::BufferAccess; -use crate::utils::select::Notification; -use crate::{ - alloc, - data_model::{core::DataModel, objects::DataModelHandler}, - error::{Error, ErrorCode}, - interaction_model::core::PROTO_ID_INTERACTION_MODEL, - secure_channel::{ - common::{OpCode, PROTO_ID_SECURE_CHANNEL}, - core::SecureChannel, - }, - transport::packet::Packet, - utils::select::EitherUnwrap, - CommissioningData, Matter, MATTER_PORT, +use crate::utils::{ + epoch::Epoch, + ifmutex::{IfMutex, IfMutexGuard}, + notification::Notification, + parsebuf::ParseBuf, + rand::Rand, + select::Coalesce, + writebuf::WriteBuf, }; +use crate::{Matter, MATTER_PORT}; -use super::{ - exchange::{ - Exchange, ExchangeCtr, ExchangeCtx, ExchangeId, ExchangeState, Role, SessionId, - MAX_EXCHANGES, - }, - mrp::ReliableMessage, - network::{Ipv6Addr, NetworkReceive, NetworkSend, SocketAddr, SocketAddrV6}, - packet::{MAX_RX_BUF_SIZE, MAX_RX_STATUS_BUF_SIZE, MAX_TX_BUF_SIZE}, +use super::exchange::{Exchange, ExchangeId, ExchangeState, MessageMeta, ResponderState, Role}; +use super::network::{ + self, Address, Ipv6Addr, NetworkReceive, NetworkSend, SocketAddr, SocketAddrV6, }; +use super::packet::PacketHdr; +use super::proto_hdr::ProtoHdr; +use super::session::{Session, SessionMgr}; + +#[cfg(all(feature = "large-buffers", feature = "alloc"))] +extern crate alloc; pub const MATTER_SOCKET_BIND_ADDR: SocketAddr = SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, MATTER_PORT, 0, 0)); -type TxBuf = MaybeUninit<[u8; MAX_TX_BUF_SIZE]>; -type RxBuf = MaybeUninit<[u8; MAX_RX_BUF_SIZE]>; -type SxBuf = MaybeUninit<[u8; MAX_RX_STATUS_BUF_SIZE]>; - -pub struct PacketBuffers { - tx: [TxBuf; MAX_EXCHANGES], - rx: [RxBuf; MAX_EXCHANGES], - sx: [SxBuf; MAX_EXCHANGES + 1], +const ACCEPT_TIMEOUT_MS: u64 = 1000; + +#[cfg(all(feature = "large-buffers", feature = "alloc"))] +pub(crate) const MAX_RX_BUF_SIZE: usize = network::MAX_RX_LARGE_PACKET_SIZE; +#[cfg(all(feature = "large-buffers", feature = "alloc"))] +pub(crate) const MAX_TX_BUF_SIZE: usize = network::MAX_TX_LARGE_PACKET_SIZE; + +#[cfg(not(all(feature = "large-buffers", feature = "alloc")))] +pub(crate) const MAX_RX_BUF_SIZE: usize = network::MAX_RX_PACKET_SIZE; +#[cfg(not(all(feature = "large-buffers", feature = "alloc")))] +pub(crate) const MAX_TX_BUF_SIZE: usize = network::MAX_TX_PACKET_SIZE; + +/// Represents the transport layer of a `Matter` instance. +/// Each `Matter` instance has exactly one `TransportMgr` instance. +/// +/// To the outside world, the transport layer is only visible and usable via the notion of `Exchange`. +pub struct TransportMgr<'m> { + pub(crate) rx: IfMutex>, + pub(crate) tx: IfMutex>, + pub(crate) dropped: Notification, + pub session_mgr: RefCell, // For testing + pub(crate) mdns: MdnsImpl<'m>, } -impl PacketBuffers { - const TX_ELEM: TxBuf = MaybeUninit::uninit(); - const RX_ELEM: RxBuf = MaybeUninit::uninit(); - const SX_ELEM: SxBuf = MaybeUninit::uninit(); - - const TX_INIT: [TxBuf; MAX_EXCHANGES] = [Self::TX_ELEM; MAX_EXCHANGES]; - const RX_INIT: [RxBuf; MAX_EXCHANGES] = [Self::RX_ELEM; MAX_EXCHANGES]; - const SX_INIT: [SxBuf; MAX_EXCHANGES + 1] = [Self::SX_ELEM; MAX_EXCHANGES + 1]; - +impl<'m> TransportMgr<'m> { #[inline(always)] - pub const fn new() -> Self { + pub(crate) const fn new(mdns: MdnsImpl<'m>, epoch: Epoch, rand: Rand) -> Self { Self { - tx: Self::TX_INIT, - rx: Self::RX_INIT, - sx: Self::SX_INIT, + rx: IfMutex::new(Packet::new()), + tx: IfMutex::new(Packet::new()), + dropped: Notification::new(), + session_mgr: RefCell::new(SessionMgr::new(epoch, rand)), + mdns, } } -} -impl<'a> Matter<'a> { + #[cfg(all(feature = "large-buffers", feature = "alloc"))] + pub fn initialize_buffers(&self) -> Result<(), Error> { + let mut rx = self.rx.try_lock().map_err(|_| ErrorCode::InvalidState)?; + let mut tx = self.tx.try_lock().map_err(|_| ErrorCode::InvalidState)?; + + if rx.buf.0.is_none() { + rx.buf.0 = Some(alloc::boxed::Box::new(heapless::Vec::new())); + } + + if tx.buf.0.is_none() { + tx.buf.0 = Some(alloc::boxed::Box::new(heapless::Vec::new())); + } + + Ok(()) + } + + #[cfg(not(all(feature = "large-buffers", feature = "alloc")))] + pub fn initialize_buffers(&self) -> Result<(), Error> { + // No-op, as buffers are allocated inline + Ok(()) + } + + pub fn reset(&self) { + self.session_mgr.borrow_mut().reset(); + self.mdns.reset(); + } + + pub(crate) async fn initiate<'a>( + &'a self, + matter: &'a Matter<'a>, + node_id: u64, + secure: bool, + ) -> Result, Error> { + let mut session_mgr = self.session_mgr.borrow_mut(); + + session_mgr + .get_for_node(node_id, secure) + .ok_or(ErrorCode::NoSession)?; + + let exch_id = session_mgr.get_next_exch_id(); + + // `unwrap` is safe because we know we have a session or else the early return from above would've triggered + // The reason why we call `get_for_node` twice is to ensure that we don't waste an `exch_id` in case + // we don't have a session in the first place + let session = session_mgr.get_for_node(node_id, secure).unwrap(); + + let exch_index = session + .add_exch(exch_id, Role::Initiator(Default::default())) + .ok_or(ErrorCode::NoSpaceExchanges)?; + + let id = ExchangeId::new(session.id, exch_index); + + info!("Exchange {id}: Initiated"); + + Ok(Exchange::new(id, matter)) + } + + pub(crate) async fn accept_if<'a, F>( + &'a self, + matter: &'a Matter<'a>, + mut f: F, + ) -> Result, Error> + where + F: FnMut(&Session, &ExchangeState, &Packet) -> bool, + { + let exchange = self + .with_locked(&self.rx, |packet| { + let mut session_mgr = self.session_mgr.borrow_mut(); + + let session = session_mgr.get_for_rx(&packet.peer, &packet.header.plain)?; + + let exch_index = session.get_exch_for_rx(&packet.header.proto)?; + + let matches = { + // `unwrap` is safe because the transport code is single threaded, and since we don't `await` + // after computing `exch_index` no code can remove the exchange from the session + let exch = session.exchanges[exch_index].as_ref().unwrap(); + + matches!(exch.role, Role::Responder(ResponderState::AcceptPending)) + && f(session, exch, packet) + }; + + if !matches { + return None; + } + + // `unwrap` is safe because the transport code is single threaded, and since we don't `await` + // after computing `exch_index` no code can remove the exchange from the session + let exch = session.exchanges[exch_index].as_mut().unwrap(); + + exch.role = Role::Responder(ResponderState::Owned); + + let id = ExchangeId::new(session.id, exch_index); + + info!("Exchange {id}: Accepted"); + + let exchange = Exchange::new(id, matter); + + Some(exchange) + }) + .await; + + Ok(exchange) + } + + pub async fn run(&self, send: S, recv: R) -> Result<(), Error> + where + S: NetworkSend, + R: NetworkReceive, + { + info!("Running Matter transport"); + + let send = IfMutex::new(send); + + let mut rx = pin!(self.process_rx(recv, &send)); + let mut tx = pin!(self.process_tx(&send)); + let mut orphaned = pin!(self.process_orphaned()); + + select3(&mut rx, &mut tx, &mut orphaned).coalesce().await + } + #[cfg(not(all( feature = "std", any(target_os = "macos", all(feature = "zeroconf", target_os = "linux")) @@ -103,680 +229,926 @@ impl<'a> Matter<'a> { S: NetworkSend, R: NetworkReceive, { - use crate::mdns::MdnsImpl; - info!("Running Matter built-in mDNS service"); if let MdnsImpl::Builtin(mdns) = &self.mdns { - mdns.run(send, recv, &self.tx_buf, &self.rx_buf, host, interface) - .await + mdns.run( + send, + recv, + &PacketBufferExternalAccess(&self.tx), + PacketBufferExternalAccess(&self.rx), + host, + interface, + ) + .await } else { Err(ErrorCode::MdnsError.into()) } } - #[allow(clippy::too_many_arguments)] - pub async fn run( - &self, - send: S, - recv: R, - buffers: &mut PacketBuffers, - dev_comm: CommissioningData, - handler: &H, - ) -> Result<(), Error> + pub(crate) async fn get_if<'a, F, const N: usize>( + &'a self, + packet_mutex: &'a IfMutex>, + f: F, + ) -> PacketAccess<'a, N> where - H: DataModelHandler, - S: NetworkSend, - R: NetworkReceive, + F: Fn(&Packet) -> bool, { - info!("Running Matter transport"); - - { - let mut recv_buf = self.rx_buf.get().await; + PacketAccess(packet_mutex.lock_if(f).await, false) + } - if self.start_comissioning(dev_comm, &mut recv_buf)? { - info!("Comissioning started"); - } - } + async fn with_locked<'a, F, R, T>( + &'a self, + packet_mutex: &'a IfMutex, + f: F, + ) -> R + where + F: FnMut(&mut T) -> Option, + { + packet_mutex.with(f).await + } - let construction_notification = Notification::new(); + async fn process_tx(&self, send: &IfMutex) -> Result<(), Error> + where + S: NetworkSend, + { + loop { + debug!("Waiting for outgoing packet"); - let mut rx = pin!(self.handle_rx(recv, buffers, &construction_notification, handler)); - let mut tx = pin!(self.handle_tx(send)); + let mut tx = self.get_if(&self.tx, |packet| !packet.buf.is_empty()).await; + tx.clear_on_drop(true); - select(&mut rx, &mut tx).await.unwrap() + Self::netw_send(send, tx.peer, &tx.buf[tx.payload_start..], false).await?; + } } - #[inline(always)] - async fn handle_rx( + async fn process_rx( &self, - recv: R, - buffers: &mut PacketBuffers, - construction_notification: &Notification, - handler: &H, + mut recv: R, + send: &IfMutex, ) -> Result<(), Error> where - H: DataModelHandler, R: NetworkReceive, + S: NetworkSend, { - info!("Creating queue for {} exchanges", 1); + loop { + debug!("Waiting for incoming packet"); - let channel = Channel::::new(); + recv.wait_available().await?; - info!("Creating {} handlers", MAX_EXCHANGES); - let mut handlers = heapless::Vec::<_, MAX_EXCHANGES>::new(); + let mut rx = self.get_if(&self.rx, |packet| packet.buf.is_empty()).await; + rx.clear_on_drop(true); // In case of error, or if the future is dropped - info!("Handlers size: {}", core::mem::size_of_val(&handlers)); + // TODO: Resizing might be a bit expensive with large buffers + // Resizing to `MAX_RX_BUF_SIZE` is always safe because the size of the `buf` heapless vec `MAX_RX_BUF_SIZE` + rx.buf.resize_default(MAX_RX_BUF_SIZE).unwrap(); - // Unsafely allow mutable aliasing in the packet pools by different indices - let pools: *mut PacketBuffers = buffers; + let (len, peer) = Self::netw_recv(&mut recv, &mut rx.buf).await?; - for index in 0..MAX_EXCHANGES { - let channel = &channel; - let handler_id = index; + rx.peer = peer; + rx.buf.truncate(len); + rx.payload_start = 0; - let pools = unsafe { pools.as_mut() }.unwrap(); + match self.handle_rx_packet(&mut rx, send).await { + Ok(true) => { + // Leave the packet in place for accepting by responders + rx.clear_on_drop(false); + } + Ok(false) => { + // Drop the packet, as no further processing is necessary + } + Err(e) => { + // Drop the packet and report the unexpected error + error!("UNEXPECTED RX ERROR: {e:?}"); + } + } + } + } - let tx_buf = unsafe { pools.tx[handler_id].assume_init_mut() }; - let rx_buf = unsafe { pools.rx[handler_id].assume_init_mut() }; - let sx_buf = unsafe { pools.sx[handler_id].assume_init_mut() }; + async fn process_orphaned(&self) -> Result<(), Error> { + let mut rx_accept_timeout = pin!(self.process_accept_timeout_rx()); + let mut rx_orphaned = pin!(self.process_orphaned_rx()); + let mut exch_dropped = pin!(self.process_dropped_exchanges()); - handlers - .push(self.exchange_handler(tx_buf, rx_buf, sx_buf, handler_id, channel, handler)) - .map_err(|_| ()) - .unwrap(); - } + select3(&mut rx_accept_timeout, &mut rx_orphaned, &mut exch_dropped) + .coalesce() + .await + } - let mut rx = pin!(self.handle_rx_multiplex( - recv, - unsafe { buffers.sx[MAX_EXCHANGES].assume_init_mut() }, - construction_notification, - &channel, - )); + async fn process_accept_timeout_rx(&self) -> Result<(), Error> { + loop { + trace!("Waiting for accept timeout"); - let result = select(&mut rx, select_slice(&mut handlers)).await; + let mut accept_timeout = pin!(self.with_locked(&self.rx, |packet| { + self.handle_accept_timeout_rx_packet(packet).then_some(()) + })); - if let Either::First(result) = result { - if let Err(e) = &result { - error!("Exitting RX loop due to an error: {:?}", e); - } + let mut timer = pin!(Timer::after(embassy_time::Duration::from_millis(50))); - result?; + select(&mut accept_timeout, &mut timer).await; } - - Ok(()) } - #[inline(always)] - pub async fn handle_tx(&self, mut send: S) -> Result<(), Error> - where - S: NetworkSend, - { + async fn process_orphaned_rx(&self) -> Result<(), Error> { loop { - loop { - { - let mut send_buf = self.tx_buf.get().await; + info!("Waiting for orphaned RX packets"); - let mut tx = alloc!(Packet::new_tx(&mut send_buf)); + self.with_locked(&self.rx, |packet| { + self.handle_orphaned_rx_packet(packet).then_some(()) + }) + .await; + } + } - if self.pull_tx(&mut tx)? { - let addr = tx.peer; + async fn process_dropped_exchanges(&self) -> Result<(), Error> { + loop { + trace!("Waiting for dropped exchanges"); - let start = tx.get_writebuf()?.get_start(); - let end = tx.get_writebuf()?.get_tail(); + let mut tx = self.get_if(&self.tx, |packet| packet.buf.is_empty()).await; + tx.clear_on_drop(true); // In case of error, or if the future is dropped - send.send_to(&send_buf[start..end], addr).await?; - } else { - break; - } + let wait = match self.handle_dropped_exchange(&mut tx) { + Ok(wait) => { + tx.clear_on_drop(false); + wait } - } + Err(e) => { + error!("UNEXPECTED RX ERROR: {e:?}"); + false + } + }; + + drop(tx); + + if wait { + let mut timeout = pin!(Timer::after(embassy_time::Duration::from_millis(100))); + let mut wait = pin!(self.dropped.wait()); - self.wait_tx().await?; + select(&mut timeout, &mut wait).await; + } } } - #[inline(always)] - pub async fn handle_rx_multiplex<'t, 'e, const N: usize, R>( - &'t self, - mut receiver: R, - sts_buf: &mut [u8; MAX_RX_STATUS_BUF_SIZE], - construction_notification: &'e Notification, - channel: &Channel, N>, - ) -> Result<(), Error> + async fn handle_rx_packet( + &self, + packet: &mut Packet, + send: &IfMutex, + ) -> Result where - R: NetworkReceive, - 't: 'e, + S: NetworkSend, { - let mut sts_tx = alloc!(Packet::new_tx(sts_buf)); - - loop { - info!("Transport: waiting for incoming packets"); + let result = self.decode_packet(packet); + match result { + Err(e) if matches!(e.code(), ErrorCode::Duplicate) => { + if !packet.peer.is_reliable() { + info!("\n>>>>> {packet}\n => Duplicate, sending ACK"); + + { + let mut session_mgr = self.session_mgr.borrow_mut(); + let epoch = session_mgr.epoch; + + // `unwrap` is safe because we know we have a session. + // If we didn't have a session, the error code would've been `NoSession` + // + // Also, since the transport code is single threaded, and since we don't `await` + // after decoding the packet, no code can the session + let session = session_mgr + .get_for_rx(&packet.peer, &packet.header.plain) + .unwrap(); + + let ack = packet.header.plain.ctr; + + packet.header.proto.toggle_initiator(); + packet.header.proto.set_ack(Some(ack)); + + self.encode_packet(packet, Some(session), None, epoch, |_| { + Ok(Some(OpCode::MRPStandAloneAck.into())) + })?; + } - receiver.wait_available().await?; + Self::netw_send(send, packet.peer, &packet.buf[packet.payload_start..], true) + .await?; + } else { + info!("\n>>>>> {packet}\n => Duplicate, discarding"); + } + } + Err(e) if matches!(e.code(), ErrorCode::NoSpaceSessions) => { + if !packet.header.plain.is_encrypted() + && MessageMeta::from(&packet.header.proto).is_new_session() + { + error!("\n>>>>> {packet}\n => No space for a new unencrypted session, sending Busy"); - { - let mut recv_buf = self.rx_buf.get().await; + let ack = packet.header.plain.ctr; - let (len, remote) = receiver.recv_from(&mut recv_buf).await?; + packet.header.proto.toggle_initiator(); + packet.header.proto.set_ack(Some(ack)); - let mut rx = alloc!(Packet::new_rx(&mut recv_buf[..len])); - rx.peer = remote; + self.encode_packet( + packet, + None, + None, + self.session_mgr.borrow().epoch, + |wb| sc_write(wb, SCStatusCodes::Busy, &[0xF4, 0x01]), + )?; - if let Some(exchange_ctr) = self - .process_rx(construction_notification, &mut rx, &mut sts_tx) - .await? - { - let exchange_id = exchange_ctr.id().clone(); + Self::netw_send(send, packet.peer, &packet.buf[packet.payload_start..], true) + .await?; - info!("Transport: got new exchange: {:?}", exchange_id); + if self.encode_evict_some_session(packet)? { + Self::netw_send( + send, + packet.peer, + &packet.buf[packet.payload_start..], + true, + ) + .await?; + } + } else { + error!("\n>>>>> {packet}\n => No space for a new encrypted session, dropping"); + } + } + Err(e) if matches!(e.code(), ErrorCode::NoSpaceExchanges) => { + // TODO: Before closing the session, try to take other measures: + // - For CASESigma1 & PBKDFParamRequest - send Busy instead + // - For Interaction Model interactions that do need an ACK - send IM Busy, + // wait for ACK and retransmit without releasing the RX buffer, potentially + // blocking all other interactions - channel.send(exchange_ctr).await; - info!("Transport: exchange sent"); + error!("\n>>>>> {packet}\n => No space for a new exchange, closing session"); - self.wait_construction(construction_notification, &rx, &exchange_id) - .await?; + { + let mut session_mgr = self.session_mgr.borrow_mut(); + + // `unwrap` is safe because we know we have a session. + // If we didn't have a session, the error code would've been `NoSession` + // + // Also, since the transport code is single threaded, and since we don't `await` + // after decoding the packet, no code can the session + let session_id = session_mgr + .get_for_rx(&packet.peer, &packet.header.plain) + .unwrap() + .id; + + packet.header.proto.exch_id = session_mgr.get_next_exch_id(); + packet.header.proto.set_initiator(); + + // See above why `unwrap` is safe + let mut session = session_mgr.remove(session_id).unwrap(); + + self.encode_packet( + packet, + Some(&mut session), + None, + session_mgr.epoch, + |wb| sc_write(wb, SCStatusCodes::CloseSession, &[]), + )?; + } - info!("Transport: exchange started"); + Self::netw_send(send, packet.peer, &packet.buf[packet.payload_start..], true) + .await?; + } + Err(e) => { + error!("\n>>>>> {packet}\n => Error ({e:?}), dropping"); + } + Ok(new_exchange) => { + let meta = MessageMeta::from(&packet.header.proto); + + if meta.is_standalone_ack() { + // No need to propagate this further + info!("\n>>>>> {packet}\n => Standalone Ack, dropping"); + } else if meta.is_sc_status() + && matches!( + Self::is_close_session(&mut packet.buf[packet.payload_start..]), + Ok(true) + ) + { + warn!("\n>>>>> {packet}\n => Close session received, removing this session"); + + let mut session_mgr = self.session_mgr.borrow_mut(); + if let Some(session_id) = session_mgr + .get_for_rx(&packet.peer, &packet.header.plain) + .map(|sess| sess.id) + { + session_mgr.remove(session_id); + } + } else { + info!( + "\n>>>>> {packet}\n => Processing{}", + if new_exchange { " (new exchange)" } else { "" } + ); + + debug!( + "{}", + Packet::<0>::display_payload( + &packet.header.proto, + &packet.buf[core::cmp::min(packet.payload_start, packet.buf.len())..] + ) + ); + + return Ok(true); } } } - #[allow(unreachable_code)] - Ok::<_, Error>(()) + Ok(false) } - #[inline(always)] - pub async fn exchange_handler( - &self, - tx_buf: &mut [u8; MAX_TX_BUF_SIZE], - rx_buf: &mut [u8; MAX_RX_BUF_SIZE], - sx_buf: &mut [u8; MAX_RX_STATUS_BUF_SIZE], - handler_id: impl core::fmt::Display, - channel: &Channel, N>, - handler: &H, - ) -> Result<(), Error> - where - H: DataModelHandler, - { - loop { - let exchange_ctr: ExchangeCtr<'_> = channel.receive().await; + fn handle_accept_timeout_rx_packet(&self, packet: &mut Packet) -> bool { + if packet.buf.is_empty() { + return false; + } - info!( - "Handler {}: Got exchange {:?}", - handler_id, - exchange_ctr.id() - ); + let mut session_mgr = self.session_mgr.borrow_mut(); + let epoch = session_mgr.epoch; - let result = self - .handle_exchange(tx_buf, rx_buf, sx_buf, exchange_ctr, handler) - .await; + let Some(session) = session_mgr.get_for_rx(&packet.peer, &packet.header.plain) else { + return false; + }; - if let Err(err) = result { - warn!( - "Handler {}: Exchange closed because of error: {:?}", - handler_id, err - ); - } else { - info!("Handler {}: Exchange completed", handler_id); - } + let Some(exch_index) = session.get_exch_for_rx(&packet.header.proto) else { + return false; + }; + + // `unwrap` is safe because we know we have a session and an exchange, or else the early returns from above would've triggered + let exchange = session.exchanges[exch_index].as_mut().unwrap(); + + if !matches!( + exchange.role, + Role::Responder(ResponderState::AcceptPending) + ) || !exchange.mrp.has_rx_timed_out(ACCEPT_TIMEOUT_MS, epoch) + { + return false; } + + warn!("\n----- {packet}\n => Accept timeout, marking exchange as dropped"); + + exchange.role = Role::Responder(ResponderState::Dropped); + packet.buf.clear(); + self.dropped.notify(); + + true } - #[inline(always)] - pub async fn handle_exchange( - &self, - tx_buf: &mut [u8; MAX_TX_BUF_SIZE], - rx_buf: &mut [u8; MAX_RX_BUF_SIZE], - sx_buf: &mut [u8; MAX_RX_STATUS_BUF_SIZE], - exchange_ctr: ExchangeCtr<'_>, - handler: &H, - ) -> Result<(), Error> - where - H: DataModelHandler, - { - let mut tx = alloc!(Packet::new_tx(tx_buf.as_mut())); - let mut rx = alloc!(Packet::new_rx(rx_buf.as_mut())); + fn handle_orphaned_rx_packet(&self, packet: &mut Packet) -> bool { + if packet.buf.is_empty() { + return false; + } - let mut exchange = alloc!(exchange_ctr.get(&mut rx).await?); + let mut session_mgr = self.session_mgr.borrow_mut(); - match rx.get_proto_id() { - PROTO_ID_SECURE_CHANNEL => { - let sc = SecureChannel::new(); + let Some(session) = session_mgr.get_for_rx(&packet.peer, &packet.header.plain) else { + warn!("\n----- {packet}\n => No session, dropping"); - sc.handle(&mut exchange, &mut rx, &mut tx).await?; + packet.buf.clear(); + return true; + }; - self.notify_changed(); - } - PROTO_ID_INTERACTION_MODEL => { - let dm = DataModel::new(handler); + let Some(exch_index) = session.get_exch_for_rx(&packet.header.proto) else { + warn!("\n----- {packet}\n => No exchange, dropping"); + + packet.buf.clear(); + return true; + }; - let mut rx_status = alloc!(Packet::new_rx(sx_buf)); + // `unwrap` is safe because we know we have a session and an exchange, or else the early returns from above would've triggered + let exchange = session.exchanges[exch_index].as_mut().unwrap(); - dm.handle(&mut exchange, &mut rx, &mut tx, &mut rx_status) - .await?; + if exchange.role.is_dropped_state() { + warn!( + "\n----- {packet}\n => Owned by orphaned dropped {}, dropping packet", + ExchangeId::new(session.id, exch_index) + ); - self.notify_changed(); - } - other => { - error!("Unknown Proto-ID: {}", other); - } + packet.buf.clear(); + return true; } - Ok(()) + false } - pub fn reset_transport(&self) { - self.exchanges.borrow_mut().clear(); - self.session_mgr.borrow_mut().reset(); - self.mdns.reset(); - } + fn handle_dropped_exchange( + &self, + packet: &mut Packet, + ) -> Result { + let mut session_mgr = self.session_mgr.borrow_mut(); - pub async fn process_rx<'r>( - &'r self, - construction_notification: &'r Notification, - src_rx: &mut Packet<'_>, - sts_tx: &mut Packet<'_>, - ) -> Result>, Error> { - src_rx.plain_hdr_decode()?; + let exch = session_mgr + .get_exch(|_, exch| exch.role.is_dropped_state() && exch.mrp.is_retrans_pending()) + .map(|(sess, exch_index)| (sess.id, exch_index, true)) + .or_else(|| { + session_mgr + .get_exch(|_, exch| { + exch.role.is_dropped_state() && !exch.mrp.is_retrans_pending() + }) + .map(|(sess, exch_index)| (sess.id, exch_index, false)) + }); + + let Some((session_id, exch_index, close_session)) = exch else { + return Ok(exch.is_none()); + }; - self.purge()?; + let exchange_id = ExchangeId::new(session_id, exch_index); - let (exchange_index, new) = loop { - let result = self.assign_exchange(&mut self.exchanges.borrow_mut(), src_rx); + if close_session { + // Found a dropped exchange which has an incomplete (re)transmission + // Close the whole session - match result { - Err(e) => match e.code() { - ErrorCode::Duplicate => { - self.send_notification.signal(()); - return Ok(None); - } - // TODO: NoSession, NoExchange and others - ErrorCode::NoSpaceSessions => self.evict_session(sts_tx).await?, - ErrorCode::NoSpaceExchanges => { - self.send_busy(src_rx, sts_tx).await?; - return Ok(None); - } - _ => break Err(e), - }, - other => break other, - } - }?; + error!( + "Dropped exchange {exchange_id}: Closing session because the exchange cannot be closed cleanly" + ); - let mut exchanges = self.exchanges.borrow_mut(); - let ctx = &mut exchanges[exchange_index]; + self.encode_evict_session(packet, &mut session_mgr, session_id)?; + } else { + // Found a dropped exchange which has no outstanding (re)transmission + // Send a standalone ACK if necessary and then close it - src_rx.log("Got packet"); + let epoch = session_mgr.epoch; - if src_rx.proto.is_ack() { - if new { - Err(ErrorCode::Invalid)?; - } else { - let state = &mut ctx.state; + // `unwrap` is safe because we know we have a session and an exchange, or else the early returns from above would've triggered + let session = session_mgr.get(session_id).unwrap(); + // Ditto + let exchange = session.exchanges[exch_index].as_mut().unwrap(); - match state { - ExchangeState::ExchangeRecv { - tx_acknowledged, .. - } => { - *tx_acknowledged = true; - } - ExchangeState::CompleteAcknowledge { notification, .. } => { - unsafe { notification.as_ref() }.unwrap().signal(()); - ctx.state = ExchangeState::Closed; - } - _ => { - // TODO: Error handling - todo!() - } - } - - self.notify_changed(); + if exchange.mrp.is_ack_pending() { + self.encode_packet(packet, Some(session), Some(exch_index), epoch, |_| { + Ok(Some(OpCode::MRPStandAloneAck.into())) + })?; } + + session.exchanges[exch_index] = None; + warn!("Dropped exchange {exchange_id}: Closed"); } - if new { - let constructor = ExchangeCtr { - exchange: Exchange { - id: ctx.id.clone(), - matter: self, - notification: Notification::new(), - }, - construction_notification, - }; + Ok(exch.is_none()) + } - self.notify_changed(); + pub(crate) async fn evict_some_session(&self) -> Result<(), Error> { + let mut tx = self.get_if(&self.tx, |packet| packet.buf.is_empty()).await; + tx.clear_on_drop(true); // By default, if an error occurs - Ok(Some(constructor)) - } else if src_rx.proto.proto_id == PROTO_ID_SECURE_CHANNEL - && src_rx.proto.proto_opcode == OpCode::MRPStandAloneAck as u8 - { - // Standalone ack, do nothing - Ok(None) + let evicted = self.encode_evict_some_session(&mut tx)?; + + if evicted { + // Send it + tx.clear_on_drop(false); + + Ok(()) } else { - let state = &mut ctx.state; + Err(ErrorCode::NoSpaceSessions.into()) + } + } - match state { - ExchangeState::ExchangeRecv { - rx, notification, .. - } => { - // TODO: Handle Busy status codes + fn decode_packet(&self, packet: &mut Packet) -> Result { + packet.header.reset(); - let rx = unsafe { rx.as_mut() }.unwrap(); - rx.load(src_rx)?; + let mut pb = ParseBuf::new(&mut packet.buf[packet.payload_start..]); + packet.header.plain.decode(&mut pb)?; - unsafe { notification.as_ref() }.unwrap().signal(()); - *state = ExchangeState::Active; - } - _ => { - // TODO: Error handling - todo!() - } + let mut session_mgr = self.session_mgr.borrow_mut(); + let epoch = session_mgr.epoch; + + let res = if let Some(session) = session_mgr.get_for_rx(&packet.peer, &packet.header.plain) + { + session.post_recv(&mut packet.header, &mut pb, epoch) + } else if !packet.header.plain.is_encrypted() { + let mut session = + session_mgr.add(false, packet.peer, packet.header.plain.get_src_nodeid()); + + if let Some(session) = session.as_mut() { + session.post_recv(&mut packet.header, &mut pb, epoch) + } else { + packet.header.decode_remaining(&mut pb, 0, None)?; + packet.header.proto.adjust_reliability(true, &packet.peer); + + Err(ErrorCode::NoSpaceSessions.into()) } + } else { + Err(ErrorCode::NoSession.into()) + }; - self.notify_changed(); + let range = pb.slice_range(); + packet.payload_start = range.0; + packet.buf.truncate(range.1); - Ok(None) - } + res } - pub async fn wait_construction( + fn encode_packet( &self, - construction_notification: &Notification, - src_rx: &Packet<'_>, - exchange_id: &ExchangeId, - ) -> Result<(), Error> { - construction_notification.wait().await; + packet: &mut Packet, + mut session: Option<&mut Session>, + exchange_index: Option, + epoch: Epoch, + payload_writer: F, + ) -> Result<(), Error> + where + F: FnOnce(&mut WriteBuf) -> Result, Error>, + { + // TODO: Resizing might be a bit expensive with large buffers + // Resizing to `N` is always safe because it is a responsibility of the caller to ensure that N is <= `MAX_RX_BUF_SIZE`, + // which is the size of `buf` heapless vec + packet.buf.resize_default(N).unwrap(); + + let mut wb = WriteBuf::new(&mut packet.buf); + wb.reserve(PacketHdr::HDR_RESERVE)?; - let mut exchanges = self.exchanges.borrow_mut(); + let Some(meta) = payload_writer(&mut wb)? else { + packet.buf.clear(); + return Ok(()); + }; - let ctx = ExchangeCtx::get(&mut exchanges, exchange_id).unwrap(); + meta.set_into(&mut packet.header.proto); - let state = &mut ctx.state; + let retransmission = if let Some(session) = &mut session { + packet.header.plain = Default::default(); - match state { - ExchangeState::Construction { rx, notification } => { - let rx = unsafe { rx.as_mut() }.unwrap(); - rx.load(src_rx)?; + let (peer, retransmission) = + session.pre_send(exchange_index, &mut packet.header, epoch)?; - unsafe { notification.as_ref() }.unwrap().signal(()); - *state = ExchangeState::Active; + packet.peer = peer; + + retransmission + } else { + if packet.header.plain.is_encrypted() + || packet.header.plain.get_src_nodeid().is_none() + || packet.header.proto.is_reliable() + { + // We can encode packets without a session only when they are unencrypted and do not need a retransmission + Err(ErrorCode::NoSession)?; } - _ => unreachable!(), + + let src_nodeid = packet.header.plain.get_src_nodeid(); + + packet.header.plain = Default::default(); + + packet.header.plain.sess_id = 0; + packet.header.plain.ctr = 1; + packet.header.plain.set_src_nodeid(None); + packet.header.plain.set_dst_unicast_nodeid(src_nodeid); + + packet.header.proto.unset_initiator(); + packet.header.proto.adjust_reliability(false, &packet.peer); + + false + }; + + info!( + "\n<<<<< {}\n => {} (system)", + Packet::<0>::display(&packet.peer, &packet.header), + if retransmission { + "Re-sending" + } else { + "Sending" + } + ); + + debug!( + "{}", + Packet::<0>::display_payload(&packet.header.proto, wb.as_slice()) + ); + + if let Some(session) = session { + session.encode(&packet.header, &mut wb)?; + } else { + packet.header.encode(&mut wb, 0, None)?; } + let range = (wb.get_start(), wb.get_tail()); + + packet.payload_start = range.0; + packet.buf.truncate(range.1); + Ok(()) } - pub async fn wait_tx(&self) -> Result<(), Error> { - select( - self.send_notification.wait(), - Timer::after(Duration::from_millis(100)), - ) - .await; + fn encode_evict_some_session( + &self, + packet: &mut Packet, + ) -> Result { + let mut session_mgr = self.session_mgr.borrow_mut(); + let id = session_mgr.get_session_for_eviction().map(|sess| sess.id); + if let Some(id) = id { + self.encode_evict_session(packet, &mut session_mgr, id)?; + + Ok(true) + } else { + Ok(false) + } + } + + fn encode_evict_session( + &self, + packet: &mut Packet, + session_mgr: &mut SessionMgr, + id: u32, + ) -> Result<(), Error> { + packet.header.proto.exch_id = session_mgr.get_next_exch_id(); + packet.header.proto.set_initiator(); + + // It is a responsibility of the caller to ensure that this method is called with a valid session ID + let mut session = session_mgr.remove(id).unwrap(); + + self.encode_packet(packet, Some(&mut session), None, session_mgr.epoch, |wb| { + sc_write(wb, SCStatusCodes::CloseSession, &[]) + })?; Ok(()) } - pub fn pull_tx(&self, dest_tx: &mut Packet) -> Result { - self.purge()?; + fn is_close_session(payload: &mut [u8]) -> Result { + let mut pb = ParseBuf::new(payload); + let report = StatusReport::read(&mut pb)?; - let mut ephemeral = self.ephemeral.borrow_mut(); - let mut exchanges = self.exchanges.borrow_mut(); + let close_session = report.proto_id == PROTO_ID_SECURE_CHANNEL as _ + && report.proto_code == SCStatusCodes::CloseSession as u16; - self.pull_tx_exchanges(ephemeral.iter_mut().chain(exchanges.iter_mut()), dest_tx) + Ok(close_session) } - fn pull_tx_exchanges<'i, I>( - &self, - mut exchanges: I, - dest_tx: &mut Packet, - ) -> Result + async fn netw_recv(mut recv: R, buf: &mut [u8]) -> Result<(usize, Address), Error> where - I: Iterator, + R: NetworkReceive, { - let ctx = exchanges.find(|ctx| { - matches!( - &ctx.state, - ExchangeState::Acknowledge { .. } - | ExchangeState::ExchangeSend { .. } - // | ExchangeState::ExchangeRecv { - // tx_acknowledged: false, - // .. - // } - | ExchangeState::Complete { .. } // | ExchangeState::CompleteAcknowledge { .. } - ) || ctx.mrp.is_ack_ready(*self.borrow()) - }); - - if let Some(ctx) = ctx { - self.notify_changed(); - - let state = &mut ctx.state; - - let send = match state { - ExchangeState::Acknowledge { notification } => { - ReliableMessage::prepare_ack(ctx.id.id, dest_tx); - - unsafe { notification.as_ref() }.unwrap().signal(()); - *state = ExchangeState::Active; - - true - } - ExchangeState::ExchangeSend { - tx, - rx, - notification, - } => { - let tx = unsafe { tx.as_ref() }.unwrap(); - dest_tx.load(tx)?; - - *state = ExchangeState::ExchangeRecv { - _tx: tx, - tx_acknowledged: false, - rx: *rx, - notification: *notification, - }; - - true - } - // ExchangeState::ExchangeRecv { .. } => { - // // TODO: Re-send the tx package if due - // false - // } - ExchangeState::Complete { tx, notification } => { - let tx = unsafe { tx.as_ref() }.unwrap(); - dest_tx.load(tx)?; - - if dest_tx.is_reliable() { - *state = ExchangeState::CompleteAcknowledge { - _tx: tx as *const _, - notification: *notification, - }; - } else { - unsafe { notification.as_ref() }.unwrap().signal(()); - ctx.state = ExchangeState::Closed; - } + match recv.recv_from(buf).await { + Ok((len, addr)) => { + debug!("\n>>>>> {} {}B:\n{:02x?}", addr, len, &buf[..len]); - true - } - // ExchangeState::CompleteAcknowledge { .. } => { - // // TODO: Re-send the tx package if due - // false - // } - _ => { - ReliableMessage::prepare_ack(ctx.id.id, dest_tx); - true - } - }; + Ok((len, addr)) + } + Err(e) => { + error!("FAILED network recv: {e:?}"); + + Err(e) + } + } + } + + async fn netw_send( + send: &IfMutex, + peer: Address, + data: &[u8], + system: bool, + ) -> Result<(), Error> + where + S: NetworkSend, + { + match send.lock().await.send_to(data, peer).await { + Ok(_) => { + debug!( + "\n<<<<< {} {}B{}: {:02x?}", + peer, + data.len(), + if system { " (system)" } else { "" }, + data + ); - if send { - dest_tx.log("Sending packet"); - self.notify_changed(); + Ok(()) + } + Err(e) => { + error!( + "\n<<<<< {} {}B{} !FAILED!: {e:?}: {:02x?}", + peer, + data.len(), + if system { " (system)" } else { "" }, + data + ); - return Ok(true); + Err(e) } } + } +} - Ok(false) +// The internal representation of a packet in the transport layer. +// There are only two such packets - RX and TX. +// +// This type is only known and used by `TransportMgr` and the `exchange` module +pub(crate) struct Packet { + pub(crate) peer: Address, + pub(crate) header: PacketHdr, + pub(crate) buf: PacketBuffer, + pub(crate) payload_start: usize, +} + +impl Packet { + #[inline(always)] + pub(crate) const fn new() -> Self { + Self { + peer: Address::new(), + header: PacketHdr::new(), + buf: PacketBuffer::new(), + payload_start: 0, + } } - fn purge(&self) -> Result<(), Error> { - loop { - let mut exchanges = self.exchanges.borrow_mut(); + pub fn display<'a>(peer: &'a Address, header: &'a PacketHdr) -> impl Display + 'a { + struct PacketInfo<'a>(&'a Address, &'a PacketHdr); - if let Some(index) = exchanges.iter_mut().enumerate().find_map(|(index, ctx)| { - matches!(ctx.state, ExchangeState::Closed).then_some(index) - }) { - exchanges.swap_remove(index); - } else { - break; + impl<'a> Display for PacketInfo<'a> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + Packet::<0>::fmt(f, self.0, self.1) } } - Ok(()) + PacketInfo(peer, header) } - pub(crate) async fn evict_session(&self, tx: &mut Packet<'_>) -> Result<(), Error> { - let sess_index = self.session_mgr.borrow().get_session_for_eviction(); - if let Some(sess_index) = sess_index { - let ctx = { - create_status_report( - tx, - GeneralCode::Success, - PROTO_ID_SECURE_CHANNEL as _, - SCStatusCodes::CloseSession as _, - None, - )?; + pub fn display_payload<'a>(proto: &'a ProtoHdr, buf: &'a [u8]) -> impl Display + 'a { + struct PacketInfo<'a>(&'a ProtoHdr, &'a [u8]); - let mut session_mgr = self.session_mgr.borrow_mut(); - let session_id = session_mgr.mut_by_index(sess_index).unwrap().id(); - warn!("Evicting session: {:?}", session_id); + impl<'a> Display for PacketInfo<'a> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + Packet::<0>::fmt_payload(f, self.0, self.1) + } + } - let ctx = ExchangeCtx::prep_ephemeral(session_id, &mut session_mgr, None, tx)?; + PacketInfo(proto, buf) + } - session_mgr.remove(sess_index); + fn fmt(f: &mut fmt::Formatter<'_>, peer: &Address, header: &PacketHdr) -> fmt::Result { + let meta = MessageMeta::from(&header.proto); - ctx - }; + write!(f, "{peer} {header}\n{meta}") + } + + fn fmt_payload(f: &mut fmt::Formatter<'_>, proto: &ProtoHdr, buf: &[u8]) -> fmt::Result { + let meta = MessageMeta::from(proto); + + write!(f, "{meta}")?; - self.send_ephemeral(ctx, tx).await + if meta.is_tlv() { + write!( + f, + "; TLV:\n----------------\n{}\n----------------\n", + TLVList::new(buf) + )?; } else { - Err(ErrorCode::NoSpaceSessions.into()) + write!( + f, + "; Payload:\n----------------\n{:02x?}\n----------------\n", + buf + )?; } + + Ok(()) } +} - async fn send_busy(&self, rx: &Packet<'_>, tx: &mut Packet<'_>) -> Result<(), Error> { - warn!("Sending Busy as all exchanges are occupied"); +impl Display for Packet { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + Self::fmt(f, &self.peer, &self.header) + } +} - create_status_report( - tx, - GeneralCode::Busy, - rx.get_proto_id() as _, - if rx.get_proto_id() == PROTO_ID_SECURE_CHANNEL { - SCStatusCodes::Busy as _ - } else { - IMStatusCode::Busy as _ - }, - None, // TODO: ms - )?; +// The buffer used inside the pair of RX and TX `Packet` instances +// When the `alloc` and `large-buffers` features are enabled, the buffer payload is allocated on the heap +// +// This type is only known and used by `TransportMgr` and the `exchange` module +#[cfg(all(feature = "large-buffers", feature = "alloc"))] +pub(crate) struct PacketBuffer(Option>>); + +// The buffer used inside the pair of RX and TX `Packet` instances +// When the either of the `alloc` and `large-buffers` features is not enabled, the buffer payload is allocated inline +// +// This type is only known and used by `TransportMgr` and the `exchange` module +#[cfg(not(all(feature = "large-buffers", feature = "alloc")))] +pub(crate) struct PacketBuffer(heapless::Vec); + +impl PacketBuffer { + #[cfg(all(feature = "large-buffers", feature = "alloc"))] + pub const fn new() -> Self { + Self(None) + } - let ctx = ExchangeCtx::prep_ephemeral( - SessionId::load(rx), - &mut self.session_mgr.borrow_mut(), - Some(rx), - tx, - )?; + #[cfg(not(all(feature = "large-buffers", feature = "alloc")))] + pub const fn new() -> Self { + Self(heapless::Vec::new()) + } - self.send_ephemeral(ctx, tx).await + #[cfg(all(feature = "large-buffers", feature = "alloc"))] + pub fn buf_mut(&mut self) -> &mut heapless::Vec { + &mut *self + .0 + .as_mut() + .expect("Buffer is not allocated. Did you forget to call `initialize_buffers`?") } - async fn send_ephemeral(&self, mut ctx: ExchangeCtx, tx: &mut Packet<'_>) -> Result<(), Error> { - let _guard = self.ephemeral_mutex.lock().await; + #[cfg(not(all(feature = "large-buffers", feature = "alloc")))] + pub fn buf_mut(&mut self) -> &mut heapless::Vec { + &mut self.0 + } - let notification = Notification::new(); + #[cfg(all(feature = "large-buffers", feature = "alloc"))] + pub fn buf_ref(&self) -> &heapless::Vec { + self.0 + .as_ref() + .expect("Buffer is not allocated. Did you forget to call `initialize_buffers`?") + } - let tx: &'static mut Packet<'static> = unsafe { core::mem::transmute(tx) }; + #[cfg(not(all(feature = "large-buffers", feature = "alloc")))] + pub fn buf_ref(&self) -> &heapless::Vec { + &self.0 + } +} - ctx.state = ExchangeState::Complete { - tx, - notification: ¬ification, - }; +impl Deref for PacketBuffer { + type Target = heapless::Vec; - *self.ephemeral.borrow_mut() = Some(ctx); + fn deref(&self) -> &Self::Target { + self.buf_ref() + } +} - self.send_notification.signal(()); +impl DerefMut for PacketBuffer { + fn deref_mut(&mut self) -> &mut Self::Target { + self.buf_mut() + } +} - notification.wait().await; +// Represents the fact that either `TransportMgr` or some `Exchange` instace has an exclusive access to the +// RX or TX packet of the transport layer. +// +// At any point in time, either the `TransportMgr` singleton, or exactly one `Exchange` instance, or nobody +// holds a lock on the RX or TX packet. This is enforced by protecting the packets with an `IfMutex` asynchronous mutex. +// +// This type is only known and used by `TransportMgr` and the `exchange` module +pub(crate) struct PacketAccess<'a, const N: usize>(IfMutexGuard<'a, NoopRawMutex, Packet>, bool); + +impl<'a, const N: usize> PacketAccess<'a, N> { + pub fn clear_on_drop(&mut self, clear: bool) { + self.1 = clear; + } +} - *self.ephemeral.borrow_mut() = None; +impl<'a, const N: usize> Deref for PacketAccess<'a, N> { + type Target = Packet; - Ok(()) + fn deref(&self) -> &Self::Target { + &self.0 } +} - fn assign_exchange( - &self, - exchanges: &mut heapless::Vec, - rx: &mut Packet<'_>, - ) -> Result<(usize, bool), Error> { - // Get the session +impl<'a, const N: usize> DerefMut for PacketAccess<'a, N> { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} - let mut session_mgr = self.session_mgr.borrow_mut(); +impl<'a, const N: usize> Drop for PacketAccess<'a, N> { + fn drop(&mut self) { + if self.1 { + self.buf.clear(); + } + } +} - let sess_index = session_mgr.post_recv(rx)?; - let session = session_mgr.mut_by_index(sess_index).unwrap(); - - // Decrypt the message - session.recv(self.epoch, rx)?; - - // Get the exchange - let (exchange_index, new) = Self::register( - exchanges, - ExchangeId::load(rx), - Role::complementary(rx.proto.is_initiator()), - // We create a new exchange, only if the peer is the initiator - rx.proto.is_initiator(), - )?; - - // Message Reliability Protocol - exchanges[exchange_index].mrp.recv(rx, self.epoch)?; - - Ok((exchange_index, new)) - } - - fn register( - exchanges: &mut heapless::Vec, - id: ExchangeId, - role: Role, - create_new: bool, - ) -> Result<(usize, bool), Error> { - let exchange_index = exchanges - .iter_mut() - .enumerate() - .find_map(|(index, exchange)| (exchange.id == id).then_some(index)); - - if let Some(exchange_index) = exchange_index { - let exchange = &mut exchanges[exchange_index]; - if exchange.role == role { - Ok((exchange_index, false)) - } else { - Err(ErrorCode::NoExchange.into()) - } - } else if create_new { - info!("Creating new exchange: {:?}", id); - - let exchange = ExchangeCtx { - id, - role, - mrp: ReliableMessage::new(), - state: ExchangeState::Active, - }; +impl<'a, const N: usize> Display for PacketAccess<'a, N> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.fmt(f) + } +} - exchanges - .push(exchange) - .map_err(|_| ErrorCode::NoSpaceExchanges)?; +// Allows other code in `rs-matter` to (ab)use the packet buffers of the transport layer +// in case it needs temporary access to a `&mut [u8]`-shaped memory +// +// Used by the builtin mDNS responder, as well as by the QR code generator +pub(crate) struct PacketBufferExternalAccess<'a, const N: usize>( + pub(crate) &'a IfMutex>, +); - Ok((exchanges.len() - 1, true)) - } else { - Err(ErrorCode::NoExchange.into()) - } +impl<'a, const N: usize> BufferAccess<[u8]> for PacketBufferExternalAccess<'a, N> { + type Buffer<'b> = ExternalPacketBuffer<'b, N> where Self: 'b; + + async fn get(&self) -> Option> { + let mut packet = self.0.lock_if(|packet| packet.buf.is_empty()).await; + + // TODO: Resizing might be a bit expensive with large buffers + // Resizing to `N` is always safe because the size of `buf` heapless vec is `N` + packet.buf.resize_default(N).unwrap(); + + Some(ExternalPacketBuffer(packet)) + } +} + +// Wraps the RX or TX packet of the transport manager in something that looks like a `&mut [u8]` buffer. +pub struct ExternalPacketBuffer<'a, const N: usize>(IfMutexGuard<'a, NoopRawMutex, Packet>); + +impl<'a, const N: usize> Deref for ExternalPacketBuffer<'a, N> { + type Target = [u8]; + + fn deref(&self) -> &Self::Target { + &self.0.buf + } +} + +impl<'a, const N: usize> DerefMut for ExternalPacketBuffer<'a, N> { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0.buf + } +} + +impl<'a, const N: usize> Drop for ExternalPacketBuffer<'a, N> { + fn drop(&mut self) { + self.0.buf.clear(); } } diff --git a/rs-matter/src/transport/dedup.rs b/rs-matter/src/transport/dedup.rs index f2c382e8..53abf848 100644 --- a/rs-matter/src/transport/dedup.rs +++ b/rs-matter/src/transport/dedup.rs @@ -39,24 +39,26 @@ impl RxCtrState { self.ctr_bitmap |= 1 << bit_number; } - /// Receive a message and update Rx State accordingly - /// Returns a bool indicating whether the message is a duplicate - pub fn recv(&mut self, msg_ctr: u32, is_encrypted: bool) -> bool { + /// Receive a message and update RX state accordingly + /// + /// The method will return `false` if the message is detected to be duplicate, and therefore, + /// the RX state had not been updated. + pub fn post_recv(&mut self, msg_ctr: u32, is_encrypted: bool) -> bool { let idiff = (msg_ctr as i32) - (self.max_ctr as i32); let udiff = idiff.unsigned_abs(); if msg_ctr == self.max_ctr { // Duplicate - true + false } else if (-(MSG_RX_STATE_BITMAP_LEN as i32)..0).contains(&idiff) { // In Rx Bitmap let index = udiff - 1; if self.contains(index) { // Duplicate - true + false } else { self.insert(index); - false + true } } // Now the leftover cases are the new counter is outside of the bitmap as well as max_ctr @@ -70,15 +72,15 @@ impl RxCtrState { } else { self.ctr_bitmap = 0xffff; } - false + true } else if !is_encrypted { // This is the case where the peer possibly rebooted and chose a different // random counter self.max_ctr = msg_ctr; self.ctr_bitmap = 0xffff; - false - } else { true + } else { + false } } } @@ -94,27 +96,27 @@ mod tests { const NOT_ENCRYPTED: bool = false; fn assert_ndup(b: bool) { - assert!(!b); + assert!(b); } fn assert_dup(b: bool) { - assert!(b); + assert!(!b); } #[test] fn new_msg_ctr() { let mut s = RxCtrState::new(101); - assert_ndup(s.recv(103, ENCRYPTED)); - assert_ndup(s.recv(104, ENCRYPTED)); - assert_ndup(s.recv(106, ENCRYPTED)); + assert_ndup(s.post_recv(103, ENCRYPTED)); + assert_ndup(s.post_recv(104, ENCRYPTED)); + assert_ndup(s.post_recv(106, ENCRYPTED)); assert_eq!(s.max_ctr, 106); assert_eq!(s.ctr_bitmap, 0b1111_1111_1111_0110); - assert_ndup(s.recv(118, NOT_ENCRYPTED)); + assert_ndup(s.post_recv(118, NOT_ENCRYPTED)); assert_eq!(s.ctr_bitmap, 0b0110_1000_0000_0000); - assert_ndup(s.recv(119, NOT_ENCRYPTED)); - assert_ndup(s.recv(121, NOT_ENCRYPTED)); + assert_ndup(s.post_recv(119, NOT_ENCRYPTED)); + assert_ndup(s.post_recv(121, NOT_ENCRYPTED)); assert_eq!(s.ctr_bitmap, 0b0100_0000_0000_0110); } @@ -122,9 +124,9 @@ mod tests { fn dup_max_ctr() { let mut s = RxCtrState::new(101); - assert_ndup(s.recv(103, ENCRYPTED)); - assert_dup(s.recv(103, ENCRYPTED)); - assert_dup(s.recv(103, NOT_ENCRYPTED)); + assert_ndup(s.post_recv(103, ENCRYPTED)); + assert_dup(s.post_recv(103, ENCRYPTED)); + assert_dup(s.post_recv(103, NOT_ENCRYPTED)); assert_eq!(s.max_ctr, 103); assert_eq!(s.ctr_bitmap, 0b1111_1111_1111_1110); @@ -136,24 +138,24 @@ mod tests { let mut s = RxCtrState::new(101); for _ in 1..8 { ctr += 2; - assert_ndup(s.recv(ctr, ENCRYPTED)); + assert_ndup(s.post_recv(ctr, ENCRYPTED)); } - assert_ndup(s.recv(116, ENCRYPTED)); - assert_ndup(s.recv(117, ENCRYPTED)); + assert_ndup(s.post_recv(116, ENCRYPTED)); + assert_ndup(s.post_recv(117, ENCRYPTED)); assert_eq!(s.max_ctr, 117); assert_eq!(s.ctr_bitmap, 0b1010_1010_1010_1011); // duplicate on the left corner - assert_dup(s.recv(101, ENCRYPTED)); - assert_dup(s.recv(101, NOT_ENCRYPTED)); + assert_dup(s.post_recv(101, ENCRYPTED)); + assert_dup(s.post_recv(101, NOT_ENCRYPTED)); // duplicate on the right corner - assert_dup(s.recv(116, ENCRYPTED)); - assert_dup(s.recv(116, NOT_ENCRYPTED)); + assert_dup(s.post_recv(116, ENCRYPTED)); + assert_dup(s.post_recv(116, NOT_ENCRYPTED)); // valid insert - assert_ndup(s.recv(102, ENCRYPTED)); - assert_dup(s.recv(102, ENCRYPTED)); + assert_ndup(s.post_recv(102, ENCRYPTED)); + assert_dup(s.post_recv(102, ENCRYPTED)); assert_eq!(s.ctr_bitmap, 0b1110_1010_1010_1011); } @@ -163,17 +165,17 @@ mod tests { let mut s = RxCtrState::new(101); for _ in 1..9 { ctr += 2; - assert_ndup(s.recv(ctr, ENCRYPTED)); + assert_ndup(s.post_recv(ctr, ENCRYPTED)); } assert_eq!(s.max_ctr, 118); assert_eq!(s.ctr_bitmap, 0b0010_1010_1010_1010); // valid insert on the left corner - assert_ndup(s.recv(102, ENCRYPTED)); + assert_ndup(s.post_recv(102, ENCRYPTED)); assert_eq!(s.ctr_bitmap, 0b1010_1010_1010_1010); // valid insert on the right corner - assert_ndup(s.recv(117, ENCRYPTED)); + assert_ndup(s.post_recv(117, ENCRYPTED)); assert_eq!(s.ctr_bitmap, 0b1010_1010_1010_1011); } @@ -181,17 +183,17 @@ mod tests { fn encrypted_wraparound() { let mut s = RxCtrState::new(65534); - assert_ndup(s.recv(65535, ENCRYPTED)); - assert_ndup(s.recv(65536, ENCRYPTED)); - assert_dup(s.recv(0, ENCRYPTED)); + assert_ndup(s.post_recv(65535, ENCRYPTED)); + assert_ndup(s.post_recv(65536, ENCRYPTED)); + assert_dup(s.post_recv(0, ENCRYPTED)); } #[test] fn unencrypted_wraparound() { let mut s = RxCtrState::new(65534); - assert_ndup(s.recv(65536, NOT_ENCRYPTED)); - assert_ndup(s.recv(0, NOT_ENCRYPTED)); + assert_ndup(s.post_recv(65536, NOT_ENCRYPTED)); + assert_ndup(s.post_recv(0, NOT_ENCRYPTED)); } #[test] @@ -202,7 +204,7 @@ mod tests { info!("Sub regular is {:?}", 2000_u16.overflowing_sub(1998)); let mut s = RxCtrState::new(20010); - assert_ndup(s.recv(20011, NOT_ENCRYPTED)); - assert_ndup(s.recv(0, NOT_ENCRYPTED)); + assert_ndup(s.post_recv(20011, NOT_ENCRYPTED)); + assert_ndup(s.post_recv(0, NOT_ENCRYPTED)); } } diff --git a/rs-matter/src/transport/exchange.rs b/rs-matter/src/transport/exchange.rs index d6291b04..1b0b478c 100644 --- a/rs-matter/src/transport/exchange.rs +++ b/rs-matter/src/transport/exchange.rs @@ -1,432 +1,1046 @@ -use crate::{ - acl::Accessor, - error::{Error, ErrorCode}, - utils::{epoch::Epoch, select::Notification}, - Matter, -}; - -use super::{ - mrp::ReliableMessage, - network::Address, - packet::Packet, - session::{CloneData, Session, SessionMgr}, -}; - -pub const MAX_EXCHANGES: usize = 8; +/* + * + * Copyright (c) 2020-2022 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +use core::cmp::max; +use core::fmt::{self, Display}; +use core::pin::pin; + +use embassy_futures::select::{select, Either}; +use embassy_time::Timer; + +use log::{debug, error, info, warn}; + +use crate::acl::Accessor; +use crate::error::{Error, ErrorCode}; +use crate::interaction_model::{self, core::PROTO_ID_INTERACTION_MODEL}; +use crate::secure_channel::{self, common::PROTO_ID_SECURE_CHANNEL}; +use crate::utils::{epoch::Epoch, writebuf::WriteBuf}; +use crate::Matter; + +use super::core::{Packet, PacketAccess, MAX_RX_BUF_SIZE, MAX_TX_BUF_SIZE}; +use super::mrp::{ReliableMessage, RetransEntry}; +use super::network; +use super::packet::PacketHdr; +use super::plain_hdr::PlainHdr; +use super::proto_hdr::ProtoHdr; +use super::session::Session; + +/// Minimum buffer which should be allocated by user code that wants to pull RX messages via `Exchange::recv_into` +// TODO: Revisit with large packets +pub const MAX_EXCHANGE_RX_BUF_SIZE: usize = network::MAX_RX_PACKET_SIZE; + +/// Maximum buffer which should be allocated and used by user code that wants to send messages via `Exchange::send` +// TODO: Revisit with large packets +pub const MAX_EXCHANGE_TX_BUF_SIZE: usize = + network::MAX_TX_PACKET_SIZE - PacketHdr::HDR_RESERVE - PacketHdr::TAIL_RESERVE; + +/// An exchange identifier, uniquely identifying a session and an exchange within that session for a given Matter stack. +#[derive(Copy, Clone, Debug, Eq, PartialEq)] +pub struct ExchangeId(u32); + +impl ExchangeId { + pub(crate) fn new(session_id: u32, exchange_index: usize) -> Self { + if session_id > 0x0fff_ffff { + panic!("Session ID out of range"); + } + + if exchange_index >= 16 { + panic!("Exchange index out of range"); + } + + Self((exchange_index as u32) << 28 | session_id) + } + + pub(crate) fn session_id(&self) -> u32 { + self.0 & 0x0fff_ffff + } + + pub(crate) fn exchange_index(&self) -> usize { + (self.0 >> 28) as _ + } + + async fn recv<'a>(&self, matter: &'a Matter<'a>) -> Result, Error> { + self.check_no_pending_retrans(matter)?; + + let transport_mgr = &matter.transport_mgr; + + let mut packet = transport_mgr + .get_if(&transport_mgr.rx, |packet| { + if packet.buf.is_empty() { + false + } else { + let for_us = self.with_ctx(matter, |sess, exch_index| { + if sess.is_for_rx(&packet.peer, &packet.header.plain) { + let exchange = sess.exchanges[exch_index].as_ref().unwrap(); + + return Ok(exchange.is_for_rx(&packet.header.proto)); + } + + Ok(false) + }); + + for_us.unwrap_or(true) + } + }) + .await; + + packet.clear_on_drop(true); + + self.check_no_pending_retrans(matter)?; + + Ok(RxMessage(packet)) + } + + /// Gets access to the TX buffer of the Matter stack for constructing a new TX message. + /// If the TX buffer is not available, the method will wait indefinitely until it becomes available. + /// + /// NOTE: + /// This is a low-level method that leaves the re-transmission logic on the shoulders of the user. + /// Therefore, prefer using `Exchange::sender`, `Exchange::send` or `Exchange::send_with` instead. + /// + /// Note also that if the uderlying session or exchange tracked by the Matter stack is dropped + /// (say, because of lack of resources or a hard networking error), the method will return an error. + async fn init_send<'a>(&self, matter: &'a Matter<'a>) -> Result, Error> { + self.with_ctx(matter, |_, _| Ok(()))?; + + let transport_mgr = &matter.transport_mgr; + + let mut packet = transport_mgr + .get_if(&transport_mgr.tx, |packet| { + packet.buf.is_empty() || self.with_ctx(matter, |_, _| Ok(())).is_err() + }) + .await; + + // TODO: Resizing might be a bit expensive with large buffers + packet.buf.resize_default(MAX_TX_BUF_SIZE).unwrap(); + + packet.clear_on_drop(true); + + let tx = TxMessage { + exchange_id: *self, + matter, + packet, + }; + + self.with_ctx(matter, |_, _| Ok(()))?; + + Ok(tx) + } + + /// Waits until the other side acknowledges the last message sent on this exchange, + /// or until time for a re-transmission had come. + /// + /// If the last sent message was not using the MRP protocol, the method will return immediately with `TxOutcome::Done`. + /// + /// NOTE: + /// This is a low-level method that leaves the re-transmission logic on the shoulders of the user. + /// Therefore, prefer using `Exchange::sender`, `Exchange::send` or `Exchange::send_with` instead. + /// + /// Note also that if the uderlying session or exchange tracked by the Matter stack is dropped + /// (say, because of lack of resources or a hard networking error), the method will return an error. + async fn wait_tx<'a>(&self, matter: &'a Matter<'a>) -> Result { + if let Some(delay) = self.retrans_delay_ms(matter)? { + let mut notification = pin!(self.internal_wait_ack(matter)); + let mut timer = pin!(Timer::after(embassy_time::Duration::from_millis(delay))); + + select(&mut notification, &mut timer).await; + + if self.retrans_delay_ms(matter)?.is_some() { + Ok(TxOutcome::Retransmit) + } else { + Ok(TxOutcome::Done) + } + } else { + Ok(TxOutcome::Done) + } + } + + fn accessor<'a>(&self, matter: &'a Matter<'a>) -> Result, Error> { + self.with_session(matter, |sess| { + Ok(Accessor::for_session(sess, &matter.acl_mgr)) + }) + } + + fn with_session<'a, F, T>(&self, matter: &'a Matter<'a>, f: F) -> Result + where + F: FnOnce(&mut Session) -> Result, + { + self.with_ctx(matter, |sess, _| f(sess)) + } + + fn with_ctx<'a, F, T>(&self, matter: &'a Matter<'a>, f: F) -> Result + where + F: FnOnce(&mut Session, usize) -> Result, + { + let mut session_mgr = matter.transport_mgr.session_mgr.borrow_mut(); + + if let Some(session) = session_mgr.get(self.session_id()) { + f(session, self.exchange_index()) + } else { + warn!("Exchange {}: No session", self); + Err(ErrorCode::NoSession.into()) + } + } + + async fn internal_wait_ack<'a>(&self, matter: &'a Matter<'a>) -> Result<(), Error> { + let transport_mgr = &matter.transport_mgr; + + transport_mgr + .get_if(&transport_mgr.rx, |_| { + self.retrans_delay_ms(matter) + .map(|retrans| retrans.is_none()) + .unwrap_or(true) + }) + .await; + + self.with_ctx(matter, |_, _| Ok(())) + } + + fn retrans_delay_ms<'a>(&self, matter: &'a Matter<'a>) -> Result, Error> { + self.with_ctx(matter, |sess, exch_index| { + let exchange = sess.exchanges[exch_index].as_mut().unwrap(); + + Ok(exchange.retrans_delay_ms()) + }) + } + + fn check_no_pending_retrans<'a>(&self, matter: &'a Matter<'a>) -> Result<(), Error> { + self.with_ctx(matter, |sess, exch_index| { + let exchange = sess.exchanges[exch_index].as_mut().unwrap(); + + if exchange.mrp.is_retrans_pending() { + error!("Exchange {}: Retransmission pending", self); + Err(ErrorCode::InvalidState)?; + } + + Ok(()) + }) + } + + fn pending_retrans<'a>(&self, matter: &'a Matter<'a>) -> Result { + Ok(self.retrans_delay_ms(matter)?.is_some()) + } + + fn pending_ack<'a>(&self, matter: &'a Matter<'a>) -> Result { + self.with_ctx(matter, |sess, exch_index| { + let exchange = sess.exchanges[exch_index].as_ref().unwrap(); + + Ok(exchange.mrp.is_ack_pending()) + }) + } +} + +impl Display for ExchangeId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}::{}", self.session_id(), self.exchange_index()) + } +} #[derive(Debug, PartialEq, Eq, Copy, Clone, Default)] -pub(crate) enum Role { +pub(crate) enum InitiatorState { + #[default] + Owned, + Dropped, +} + +#[derive(Debug, PartialEq, Eq, Copy, Clone, Default)] +pub(crate) enum ResponderState { #[default] - Initiator = 0, - Responder = 1, + AcceptPending, + Owned, + Dropped, +} + +#[derive(Debug, PartialEq, Eq, Copy, Clone)] +pub(crate) enum Role { + Initiator(InitiatorState), + Responder(ResponderState), } impl Role { - pub fn complementary(is_initiator: bool) -> Self { - if is_initiator { - Self::Responder - } else { - Self::Initiator + pub fn is_dropped_state(&self) -> bool { + match self { + Self::Initiator(state) => *state == InitiatorState::Dropped, + Self::Responder(state) => *state == ResponderState::Dropped, + } + } + + pub fn set_dropped_state(&mut self) { + match self { + Self::Initiator(state) => *state = InitiatorState::Dropped, + Self::Responder(state) => *state = ResponderState::Dropped, } } } #[derive(Debug)] -pub(crate) struct ExchangeCtx { - pub(crate) id: ExchangeId, +pub(crate) struct ExchangeState { + pub(crate) exch_id: u16, pub(crate) role: Role, pub(crate) mrp: ReliableMessage, - pub(crate) state: ExchangeState, } -impl ExchangeCtx { - pub(crate) fn get<'r>( - exchanges: &'r mut heapless::Vec, - id: &ExchangeId, - ) -> Option<&'r mut ExchangeCtx> { - exchanges.iter_mut().find(|exchange| exchange.id == *id) +impl ExchangeState { + pub fn is_for_rx(&self, rx_proto: &ProtoHdr) -> bool { + self.exch_id == rx_proto.exch_id + && rx_proto.is_initiator() == matches!(self.role, Role::Responder(_)) } - pub fn new_ephemeral(session_id: SessionId, reply_to: Option<&Packet<'_>>) -> Self { - Self { - id: ExchangeId { - id: if let Some(rx) = reply_to { - rx.proto.exch_id - } else { - 0 - }, - session_id: session_id.clone(), - }, - role: if reply_to.is_some() { - Role::Responder - } else { - Role::Initiator - }, - mrp: ReliableMessage::new(), - state: ExchangeState::Active, + pub fn post_recv( + &mut self, + rx_plain: &PlainHdr, + rx_proto: &ProtoHdr, + epoch: Epoch, + ) -> Result<(), Error> { + self.mrp.post_recv(rx_plain, rx_proto, epoch)?; + + Ok(()) + } + + pub fn pre_send( + &mut self, + tx_plain: &PlainHdr, + tx_proto: &mut ProtoHdr, + epoch: Epoch, + ) -> Result<(), Error> { + if matches!(self.role, Role::Initiator(_)) { + tx_proto.set_initiator(); + } else { + tx_proto.unset_initiator(); } + + tx_proto.exch_id = self.exch_id; + + self.mrp.pre_send(tx_plain, tx_proto, epoch) + } + + pub fn retrans_delay_ms(&mut self) -> Option { + self.mrp.retrans.as_ref().map(RetransEntry::delay_ms) } +} - pub(crate) fn prep_ephemeral( - session_id: SessionId, - session_mgr: &mut SessionMgr, - reply_to: Option<&Packet<'_>>, - tx: &mut Packet<'_>, - ) -> Result { - let mut ctx = Self::new_ephemeral(session_id.clone(), reply_to); +/// Meta-data when sending/receving messages via an Exchange. +/// Basically, the protocol ID, the protocol opcode and whether the message should be set in a reliable manner. +#[derive(Debug, Eq, PartialEq, Copy, Clone)] +pub struct MessageMeta { + pub proto_id: u16, + pub proto_opcode: u8, + pub reliable: bool, +} - let sess_index = session_mgr.get( - session_id.id, - session_id.peer_addr, - session_id.peer_nodeid, - session_id.is_encrypted, - ); +impl MessageMeta { + // Create a new message meta-data instance + pub const fn new(proto_id: u16, proto_opcode: u8, reliable: bool) -> Self { + Self { + proto_id, + proto_opcode, + reliable, + } + } - let epoch = session_mgr.epoch; - let rand = session_mgr.rand; + /// Try to cast the protocol opcode to a specific type + pub fn opcode(&self) -> Result { + num::FromPrimitive::from_u8(self.proto_opcode).ok_or(ErrorCode::InvalidOpcode.into()) + } - if let Some(rx) = reply_to { - ctx.mrp.recv(rx, epoch)?; + /// Check if the protocol opcode is equal to a specific value + pub fn check_opcode(&self, opcode: T) -> Result<(), Error> { + if self.opcode::()? == opcode { + Ok(()) } else { - tx.proto.set_initiator(); + Err(ErrorCode::Invalid.into()) + } + } + + /// Create an instance from a ProtoHdr instance + pub fn from(proto: &ProtoHdr) -> Self { + Self { + proto_id: proto.proto_id, + proto_opcode: proto.proto_opcode, + reliable: proto.is_reliable(), } + } - tx.unset_reliable(); + /// Set the protocol ID and opcode into a ProtoHdr instance + pub fn set_into(&self, proto: &mut ProtoHdr) { + proto.proto_id = self.proto_id; + proto.proto_opcode = self.proto_opcode; + proto.set_vendor(None); - if let Some(sess_index) = sess_index { - let session = session_mgr.mut_by_index(sess_index).unwrap(); - ctx.pre_send_sess(session, tx, epoch)?; + if self.reliable { + proto.set_reliable(); } else { - let mut session = - Session::new(session_id.peer_addr, session_id.peer_nodeid, epoch, rand); - ctx.pre_send_sess(&mut session, tx, epoch)?; + proto.unset_reliable(); } + } - Ok(ctx) + pub fn reliable(self, reliable: bool) -> Self { + Self { reliable, ..self } } - pub(crate) fn pre_send( - &mut self, - session_mgr: &mut SessionMgr, - tx: &mut Packet, - ) -> Result<(), Error> { - let epoch = session_mgr.epoch; - - let sess_index = session_mgr - .get( - self.id.session_id.id, - self.id.session_id.peer_addr, - self.id.session_id.peer_nodeid, - self.id.session_id.is_encrypted, - ) - .ok_or(ErrorCode::NoSession)?; + /// Utility method to check if the specific proto opcode in the instance is expecting a TLV payload. + pub(crate) fn is_tlv(&self) -> bool { + match self.proto_id { + PROTO_ID_SECURE_CHANNEL => self + .opcode::() + .ok() + .map(|op| op.is_tlv()) + .unwrap_or(false), + PROTO_ID_INTERACTION_MODEL => self + .opcode::() + .ok() + .map(|op| op.is_tlv()) + .unwrap_or(false), + _ => false, + } + } - let session = session_mgr.mut_by_index(sess_index).unwrap(); + /// Utility method to check if the protocol is Secure Channel, and the opcode is a standalone ACK (`MrpStandaloneAck`). + pub(crate) fn is_standalone_ack(&self) -> bool { + !self.reliable + && self.proto_id == PROTO_ID_SECURE_CHANNEL + && self.proto_opcode == secure_channel::common::OpCode::MRPStandAloneAck as u8 + } - self.pre_send_sess(session, tx, epoch) + /// Utility method to check if the protocol is Secure Channel, and the opcode is Status. + pub(crate) fn is_sc_status(&self) -> bool { + !self.reliable + && self.proto_id == PROTO_ID_SECURE_CHANNEL + && self.proto_opcode == secure_channel::common::OpCode::StatusReport as u8 } - pub(crate) fn pre_send_sess( - &mut self, - session: &mut Session, - tx: &mut Packet, - epoch: Epoch, - ) -> Result<(), Error> { - tx.proto.exch_id = self.id.id; - if self.role == Role::Initiator { - tx.proto.set_initiator(); - } + /// Utility method to check if the protocol is Secure Channel, and the opcode is a new session request. + pub(crate) fn is_new_session(&self) -> bool { + self.reliable + && self.proto_id == PROTO_ID_SECURE_CHANNEL + && (self.proto_opcode == secure_channel::common::OpCode::PBKDFParamRequest as u8 + || self.proto_opcode == secure_channel::common::OpCode::CASESigma1 as u8) + } +} - session.pre_send(tx)?; - self.mrp.pre_send(tx)?; - session.send(epoch, tx) +impl Display for MessageMeta { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self.proto_id { + PROTO_ID_SECURE_CHANNEL => { + if let Ok(opcode) = self.opcode::() { + write!(f, "SC::{:?}", opcode) + } else { + write!(f, "SC::{:02x}", self.proto_opcode) + } + } + PROTO_ID_INTERACTION_MODEL => { + if let Ok(opcode) = self.opcode::() { + write!(f, "IM::{:?}", opcode) + } else { + write!(f, "IM::{:02x}", self.proto_opcode) + } + } + _ => write!(f, "{:02x}::{:02x}", self.proto_id, self.proto_opcode), + } } } -#[derive(Debug, Clone)] -pub(crate) enum ExchangeState { - Construction { - rx: *mut Packet<'static>, - notification: *const Notification, - }, - Active, - Acknowledge { - notification: *const Notification, - }, - ExchangeSend { - tx: *const Packet<'static>, - rx: *mut Packet<'static>, - notification: *const Notification, - }, - ExchangeRecv { - _tx: *const Packet<'static>, - tx_acknowledged: bool, - rx: *mut Packet<'static>, - notification: *const Notification, - }, - Complete { - tx: *const Packet<'static>, - notification: *const Notification, - }, - CompleteAcknowledge { - _tx: *const Packet<'static>, - notification: *const Notification, - }, - Closed, +/// An RX message pending on an `Exchange` instance. +pub struct RxMessage<'a>(PacketAccess<'a, MAX_RX_BUF_SIZE>); + +impl<'a> RxMessage<'a> { + /// Get the meta-data of the pending message + pub fn meta(&self) -> MessageMeta { + MessageMeta::from(&self.0.header.proto) + } + + /// Get the payload of the pending message + pub fn payload(&self) -> &[u8] { + &self.0.buf[self.0.payload_start..] + } } -pub struct ExchangeCtr<'a> { - pub(crate) exchange: Exchange<'a>, - pub(crate) construction_notification: &'a Notification, +/// Accessor to the TX message buffer of the underlying Matter transport stack. +/// +/// This is used to construct a new TX message to be sent on an `Exchange` instance. +/// +/// NOTE: It is strongly advised to use the `TxMessage` accessor in combination with the `Sender` utility, +/// which takes care of all message retransmission logic. Alternatively, one can use the +/// `Exchange::send` or `Exchange::send_with` which also take care of re-transmissions. +pub struct TxMessage<'a> { + exchange_id: ExchangeId, + matter: &'a Matter<'a>, + packet: PacketAccess<'a, MAX_TX_BUF_SIZE>, } -impl<'a> ExchangeCtr<'a> { - pub const fn id(&self) -> &ExchangeId { - self.exchange.id() +impl<'a> TxMessage<'a> { + /// Get a reference to the payload buffer of the TX message being built + pub fn payload(&mut self) -> &mut [u8] { + &mut self.packet.buf[PacketHdr::HDR_RESERVE..MAX_TX_BUF_SIZE - PacketHdr::TAIL_RESERVE] } - #[allow(clippy::all)] - // Should be #[allow(clippy::needless_pass_by_ref_mut)], but this is only in 1.73 which is not released yet - // rx is actually modified, but via an unsafe `*mut Packet<'static>` and apparently Clippy can't see this - pub async fn get(mut self, rx: &mut Packet<'_>) -> Result, Error> { - let construction_notification = self.construction_notification; + /// Complete and send a TX message by providing: + /// - The payload size that was filled-in by user code in the payload buffer returned by `TxMessage::payload` + /// - The TX message meta-data + pub fn complete( + mut self, + payload_start: usize, + payload_end: usize, + meta: M, + ) -> Result<(), Error> + where + M: Into, + { + if payload_start > payload_end + || payload_end > MAX_TX_BUF_SIZE - PacketHdr::HDR_RESERVE - PacketHdr::TAIL_RESERVE + { + Err(ErrorCode::Invalid)?; + } - self.exchange.with_ctx_mut(move |exchange, ctx| { - if !matches!(ctx.state, ExchangeState::Active) { - Err(ErrorCode::NoExchange)?; - } + let meta: MessageMeta = meta.into(); - let rx: &'static mut Packet<'static> = unsafe { core::mem::transmute(rx) }; - let notification: &'static Notification = - unsafe { core::mem::transmute(&exchange.notification) }; + self.packet.header.reset(); - ctx.state = ExchangeState::Construction { rx, notification }; + meta.set_into(&mut self.packet.header.proto); - construction_notification.signal(()); + let mut session_mgr = self.matter.transport_mgr.session_mgr.borrow_mut(); - Ok(()) - })?; + let session = session_mgr + .get(self.exchange_id.session_id()) + .ok_or(ErrorCode::NoSession)?; - self.exchange.notification.wait().await; + let (peer, retransmission) = session.pre_send( + Some(self.exchange_id.exchange_index()), + &mut self.packet.header, + self.matter.epoch, + )?; + + self.packet.peer = peer; + + info!( + "\n<<<<< {}\n => {}", + Packet::<0>::display(&self.packet.peer, &self.packet.header), + if retransmission { + "Re-sending" + } else { + "Sending" + }, + ); + + debug!( + "{}", + Packet::<0>::display_payload( + &self.packet.header.proto, + &self.packet.buf + [PacketHdr::HDR_RESERVE + payload_start..PacketHdr::HDR_RESERVE + payload_end] + ) + ); - Ok(self.exchange) + let packet = &mut *self.packet; + + let mut writebuf = WriteBuf::new_with( + &mut packet.buf, + PacketHdr::HDR_RESERVE + payload_start, + PacketHdr::HDR_RESERVE + payload_end, + ); + session.encode(&packet.header, &mut writebuf)?; + + let encoded_payload_start = writebuf.get_start(); + let encoded_payload_end = writebuf.get_tail(); + + self.packet.payload_start = encoded_payload_start; + self.packet.buf.truncate(encoded_payload_end); + self.packet.clear_on_drop(false); + + Ok(()) } } -#[derive(Debug, Clone, Eq, PartialEq)] -pub struct ExchangeId { - pub id: u16, - pub session_id: SessionId, +/// Outcome from calling `Exchange::wait_tx` +#[derive(Copy, Clone, Eq, PartialEq, Debug)] +pub enum TxOutcome { + /// The other side has acknowledged the last message or the last message was not using the MRP protocol + /// Stop re-sending. + Done, + /// Need to re-send the last message. + Retransmit, } -impl ExchangeId { - pub fn load(rx: &Packet) -> Self { - Self { - id: rx.proto.exch_id, - session_id: SessionId::load(rx), - } +impl TxOutcome { + /// Check if the outcome is `Done` + pub const fn is_done(&self) -> bool { + matches!(self, Self::Done) } } -#[derive(Debug, Clone, Eq, PartialEq)] -pub struct SessionId { - pub id: u16, - pub peer_addr: Address, - pub peer_nodeid: Option, - pub is_encrypted: bool, + +pub struct SenderTx<'a, 'b> { + sender: &'b mut Sender<'a>, + message: TxMessage<'a>, } -impl SessionId { - pub fn load(rx: &Packet) -> Self { - Self { - id: rx.plain.sess_id, - peer_addr: rx.peer, - peer_nodeid: rx.plain.get_src_u64(), - is_encrypted: rx.plain.is_encrypted(), +impl<'a, 'b> SenderTx<'a, 'b> { + pub fn split(&mut self) -> (&Exchange<'_>, &mut [u8]) { + (&self.sender.exchange, self.message.payload()) + } + + pub fn payload(&mut self) -> &mut [u8] { + self.message.payload() + } + + pub fn complete( + self, + payload_start: usize, + payload_end: usize, + meta: MessageMeta, + ) -> Result<(), Error> { + self.message.complete(payload_start, payload_end, meta)?; + + self.sender.initial = false; + + Ok(()) + } +} + +/// Utility struct for sending a message with potential retransmissions. +pub struct Sender<'a> { + exchange: &'a Exchange<'a>, + initial: bool, + complete: bool, +} + +impl<'a> Sender<'a> { + fn new(exchange: &'a Exchange<'a>) -> Result { + exchange.id.check_no_pending_retrans(exchange.matter)?; + + Ok(Self { + exchange, + initial: true, + complete: false, + }) + } + + /// Get the TX buffer of the underlying Matter stack for (re)constructing a new TX message, + /// waiting for the TX buffer to become available, if it is not. + /// + /// If the method returns `None`, it means that the message was already acknowledged by the other side, + /// or that the message does not need acknowledgement and re-transmissions. + /// + /// When called for the first time, the method will always return a `Some` value, as the message has not been sent even once yet. + /// Once the method returns `None`, it will always return `None` on subsequent calls, as the message has been acknowledged by the other side. + /// + /// Example: + /// ```ignore + /// let exchange = ...; + /// + /// let sender = exchange.sender()?; + /// + /// while let Some(mut tx) = sender.tx().await? { + /// let (exchange, payload) = tx.split()?; + /// + /// // Write the message payload in the `payload` `&mut [u8]` slice + /// // On every iteration of the loop, write the _same_ payload (as message re-transmission is idempotent w.r.t. the message) + /// ... + /// + /// // Complete the payload by providing `MessageMeta`, payload start and payload end + /// // On every iteration of the loop, proide the _same_ meta-data (as message re-transmission is idempotent w.r.t. the message) + /// let meta = ...; + /// let payload_start = ...; + /// let payload_end = ...; + /// + /// tx.complete(payload_start, payload_end, meta)?; + /// } + /// ``` + pub async fn tx(&mut self) -> Result>, Error> { + if self.complete { + return Ok(None); + } + + if !self.initial + && self + .exchange + .id + .wait_tx(self.exchange.matter) + .await? + .is_done() + { + // No need to re-transmit + self.complete = true; + return Ok(None); + } + + let id = self.exchange.id; + let matter = self.exchange.matter; + + let tx = id.init_send(matter).await?; + + if self.initial || id.pending_retrans(matter)? { + Ok(Some(SenderTx { + sender: self, + message: tx, + })) + } else { + self.complete = true; + Ok(None) } } } + +/// An exchange within a Matter stack, representing a session and an exchange within that session. +/// +/// This is the main API for sending and receiving messages within the Matter stack. +/// Used by upper-level layers like the Secure Channel and Interaction Model. pub struct Exchange<'a> { - pub(crate) id: ExchangeId, - pub(crate) matter: &'a Matter<'a>, - pub(crate) notification: Notification, + id: ExchangeId, + matter: &'a Matter<'a>, + rx: Option>, } impl<'a> Exchange<'a> { - pub const fn id(&self) -> &ExchangeId { - &self.id + pub(crate) const fn new(id: ExchangeId, matter: &'a Matter<'a>) -> Self { + Self { + id, + matter, + rx: None, + } } - pub fn accessor(&self) -> Result, Error> { - self.with_session(|sess| Ok(Accessor::for_session(sess, &self.matter.acl_mgr))) + /// Get the Id of the exchange + pub fn id(&self) -> ExchangeId { + self.id } - pub fn with_session_mut(&self, f: F) -> Result - where - F: FnOnce(&mut Session) -> Result, - { - self.with_ctx(|_self, ctx| { - let mut session_mgr = _self.matter.session_mgr.borrow_mut(); - - let sess_index = session_mgr - .get( - ctx.id.session_id.id, - ctx.id.session_id.peer_addr, - ctx.id.session_id.peer_nodeid, - ctx.id.session_id.is_encrypted, - ) - .ok_or(ErrorCode::NoSession)?; - - f(session_mgr.mut_by_index(sess_index).unwrap()) - }) + /// Get the Matter stack instance associated with this exchange + pub fn matter(&self) -> &'a Matter<'a> { + self.matter } - pub fn with_session(&self, f: F) -> Result - where - F: FnOnce(&Session) -> Result, - { - self.with_session_mut(|sess| f(sess)) + /// Create a new initiator exchange on the provided Matter stack for the provided Node ID + /// + /// This method will fail if there is no existing session in the provided Matter satack for the provided Node ID. + /// + // TODO: This signature will change in future + #[inline(always)] + pub async fn initiate( + matter: &'a Matter<'a>, + node_id: u64, + secure: bool, + ) -> Result { + matter.transport_mgr.initiate(matter, node_id, secure).await } - pub async fn acknowledge(&mut self) -> Result<(), Error> { - let wait = self.with_ctx_mut(|_self, ctx| { - if !matches!(ctx.state, ExchangeState::Active) { - Err(ErrorCode::NoExchange)?; - } - - if ctx.mrp.is_empty() { - Ok(false) - } else { - ctx.state = ExchangeState::Acknowledge { - notification: &_self.notification as *const _, - }; - _self.matter.send_notification.signal(()); + /// Accepts a new responder exchange pending on the provided Matter stack. + /// + /// If there is no new pending responder exchange, the method will wait indefinitely until one appears. + #[inline(always)] + pub async fn accept(matter: &'a Matter<'a>) -> Result { + Self::accept_after(matter, 0).await + } - Ok(true) + /// Accepts a new responder exchange pending on the provided Matter stack, but only if the + /// pending exchange was pending for longer than `received_timeout_ms`. + /// + /// If there is no new pending responder exchange, the method will wait indefinitely until one appears. + pub async fn accept_after( + matter: &'a Matter<'a>, + received_timeout_ms: u64, + ) -> Result { + if received_timeout_ms > 0 { + let epoch = matter.epoch; + + loop { + let mut accept = pin!(matter.transport_mgr.accept_if(matter, |_, exch, _| { + exch.mrp.has_rx_timed_out(received_timeout_ms, epoch) + })); + + let mut timer = pin!(Timer::after(embassy_time::Duration::from_millis(max( + received_timeout_ms / 2, + 1, + )))); + + if let Either::First(exchange) = select(&mut accept, &mut timer).await { + break exchange; + } } - })?; - - if wait { - self.notification.wait().await; + } else { + matter.transport_mgr.accept_if(matter, |_, _, _| true).await } - - Ok(()) } - pub async fn exchange( - &mut self, - tx: &mut Packet<'_>, - rx: &mut Packet<'_>, - ) -> Result<(), Error> { - let tx: &mut Packet<'static> = unsafe { core::mem::transmute(tx) }; - let rx: &mut Packet<'static> = unsafe { core::mem::transmute(rx) }; + /// Get access to the pending RX message on this exchange, and consume it when the returned `RxMessage` instance is dropped. + /// + /// If there is no pending RX message, the method will wait indefinitely until one appears. + /// + /// Note that if the uderlying session or exchange tracked by the Matter stack is dropped + /// (say, because of lack of resources or a hard networking error), the method will return an error. + #[inline(always)] + pub async fn recv(&mut self) -> Result, Error> { + self.recv_fetch().await?; + + self.rx.take().ok_or(ErrorCode::InvalidState.into()) + } - self.with_ctx_mut(|_self, ctx| { - if !matches!(ctx.state, ExchangeState::Active) { - Err(ErrorCode::NoExchange)?; - } + /// Get access to the pending RX message on this exchange, and consume it + /// by copying the payload into the provided `WriteBuf` instance. + /// + /// A syntax sugar for calling ```self.recv().await?``` and then copying the payload. + /// + /// Returns the exchange message meta-data. + /// + /// If there is no pending RX message, the method will wait indefinitely until one appears. + /// + /// If there is already a pending RX message, which was already fetched using `Exchange::recv_fetch` and that + /// message is not cleared yet using `Exchange::rx_done` or via some of the `Exchange::send*` methods, + /// the method will return that message. + /// + /// Note that if the uderlying session or exchange tracked by the Matter stack is dropped + /// (say, because of lack of resources or a hard networking error), the method will return an error. + #[inline(always)] + pub async fn recv_into(&mut self, wb: &mut WriteBuf<'_>) -> Result { + let rx = self.recv().await?; + + wb.reset(); + wb.append(rx.payload())?; + + Ok(rx.meta()) + } - let mut session_mgr = _self.matter.session_mgr.borrow_mut(); - ctx.pre_send(&mut session_mgr, tx)?; + /// Return a _reference_ to the pending RX message on this exchange. + /// + /// If there is no pending RX message, the method will wait indefinitely until one appears. + /// + /// Unlike `recv` which returns the actual message object which - when dropped - allows the transport to + /// fetch the _next_ RX message for this or other exchanges, `recv_fetch` keeps the received message around, + /// which is convenient when the message needs to be examined / processed by multiple layers of application code. + /// + /// Note however that this does not come for free - keeping the RX message around means that the transport cannot receive + /// _other_ RX messages which blocks the whole transport layer, as the transport layer uses a single RX message buffer. + /// + /// Therefore, calling `recv_fetch` should be done with care and the message should be marked as processed (and thus dropped) - + /// via `rx_done` as soon as possible, ideally without `await`-ing between `recv_fetch` and `rx_done` + /// + /// Note that if the uderlying session or exchange tracked by the Matter stack is dropped + /// (say, because of lack of resources or a hard networking error), the method will return an error. + #[inline(always)] + pub async fn recv_fetch(&mut self) -> Result<&RxMessage<'a>, Error> { + if self.rx.is_none() { + let rx = self.id.recv(self.matter).await?; + + self.rx = Some(rx); + } - ctx.state = ExchangeState::ExchangeSend { - tx: tx as *const _, - rx: rx as *mut _, - notification: &_self.notification as *const _, - }; - _self.matter.send_notification.signal(()); + self.rx() + } - Ok(()) - })?; + /// Returns the RX message which was already fetched using a previous call to `recv_fetch`. + /// If there is no fetched RX message, the method will fail with `ErrorCode::InvalidState`. + /// + /// This method only exists as a slight optimization for the cases where the user is sure, that there is + /// an RX message already fetched with `recv_fetch`, as - unlike `recv_fetch` - this method does not `await` and hence + /// variables used after calling `rx` do not have to be stored in the generated future. + /// + /// But in general and putting optimizations aside, it is always safe to replace calls to `rx` with calls to `recv_fetch`. + #[inline(always)] + pub fn rx(&self) -> Result<&RxMessage<'a>, Error> { + self.rx.as_ref().ok_or(ErrorCode::InvalidState.into()) + } - self.notification.wait().await; + /// Clears the RX message which was already fetched using a previous call to `recv_fetch`. + /// If there is no fetched RX message, the method will do nothing. + #[inline(always)] + pub fn rx_done(&mut self) -> Result<(), Error> { + self.rx = None; Ok(()) } - pub async fn complete(mut self, tx: &mut Packet<'_>) -> Result<(), Error> { - self.send_complete(tx).await + /// Gets access to the TX buffer of the Matter stack for constructing a new TX message. + /// If the TX buffer is not available, the method will wait indefinitely until it becomes available. + /// + /// NOTE: + /// This is a low-level method that leaves the re-transmission logic on the shoulders of the user. + /// Therefore, prefer using `Exchange::sender`, `Exchange::send` or `Exchange::send_with` instead. + /// + /// Note also that if the uderlying session or exchange tracked by the Matter stack is dropped + /// (say, because of lack of resources or a hard networking error), the method will return an error. + #[inline(always)] + pub async fn init_send(&mut self) -> Result, Error> { + self.rx = None; + + self.id.init_send(self.matter).await } - pub async fn send_complete(&mut self, tx: &mut Packet<'_>) -> Result<(), Error> { - let tx: &mut Packet<'static> = unsafe { core::mem::transmute(tx) }; - - self.with_ctx_mut(|_self, ctx| { - if !matches!(ctx.state, ExchangeState::Active) { - Err(ErrorCode::NoExchange)?; - } - - let mut session_mgr = _self.matter.session_mgr.borrow_mut(); - ctx.pre_send(&mut session_mgr, tx)?; + /// Waits until the other side acknowledges the last message sent on this exchange, + /// or until time for a re-transmission had come. + /// + /// If the last sent message was not using the MRP protocol, the method will return immediately with `TxOutcome::Done`. + /// + /// NOTE: + /// This is a low-level method that leaves the re-transmission logic on the shoulders of the user. + /// Therefore, prefer using `Exchange::sender`, `Exchange::send` or `Exchange::send_with` instead. + /// + /// Note also that if the uderlying session or exchange tracked by the Matter stack is dropped + /// (say, because of lack of resources or a hard networking error), the method will return an error. + #[inline(always)] + pub async fn wait_tx(&mut self) -> Result { + self.rx = None; + + self.id.wait_tx(self.matter).await + } - ctx.state = ExchangeState::Complete { - tx: tx as *const _, - notification: &_self.notification as *const _, - }; - _self.matter.send_notification.signal(()); + /// Returns `true` if there is a pending message re-transmission. + /// A re-transmission will be pending if the last sent message was using the MRP protocol, and + /// an acknowledgement for the other side is still pending. + /// + /// NOTE: + /// This is a low-level method that leaves the re-transmission logic on the shoulders of the user. + /// Therefore, prefer using `Exchange::sender`, `Exchange::send` or `Exchange::send_with` instead. + /// + /// Note also that if the uderlying session or exchange tracked by the Matter stack is dropped + /// (say, because of lack of resources or a hard networking error), the method will return an error. + pub fn pending_retrans(&self) -> Result { + self.id.pending_retrans(self.matter) + } - Ok(()) - })?; + /// Returns `true` if there is a pending message acknowledgement. + /// An acknowledgement be pending if the last received message was using the MRP protocol, and we have to acknowledge it. + /// + /// NOTE: + /// This is a low-level method that leaves the re-transmission logic on the shoulders of the user. + /// Therefore, prefer using `Exchange::sender`, `Exchange::send` or `Exchange::send_with` instead. + /// + /// Note also that if the uderlying session or exchange tracked by the Matter stack is dropped + /// (say, because of lack of resources or a hard networking error), the method will return an error. + pub fn pending_ack(&self) -> Result { + self.id.pending_ack(self.matter) + } - self.notification.wait().await; + /// Acknowledge the last message received on this exchange (by sending a `MrpStandaloneAck`). + /// + /// If the last message was already acknowledged + /// (either by a previous call to this method, by piggy-backing on a reliable message, or by the Matter stack itself), + /// this method does nothing. + /// + /// Note that if the uderlying session or exchange tracked by the Matter stack is dropped + /// (say, because of lack of resources or a hard networking error), the method will return an error. + #[inline(always)] + pub async fn acknowledge(&mut self) -> Result<(), Error> { + if self.pending_ack()? { + self.send_with(|exchange, _| { + Ok(exchange + .pending_ack()? + .then_some(secure_channel::common::OpCode::MRPStandAloneAck.into())) + }) + .await?; + } Ok(()) } - pub(crate) fn get_next_sess_id(&mut self) -> u16 { - self.matter.session_mgr.borrow_mut().get_next_sess_id() + /// Utility for sending a message on this exchange that automatically handles all re-transmission logic + /// in case the constructed message needs to be send reliably. + /// + /// Note that if the uderlying session or exchange tracked by the Matter stack is dropped + /// (say, because of lack of resources or a hard networking error), the method will return an error. + pub fn sender(&mut self) -> Result, Error> { + self.rx = None; + + Sender::new(self) } - pub(crate) async fn clone_session( - &mut self, - tx: &mut Packet<'_>, - clone_data: &CloneData, - ) -> Result { - loop { - let result = self - .matter - .session_mgr - .borrow_mut() - .clone_session(clone_data); - - match result { - Err(err) if err.code() == ErrorCode::NoSpaceSessions => { - self.matter.evict_session(tx).await? - } - other => break other, + /// Utility for sending a message on this exchange that automatically handles all re-transmission logic + /// in case the constructed message needs to be send reliably. + /// + /// The message is constructed by the provided closure, which is given a `WriteBuf` instance to write the message payload into. + /// + /// Note that the closure is expected to construct the exact same message when called multiple times. + /// + /// Note also that if the uderlying session or exchange tracked by the Matter stack is dropped + /// (say, because of lack of resources or a hard networking error), the method will return an error. + pub async fn send_with(&mut self, mut f: F) -> Result<(), Error> + where + F: FnMut(&Exchange, &mut WriteBuf) -> Result, Error>, + { + let mut sender = self.sender()?; + + while let Some(mut tx) = sender.tx().await? { + let (exchange, payload) = tx.split(); + + let mut wb = WriteBuf::new(payload); + + if let Some(meta) = f(exchange, &mut wb)? { + let payload_start = wb.get_start(); + let payload_end = wb.get_tail(); + tx.complete(payload_start, payload_end, meta)?; + } else { + // Closure aborted sending + break; } } + + Ok(()) } - fn with_ctx(&self, f: F) -> Result + /// Send the provided exchange meta-data and payload as part of this exchange. + /// + /// If the provided exchange meta-data indicates a reliable message, the message will be automatically re-transmitted until + /// the other side acknowledges it. + /// + /// Note that if the uderlying session or exchange tracked by the Matter stack is dropped + /// (say, because of lack of resources or a hard networking error), the method will return an error. + pub async fn send(&mut self, meta: M, payload: &[u8]) -> Result<(), Error> where - F: FnOnce(&Self, &ExchangeCtx) -> Result, + M: Into, { - let mut exchanges = self.matter.exchanges.borrow_mut(); + let meta = meta.into(); + + self.send_with(|_, wb| { + wb.append(payload)?; - let exchange = ExchangeCtx::get(&mut exchanges, &self.id).ok_or(ErrorCode::NoExchange)?; // TODO + Ok(Some(meta)) + }) + .await + } - f(self, exchange) + pub(crate) fn accessor(&self) -> Result, Error> { + self.id.accessor(self.matter) } - fn with_ctx_mut(&mut self, f: F) -> Result + pub(crate) fn with_session(&self, f: F) -> Result where - F: FnOnce(&mut Self, &mut ExchangeCtx) -> Result, + F: FnOnce(&mut Session) -> Result, { - let mut exchanges = self.matter.exchanges.borrow_mut(); - - let exchange = ExchangeCtx::get(&mut exchanges, &self.id).ok_or(ErrorCode::NoExchange)?; // TODO + self.id.with_session(self.matter, f) + } - f(self, exchange) + pub(crate) fn with_ctx(&self, f: F) -> Result + where + F: FnOnce(&mut Session, usize) -> Result, + { + self.id.with_ctx(self.matter, f) } } impl<'a> Drop for Exchange<'a> { fn drop(&mut self) { - let _ = self.with_ctx_mut(|_self, ctx| { - ctx.state = ExchangeState::Closed; - _self.matter.send_notification.signal(()); + let closed = self.with_ctx(|sess, exch_index| Ok(sess.remove_exch(exch_index))); - Ok(()) - }); + if !matches!(closed, Ok(true)) { + self.matter.transport_mgr.dropped.notify(); + } + } +} + +impl<'a> Display for Exchange<'a> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.id) } } diff --git a/rs-matter/src/transport/mrp.rs b/rs-matter/src/transport/mrp.rs index d9815919..693f60be 100644 --- a/rs-matter/src/transport/mrp.rs +++ b/rs-matter/src/transport/mrp.rs @@ -15,115 +15,159 @@ * limitations under the License. */ +use log::error; + +use crate::error::*; use crate::utils::epoch::Epoch; -use core::time::Duration; -use crate::{error::*, secure_channel, transport::packet::Packet}; -use log::error; +use super::{plain_hdr::PlainHdr, proto_hdr::ProtoHdr}; -// 200 ms -const MRP_STANDALONE_ACK_TIMEOUT: u64 = 200; +//const MRP_STANDALONE_ACK_TIMEOUT_MS: u64 = 200; // TODO: Use to pro-actively send ACKs +const MRP_BASE_RETRY_INTERVAL_MS: u64 = 200; // TODO: Un-hardcode for Sleepy vs Active devices +const MRP_MAX_TRANSMISSIONS: usize = 10; +const MRP_BACKOFF_THRESHOLD: usize = 3; +const MRP_BACKOFF_BASE: (u64, u64) = (16, 10); // 1.6 + //const MRP_BACKOFF_JITTER: (u64, u64) = (25, 100); // 0.25 + //const MRP_BACKOFF_MARGIN: (u64, u64) = (11, 10); // 1.1 #[derive(Debug)] pub struct RetransEntry { // The msg counter that we are waiting to be acknowledged msg_ctr: u32, - // This will additionally have retransmission count and periods once we implement it + sent_at_ms: u64, + counter: usize, } impl RetransEntry { - pub fn new(msg_ctr: u32) -> Self { - Self { msg_ctr } + pub fn new(msg_ctr: u32, epoch: Epoch) -> Self { + Self { + msg_ctr, + sent_at_ms: epoch().as_millis() as u64, + counter: 0, + } } pub fn get_msg_ctr(&self) -> u32 { self.msg_ctr } + + pub fn is_due(&self, epoch: Epoch) -> bool { + self.sent_at_ms + .checked_add(self.delay_ms()) + .map(|d| d <= epoch().as_millis() as u64) + .unwrap_or(true) + } + + pub fn delay_ms(&self) -> u64 { + let mut delay = MRP_BASE_RETRY_INTERVAL_MS; + + if self.counter >= MRP_BACKOFF_THRESHOLD { + for _ in 0..self.counter - MRP_BACKOFF_THRESHOLD { + delay = delay * MRP_BACKOFF_BASE.0 / MRP_BACKOFF_BASE.1; + } + } + + delay + } + + pub fn pre_send(&mut self, ctr: u32) -> Result<(), Error> { + if self.msg_ctr == ctr { + if self.counter < MRP_MAX_TRANSMISSIONS { + self.counter += 1; + Ok(()) + } else { + Err(ErrorCode::Invalid.into()) // TODO + } + } else { + // This indicates there was some existing entry for same sess-id/exch-id, which shouldn't happen + panic!("Previous retrans entry for this exchange already exists"); + } + } } #[derive(Debug, Clone)] pub struct AckEntry { // The msg counter that we should acknowledge - msg_ctr: u32, - // The max time after which this entry must be ACK - ack_timeout: Duration, + pub(crate) msg_ctr: u32, + // Whether the message was acknowledged at least once + pub(crate) acknowledged: bool, } impl AckEntry { - pub fn new(msg_ctr: u32, epoch: Epoch) -> Result { - if let Some(ack_timeout) = - epoch().checked_add(Duration::from_millis(MRP_STANDALONE_ACK_TIMEOUT)) - { - Ok(Self { - msg_ctr, - ack_timeout, - }) - } else { - Err(ErrorCode::Invalid.into()) - } + pub fn new(msg_ctr: u32) -> Result { + Ok(Self { + msg_ctr, + acknowledged: false, + }) } pub fn get_msg_ctr(&self) -> u32 { self.msg_ctr } - - pub fn has_timed_out(&self, epoch: Epoch) -> bool { - self.ack_timeout > epoch() - } } #[derive(Default, Debug)] pub struct ReliableMessage { - retrans: Option, - ack: Option, + pub(crate) retrans: Option, + pub(crate) ack: Option, + pub(crate) received_at_ms: Option, } impl ReliableMessage { pub fn new() -> Self { - Self { - ..Default::default() - } + Default::default() } - pub fn is_empty(&self) -> bool { - self.retrans.is_none() && self.ack.is_none() + pub fn is_retrans_pending(&self) -> bool { + self.retrans.is_some() } - // Check any pending acknowledgements / retransmissions and take action - pub fn is_ack_ready(&self, epoch: Epoch) -> bool { - // Acknowledgements - if let Some(ack_entry) = &self.ack { - ack_entry.has_timed_out(epoch) - } else { - false - } + pub fn is_ack_pending(&self) -> bool { + self.ack + .as_ref() + .map(|ack| !ack.acknowledged) + .unwrap_or(false) } - pub fn prepare_ack(_exch_id: u16, proto_tx: &mut Packet) { - secure_channel::common::create_mrp_standalone_ack(proto_tx); + pub fn has_rx_timed_out(&self, timeout_ms: u64, epoch: Epoch) -> bool { + self.received_at_ms + .and_then(|received_at_ms| { + received_at_ms + .checked_add(timeout_ms) + .map(|d| d <= epoch().as_millis() as u64) + }) + .unwrap_or(false) } - pub fn pre_send(&mut self, proto_tx: &mut Packet) -> Result<(), Error> { + pub fn pre_send( + &mut self, + tx_plain: &PlainHdr, + tx_proto: &mut ProtoHdr, + epoch: Epoch, + ) -> Result<(), Error> { // Check if any acknowledgements are pending for this exchange, - - // if so, piggy back in the encoded header here - if let Some(ack_entry) = &self.ack { - // Ack Entry exists, set ACK bit and remove from table - proto_tx.proto.set_ack(ack_entry.get_msg_ctr()); - self.ack = None; + if let Some(ack) = &mut self.ack { + // if so, piggy back in the encoded header here + tx_proto.set_ack(Some(ack.get_msg_ctr())); + ack.acknowledged = true; } - if !proto_tx.is_reliable() { - return Ok(()); - } + if tx_proto.is_reliable() { + if let Some(retrans) = &mut self.retrans { + if retrans.pre_send(tx_plain.ctr).is_err() { + // Too many retransmissions, give up + error!("Too many retransmissions. Giving up"); - if self.retrans.is_some() { - // This indicates there was some existing entry for same sess-id/exch-id, which shouldnt happen - error!("Previous retrans entry for this exchange already exists"); - Err(ErrorCode::Invalid)?; + self.retrans = None; + self.ack = None; + } + } else { + self.retrans = Some(RetransEntry::new(tx_plain.ctr, epoch)); + } } - self.retrans = Some(RetransEntry::new(proto_tx.plain.ctr)); + self.received_at_ms = None; + Ok(()) } @@ -132,29 +176,39 @@ impl ReliableMessage { * - there can be only one pending retransmission per exchange (so this is per-exchange) * - duplicate detection should happen per session (obviously), so that part is per-session */ - pub fn recv(&mut self, proto_rx: &Packet, epoch: Epoch) -> Result<(), Error> { - if proto_rx.proto.is_ack() { + pub fn post_recv( + &mut self, + rx_plain: &PlainHdr, + rx_proto: &ProtoHdr, + epoch: Epoch, + ) -> Result<(), Error> { + if let Some(ack_msg_ctr) = rx_proto.get_ack() { // Handle received Acks - let ack_msg_ctr = proto_rx.proto.get_ack_msg_ctr().ok_or(ErrorCode::Invalid)?; if let Some(entry) = &self.retrans { if entry.get_msg_ctr() != ack_msg_ctr { - // TODO: XXX Fix this - error!("Mismatch in retrans-table's msg counter and received msg counter: received {}, expected {}. This is expected for the timebeing", ack_msg_ctr, entry.get_msg_ctr()); + error!("Mismatch in retrans-table's msg counter and received msg counter: received {:x}, expected {:x}.", ack_msg_ctr, entry.msg_ctr); } + self.retrans = None; + self.ack = None; } } - if proto_rx.proto.is_reliable() { - if self.ack.is_some() { + if rx_proto.is_reliable() { + if let Some(ack) = &self.ack { // This indicates there was some existing entry for same sess-id/exch-id, which shouldnt happen // TODO: As per the spec if this happens, we need to send out the previous ACK and note this new ACK - error!("Previous ACK entry for this exchange already exists"); - Err(ErrorCode::Invalid)?; + error!( + "Previous ACK entry {:x} for this exchange already exists", + ack.get_msg_ctr() + ); } - self.ack = Some(AckEntry::new(proto_rx.plain.ctr, epoch)?); + self.ack = Some(AckEntry::new(rx_plain.ctr)?); } + + self.received_at_ms = Some(epoch().as_millis() as u64); + Ok(()) } } diff --git a/rs-matter/src/transport/network.rs b/rs-matter/src/transport/network.rs index 254833d3..757f8fb8 100644 --- a/rs-matter/src/transport/network.rs +++ b/rs-matter/src/transport/network.rs @@ -17,16 +17,26 @@ use core::fmt::{Debug, Display}; -#[cfg(not(feature = "std"))] pub use core::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; -#[cfg(feature = "std")] -pub use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; use crate::error::Error; +// Maximum UDP RX packet size per Matter spec +pub const MAX_RX_PACKET_SIZE: usize = 1583; + +// Maximum UDP TX packet size per Matter spec +pub const MAX_TX_PACKET_SIZE: usize = 1280 - 40/*IPV6 header size*/ - 8/*UDP header size*/; + +// Maximum TCP RX packet size per Matter spec +pub const MAX_RX_LARGE_PACKET_SIZE: usize = 1024 * 1024; + +// Maximum TCP TX packet size per Matter spec +pub const MAX_TX_LARGE_PACKET_SIZE: usize = MAX_RX_LARGE_PACKET_SIZE; + #[derive(Eq, PartialEq, Copy, Clone)] pub enum Address { Udp(SocketAddr), + Tcp(SocketAddr), } impl Address { @@ -35,14 +45,20 @@ impl Address { } pub fn is_reliable(&self) -> bool { - match self { - Self::Udp(_) => false, - } + matches!(self, Self::Tcp(_)) } pub fn unwrap_udp(self) -> SocketAddr { match self { Self::Udp(addr) => addr, + other => panic!("Expected UDP address, got {:?}", other), + } + } + + pub fn unwrap_tcp(self) -> SocketAddr { + match self { + Self::Tcp(addr) => addr, + other => panic!("Expected TCP address, got {:?}", other), } } } @@ -57,6 +73,7 @@ impl Display for Address { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { match self { Address::Udp(addr) => write!(f, "UDP {}", addr), + Address::Tcp(addr) => write!(f, "TCP {}", addr), } } } @@ -65,6 +82,7 @@ impl Debug for Address { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { match self { Address::Udp(addr) => writeln!(f, "{}", addr), + Address::Tcp(addr) => writeln!(f, "{}", addr), } } } diff --git a/rs-matter/src/transport/packet.rs b/rs-matter/src/transport/packet.rs index 5e0cf98a..8c4dd882 100644 --- a/rs-matter/src/transport/packet.rs +++ b/rs-matter/src/transport/packet.rs @@ -15,312 +15,96 @@ * limitations under the License. */ -use log::{error, info, trace}; -use owo_colors::OwoColorize; +use core::fmt; + +use log::trace; use crate::{ - error::{Error, ErrorCode}, - interaction_model::core::PROTO_ID_INTERACTION_MODEL, - secure_channel::common::PROTO_ID_SECURE_CHANNEL, - tlv, + crypto::AEAD_MIC_LEN_BYTES, + error::Error, utils::{parsebuf::ParseBuf, writebuf::WriteBuf}, }; use super::{ - network::Address, plain_hdr::{self, PlainHdr}, proto_hdr::{self, ProtoHdr}, }; -pub const MAX_RX_BUF_SIZE: usize = 1583; -pub const MAX_RX_STATUS_BUF_SIZE: usize = 100; -pub const MAX_TX_BUF_SIZE: usize = 1280 - 40/*IPV6 header size*/ - 8/*UDP header size*/; - -#[derive(Debug, PartialEq, Eq, Copy, Clone)] -enum RxState { - Uninit, - PlainDecode, - ProtoDecode, -} - -enum Direction<'a> { - Tx(WriteBuf<'a>), - Rx(ParseBuf<'a>, RxState), -} - -impl<'a> Direction<'a> { - pub fn load(&mut self, direction: &Direction) -> Result<(), Error> { - if matches!(self, Self::Tx(_)) != matches!(direction, Direction::Tx(_)) { - Err(ErrorCode::Invalid)?; - } - - match self { - Self::Tx(wb) => match direction { - Direction::Tx(src_wb) => wb.load(src_wb)?, - Direction::Rx(_, _) => Err(ErrorCode::Invalid)?, - }, - Self::Rx(pb, state) => match direction { - Direction::Tx(_) => Err(ErrorCode::Invalid)?, - Direction::Rx(src_pb, src_state) => { - pb.load(src_pb)?; - *state = *src_state; - } - }, - } - - Ok(()) - } -} - -pub struct Packet<'a> { +#[derive(Debug, Default, Clone)] +pub struct PacketHdr { pub plain: PlainHdr, pub proto: ProtoHdr, - pub peer: Address, - data: Direction<'a>, } -impl<'a> Packet<'a> { - const HDR_RESERVE: usize = plain_hdr::max_plain_hdr_len() + proto_hdr::max_proto_hdr_len(); +impl PacketHdr { + pub const HDR_RESERVE: usize = plain_hdr::max_plain_hdr_len() + proto_hdr::max_proto_hdr_len(); + pub const TAIL_RESERVE: usize = AEAD_MIC_LEN_BYTES; - pub fn new_rx(buf: &'a mut [u8]) -> Self { + #[inline(always)] + pub const fn new() -> Self { Self { - plain: Default::default(), - proto: Default::default(), - peer: Address::default(), - data: Direction::Rx(ParseBuf::new(buf), RxState::Uninit), - } - } - - pub fn new_tx(buf: &'a mut [u8]) -> Self { - let mut wb = WriteBuf::new(buf); - wb.reserve(Packet::HDR_RESERVE).unwrap(); - - // Reliability on by default - let mut proto: ProtoHdr = Default::default(); - proto.set_reliable(); - - Self { - plain: Default::default(), - proto, - peer: Address::default(), - data: Direction::Tx(wb), + plain: PlainHdr::new(), + proto: ProtoHdr::new(), } } pub fn reset(&mut self) { - if let Direction::Tx(wb) = &mut self.data { - wb.reset(); - wb.reserve(Packet::HDR_RESERVE).unwrap(); - - self.plain = Default::default(); - self.proto = Default::default(); - self.peer = Address::default(); - - self.proto.set_reliable(); - } + self.plain = Default::default(); + self.proto = Default::default(); + self.proto.set_reliable(); } - pub fn load(&mut self, packet: &Packet) -> Result<(), Error> { + pub fn load(&mut self, packet: &PacketHdr) { self.plain = packet.plain.clone(); self.proto = packet.proto.clone(); - self.peer = packet.peer; - self.data.load(&packet.data) - } - - pub fn as_slice(&self) -> &[u8] { - match &self.data { - Direction::Rx(pb, _) => pb.as_slice(), - Direction::Tx(wb) => wb.as_slice(), - } - } - - pub fn as_mut_slice(&mut self) -> &mut [u8] { - match &mut self.data { - Direction::Rx(pb, _) => pb.as_mut_slice(), - Direction::Tx(wb) => wb.as_mut_slice(), - } - } - - pub fn get_parsebuf(&mut self) -> Result<&mut ParseBuf<'a>, Error> { - if let Direction::Rx(pbuf, _) = &mut self.data { - Ok(pbuf) - } else { - Err(ErrorCode::Invalid.into()) - } - } - - pub fn get_writebuf(&mut self) -> Result<&mut WriteBuf<'a>, Error> { - if let Direction::Tx(wbuf) = &mut self.data { - Ok(wbuf) - } else { - Err(ErrorCode::Invalid.into()) - } } - pub fn get_proto_id(&self) -> u16 { - self.proto.proto_id + pub fn decode_plain_hdr(&mut self, pb: &mut ParseBuf) -> Result<(), Error> { + self.plain.decode(pb) } - pub fn set_proto_id(&mut self, proto_id: u16) { - self.proto.proto_id = proto_id; - } - - pub fn get_proto_opcode(&self) -> Result { - num::FromPrimitive::from_u8(self.proto.proto_opcode).ok_or(ErrorCode::Invalid.into()) - } - - pub fn get_proto_raw_opcode(&self) -> u8 { - self.proto.proto_opcode - } - - pub fn check_proto_opcode(&self, opcode: u8) -> Result<(), Error> { - if self.proto.proto_opcode == opcode { - Ok(()) - } else { - Err(ErrorCode::Invalid.into()) - } - } - - pub fn set_proto_opcode(&mut self, proto_opcode: u8) { - self.proto.proto_opcode = proto_opcode; - } - - pub fn set_reliable(&mut self) { - self.proto.set_reliable() - } - - pub fn unset_reliable(&mut self) { - self.proto.unset_reliable() - } - - pub fn is_reliable(&mut self) -> bool { - self.proto.is_reliable() - } - - pub fn proto_decode(&mut self, peer_nodeid: u64, dec_key: Option<&[u8]>) -> Result<(), Error> { - match &mut self.data { - Direction::Rx(pb, state) => { - if *state == RxState::PlainDecode { - *state = RxState::ProtoDecode; - self.proto - .decrypt_and_decode(&self.plain, pb, peer_nodeid, dec_key) - } else { - error!("Invalid state for proto_decode"); - Err(ErrorCode::InvalidState.into()) - } - } - _ => Err(ErrorCode::InvalidState.into()), - } + pub fn decode_remaining( + &mut self, + pb: &mut ParseBuf, + peer_nodeid: u64, + dec_key: Option<&[u8]>, + ) -> Result<(), Error> { + self.proto + .decrypt_and_decode(&self.plain, pb, peer_nodeid, dec_key) } - pub fn proto_encode( - &mut self, - peer: Address, - peer_nodeid: Option, + pub fn encode( + &self, + wb: &mut WriteBuf, local_nodeid: u64, - plain_text: bool, enc_key: Option<&[u8]>, ) -> Result<(), Error> { - self.peer = peer; - // Generate encrypted header let mut tmp_buf = [0_u8; proto_hdr::max_proto_hdr_len()]; let mut write_buf = WriteBuf::new(&mut tmp_buf); self.proto.encode(&mut write_buf)?; - self.get_writebuf()?.prepend(write_buf.as_slice())?; - - // Generate plain-text header - if plain_text { - if let Some(d) = peer_nodeid { - self.plain.set_dest_u64(d); - } - } + wb.prepend(write_buf.as_slice())?; let mut tmp_buf = [0_u8; plain_hdr::max_plain_hdr_len()]; let mut write_buf = WriteBuf::new(&mut tmp_buf); self.plain.encode(&mut write_buf)?; let plain_hdr_bytes = write_buf.as_slice(); - trace!("unencrypted packet: {:x?}", self.as_mut_slice()); + trace!("unencrypted packet: {:x?}", wb.as_slice()); let ctr = self.plain.ctr; if let Some(e) = enc_key { - proto_hdr::encrypt_in_place( - ctr, - local_nodeid, - plain_hdr_bytes, - self.get_writebuf()?, - e, - )?; + proto_hdr::encrypt_in_place(ctr, local_nodeid, plain_hdr_bytes, wb, e)?; } - self.get_writebuf()?.prepend(plain_hdr_bytes)?; - trace!("Full encrypted packet: {:x?}", self.as_mut_slice()); + wb.prepend(plain_hdr_bytes)?; + trace!("Full encrypted packet: {:x?}", wb.as_slice()); Ok(()) } +} - pub fn is_plain_hdr_decoded(&self) -> Result { - match &self.data { - Direction::Rx(_, state) => match state { - RxState::Uninit => Ok(false), - _ => Ok(true), - }, - _ => Err(ErrorCode::InvalidState.into()), - } - } - - pub fn plain_hdr_decode(&mut self) -> Result<(), Error> { - match &mut self.data { - Direction::Rx(pb, state) => { - if *state == RxState::Uninit { - *state = RxState::PlainDecode; - self.plain.decode(pb) - } else { - error!("Invalid state for plain_decode"); - Err(ErrorCode::InvalidState.into()) - } - } - _ => Err(ErrorCode::InvalidState.into()), - } - } - - pub fn log(&self, operation: &str) { - match self.get_proto_id() { - PROTO_ID_SECURE_CHANNEL => { - if let Ok(opcode) = self.get_proto_opcode::() - { - info!("{} SC:{:?}: ", operation.cyan(), opcode); - } else { - info!( - "{} SC:{}??: ", - operation.cyan(), - self.get_proto_raw_opcode() - ); - } - - tlv::print_tlv_list(self.as_slice()); - } - PROTO_ID_INTERACTION_MODEL => { - if let Ok(opcode) = - self.get_proto_opcode::() - { - info!("{} IM:{:?}: ", operation.cyan(), opcode); - } else { - info!( - "{} IM:{}??: ", - operation.cyan(), - self.get_proto_raw_opcode() - ); - } - - tlv::print_tlv_list(self.as_slice()); - } - other => info!( - "{} {}??:{}??: ", - operation.cyan(), - other, - self.get_proto_raw_opcode() - ), - } +impl fmt::Display for PacketHdr { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "[{}][{}]", self.plain, self.proto) } } diff --git a/rs-matter/src/transport/plain_hdr.rs b/rs-matter/src/transport/plain_hdr.rs index f939829a..e2b54812 100644 --- a/rs-matter/src/transport/plain_hdr.rs +++ b/rs-matter/src/transport/plain_hdr.rs @@ -15,18 +15,13 @@ * limitations under the License. */ +use core::fmt; + use crate::error::*; use crate::utils::parsebuf::ParseBuf; use crate::utils::writebuf::WriteBuf; use bitflags::bitflags; -use log::info; - -#[derive(Debug, PartialEq, Eq, Default, Copy, Clone)] -pub enum SessionType { - #[default] - None, - Encrypted, -} +use log::trace; bitflags! { #[repr(transparent)] @@ -38,68 +33,192 @@ bitflags! { } } +impl fmt::Display for MsgFlags { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut sep = false; + for flag in [ + Self::SRC_ADDR_PRESENT, + Self::DSIZ_UNICAST_NODEID, + Self::DSIZ_GROUPCAST_NODEID, + ] { + if self.contains(flag) { + if sep { + write!(f, "|")?; + } + + let str = match flag { + Self::DSIZ_UNICAST_NODEID => "U", + Self::DSIZ_GROUPCAST_NODEID => "G", + Self::SRC_ADDR_PRESENT => "S", + _ => "?", + }; + + write!(f, "{}", str)?; + sep = true; + } + } + + Ok(()) + } +} + // This is the unencrypted message #[derive(Debug, Default, Clone)] pub struct PlainHdr { - pub flags: MsgFlags, - pub sess_type: SessionType, + flags: MsgFlags, pub sess_id: u16, pub ctr: u32, - peer_nodeid: Option, + src_nodeid: u64, + dst_nodeid: u64, } impl PlainHdr { - pub fn set_dest_u64(&mut self, id: u64) { - self.flags |= MsgFlags::DSIZ_UNICAST_NODEID; - self.peer_nodeid = Some(id); + #[inline(always)] + pub const fn new() -> Self { + Self { + flags: MsgFlags::empty(), + sess_id: 0, + ctr: 0, + src_nodeid: 0, + dst_nodeid: 0, + } } - pub fn get_src_u64(&self) -> Option { + pub fn get_src_nodeid(&self) -> Option { if self.flags.contains(MsgFlags::SRC_ADDR_PRESENT) { - self.peer_nodeid + Some(self.src_nodeid) } else { None } } -} -impl PlainHdr { + pub fn set_src_nodeid(&mut self, id: Option) { + if let Some(id) = id { + self.flags |= MsgFlags::SRC_ADDR_PRESENT; + self.src_nodeid = id; + } else { + self.flags.remove(MsgFlags::SRC_ADDR_PRESENT); + self.src_nodeid = 0; + } + } + + pub fn get_dst_unicast_nodeid(&self) -> Option { + if self.flags.contains(MsgFlags::DSIZ_UNICAST_NODEID) { + Some(self.dst_nodeid) + } else { + None + } + } + + pub fn set_dst_unicast_nodeid(&mut self, id: Option) { + if let Some(id) = id { + self.flags |= MsgFlags::DSIZ_UNICAST_NODEID; + self.flags.remove(MsgFlags::DSIZ_GROUPCAST_NODEID); + self.dst_nodeid = id; + } else { + self.flags + .remove(MsgFlags::DSIZ_UNICAST_NODEID | MsgFlags::DSIZ_GROUPCAST_NODEID); + self.dst_nodeid = 0; + } + } + + pub fn get_dst_groupcast_nodeid(&self) -> Option { + if self.flags.contains(MsgFlags::DSIZ_GROUPCAST_NODEID) { + Some(self.dst_nodeid as u16) + } else { + None + } + } + + pub fn set_dst_groupcast_nodeid(&mut self, id: Option) { + if let Some(id) = id { + self.flags |= MsgFlags::DSIZ_GROUPCAST_NODEID; + self.flags.remove(MsgFlags::DSIZ_UNICAST_NODEID); + self.dst_nodeid = id as u64; + } else { + self.flags + .remove(MsgFlags::DSIZ_UNICAST_NODEID | MsgFlags::DSIZ_GROUPCAST_NODEID); + self.dst_nodeid = 0; + } + } + // it will have an additional 'message length' field first pub fn decode(&mut self, msg: &mut ParseBuf) -> Result<(), Error> { self.flags = MsgFlags::from_bits(msg.le_u8()?).ok_or(ErrorCode::Invalid)?; self.sess_id = msg.le_u16()?; let _sec_flags = msg.le_u8()?; - self.sess_type = if self.sess_id != 0 { - SessionType::Encrypted - } else { - SessionType::None - }; self.ctr = msg.le_u32()?; if self.flags.contains(MsgFlags::SRC_ADDR_PRESENT) { - self.peer_nodeid = Some(msg.le_u64()?); + self.src_nodeid = msg.le_u64()?; + } + + if !self + .flags + .contains(MsgFlags::DSIZ_UNICAST_NODEID | MsgFlags::DSIZ_GROUPCAST_NODEID) + { + if self.flags.contains(MsgFlags::DSIZ_UNICAST_NODEID) { + self.dst_nodeid = msg.le_u64()?; + } else if self.flags.contains(MsgFlags::DSIZ_GROUPCAST_NODEID) { + self.dst_nodeid = msg.le_u16()? as u64; + } } - info!( - "[decode] flags: {:?}, session type: {:#?}, sess_id: {}, ctr: {}", - self.flags, self.sess_type, self.sess_id, self.ctr - ); + trace!("[decode] {}", self); Ok(()) } - pub fn encode(&mut self, resp_buf: &mut WriteBuf) -> Result<(), Error> { + pub fn encode(&self, resp_buf: &mut WriteBuf) -> Result<(), Error> { + trace!("[encode] {}", self); resp_buf.le_u8(self.flags.bits())?; resp_buf.le_u16(self.sess_id)?; resp_buf.le_u8(0)?; resp_buf.le_u32(self.ctr)?; - if let Some(d) = self.peer_nodeid { - resp_buf.le_u64(d)?; + + if self.flags.contains(MsgFlags::SRC_ADDR_PRESENT) { + resp_buf.le_u64(self.src_nodeid)?; + } + + if !self + .flags + .contains(MsgFlags::DSIZ_UNICAST_NODEID | MsgFlags::DSIZ_GROUPCAST_NODEID) + { + if self.flags.contains(MsgFlags::DSIZ_UNICAST_NODEID) { + resp_buf.le_u64(self.dst_nodeid)?; + } else if self.flags.contains(MsgFlags::DSIZ_GROUPCAST_NODEID) { + resp_buf.le_u16(self.dst_nodeid as u16)?; + } } + Ok(()) } pub fn is_encrypted(&self) -> bool { - self.sess_type == SessionType::Encrypted + self.sess_id != 0 + } +} + +impl fmt::Display for PlainHdr { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if !self.flags.is_empty() { + write!(f, "{},", self.flags)?; + } + + write!(f, "SID:{:x},CTR:{:x}", self.sess_id, self.ctr)?; + + if let Some(src_nodeid) = self.get_src_nodeid() { + write!(f, ",SRC:{:x}", src_nodeid)?; + } + + if let Some(dst_nodeid) = self.get_dst_unicast_nodeid() { + write!(f, ",DST:{:x}", dst_nodeid)?; + } + + if let Some(dst_group_nodeid) = self.get_dst_groupcast_nodeid() { + write!(f, ",GRP:{:x}", dst_group_nodeid)?; + } + + Ok(()) } } diff --git a/rs-matter/src/transport/proto_hdr.rs b/rs-matter/src/transport/proto_hdr.rs index 003b93d7..19c2ccad 100644 --- a/rs-matter/src/transport/proto_hdr.rs +++ b/rs-matter/src/transport/proto_hdr.rs @@ -23,7 +23,9 @@ use crate::utils::parsebuf::ParseBuf; use crate::utils::writebuf::WriteBuf; use crate::{crypto, error::*}; -use log::{info, trace}; +use log::{trace, warn}; + +use super::network::Address; bitflags! { #[repr(transparent)] @@ -37,24 +39,88 @@ bitflags! { } } +impl fmt::Display for ExchFlags { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut sep = false; + for flag in [ + Self::INITIATOR, + Self::ACK, + Self::RELIABLE, + Self::SECEX, + Self::VENDOR, + ] { + if self.contains(flag) { + if sep { + write!(f, "|")?; + } + + let str = match flag { + Self::INITIATOR => "I", + Self::ACK => "A", + Self::RELIABLE => "R", + Self::SECEX => "SX", + Self::VENDOR => "V", + _ => "?", + }; + + write!(f, "{}", str)?; + sep = true; + } + } + + Ok(()) + } +} + #[derive(Debug, Default, Clone)] pub struct ProtoHdr { pub exch_id: u16, - pub exch_flags: ExchFlags, + exch_flags: ExchFlags, pub proto_id: u16, pub proto_opcode: u8, - pub proto_vendor_id: Option, - pub ack_msg_ctr: Option, + proto_vendor_id: u16, + ack_msg_ctr: u32, } impl ProtoHdr { - pub fn is_vendor(&self) -> bool { - self.exch_flags.contains(ExchFlags::VENDOR) + #[inline(always)] + pub const fn new() -> Self { + Self { + exch_id: 0, + exch_flags: ExchFlags::empty(), + proto_id: 0, + proto_opcode: 0, + proto_vendor_id: 0, + ack_msg_ctr: 0, + } } - pub fn set_vendor(&mut self, proto_vendor_id: u16) { - self.exch_flags |= ExchFlags::RELIABLE; - self.proto_vendor_id = Some(proto_vendor_id); + pub fn opcode(&self) -> Result { + num::FromPrimitive::from_u8(self.proto_opcode).ok_or(ErrorCode::Invalid.into()) + } + + pub fn check_opcode(&self, opcode: T) -> Result<(), Error> { + if self.opcode::()? == opcode { + Ok(()) + } else { + Err(ErrorCode::Invalid.into()) + } + } + + pub fn get_vendor(&self) -> Option { + self.exch_flags + .contains(ExchFlags::VENDOR) + .then_some(self.proto_vendor_id) + } + + pub fn set_vendor(&mut self, vendor_id: Option) { + if let Some(vendor_id) = vendor_id { + self.exch_flags |= ExchFlags::VENDOR; + self.proto_vendor_id = vendor_id; + } else { + self.exch_flags.remove(ExchFlags::VENDOR); + self.proto_vendor_id = 0; + } } pub fn is_security_ext(&self) -> bool { @@ -73,27 +139,66 @@ impl ProtoHdr { self.exch_flags |= ExchFlags::RELIABLE; } - pub fn is_ack(&self) -> bool { - self.exch_flags.contains(ExchFlags::ACK) + pub fn get_ack(&self) -> Option { + self.exch_flags + .contains(ExchFlags::ACK) + .then_some(self.ack_msg_ctr) } - pub fn get_ack_msg_ctr(&self) -> Option { - self.ack_msg_ctr - } - - pub fn set_ack(&mut self, ack_msg_ctr: u32) { - self.exch_flags |= ExchFlags::ACK; - self.ack_msg_ctr = Some(ack_msg_ctr); + pub fn set_ack(&mut self, ack_msg_ctr: Option) { + if let Some(ack_msg_ctr) = ack_msg_ctr { + self.exch_flags |= ExchFlags::ACK; + self.ack_msg_ctr = ack_msg_ctr; + } else { + self.exch_flags.remove(ExchFlags::ACK); + self.ack_msg_ctr = 0; + } } pub fn is_initiator(&self) -> bool { self.exch_flags.contains(ExchFlags::INITIATOR) } + pub fn unset_initiator(&mut self) { + self.exch_flags.remove(ExchFlags::INITIATOR); + } + pub fn set_initiator(&mut self) { self.exch_flags |= ExchFlags::INITIATOR; } + pub fn toggle_initiator(&mut self) { + if self.is_initiator() { + self.unset_initiator(); + } else { + self.set_initiator(); + } + } + + /// Adjusts the reliability settings (flags R and A) in the proto header + /// by inspecting the reliability of the network protocol itself. + /// + /// In case the protocol is reliable - yet the message has the R or A flags set - + /// these flags are lowered. Warnings will be logged in this case if the `rx` parameter + /// is set to `true` (i.e. this is an incoming message), because this situation + /// represents a Matter protocol violation, as per the Matter spec. + pub fn adjust_reliability(&mut self, rx: bool, addr: &Address) { + if addr.is_reliable() { + if rx { + if self.is_reliable() { + warn!("Detected a reliable message over a reliable transport; reliability request will not be honored with an ACK"); + } + + if self.get_ack().is_some() { + warn!("Detected an ACK counter over a reliable transport; ACK counter will be discarded"); + } + } + + self.unset_reliable(); + self.set_ack(None); + } + } + pub fn decrypt_and_decode( &mut self, plain_hdr: &plain_hdr::PlainHdr, @@ -111,28 +216,28 @@ impl ProtoHdr { self.exch_id = parsebuf.le_u16()?; self.proto_id = parsebuf.le_u16()?; - info!("[decode] {} ", self); - if self.is_vendor() { - self.proto_vendor_id = Some(parsebuf.le_u16()?); + if self.exch_flags.contains(ExchFlags::VENDOR) { + self.proto_vendor_id = parsebuf.le_u16()?; } - if self.is_ack() { - self.ack_msg_ctr = Some(parsebuf.le_u32()?); + if self.exch_flags.contains(ExchFlags::ACK) { + self.ack_msg_ctr = parsebuf.le_u32()?; } + trace!("[decode] {}", self); trace!("[rx payload]: {:x?}", parsebuf.as_mut_slice()); Ok(()) } - pub fn encode(&mut self, resp_buf: &mut WriteBuf) -> Result<(), Error> { - info!("[encode] {}", self); + pub fn encode(&self, resp_buf: &mut WriteBuf) -> Result<(), Error> { + trace!("[encode] {}", self); resp_buf.le_u8(self.exch_flags.bits())?; resp_buf.le_u8(self.proto_opcode)?; resp_buf.le_u16(self.exch_id)?; resp_buf.le_u16(self.proto_id)?; - if self.is_vendor() { - resp_buf.le_u16(self.proto_vendor_id.ok_or(ErrorCode::Invalid)?)?; + if let Some(vendor_id) = self.get_vendor() { + resp_buf.le_u16(vendor_id)?; } - if self.is_ack() { - resp_buf.le_u32(self.ack_msg_ctr.ok_or(ErrorCode::Invalid)?)?; + if let Some(ack_msg_ctr) = self.get_ack() { + resp_buf.le_u32(ack_msg_ctr)?; } Ok(()) } @@ -140,27 +245,25 @@ impl ProtoHdr { impl fmt::Display for ProtoHdr { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let mut flag_str = heapless::String::<16>::new(); - if self.is_vendor() { - flag_str.push_str("V|").unwrap(); - } - if self.is_security_ext() { - flag_str.push_str("SX|").unwrap(); - } - if self.is_reliable() { - flag_str.push_str("R|").unwrap(); - } - if self.is_ack() { - flag_str.push_str("A|").unwrap(); - } - if self.is_initiator() { - flag_str.push_str("I|").unwrap(); + if !self.exch_flags.is_empty() { + write!(f, "{},", self.exch_flags)?; } + write!( f, - "ExId: {}, Proto: {}, Opcode: {}, Flags: {}", - self.exch_id, self.proto_id, self.proto_opcode, flag_str - ) + "EID:{:x},PROTO:{:x},OP:{:x}", + self.exch_id, self.proto_id, self.proto_opcode + )?; + + if let Some(ack_msg_ctr) = self.get_ack() { + write!(f, ",ACTR:{:x}", ack_msg_ctr)?; + } + + if let Some(vendor_id) = self.get_vendor() { + write!(f, ",VID:{:x}", vendor_id)?; + } + + Ok(()) } } diff --git a/rs-matter/src/transport/session.rs b/rs-matter/src/transport/session.rs index aa91b2c4..ad71a542 100644 --- a/rs-matter/src/transport/session.rs +++ b/rs-matter/src/transport/session.rs @@ -15,18 +15,29 @@ * limitations under the License. */ -use crate::data_model::sdm::noc::NocData; -use crate::utils::epoch::Epoch; -use crate::utils::rand::Rand; +use core::cell::RefCell; use core::fmt; use core::time::Duration; -use crate::{error::*, transport::plain_hdr}; -use log::info; +use log::{error, info, trace, warn}; + +use crate::data_model::sdm::noc::NocData; +use crate::error::*; +use crate::transport::exchange::ExchangeId; +use crate::transport::mrp::ReliableMessage; +use crate::utils::epoch::Epoch; +use crate::utils::parsebuf::ParseBuf; +use crate::utils::rand::Rand; +use crate::utils::writebuf::WriteBuf; +use crate::Matter; use super::dedup::RxCtrState; -use super::exchange::SessionId; -use super::{network::Address, packet::Packet}; +use super::exchange::{ExchangeState, Role}; +use super::mrp::RetransEntry; +use super::network::Address; +use super::packet::PacketHdr; +use super::plain_hdr::PlainHdr; +use super::proto_hdr::ProtoHdr; pub const MAX_CAT_IDS_PER_NOC: usize = 3; pub type NocCatIds = [u32; MAX_CAT_IDS_PER_NOC]; @@ -58,6 +69,8 @@ pub enum SessionMode { } pub struct Session { + // Internal ID which is guaranteeed to be unique accross all sessions and not change when sessions are added/removed + pub(crate) id: u32, peer_addr: Address, local_nodeid: u64, peer_nodeid: Option, @@ -72,50 +85,23 @@ pub struct Session { rx_ctr_state: RxCtrState, mode: SessionMode, data: Option, + pub(crate) exchanges: heapless::Vec, MAX_EXCHANGES>, last_use: Duration, + reserved: bool, } -#[derive(Debug)] -pub struct CloneData { - pub dec_key: [u8; MATTER_AES128_KEY_SIZE], - pub enc_key: [u8; MATTER_AES128_KEY_SIZE], - pub att_challenge: [u8; MATTER_AES128_KEY_SIZE], - local_sess_id: u16, - peer_sess_id: u16, - local_nodeid: u64, - peer_nodeid: u64, - peer_addr: Address, - mode: SessionMode, -} - -impl CloneData { +impl Session { pub fn new( - local_nodeid: u64, - peer_nodeid: u64, - peer_sess_id: u16, - local_sess_id: u16, + id: u32, + reserved: bool, peer_addr: Address, - mode: SessionMode, - ) -> CloneData { - CloneData { - dec_key: [0; MATTER_AES128_KEY_SIZE], - enc_key: [0; MATTER_AES128_KEY_SIZE], - att_challenge: [0; MATTER_AES128_KEY_SIZE], - local_nodeid, - peer_nodeid, - peer_addr, - peer_sess_id, - local_sess_id, - mode, - } - } -} - -const MATTER_MSG_CTR_RANGE: u32 = 0x0fffffff; - -impl Session { - pub fn new(peer_addr: Address, peer_nodeid: Option, epoch: Epoch, rand: Rand) -> Self { + peer_nodeid: Option, + epoch: Epoch, + rand: Rand, + ) -> Self { Self { + id, + reserved, peer_addr, local_nodeid: 0, peer_nodeid, @@ -128,38 +114,11 @@ impl Session { rx_ctr_state: RxCtrState::new(0), mode: SessionMode::PlainText, data: None, + exchanges: heapless::Vec::new(), last_use: epoch(), } } - // A new encrypted session always clones from a previous 'new' session - pub fn clone(clone_from: &CloneData, epoch: Epoch, rand: Rand) -> Session { - Session { - peer_addr: clone_from.peer_addr, - local_nodeid: clone_from.local_nodeid, - peer_nodeid: Some(clone_from.peer_nodeid), - dec_key: clone_from.dec_key, - enc_key: clone_from.enc_key, - att_challenge: clone_from.att_challenge, - local_sess_id: clone_from.local_sess_id, - peer_sess_id: clone_from.peer_sess_id, - msg_ctr: Self::rand_msg_ctr(rand), - rx_ctr_state: RxCtrState::new(0), - mode: clone_from.mode.clone(), - data: None, - last_use: epoch(), - } - } - - pub fn id(&self) -> SessionId { - SessionId { - id: self.local_sess_id, - peer_addr: self.peer_addr, - peer_nodeid: self.peer_nodeid, - is_encrypted: self.is_encrypted(), - } - } - pub fn set_noc_data(&mut self, data: NocData) { self.data = Some(data); } @@ -222,7 +181,7 @@ impl Session { &self.mode } - pub fn get_msg_ctr(&mut self) -> u32 { + fn get_msg_ctr(&mut self) -> u32 { let ctr = self.msg_ctr; self.msg_ctr += 1; ctr @@ -246,30 +205,110 @@ impl Session { &self.att_challenge } - pub fn recv(&mut self, epoch: Epoch, rx: &mut Packet) -> Result<(), Error> { - self.last_use = epoch(); - rx.proto_decode(self.peer_nodeid.unwrap_or_default(), self.get_dec_key()) + pub(crate) fn is_for_node(&self, node_id: u64, secure: bool) -> bool { + self.peer_nodeid == Some(node_id) && self.is_encrypted() == secure && !self.reserved + } + + pub(crate) fn is_for_rx(&self, rx_peer: &Address, rx_plain: &PlainHdr) -> bool { + let nodeid_matches = self.peer_nodeid.is_none() + || rx_plain.get_src_nodeid().is_none() + || self.peer_nodeid == rx_plain.get_src_nodeid(); + + nodeid_matches + && self.local_sess_id == rx_plain.sess_id + && self.peer_addr == *rx_peer + && self.is_encrypted() == rx_plain.is_encrypted() + && !self.reserved } - pub fn pre_send(&mut self, tx: &mut Packet) -> Result<(), Error> { - tx.plain.sess_id = self.get_peer_sess_id(); - tx.plain.ctr = self.get_msg_ctr(); - if self.is_encrypted() { - tx.plain.sess_type = plain_hdr::SessionType::Encrypted; + pub(crate) fn post_recv( + &mut self, + rx_header: &mut PacketHdr, + pb: &mut ParseBuf, + epoch: Epoch, + ) -> Result { + self.decode_remaining(rx_header, pb)?; + + rx_header.proto.adjust_reliability(true, &self.peer_addr); + + if !self + .rx_ctr_state + .post_recv(rx_header.plain.ctr, self.is_encrypted()) + { + Err(ErrorCode::Duplicate)?; + } + + let exch_index = self.get_exch_for_rx(&rx_header.proto); + if let Some(exch_index) = exch_index { + let exch = self.exchanges[exch_index].as_mut().unwrap(); + + exch.post_recv(&rx_header.plain, &rx_header.proto, epoch)?; + + Ok(false) + } else { + if !rx_header.proto.is_initiator() { + Err(ErrorCode::NoExchange)?; + } + + if let Some(exch_index) = + self.add_exch(rx_header.proto.exch_id, Role::Responder(Default::default())) + { + let exch = self.exchanges[exch_index].as_mut().unwrap(); + + exch.post_recv(&rx_header.plain, &rx_header.proto, epoch)?; + + Ok(true) + } else { + Err(ErrorCode::NoSpaceExchanges)? + } } - Ok(()) } - pub(crate) fn send(&mut self, epoch: Epoch, tx: &mut Packet) -> Result<(), Error> { - self.last_use = epoch(); + pub(crate) fn pre_send( + &mut self, + exch_index: Option, + tx_header: &mut PacketHdr, + epoch: Epoch, + ) -> Result<(Address, bool), Error> { + let ctr = if let Some(exchange_index) = exch_index { + let exchange = self.exchanges[exchange_index].as_mut().unwrap(); + exchange.mrp.retrans.as_ref().map(RetransEntry::get_msg_ctr) + } else { + None + }; - tx.proto_encode( - self.peer_addr, - self.peer_nodeid, - self.local_nodeid, - self.mode == SessionMode::PlainText, - self.get_enc_key(), - ) + let retransmission = ctr.is_some(); + + tx_header.plain.sess_id = self.get_peer_sess_id(); + tx_header.plain.ctr = ctr.unwrap_or_else(|| self.get_msg_ctr()); + tx_header.plain.set_src_nodeid(None); + tx_header.plain.set_dst_unicast_nodeid( + (self.mode == SessionMode::PlainText) + .then_some(self.peer_nodeid) + .flatten(), + ); + + tx_header.proto.adjust_reliability(false, &self.peer_addr); + + if let Some(exchange_index) = exch_index { + let exchange = self.exchanges[exchange_index].as_mut().unwrap(); + + exchange.pre_send(&tx_header.plain, &mut tx_header.proto, epoch)?; + } + + Ok((self.peer_addr, retransmission)) + } + + fn decode_remaining(&self, rx: &mut PacketHdr, pb: &mut ParseBuf) -> Result<(), Error> { + rx.decode_remaining(pb, self.peer_nodeid.unwrap_or_default(), self.get_dec_key()) + } + + pub(crate) fn encode(&self, tx: &PacketHdr, wb: &mut WriteBuf) -> Result<(), Error> { + tx.encode(wb, self.local_nodeid, self.get_enc_key()) + } + + fn update_last_used(&mut self, epoch: Epoch) { + self.last_use = epoch(); } fn rand_msg_ctr(rand: Rand) -> u32 { @@ -277,6 +316,71 @@ impl Session { rand(&mut buf); u32::from_be_bytes(buf) & MATTER_MSG_CTR_RANGE } + + pub(crate) fn get_exch_for_rx(&self, rx_proto: &ProtoHdr) -> Option { + self.exchanges + .iter() + .enumerate() + .filter(|(_, exch)| { + exch.as_ref() + .map(|exch| exch.is_for_rx(rx_proto)) + .unwrap_or(false) + }) + .map(|(index, _)| index) + .next() + } + + pub(crate) fn add_exch(&mut self, exch_id: u16, role: Role) -> Option { + let exch_state = Some(ExchangeState { + exch_id, + role, + mrp: ReliableMessage::new(), + }); + + if self.exchanges.len() < MAX_EXCHANGES { + let _ = self.exchanges.push(exch_state); + + info!("Creating a new exchange: {:x}/{:?}", exch_id, role); + Some(self.exchanges.len() - 1) + } else { + let index = self.exchanges.iter().position(Option::is_none); + + if let Some(index) = index { + self.exchanges[index] = exch_state; + + info!("Creating a new exchange: {:x}/{:?}", exch_id, role); + Some(index) + } else { + error!( + "Too many exchanges for session {:x}; exchange creation failed", + self.id + ); + None + } + } + } + + pub(crate) fn remove_exch(&mut self, index: usize) -> bool { + let exchange = self.exchanges[index].as_mut().unwrap(); + let exchange_id = ExchangeId::new(self.id, index); + + if exchange.mrp.is_retrans_pending() { + exchange.role.set_dropped_state(); + error!("Exchange {exchange_id}, session {self}: A packet is still (re)transmitted! Marking as dropped, but session will be closed"); + + false + } else if exchange.mrp.is_ack_pending() { + exchange.role.set_dropped_state(); + warn!("Exchange {exchange_id}, session {self}: Pending ACK. Marking as dropped"); + + false + } else { + self.exchanges[index] = None; + trace!("Exchange {exchange_id}: Dropped cleanly"); + + true + } + } } impl fmt::Display for Session { @@ -295,11 +399,105 @@ impl fmt::Display for Session { } } -pub const MAX_SESSIONS: usize = 16; +pub struct ReservedSession<'a> { + id: u32, + session_mgr: &'a RefCell, + complete: bool, +} + +impl<'a> ReservedSession<'a> { + pub fn reserve_now(matter: &'a Matter<'a>) -> Result { + let mut mgr = matter.transport_mgr.session_mgr.borrow_mut(); + + let id = mgr + .add(true, Address::new(), None) + .ok_or(ErrorCode::NoSpaceSessions)? + .id; + + Ok(Self { + id, + session_mgr: &matter.transport_mgr.session_mgr, + complete: false, + }) + } + + pub async fn reserve(matter: &'a Matter<'a>) -> Result, Error> { + let session = Self::reserve_now(matter); + + if let Ok(session) = session { + Ok(session) + } else { + matter.transport_mgr.evict_some_session().await?; + + Self::reserve_now(matter) + } + } + + #[allow(clippy::too_many_arguments)] + pub fn update( + &mut self, + local_nodeid: u64, + peer_nodeid: u64, + peer_sessid: u16, + local_sessid: u16, + peer_addr: Address, + mode: SessionMode, + dec_key: Option<&[u8]>, + enc_key: Option<&[u8]>, + att_challenge: Option<&[u8]>, + ) -> Result<(), Error> { + let mut mgr = self.session_mgr.borrow_mut(); + let session = mgr.get(self.id).ok_or(ErrorCode::NoSession)?; + + session.local_nodeid = local_nodeid; + session.peer_nodeid = Some(peer_nodeid); + session.peer_sess_id = peer_sessid; + session.local_sess_id = local_sessid; + session.peer_addr = peer_addr; + session.mode = mode; + + if let Some(dec_key) = dec_key { + session.dec_key.copy_from_slice(dec_key); + } + + if let Some(enc_key) = enc_key { + session.enc_key.copy_from_slice(enc_key); + } + + if let Some(att_challenge) = att_challenge { + session.att_challenge.copy_from_slice(att_challenge); + } + + Ok(()) + } + + pub fn complete(&mut self) { + self.complete = true; + } +} + +impl<'a> Drop for ReservedSession<'a> { + fn drop(&mut self) { + if self.complete { + let mut session_mgr = self.session_mgr.borrow_mut(); + let session = session_mgr.get(self.id).unwrap(); + session.reserved = false; + } else { + self.session_mgr.borrow_mut().remove(self.id); + } + } +} + +const MAX_SESSIONS: usize = 16; +const MAX_EXCHANGES: usize = 5; + +const MATTER_MSG_CTR_RANGE: u32 = 0x0fffffff; pub struct SessionMgr { + next_sess_unique_id: u32, next_sess_id: u16, - sessions: heapless::Vec, MAX_SESSIONS>, + next_exch_id: u16, + sessions: heapless::Vec, pub(crate) epoch: Epoch, pub(crate) rand: Rand, } @@ -309,7 +507,9 @@ impl SessionMgr { pub const fn new(epoch: Epoch, rand: Rand) -> Self { Self { sessions: heapless::Vec::new(), + next_sess_unique_id: 0, next_sess_id: 1, + next_exch_id: 1, epoch, rand, } @@ -318,10 +518,7 @@ impl SessionMgr { pub fn reset(&mut self) { self.sessions.clear(); self.next_sess_id = 1; - } - - pub fn mut_by_index(&mut self, index: usize) -> Option<&mut Session> { - self.sessions.get_mut(index).and_then(Option::as_mut) + self.next_exch_id = 1; } pub fn get_next_sess_id(&mut self) -> u16 { @@ -336,153 +533,170 @@ impl SessionMgr { } // Ensure the currently selected id doesn't match any existing session - if self.sessions.iter().all(|sess| { - sess.as_ref() - .map(|sess| sess.get_local_sess_id() != next_sess_id) - .unwrap_or(true) - }) { + if self + .sessions + .iter() + .all(|sess| sess.get_local_sess_id() != next_sess_id) + { break; } } next_sess_id } - pub fn get_session_for_eviction(&self) -> Option { - if self.sessions.len() == MAX_SESSIONS && self.get_empty_slot().is_none() { - Some(self.get_lru()) - } else { - None - } - } + pub fn get_next_exch_id(&mut self) -> u16 { + let mut next_exch_id: u16; + loop { + next_exch_id = self.next_exch_id; + + // Increment next exch id + self.next_exch_id = self.next_exch_id.overflowing_add(1).0; + if self.next_exch_id == 0 { + self.next_exch_id = 1; + } - fn get_empty_slot(&self) -> Option { - self.sessions.iter().position(|x| x.is_none()) + // Ensure the currently selected id doesn't match any existing exchange + if self + .sessions + .iter() + .flat_map(|sess| sess.exchanges.iter()) + .filter_map(|exch| exch.as_ref()) + .all(|exch| { + !matches!(exch.role, Role::Responder(_)) || exch.exch_id != next_exch_id + }) + { + break; + } + } + next_exch_id } - fn get_lru(&self) -> usize { - let mut lru_index = 0; + pub fn get_session_for_eviction(&mut self) -> Option<&mut Session> { + let mut lru_index = None; let mut lru_ts = (self.epoch)(); for (i, s) in self.sessions.iter().enumerate() { - if let Some(s) = s { - if s.last_use < lru_ts { - lru_ts = s.last_use; - lru_index = i; - } + if s.last_use < lru_ts && !s.reserved && s.exchanges.iter().all(Option::is_none) { + lru_ts = s.last_use; + lru_index = Some(i); } } - lru_index + + lru_index.map(|index| &mut self.sessions[index]) } - pub fn add(&mut self, peer_addr: Address, peer_nodeid: Option) -> Result { - let session = Session::new(peer_addr, peer_nodeid, self.epoch, self.rand); - self.add_session(session) + pub fn add( + &mut self, + reserved: bool, + peer_addr: Address, + peer_nodeid: Option, + ) -> Option<&mut Session> { + let session_id = self.next_sess_unique_id; + + self.next_sess_unique_id += 1; + if self.next_sess_unique_id > 0x0fff_ffff { + // Reserve the upper 4 bits for the exchange index + self.next_sess_unique_id = 0; + } + + let session = Session::new( + session_id, + reserved, + peer_addr, + peer_nodeid, + self.epoch, + self.rand, + ); + + self.sessions.push(session).ok()?; + + Some(self.sessions.last_mut().unwrap()) } /// This assumes that the higher layer has taken care of doing anything required /// as per the spec before the session is erased - pub fn remove(&mut self, idx: usize) { - self.sessions[idx] = None; - } - - /// We could have returned a SessionHandle here. But the borrow checker doesn't support - /// non-lexical lifetimes. This makes it harder for the caller of this function to take - /// action in the error return path - fn add_session(&mut self, session: Session) -> Result { - if let Some(index) = self.get_empty_slot() { - self.sessions[index] = Some(session); - Ok(index) - } else if self.sessions.len() < MAX_SESSIONS { - self.sessions - .push(Some(session)) - .map_err(|_| ErrorCode::NoSpaceSessions) - .unwrap(); - - Ok(self.sessions.len() - 1) + pub fn remove(&mut self, id: u32) -> Option { + if let Some(index) = self.sessions.iter().position(|sess| sess.id == id) { + Some(self.sessions.swap_remove(index)) } else { - Err(ErrorCode::NoSpaceSessions.into()) + None } } - pub fn clone_session(&mut self, clone_data: &CloneData) -> Result { - let session = Session::clone(clone_data, self.epoch, self.rand); - self.add_session(session) + pub fn get(&mut self, id: u32) -> Option<&mut Session> { + let mut session = self.sessions.iter_mut().find(|sess| sess.id == id); + + if let Some(session) = session.as_mut() { + session.update_last_used(self.epoch); + } + + session } - pub fn get( - &self, - sess_id: u16, - peer_addr: Address, - peer_nodeid: Option, - is_encrypted: bool, - ) -> Option { - self.sessions.iter().position(|x| { - if let Some(x) = x { - let mut nodeid_matches = true; - if x.peer_nodeid.is_some() && peer_nodeid.is_some() && x.peer_nodeid != peer_nodeid - { - nodeid_matches = false; - } - x.local_sess_id == sess_id - && x.peer_addr == peer_addr - && x.is_encrypted() == is_encrypted - && nodeid_matches - } else { - false - } - }) + pub(crate) fn get_for_node(&mut self, node_id: u64, secure: bool) -> Option<&mut Session> { + let mut session = self + .sessions + .iter_mut() + .find(|sess| sess.is_for_node(node_id, secure)); + + if let Some(session) = session.as_mut() { + session.update_last_used(self.epoch); + } + + session } - pub fn get_or_add( + pub(crate) fn get_for_rx( &mut self, - sess_id: u16, - peer_addr: Address, - peer_nodeid: Option, - is_encrypted: bool, - ) -> Result { - if let Some(index) = self.get(sess_id, peer_addr, peer_nodeid, is_encrypted) { - Ok(index) - } else if sess_id == 0 && !is_encrypted { - // We must create a new session for this case - info!("Creating new session"); - self.add(peer_addr, peer_nodeid) - } else { - Err(ErrorCode::NotFound.into()) + rx_peer: &Address, + rx_plain: &PlainHdr, + ) -> Option<&mut Session> { + let mut session = self + .sessions + .iter_mut() + .find(|sess| sess.is_for_rx(rx_peer, rx_plain)); + + if let Some(session) = session.as_mut() { + session.update_last_used(self.epoch); } - } - // We will try to get a session for this Packet. If no session exists, we will try to add one - // If the session list is full we will return a None - pub fn post_recv(&mut self, rx: &Packet) -> Result { - let sess_index = self.get_or_add( - rx.plain.sess_id, - rx.peer, - rx.plain.get_src_u64(), - rx.plain.is_encrypted(), - )?; - - let session = self.sessions[sess_index].as_mut().unwrap(); - let is_encrypted = session.is_encrypted(); - let duplicate = session.rx_ctr_state.recv(rx.plain.ctr, is_encrypted); - if duplicate { - info!("Dropping duplicate packet"); - Err(ErrorCode::Duplicate.into()) + session + } + + pub(crate) fn get_exch(&mut self, f: F) -> Option<(&mut Session, usize)> + where + F: Fn(&Session, &ExchangeState) -> bool, + { + let exch = self + .sessions + .iter() + .flat_map(|sess| { + sess.exchanges + .iter() + .enumerate() + .filter_map(move |(exch_index, exch)| { + exch.as_ref().map(|exch| (sess, exch, exch_index)) + }) + }) + .filter(|(sess, exch, _)| f(sess, exch)) + .map(|(sess, _, exch_index)| (sess.id, exch_index)) + .next(); + + if let Some((id, exch_index)) = exch { + let epoch = self.epoch; + let session = self.get(id).unwrap(); + session.update_last_used(epoch); + + Some((session, exch_index)) } else { - Ok(sess_index) + None } } - - pub fn send(&mut self, sess_idx: usize, tx: &mut Packet) -> Result<(), Error> { - self.sessions[sess_idx] - .as_mut() - .ok_or(ErrorCode::NoSession)? - .send(self.epoch, tx) - } } impl fmt::Display for SessionMgr { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { writeln!(f, "{{[")?; - for s in self.sessions.iter().flatten() { + for s in &self.sessions { writeln!(f, "{{ {}, }},", s)?; } write!(f, "], next_sess_id: {}", self.next_sess_id)?; @@ -503,13 +717,11 @@ mod tests { #[test] fn test_next_sess_id_doesnt_reuse() { let mut sm = SessionMgr::new(dummy_epoch, dummy_rand); - let sess_idx = sm.add(Address::default(), None).unwrap(); - let sess = sm.mut_by_index(sess_idx).unwrap(); + let sess = sm.add(false, Address::default(), None).unwrap(); sess.set_local_sess_id(1); assert_eq!(sm.get_next_sess_id(), 2); assert_eq!(sm.get_next_sess_id(), 3); - let sess_idx = sm.add(Address::default(), None).unwrap(); - let sess = sm.mut_by_index(sess_idx).unwrap(); + let sess = sm.add(false, Address::default(), None).unwrap(); sess.set_local_sess_id(4); assert_eq!(sm.get_next_sess_id(), 5); } @@ -517,8 +729,7 @@ mod tests { #[test] fn test_next_sess_id_overflows() { let mut sm = SessionMgr::new(dummy_epoch, dummy_rand); - let sess_idx = sm.add(Address::default(), None).unwrap(); - let sess = sm.mut_by_index(sess_idx).unwrap(); + let sess = sm.add(false, Address::default(), None).unwrap(); sess.set_local_sess_id(1); assert_eq!(sm.get_next_sess_id(), 2); sm.next_sess_id = 65534; diff --git a/rs-matter/src/utils/buf.rs b/rs-matter/src/utils/buf.rs index 8f9e50ca..898f3bd5 100644 --- a/rs-matter/src/utils/buf.rs +++ b/rs-matter/src/utils/buf.rs @@ -15,67 +15,167 @@ * limitations under the License. */ +use core::cell::UnsafeCell; use core::ops::{Deref, DerefMut}; +use core::pin::pin; -use embassy_sync::blocking_mutex::raw::NoopRawMutex; -use embassy_sync::mutex::{Mutex, MutexGuard}; +use embassy_futures::select::{select, Either}; +use embassy_sync::blocking_mutex::raw::RawMutex; +use embassy_time::{Duration, Timer}; -/// A trait for concurrently accessing a &mut [u8] buffer from multiple async tasks. -pub trait BufferAccess { - type Buffer<'a>: DerefMut +use super::signal::Signal; + +/// A trait for getting access to a `&mut T` buffer, potentially awaiting until a buffer becomes available. +pub trait BufferAccess +where + T: ?Sized, +{ + type Buffer<'a>: DerefMut where Self: 'a; - /// Get a reference to the buffer. - /// Await until the buffer is available, as it might be in use by somebody else. - async fn get(&self) -> Self::Buffer<'_>; + /// Get a reference to a buffer. + /// Might await until a buffer is available, as it might be in use by somebody else. + /// + /// Depending on its internal implementation details, access to a buffer might also be denied + /// immediately, or after a certain amount of time (subject to the concrete implementation of the method). + /// In that case, the method will return `None`. + async fn get(&self) -> Option>; } -impl BufferAccess for &T +impl BufferAccess for &B where - T: BufferAccess, + B: BufferAccess, + T: ?Sized, { - type Buffer<'a> = T::Buffer<'a> where Self: 'a; + type Buffer<'a> = B::Buffer<'a> where Self: 'a; - async fn get(&self) -> Self::Buffer<'_> { + async fn get(&self) -> Option> { (*self).get().await } } -/// A concrete implementation of `BufferAccess` utilizing a single internal buffer. -pub struct BufferAccessImpl(Mutex>); +/// A concrete implementation of `BufferAccess` utilizing an internal pool of buffers. +/// Accessing a buffer would fail when all buffers are still used elsewhere after a wait timeout expires. +pub struct PooledBuffers { + available: Signal, + pool: UnsafeCell>, + wait_timeout_ms: u32, +} -impl BufferAccessImpl { +impl PooledBuffers +where + M: RawMutex, +{ #[inline(always)] - pub const fn new() -> Self { - Self(Mutex::new(heapless::Vec::new())) + pub const fn new(wait_timeout_ms: u32) -> Self { + Self { + available: Signal::new([true; N]), + pool: UnsafeCell::new(heapless::Vec::new()), + wait_timeout_ms, + } } } -impl BufferAccess for BufferAccessImpl { - type Buffer<'a> = BufferImpl<'a, N> where Self: 'a; - - async fn get(&self) -> Self::Buffer<'_> { - let mut guard = self.0.lock().await; +impl BufferAccess for PooledBuffers +where + M: RawMutex, + T: Default + Clone, +{ + type Buffer<'b> = PooledBuffer<'b, N, M, T> where Self: 'b; + + async fn get(&self) -> Option> { + if self.wait_timeout_ms > 0 { + let mut wait = pin!(self.available.wait(|available| { + if let Some(index) = available.iter().position(|a| *a) { + available[index] = false; + Some(index) + } else { + None + } + })); + + let mut timeout = pin!(Timer::after(Duration::from_millis( + self.wait_timeout_ms as u64 + ))); + + let result = select(&mut wait, &mut timeout).await; + + match result { + Either::First(index) => { + let buffer = &mut unsafe { self.pool.get().as_mut() }.unwrap()[index]; + + Some(PooledBuffer { + index, + buffer, + access: self, + }) + } + Either::Second(()) => None, + } + } else { + let index = self.available.modify(|available| { + if let Some(index) = available.iter().position(|a| *a) { + available[index] = false; + (false, Some(index)) + } else { + (false, None) + } + }); + + index.map(|index| { + let buffers = unsafe { self.pool.get().as_mut() }.unwrap(); + buffers.resize_default(N).unwrap(); + + let buffer = &mut buffers[index]; + + PooledBuffer { + index, + buffer, + access: self, + } + }) + } + } +} - guard.resize_default(N).unwrap(); +pub struct PooledBuffer<'a, const N: usize, M, T> +where + M: RawMutex, +{ + index: usize, + buffer: &'a mut T, + access: &'a PooledBuffers, +} - BufferImpl(guard) +impl<'a, const N: usize, M, T> Drop for PooledBuffer<'a, N, M, T> +where + M: RawMutex, +{ + fn drop(&mut self) { + self.access.available.modify(|available| { + available[self.index] = true; + (true, ()) + }); } } -pub struct BufferImpl<'a, const N: usize>(MutexGuard<'a, NoopRawMutex, heapless::Vec>); - -impl<'a, const N: usize> Deref for BufferImpl<'a, N> { - type Target = [u8]; +impl<'a, const N: usize, M, T> Deref for PooledBuffer<'a, N, M, T> +where + M: RawMutex, +{ + type Target = T; fn deref(&self) -> &Self::Target { - &self.0 + self.buffer.deref() } } -impl<'a, const N: usize> DerefMut for BufferImpl<'a, N> { +impl<'a, const N: usize, M, T> DerefMut for PooledBuffer<'a, N, M, T> +where + M: RawMutex, +{ fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.0 + self.buffer.deref_mut() } } diff --git a/rs-matter/src/utils/ifmutex.rs b/rs-matter/src/utils/ifmutex.rs new file mode 100644 index 00000000..776f450c --- /dev/null +++ b/rs-matter/src/utils/ifmutex.rs @@ -0,0 +1,227 @@ +/* + * + * Copyright (c) 2020-2022 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +//! A variation of the `embassy-sync` async mutex that only locks the mutex if a certain condition on the content of the data holds true. +//! Check `embassy_sync::Mutex` for the original unconditional implementation. +use core::cell::UnsafeCell; +use core::ops::{Deref, DerefMut}; + +use embassy_sync::blocking_mutex::raw::RawMutex; + +use super::signal::Signal; + +/// Error returned by [`Mutex::try_lock`] +#[derive(PartialEq, Eq, Clone, Copy, Debug)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct TryLockError; + +/// Async mutex with conditional locking based on the data inside the mutex. +/// Check `embassy_sync::Mutex` for the original unconditional implementation. +pub struct IfMutex +where + M: RawMutex, + T: ?Sized, +{ + state: Signal, + inner: UnsafeCell, +} + +unsafe impl Send for IfMutex {} +unsafe impl Sync for IfMutex {} + +/// Async mutex. +impl IfMutex +where + M: RawMutex, +{ + /// Create a new mutex with the given value. + #[inline(always)] + pub const fn new(value: T) -> Self { + Self { + state: Signal::::new(false), + inner: UnsafeCell::new(value), + } + } +} + +impl IfMutex +where + M: RawMutex, + T: ?Sized, +{ + /// Lock the mutex. + /// + /// This will wait for the mutex to be unlocked if it's already locked. + pub async fn lock(&self) -> IfMutexGuard<'_, M, T> { + self.lock_if(|_| true).await + } + + /// Lock the mutex. + /// + /// This will wait for the mutex to be unlocked if it's already locked _and_ for the provided condition on the data to become true. + pub async fn lock_if(&self, f: F) -> IfMutexGuard<'_, M, T> + where + F: Fn(&T) -> bool, + { + self.state + .wait(|locked| { + // Safety: it is safe to access the unsafe cell data, because: + // - nobody holds the long term (async) lock on the mutex right now (`locked == false`) + // - we have gained the blocking short-term mutex lock + if !*locked && f(unsafe { &*self.inner.get() }) { + *locked = true; + + Some(()) + } else { + None + } + }) + .await; + + IfMutexGuard { mutex: self } + } + + /// Waits for the mutex to become unlocked and then executes the provided closure. + /// Will become ready only when the callback closure returns a `Some` result. + pub async fn with(&self, mut f: F) -> R + where + F: FnMut(&mut T) -> Option, + { + let result = self + .state + .wait(|locked| { + if !*locked { + // Safety: it is safe to access the unsafe cell data, because: + // - nobody holds the long term (async) lock on the mutex right now (`locked == false`) + // - we have gained the blocking short-term mutex lock + if let Some(result) = f(unsafe { &mut *self.inner.get() }) { + *locked = true; + return Some(result); + } + } + + None + }) + .await; + + // Construct and immediately drop the guard to unlock the mutex + let _ = IfMutexGuard { mutex: self }; + + result + } + + /// Attempt to immediately lock the mutex. + pub fn try_lock(&self) -> Result, TryLockError> { + self.try_lock_if(|_| true) + } + + /// Attempt to immediately lock the mutex. + /// + /// If the mutex is already locked or the condition on the data is not true, this will return an error instead of waiting. + pub fn try_lock_if(&self, mut f: F) -> Result, TryLockError> + where + F: FnMut(&T) -> bool, + { + self.state.modify(|locked| { + if *locked { + (false, Err(TryLockError)) + } else if f(unsafe { &*self.inner.get() }) { + // Safety: it is safe to access the unsafe cell data, because: + // - nobody holds the long term (async) lock on the mutex right now (`locked == false`) + // - we have gained the blocking short-term mutex lock + *locked = true; + (false, Ok(())) + } else { + (false, Err(TryLockError)) + } + })?; + + Ok(IfMutexGuard { mutex: self }) + } + + /// Consumes this mutex, returning the underlying data. + pub fn into_inner(self) -> T + where + T: Sized, + { + self.inner.into_inner() + } + + /// Returns a mutable reference to the underlying data. + /// + /// Since this call borrows the Mutex mutably, no actual locking needs to + /// take place -- the mutable borrow statically guarantees no locks exist. + pub fn get_mut(&mut self) -> &mut T { + self.inner.get_mut() + } +} + +/// Async mutex guard. +/// +/// Owning an instance of this type indicates having +/// successfully locked the mutex, and grants access to the contents. +/// +/// Dropping it unlocks the mutex. +pub struct IfMutexGuard<'a, M, T> +where + M: RawMutex, + T: ?Sized, +{ + mutex: &'a IfMutex, +} + +impl<'a, M, T> Drop for IfMutexGuard<'a, M, T> +where + M: RawMutex, + T: ?Sized, +{ + fn drop(&mut self) { + self.mutex.state.modify(|locked| { + assert!(*locked); + + *locked = false; + + (true, ()) + }) + } +} + +impl<'a, M, T> Deref for IfMutexGuard<'a, M, T> +where + M: RawMutex, + T: ?Sized, +{ + type Target = T; + + fn deref(&self) -> &Self::Target { + // Safety: the MutexGuard represents exclusive access to the contents + // of the mutex, so it's OK to get it. + unsafe { &*(self.mutex.inner.get() as *const T) } + } +} + +impl<'a, M, T> DerefMut for IfMutexGuard<'a, M, T> +where + M: RawMutex, + T: ?Sized, +{ + fn deref_mut(&mut self) -> &mut Self::Target { + // Safety: the MutexGuard represents exclusive access to the contents + // of the mutex, so it's OK to get it. + unsafe { &mut *(self.mutex.inner.get()) } + } +} diff --git a/rs-matter/src/utils/mod.rs b/rs-matter/src/utils/mod.rs index d634894e..0b09e5fc 100644 --- a/rs-matter/src/utils/mod.rs +++ b/rs-matter/src/utils/mod.rs @@ -17,7 +17,10 @@ pub mod buf; pub mod epoch; +pub mod ifmutex; +pub mod notification; pub mod parsebuf; pub mod rand; pub mod select; +pub mod signal; pub mod writebuf; diff --git a/rs-matter/src/utils/notification.rs b/rs-matter/src/utils/notification.rs new file mode 100644 index 00000000..b13320ca --- /dev/null +++ b/rs-matter/src/utils/notification.rs @@ -0,0 +1,46 @@ +/* + * + * Copyright (c) 2020-2022 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +use embassy_sync::blocking_mutex::raw::RawMutex; + +use super::signal::Signal; + +/// A notification primitive that allows for notifying a single waiter. +pub struct Notification(Signal>); + +impl Notification +where + M: RawMutex, +{ + /// Create a new `Notification`. + pub const fn new() -> Self { + Self(Signal::new(None)) + } + + /// Notify the waiter. + pub fn notify(&self) { + self.0.modify(|state| { + *state = Some(()); + (true, ()) + }); + } + + /// Wait for the notification. + pub async fn wait(&self) { + self.0.wait(|state| state.take()).await; + } +} diff --git a/rs-matter/src/utils/parsebuf.rs b/rs-matter/src/utils/parsebuf.rs index 233693cd..d96c416f 100644 --- a/rs-matter/src/utils/parsebuf.rs +++ b/rs-matter/src/utils/parsebuf.rs @@ -56,6 +56,10 @@ impl<'a> ParseBuf<'a> { self.left = left; } + pub fn slice_range(&self) -> (usize, usize) { + (self.read_off, self.read_off + self.left) + } + // Return the data that is valid as a slice pub fn as_slice(&self) -> &[u8] { &self.buf[self.read_off..(self.read_off + self.left)] diff --git a/rs-matter/src/utils/select.rs b/rs-matter/src/utils/select.rs index a63c10be..6d147665 100644 --- a/rs-matter/src/utils/select.rs +++ b/rs-matter/src/utils/select.rs @@ -1,38 +1,149 @@ -use embassy_futures::select::{Either, Either3, Either4}; -use embassy_sync::blocking_mutex::raw::NoopRawMutex; +/* + * + * Copyright (c) 2020-2022 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ -pub type Notification = embassy_sync::signal::Signal; +use core::future::Future; -pub trait EitherUnwrap { - fn unwrap(self) -> T; +use embassy_futures::{ + join::{Join, Join3, Join4, Join5}, + select::{Either, Either3, Either4, Select, Select3, Select4}, +}; + +/// A trait for coalescing the outputs of `embassy_futures::Select*` and `embassy_futures::Join*` futures. +/// +/// - The outputs of the `embassy_futures::Select*` future can be coalesced only +/// if all legs of the `Select*` future return the same type +/// +/// - The outputs of the `embassy_futures::Join*` future can be coalesced only if +/// all legs of the `Join*` future return `Result<(), T>` where T is the same error type. +/// Note that in the case when multiple legs of the `Join*` future resulted in an error, +/// only the error of the leftmost leg is returned, while the others are discarded. +pub trait Coalesce { + fn coalesce(self) -> impl Future; +} + +impl Coalesce for Select +where + F1: Future, + F2: Future, +{ + async fn coalesce(self) -> T { + match self.await { + Either::First(t) => t, + Either::Second(t) => t, + } + } +} + +impl Coalesce for Select3 +where + F1: Future, + F2: Future, + F3: Future, +{ + async fn coalesce(self) -> T { + match self.await { + Either3::First(t) => t, + Either3::Second(t) => t, + Either3::Third(t) => t, + } + } +} + +impl Coalesce for Select4 +where + F1: Future, + F2: Future, + F3: Future, + F4: Future, +{ + async fn coalesce(self) -> T { + match self.await { + Either4::First(t) => t, + Either4::Second(t) => t, + Either4::Third(t) => t, + Either4::Fourth(t) => t, + } + } +} + +impl Coalesce> for Join +where + F1: Future>, + F2: Future>, +{ + async fn coalesce(self) -> Result<(), T> { + match self.await { + (Err(e), _) => Err(e), + (_, Err(e)) => Err(e), + _ => Ok(()), + } + } } -impl EitherUnwrap for Either { - fn unwrap(self) -> T { - match self { - Self::First(t) => t, - Self::Second(t) => t, +impl Coalesce> for Join3 +where + F1: Future>, + F2: Future>, + F3: Future>, +{ + async fn coalesce(self) -> Result<(), T> { + match self.await { + (Err(e), _, _) => Err(e), + (_, Err(e), _) => Err(e), + (_, _, Err(e)) => Err(e), + _ => Ok(()), } } } -impl EitherUnwrap for Either3 { - fn unwrap(self) -> T { - match self { - Self::First(t) => t, - Self::Second(t) => t, - Self::Third(t) => t, +impl Coalesce> for Join4 +where + F1: Future>, + F2: Future>, + F3: Future>, + F4: Future>, +{ + async fn coalesce(self) -> Result<(), T> { + match self.await { + (Err(e), _, _, _) => Err(e), + (_, Err(e), _, _) => Err(e), + (_, _, Err(e), _) => Err(e), + (_, _, _, Err(e)) => Err(e), + _ => Ok(()), } } } -impl EitherUnwrap for Either4 { - fn unwrap(self) -> T { - match self { - Self::First(t) => t, - Self::Second(t) => t, - Self::Third(t) => t, - Self::Fourth(t) => t, +impl Coalesce> for Join5 +where + F1: Future>, + F2: Future>, + F3: Future>, + F4: Future>, + F5: Future>, +{ + async fn coalesce(self) -> Result<(), T> { + match self.await { + (Err(e), _, _, _, _) => Err(e), + (_, Err(e), _, _, _) => Err(e), + (_, _, Err(e), _, _) => Err(e), + (_, _, _, Err(e), _) => Err(e), + (_, _, _, _, Err(e)) => Err(e), + _ => Ok(()), } } } diff --git a/rs-matter/src/utils/signal.rs b/rs-matter/src/utils/signal.rs new file mode 100644 index 00000000..a747f654 --- /dev/null +++ b/rs-matter/src/utils/signal.rs @@ -0,0 +1,97 @@ +/* + * + * Copyright (c) 2020-2022 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +use core::cell::RefCell; +use core::future::poll_fn; +use core::task::{Context, Poll}; + +use embassy_sync::blocking_mutex::{raw::RawMutex, Mutex}; +use embassy_sync::waitqueue::WakerRegistration; + +struct State { + state: S, + waker: WakerRegistration, +} + +/// `Signal` is an async synchonization primitive that can be viewed as a generalization of the `embassy_sync::Signal` primitive +/// that takes callback closures. +/// +/// It allows for waiting on a condition of its state `S` to become true, where whether the condition is met is decided by a callback closure. +/// +/// It also allows for modifying the state `S` and waking up the waiters - but only as long as a callback closure provides information that +/// the state is modified in such a way, that the waiters should be notified. +/// +/// The generic nature of `Signal` allows for a wide range of use cases, including the implementation of: +/// - the `Notification` primitive +/// - the `IfMutex` primitive +pub struct Signal(Mutex>>); + +impl Signal +where + M: RawMutex, +{ + /// Crate a `Signal` with the given initial state `S`. + pub const fn new(state: S) -> Self { + Self(Mutex::new(RefCell::new(State { + state, + waker: WakerRegistration::new(), + }))) + } + + // Modify the state `S` and wake up the waiters if necessary. + pub fn modify(&self, f: F) -> R + where + F: FnOnce(&mut S) -> (bool, R), + { + self.0.lock(|s| { + let mut s = s.borrow_mut(); + + let (wake, result) = f(&mut s.state); + + if wake { + s.waker.wake(); + } + + result + }) + } + + // Wait for the condition of the state `S` to become true. + pub async fn wait(&self, mut f: F) -> R + where + F: FnMut(&mut S) -> Option, + { + poll_fn(move |ctx| self.poll_wait(ctx, &mut f)).await + } + + // Poll the condition of the state `S` to become true. + pub fn poll_wait(&self, ctx: &mut Context, f: F) -> Poll + where + F: FnOnce(&mut S) -> Option, + { + self.0.lock(|s| { + let mut s = s.borrow_mut(); + + if let Some(result) = f(&mut s.state) { + Poll::Ready(result) + } else { + s.waker.register(ctx.waker()); + Poll::Pending + } + }) + } +} diff --git a/rs-matter/src/utils/writebuf.rs b/rs-matter/src/utils/writebuf.rs index d091dfb5..f3363cc4 100644 --- a/rs-matter/src/utils/writebuf.rs +++ b/rs-matter/src/utils/writebuf.rs @@ -28,13 +28,17 @@ pub struct WriteBuf<'a> { impl<'a> WriteBuf<'a> { pub fn new(buf: &'a mut [u8]) -> Self { + Self::new_with(buf, 0, 0) + } + + pub fn new_with(buf: &'a mut [u8], start: usize, end: usize) -> Self { let buf_size = buf.len(); Self { buf, buf_size, - start: 0, - end: 0, + start, + end, } } diff --git a/rs-matter/tests/common/handlers.rs b/rs-matter/tests/common/handlers.rs index 198eb739..63d743c9 100644 --- a/rs-matter/tests/common/handlers.rs +++ b/rs-matter/tests/common/handlers.rs @@ -206,7 +206,8 @@ impl<'a> ImEngine<'a> { delay: u16, ) { let mut out = heapless::Vec::<_, 2>::new(); - let write_req = WriteReq::new(false, input); + let mut write_req = WriteReq::new(false, input); + write_req.timed_request = Some(true); self.gen_timed_reqs_output( handler, diff --git a/rs-matter/tests/common/im_engine.rs b/rs-matter/tests/common/im_engine.rs index aa1e8980..3eb08e5e 100644 --- a/rs-matter/tests/common/im_engine.rs +++ b/rs-matter/tests/common/im_engine.rs @@ -17,18 +17,22 @@ use crate::common::echo_cluster; use core::borrow::Borrow; -use core::future::pending; -use core::time::Duration; -use embassy_futures::select::select3; + +use embassy_futures::{block_on, join::join, select::select3}; + use embassy_sync::{ - blocking_mutex::raw::{NoopRawMutex, RawMutex}, + blocking_mutex::raw::NoopRawMutex, zerocopy_channel::{Channel, Receiver, Sender}, }; + +use embassy_time::{Duration, Timer}; + use rs_matter::{ acl::{AclEntry, AuthMode}, data_model::{ cluster_basic_information::{self, BasicInfoConfig}, cluster_on_off::{self, OnOffCluster}, + core::{DataModel, IMBuffer}, device_types::{DEV_TYPE_ON_OFF_LIGHT, DEV_TYPE_ROOT_NODE}, objects::{ AttrData, AttrDataEncoder, AttrDetails, Endpoint, Handler, HandlerCompat, Metadata, @@ -40,6 +44,7 @@ use rs_matter::{ dev_att::{DataType, DevAttDataFetcher}, general_commissioning, noc, nw_commissioning, }, + subscriptions::Subscriptions, system_model::{ access_control, descriptor::{self, DescriptorCluster}, @@ -49,16 +54,18 @@ use rs_matter::{ handler_chain_type, interaction_model::core::{OpCode, PROTO_ID_INTERACTION_MODEL}, mdns::MdnsService, - secure_channel::{self, common::PROTO_ID_SECURE_CHANNEL, spake2p::VerifierData}, + respond::Responder, tlv::{TLVWriter, TagType, ToTLV}, transport::{ - core::PacketBuffers, - network::{Address, Ipv4Addr, NetworkReceive, NetworkSend, SocketAddr, SocketAddrV4}, - packet::{Packet, MAX_RX_BUF_SIZE, MAX_TX_BUF_SIZE}, - session::{CaseDetails, CloneData, NocCatIds, SessionMode}, + exchange::{Exchange, MessageMeta, MAX_EXCHANGE_TX_BUF_SIZE}, + network::{ + Address, Ipv4Addr, NetworkReceive, NetworkSend, SocketAddr, SocketAddrV4, + MAX_RX_PACKET_SIZE, MAX_TX_PACKET_SIZE, + }, + session::{CaseDetails, NocCatIds, ReservedSession, SessionMode}, }, - utils::select::{EitherUnwrap, Notification}, - CommissioningData, Matter, MATTER_PORT, + utils::{buf::PooledBuffers, select::Coalesce}, + Matter, MATTER_PORT, }; use super::echo_cluster::EchoCluster; @@ -137,7 +144,7 @@ impl<'a> ImInput<'a> { pub struct ImOutput { pub action: OpCode, - pub data: heapless::Vec, + pub data: heapless::Vec, } pub struct ImEngineHandler<'a> { @@ -207,6 +214,24 @@ impl<'a> ImEngine<'a> { /// Create the interaction model engine pub fn new(cat_ids: NocCatIds) -> Self { + Self { + matter: Self::new_matter(), + cat_ids, + } + } + + pub fn add_default_acl(&self) { + // Only allow the standard peer node id of the IM Engine + let mut default_acl = AclEntry::new(1, Privilege::ADMIN, AuthMode::Case); + default_acl.add_subject(IM_ENGINE_PEER_ID).unwrap(); + self.matter.acl_mgr.borrow_mut().add(default_acl).unwrap(); + } + + pub fn handler(&self) -> ImEngineHandler<'_> { + ImEngineHandler::new(&self.matter) + } + + fn new_matter() -> Matter<'static> { #[cfg(feature = "std")] use rs_matter::utils::epoch::sys_epoch as epoch; @@ -228,18 +253,31 @@ impl<'a> ImEngine<'a> { MATTER_PORT, ); - Self { matter, cat_ids } - } + matter.initialize_transport_buffers().unwrap(); - pub fn add_default_acl(&self) { - // Only allow the standard peer node id of the IM Engine - let mut default_acl = AclEntry::new(1, Privilege::ADMIN, AuthMode::Case); - default_acl.add_subject(IM_ENGINE_PEER_ID).unwrap(); - self.matter.acl_mgr.borrow_mut().add(default_acl).unwrap(); + matter } - pub fn handler(&self) -> ImEngineHandler<'_> { - ImEngineHandler::new(&self.matter) + fn init_matter(matter: &Matter, local_nodeid: u64, remote_nodeid: u64, cat_ids: &NocCatIds) { + matter.transport_mgr.reset(); + + let mut session = ReservedSession::reserve_now(matter).unwrap(); + + session + .update( + local_nodeid, + remote_nodeid, + 1, + 1, + ADDR, + SessionMode::Case(CaseDetails::new(1, cat_ids)), + None, + None, + None, + ) + .unwrap(); + + session.complete(); } pub fn process( @@ -248,177 +286,109 @@ impl<'a> ImEngine<'a> { input: &[&ImInput], out: &mut heapless::Vec, ) -> Result<(), Error> { - self.matter.reset_transport(); + out.clear(); - let clone_data = CloneData::new( + Self::init_matter( + &self.matter, IM_ENGINE_REMOTE_PEER_ID, IM_ENGINE_PEER_ID, - 1, - 1, - Address::default(), - SessionMode::Case(CaseDetails::new(1, &self.cat_ids)), + &self.cat_ids, ); - let sess_idx = self - .matter - .session_mgr - .borrow_mut() - .clone_session(&clone_data) - .unwrap(); + let matter_client = Self::new_matter(); + Self::init_matter( + &matter_client, + IM_ENGINE_PEER_ID, + IM_ENGINE_REMOTE_PEER_ID, + &self.cat_ids, + ); - let mut send_channel_buf = [heapless::Vec::new(); 1]; - let mut recv_channel_buf = [heapless::Vec::new(); 1]; + let mut buf1 = [heapless::Vec::new(); 1]; + let mut buf2 = [heapless::Vec::new(); 1]; - let mut send_channel = Channel::::new(&mut send_channel_buf); - let mut recv_channel = Channel::::new(&mut recv_channel_buf); + let mut pipe1 = NetworkPipe::::new(&mut buf1); + let mut pipe2 = NetworkPipe::::new(&mut buf2); - let handler = &handler; + let (send_remote, recv_local) = pipe1.split(); + let (send_local, recv_remote) = pipe2.split(); - let mut msg_ctr = self - .matter - .session_mgr - .borrow_mut() - .mut_by_index(sess_idx) - .unwrap() - .get_msg_ctr(); + let matter_client = &matter_client; - let resp_notif = Notification::new(); - let resp_notif = &resp_notif; + let buffers = PooledBuffers::<10, NoopRawMutex, IMBuffer>::new(0); - let mut buffers = PacketBuffers::new(); - let buffers = &mut buffers; + let subscriptions = Subscriptions::<1>::new(); - let (send, mut send_dest) = send_channel.split(); - let (mut recv_dest, recv) = recv_channel.split(); + let responder = Responder::new( + "Default", + DataModel::new(&buffers, &subscriptions, HandlerCompat(handler)), + &self.matter, + 0, + ); - embassy_futures::block_on(async move { + block_on( select3( - self.matter.run( - NetworkSender(send), - NetworkReceiver(recv), - buffers, - CommissioningData { - // TODO: Hard-coded for now - verifier: VerifierData::new_with_pw(123456, *self.matter.borrow()), - discriminator: 250, - }, - &HandlerCompat(handler), + matter_client + .transport_mgr + .run(NetworkSendImpl(send_local), NetworkReceiveImpl(recv_local)), + self.matter.transport_mgr.run( + NetworkSendImpl(send_remote), + NetworkReceiveImpl(recv_remote), ), - async move { - let mut acknowledge = false; - for ip in input { - Self::send(ip, &mut recv_dest, msg_ctr, acknowledge).await?; - resp_notif.wait().await; - - if let Some(delay) = ip.delay { - if delay > 0 { - #[cfg(feature = "std")] - std::thread::sleep(Duration::from_millis(delay as _)); - } - } + join(responder.respond_once("0"), async move { + let mut exchange = + Exchange::initiate(matter_client, IM_ENGINE_REMOTE_PEER_ID, true).await?; - msg_ctr += 2; - acknowledge = true; - } - - pending::<()>().await; - - Ok(()) - }, - async move { - out.clear(); - - while out.len() < input.len() { - let vec = send_dest.receive().await; + for ip in input { + exchange + .send_with(|_, wb| { + ip.data + .to_tlv(&mut TLVWriter::new(wb), TagType::Anonymous)?; + + Ok(Some(MessageMeta { + proto_id: PROTO_ID_INTERACTION_MODEL, + proto_opcode: ip.action as _, + reliable: true, + })) + }) + .await?; - let mut rx = Packet::new_rx(vec); + { + // In a separate block so that the RX message is dropped before we start waiting - rx.plain_hdr_decode()?; - rx.proto_decode(IM_ENGINE_REMOTE_PEER_ID, Some(&[0u8; 16]))?; + let rx = exchange.recv().await?; - if rx.get_proto_id() != PROTO_ID_SECURE_CHANNEL - || rx.get_proto_opcode::()? - != secure_channel::common::OpCode::MRPStandAloneAck - { out.push(ImOutput { - action: rx.get_proto_opcode()?, - data: heapless::Vec::from_slice(rx.as_slice()) + action: rx.meta().opcode()?, + data: heapless::Vec::from_slice(rx.payload()) .map_err(|_| ErrorCode::NoSpace)?, }) .map_err(|_| ErrorCode::NoSpace)?; - - resp_notif.signal(()); } - send_dest.receive_done(); + let delay = ip.delay.unwrap_or(0); + if delay > 0 { + Timer::after(Duration::from_millis(delay as _)).await; + } } + exchange.acknowledge().await?; + Ok(()) - }, + }) + .coalesce(), ) - .await - .unwrap() - })?; - - Ok(()) - } - - async fn send( - input: &ImInput<'_>, - sender: &mut Sender<'_, impl RawMutex, heapless::Vec>, - msg_ctr: u32, - acknowledge: bool, - ) -> Result<(), Error> { - let vec = sender.send().await; - - vec.clear(); - vec.extend(core::iter::repeat(0).take(MAX_RX_BUF_SIZE)); - - let mut tx = Packet::new_tx(vec); - - tx.set_proto_id(PROTO_ID_INTERACTION_MODEL); - tx.set_proto_opcode(input.action as u8); - - let mut tw = TLVWriter::new(tx.get_writebuf()?); - - input.data.to_tlv(&mut tw, TagType::Anonymous)?; - - tx.plain.ctr = msg_ctr + 1; - tx.plain.sess_id = 1; - tx.proto.set_initiator(); - - if acknowledge { - tx.proto.set_ack(msg_ctr - 1); - } - - tx.proto_encode( - Address::default(), - Some(IM_ENGINE_REMOTE_PEER_ID), - IM_ENGINE_PEER_ID, - false, - Some(&[0u8; 16]), - )?; - - let start = tx.get_writebuf()?.get_start(); - let end = tx.get_writebuf()?.get_tail(); - - if start > 0 { - for offset in 0..(end - start) { - vec[offset] = vec[start + offset]; - } - } - - vec.truncate(end - start); - - sender.send_done(); - - Ok(()) + .coalesce(), + ) } } -struct NetworkSender<'a>(Sender<'a, NoopRawMutex, heapless::Vec>); +const ADDR: Address = Address::Udp(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0))); -impl<'a> NetworkSend for NetworkSender<'a> { +type NetworkPipe<'a, const N: usize> = Channel<'a, NoopRawMutex, heapless::Vec>; +struct NetworkReceiveImpl<'a, const N: usize>(Receiver<'a, NoopRawMutex, heapless::Vec>); +struct NetworkSendImpl<'a, const N: usize>(Sender<'a, NoopRawMutex, heapless::Vec>); + +impl<'a, const N: usize> NetworkSend for NetworkSendImpl<'a, N> { async fn send_to(&mut self, data: &[u8], _addr: Address) -> Result<(), Error> { let vec = self.0.send().await; @@ -431,9 +401,7 @@ impl<'a> NetworkSend for NetworkSender<'a> { } } -struct NetworkReceiver<'a>(Receiver<'a, NoopRawMutex, heapless::Vec>); - -impl<'a> NetworkReceive for NetworkReceiver<'a> { +impl<'a, const N: usize> NetworkReceive for NetworkReceiveImpl<'a, N> { async fn wait_available(&mut self) -> Result<(), Error> { self.0.receive().await; @@ -448,9 +416,6 @@ impl<'a> NetworkReceive for NetworkReceiver<'a> { self.0.receive_done(); - Ok(( - len, - Address::Udp(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0))), - )) + Ok((len, ADDR)) } } diff --git a/rs-matter/tests/common/mod.rs b/rs-matter/tests/common/mod.rs index 94837fc1..052f8077 100644 --- a/rs-matter/tests/common/mod.rs +++ b/rs-matter/tests/common/mod.rs @@ -24,6 +24,8 @@ pub mod im_engine; pub fn init_env_logger() { #[cfg(all(feature = "std", not(target_os = "espidf")))] { - let _ = env_logger::try_init(); + let _ = env_logger::try_init_from_env( + env_logger::Env::default().filter_or(env_logger::DEFAULT_FILTER_ENV, "info"), + ); } } diff --git a/rs-matter/tests/data_model/long_reads.rs b/rs-matter/tests/data_model/long_reads.rs index 333801c4..cdbbb59a 100644 --- a/rs-matter/tests/data_model/long_reads.rs +++ b/rs-matter/tests/data_model/long_reads.rs @@ -207,15 +207,15 @@ fn wildcard_read_resp(part: u8) -> Vec> { adm_comm::AttributesDiscriminants::AdminFabricIndex, dont_care.clone() ), + ]; + + let part2 = vec![ attr_data!( 0, 60, adm_comm::AttributesDiscriminants::AdminVendorId, dont_care.clone() ), - ]; - - let part2 = vec![ attr_data!(0, 62, GlobalElements::FeatureMap, dont_care.clone()), attr_data!(0, 62, GlobalElements::AttributeList, dont_care.clone()), attr_data!(