diff --git a/zenoh-plugin-mqtt/src/config.rs b/zenoh-plugin-mqtt/src/config.rs index 6817853..b074c0b 100644 --- a/zenoh-plugin-mqtt/src/config.rs +++ b/zenoh-plugin-mqtt/src/config.rs @@ -46,10 +46,9 @@ pub struct Config { pub generalise_subs: Vec, #[serde(default)] pub generalise_pubs: Vec, - #[serde(default, skip_serializing)] - __required__: bool, - #[serde(default, skip_serializing, deserialize_with = "deserialize_paths")] - __path__: Vec, + __required__: Option, + #[serde(default, deserialize_with = "deserialize_path")] + __path__: Option>, } fn default_mqtt_port() -> String { @@ -63,38 +62,67 @@ where deserializer.deserialize_any(MqttPortVisitor) } -fn deserialize_paths<'de, D>(deserializer: D) -> Result, D::Error> +fn deserialize_path<'de, D>(deserializer: D) -> Result>, D::Error> where D: Deserializer<'de>, { - struct V; - impl<'de> serde::de::Visitor<'de> for V { - type Value = Vec; - fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(formatter, "a string or vector of strings") - } - fn visit_str(self, v: &str) -> Result - where - E: de::Error, - { - Ok(vec![v.into()]) - } - fn visit_seq(self, mut seq: A) -> Result - where - A: de::SeqAccess<'de>, - { - let mut v = if let Some(l) = seq.size_hint() { - Vec::with_capacity(l) - } else { - Vec::new() - }; - while let Some(s) = seq.next_element()? { - v.push(s); - } - Ok(v) + deserializer.deserialize_option(OptPathVisitor) +} + +struct OptPathVisitor; + +impl<'de> serde::de::Visitor<'de> for OptPathVisitor { + type Value = Option>; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(formatter, "none or a string or an array of strings") + } + + fn visit_none(self) -> Result + where + E: de::Error, + { + Ok(None) + } + + fn visit_some(self, deserializer: D) -> Result + where + D: Deserializer<'de>, + { + deserializer.deserialize_any(PathVisitor).map(Some) + } +} + +struct PathVisitor; + +impl<'de> serde::de::Visitor<'de> for PathVisitor { + type Value = Vec; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(formatter, "a string or an array of strings") + } + + fn visit_str(self, v: &str) -> Result + where + E: de::Error, + { + Ok(vec![v.into()]) + } + + fn visit_seq(self, mut seq: A) -> Result + where + A: de::SeqAccess<'de>, + { + let mut v = if let Some(l) = seq.size_hint() { + Vec::with_capacity(l) + } else { + Vec::new() + }; + while let Some(s) = seq.next_element()? { + v.push(s); } + Ok(v) } - deserializer.deserialize_any(V) } fn deserialize_regex<'de, D>(deserializer: D) -> Result, D::Error> @@ -162,3 +190,73 @@ impl<'de> Visitor<'de> for MqttPortVisitor { Ok(format!("{interface}:{port}")) } } + +#[cfg(test)] +mod tests { + use super::Config; + + #[test] + fn test_path_field() { + // See: https://github.com/eclipse-zenoh/zenoh-plugin-webserver/issues/19 + let config = serde_json::from_str::(r#"{"__path__": "/example/path"}"#); + + assert!(config.is_ok()); + let Config { + __required__, + __path__, + .. + } = config.unwrap(); + + assert_eq!(__path__, Some(vec![String::from("/example/path")])); + assert_eq!(__required__, None); + } + + #[test] + fn test_required_field() { + // See: https://github.com/eclipse-zenoh/zenoh-plugin-webserver/issues/19 + let config = serde_json::from_str::(r#"{"__required__": true}"#); + assert!(config.is_ok()); + let Config { + __required__, + __path__, + .. + } = config.unwrap(); + + assert_eq!(__path__, None); + assert_eq!(__required__, Some(true)); + } + + #[test] + fn test_path_field_and_required_field() { + // See: https://github.com/eclipse-zenoh/zenoh-plugin-webserver/issues/19 + let config = serde_json::from_str::( + r#"{"__path__": "/example/path", "__required__": true}"#, + ); + + assert!(config.is_ok()); + let Config { + __required__, + __path__, + .. + } = config.unwrap(); + + assert_eq!(__path__, Some(vec![String::from("/example/path")])); + assert_eq!(__required__, Some(true)); + } + + #[test] + fn test_no_path_field_and_no_required_field() { + // See: https://github.com/eclipse-zenoh/zenoh-plugin-webserver/issues/19 + let config = serde_json::from_str::("{}"); + + assert!(config.is_ok()); + let Config { + __required__, + __path__, + .. + } = config.unwrap(); + + assert_eq!(__path__, None); + assert_eq!(__required__, None); + } +}