diff --git a/mqtt-v5/src/decoder.rs b/mqtt-v5/src/decoder.rs index 37e6014..49777f1 100644 --- a/mqtt-v5/src/decoder.rs +++ b/mqtt-v5/src/decoder.rs @@ -192,7 +192,7 @@ fn decode_property( Ok(Some(Property::CorrelationData(CorrelationData(correlation_data)))) }, PropertyType::SubscriptionIdentifier => { - let subscription_identifier = read_u32!(bytes); + let subscription_identifier = read_variable_int!(bytes); Ok(Some(Property::SubscriptionIdentifier(SubscriptionIdentifier(VariableByteInt( subscription_identifier, ))))) @@ -303,6 +303,16 @@ fn decode_property( fn decode_properties( bytes: &mut Cursor<&mut BytesMut>, mut closure: F, +) -> Result, DecodeError> { + try_decode_properties(bytes, |property| { + closure(property); + Ok(()) + }) +} + +fn try_decode_properties Result<(), DecodeError>>( + bytes: &mut Cursor<&mut BytesMut>, + mut closure: F, ) -> Result, DecodeError> { let property_length = read_variable_int!(bytes); @@ -322,7 +332,7 @@ fn decode_properties( } let property = read_property!(bytes); - closure(property); + closure(property)?; } Ok(Some(())) @@ -564,23 +574,48 @@ fn decode_publish( let mut response_topic = None; let mut correlation_data = None; let mut user_properties = vec![]; - let mut subscription_identifier = None; + let mut subscription_identifiers = None; let mut content_type = None; if protocol_version == ProtocolVersion::V500 { - return_if_none!(decode_properties(bytes, |property| { - match property { - Property::PayloadFormatIndicator(p) => payload_format_indicator = Some(p), - Property::MessageExpiryInterval(p) => message_expiry_interval = Some(p), - Property::TopicAlias(p) => topic_alias = Some(p), - Property::ResponseTopic(p) => response_topic = Some(p), - Property::CorrelationData(p) => correlation_data = Some(p), - Property::UserProperty(p) => user_properties.push(p), - Property::SubscriptionIdentifier(p) => subscription_identifier = Some(p), - Property::ContentType(p) => content_type = Some(p), - _ => {}, // Invalid property for packet - } - })?); + try_decode_properties(bytes, |property| match property { + Property::PayloadFormatIndicator(p) => { + payload_format_indicator = Some(p); + Ok(()) + }, + Property::MessageExpiryInterval(p) => { + message_expiry_interval = Some(p); + Ok(()) + }, + Property::TopicAlias(p) => { + topic_alias = Some(p); + Ok(()) + }, + Property::ResponseTopic(p) => { + response_topic = Some(p); + Ok(()) + }, + Property::CorrelationData(p) => { + correlation_data = Some(p); + Ok(()) + }, + Property::UserProperty(p) => { + user_properties.push(p); + Ok(()) + }, + Property::SubscriptionIdentifier(SubscriptionIdentifier(VariableByteInt(0))) => { + Err(DecodeError::InvalidSubscriptionIdentifier) + }, + Property::SubscriptionIdentifier(p) => { + subscription_identifiers.get_or_insert(Vec::new()).push(p); + Ok(()) + }, + Property::ContentType(p) => { + content_type = Some(p); + Ok(()) + }, + _ => Err(DecodeError::InvalidPropertyForPacket), + })?; } let end_cursor_pos = bytes.position(); @@ -592,6 +627,8 @@ fn decode_publish( } let payload_size = remaining_packet_length - variable_header_size; let payload = return_if_none!(decode_binary_data_with_size(bytes, payload_size as usize)?); + let subscription_identifiers = + subscription_identifiers.unwrap_or_else(|| Vec::with_capacity(0)); let packet = PublishPacket { is_duplicate, @@ -607,7 +644,7 @@ fn decode_publish( response_topic, correlation_data, user_properties, - subscription_identifier, + subscription_identifiers, content_type, payload, @@ -781,13 +818,27 @@ fn decode_subscribe( let mut user_properties = vec![]; if protocol_version == ProtocolVersion::V500 { - return_if_none!(decode_properties(bytes, |property| { + try_decode_properties(bytes, |property| { match property { - Property::SubscriptionIdentifier(p) => subscription_identifier = Some(p), - Property::UserProperty(p) => user_properties.push(p), - _ => {}, // Invalid property for packet + // [MQTT-3.8.2.1.2] The subscription identifier is allowed exactly once + Property::SubscriptionIdentifier(_) if subscription_identifier.is_some() => { + Err(DecodeError::InvalidSubscriptionIdentifier) + }, + // [MQTT-3.8.2.1.2] The subscription identifier must not be 0 + Property::SubscriptionIdentifier(SubscriptionIdentifier(VariableByteInt(0))) => { + Err(DecodeError::InvalidSubscriptionIdentifier) + }, + Property::SubscriptionIdentifier(p) => { + subscription_identifier = Some(p); + Ok(()) + }, + Property::UserProperty(p) => { + user_properties.push(p); + Ok(()) + }, + _ => Err(DecodeError::InvalidPropertyForPacket), } - })?); + })?; } let variable_header_size = (bytes.position() - start_cursor_pos) as u32; @@ -1153,7 +1204,7 @@ pub fn decode_mqtt( #[cfg(test)] mod tests { - use crate::{decoder::*, types::*}; + use crate::{decoder::*, topic::TopicFilter, types::*}; use bytes::BytesMut; #[test] @@ -1196,4 +1247,48 @@ mod tests { normal_test(&[0x80, 0x80, 0x80, 0x01], 2097152); normal_test(&[0xFF, 0xFF, 0xFF, 0x7F], 268435455); } + + #[test] + fn test_decode_subscribe() { + // Subscribe packet *without* Subscription Identifier + let mut without_subscription_identifier = BytesMut::from( + [0x82, 0x0a, 0x00, 0x01, 0x00, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, 0x00].as_slice(), + ); + let without_subscription_identifier_expected = Packet::Subscribe(SubscribePacket { + packet_id: 1, + subscription_identifier: None, + user_properties: vec![], + subscription_topics: vec![SubscriptionTopic { + topic_filter: TopicFilter::Concrete { filter: "test".into(), level_count: 1 }, + maximum_qos: QoS::AtMostOnce, + no_local: false, + retain_as_published: false, + retain_handling: RetainHandling::SendAtSubscribeTime, + }], + }); + let decoded = decode_mqtt(&mut without_subscription_identifier, ProtocolVersion::V500) + .unwrap() + .unwrap(); + assert_eq!(without_subscription_identifier_expected, decoded); + + // Subscribe packet with Subscription Identifier + let mut packet = BytesMut::from( + [0x82, 0x0c, 0xff, 0xf6, 0x02, 0x0b, 0x01, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, 0x02] + .as_slice(), + ); + let decoded = decode_mqtt(&mut packet, ProtocolVersion::V500).unwrap().unwrap(); + let with_subscription_identifier_expected = Packet::Subscribe(SubscribePacket { + packet_id: 65526, + subscription_identifier: Some(SubscriptionIdentifier(VariableByteInt(1))), + user_properties: vec![], + subscription_topics: vec![SubscriptionTopic { + topic_filter: TopicFilter::Concrete { filter: "test".into(), level_count: 1 }, + maximum_qos: QoS::ExactlyOnce, + no_local: false, + retain_as_published: false, + retain_handling: RetainHandling::SendAtSubscribeTime, + }], + }); + assert_eq!(with_subscription_identifier_expected, decoded); + } } diff --git a/mqtt-v5/src/encoder.rs b/mqtt-v5/src/encoder.rs index cc0d3ff..79984e3 100644 --- a/mqtt-v5/src/encoder.rs +++ b/mqtt-v5/src/encoder.rs @@ -333,7 +333,7 @@ fn encode_publish(packet: &PublishPacket, bytes: &mut BytesMut, protocol_version packet.response_topic.encode(bytes); packet.correlation_data.encode(bytes); packet.user_properties.encode(bytes); - packet.subscription_identifier.encode(bytes); + packet.subscription_identifiers.encode(bytes); packet.content_type.encode(bytes); } @@ -660,7 +660,7 @@ mod tests { response_topic: None, correlation_data: None, user_properties: vec![], - subscription_identifier: None, + subscription_identifiers: Vec::with_capacity(0), content_type: None, payload: vec![22; 100].into(), diff --git a/mqtt-v5/src/types.rs b/mqtt-v5/src/types.rs index 7c3b4e6..0be9b82 100644 --- a/mqtt-v5/src/types.rs +++ b/mqtt-v5/src/types.rs @@ -23,6 +23,7 @@ pub enum DecodeError { InvalidPublishReleaseReason, InvalidPublishCompleteReason, InvalidSubscribeAckReason, + InvalidSubscriptionIdentifier, InvalidUnsubscribeAckReason, InvalidAuthenticateReason, InvalidPropertyId, @@ -129,6 +130,14 @@ impl Encode for Vec { } } +impl Encode for Vec { + fn encode(&self, bytes: &mut BytesMut) { + for identifier in self { + identifier.encode(bytes); + } + } +} + impl PacketSize for u16 { fn calc_size(&self, _protocol_version: ProtocolVersion) -> u32 { 2 @@ -317,6 +326,11 @@ pub mod properties { 1 + self.0.calc_size(protocol_version) } } + impl PacketSize for Vec { + fn calc_size(&self, protocol_version: ProtocolVersion) -> u32 { + self.iter().map(|x| x.calc_size(protocol_version)).sum() + } + } #[derive(Debug, Clone, Copy, PartialEq)] pub struct SessionExpiryInterval(pub u32); @@ -917,7 +931,7 @@ pub struct PublishPacket { pub response_topic: Option, pub correlation_data: Option, pub user_properties: Vec, - pub subscription_identifier: Option, + pub subscription_identifiers: Vec, pub content_type: Option, // Payload @@ -933,7 +947,7 @@ impl PropertySize for PublishPacket { property_size += self.response_topic.calc_size(protocol_version); property_size += self.correlation_data.calc_size(protocol_version); property_size += self.user_properties.calc_size(protocol_version); - property_size += self.subscription_identifier.calc_size(protocol_version); + property_size += self.subscription_identifiers.calc_size(protocol_version); property_size += self.content_type.calc_size(protocol_version); property_size @@ -958,7 +972,7 @@ impl From for PublishPacket { response_topic: will.response_topic, correlation_data: will.correlation_data, user_properties: will.user_properties, - subscription_identifier: None, + subscription_identifiers: Vec::with_capacity(0), content_type: will.content_type, // Payload