diff --git a/plugins/src/lib.rs b/plugins/src/lib.rs index ee50ccbe419e..474615468291 100644 --- a/plugins/src/lib.rs +++ b/plugins/src/lib.rs @@ -47,6 +47,8 @@ where option_values: HashMap>, rpcmethods: HashMap>, subscriptions: HashMap>, + // Contains a Subscription if the user subscribed to "*" + wildcard_subscription : Option>, notifications: Vec, custommessages: Vec, featurebits: FeatureBits, @@ -72,6 +74,7 @@ where rpcmethods: HashMap>, hooks: HashMap>, subscriptions: HashMap>, + wildcard_subscription : Option>, #[allow(dead_code)] // unsure why rust thinks this field isn't used notifications: Vec, } @@ -91,6 +94,7 @@ where #[allow(dead_code)] // Unused until we fill in the Hook structs. hooks: HashMap>, subscriptions: HashMap>, + wildcard_subscription : Option> } #[derive(Clone)] @@ -123,6 +127,7 @@ where output: Some(output), hooks: HashMap::new(), subscriptions: HashMap::new(), + wildcard_subscription: None, options: HashMap::new(), // Should not be configured by user. // This values are set when parsing the init-call @@ -173,12 +178,16 @@ where C: Fn(Plugin, Request) -> F + 'static, F: Future> + Send + 'static, { - self.subscriptions.insert( - topic.to_string(), - Subscription { - callback: Box::new(move |p, r| Box::pin(callback(p, r))), - }, - ); + let subscription = Subscription { + callback : Box::new(move |p, r| Box::pin(callback(p, r))) + }; + + if topic == "*" { + self.wildcard_subscription = Some(subscription); + } + else { + self.subscriptions.insert(topic.to_string(), subscription); + }; self } @@ -328,6 +337,7 @@ where let subscriptions = HashMap::from_iter(self.subscriptions.drain().map(|(k, v)| (k, v.callback))); + let all_subscription = self.wildcard_subscription.map(|s| s.callback); // Leave the `init` reply pending, so we can disable based on // the options if required. @@ -339,6 +349,7 @@ where rpcmethods, notifications: self.notifications, subscriptions, + wildcard_subscription: all_subscription, options: self.options, option_values: self.option_values, configuration, @@ -378,9 +389,13 @@ where }) .collect(); + let subscriptions = self.subscriptions.keys() + .map(|s| s.clone()) + .chain(self.wildcard_subscription.iter().map(|_| String::from("*"))).collect(); + messages::GetManifestResponse { options: self.options.values().cloned().collect(), - subscriptions: self.subscriptions.keys().map(|s| s.clone()).collect(), + subscriptions, hooks: self.hooks.keys().map(|s| s.clone()).collect(), rpcmethods, notifications: self.notifications.clone(), @@ -553,6 +568,7 @@ where rpcmethods: self.rpcmethods, hooks: self.hooks, subscriptions: self.subscriptions, + wildcard_subscription : self.wildcard_subscription }; output @@ -724,25 +740,41 @@ where Ok(()) } messages::JsonRpc::CustomNotification(request) => { + // This code handles notifications trace!("Dispatching custom notification {:?}", request); let method = request .get("method") .context("Missing 'method' in request")? .as_str() .context("'method' is not a string")?; - let callback = self.subscriptions.get(method).with_context(|| { - anyhow!("No handler for notification '{}' registered", method) - })?; + let params = request .get("params") - .context("Missing 'params' field in request")? - .clone(); - - let plugin = plugin.clone(); - let call = callback(plugin.clone(), params); - - tokio::spawn(async move { call.await.unwrap() }); - Ok(()) + .context("Missing 'params' field in request")?; + + // Send to notification to the wildcard + // subscription "*" it it exists + match &self.wildcard_subscription { + Some(cb) => { + let call = cb(plugin.clone(), params.clone()); + tokio::spawn(async move {call.await.unwrap()});} + None => {} + }; + + // Find the appropriate callback and process it + // We'll log a warning if no handler is defined + match self.subscriptions.get(method) { + Some(cb) => { + let call = cb(plugin.clone(), params.clone()); + tokio::spawn(async move {call.await.unwrap()}); + }, + None => { + if self.wildcard_subscription.is_none() { + log::warn!("No handler for notification '{}' registered", method); + } + } + }; + Ok(()) } } } diff --git a/tests/test_cln_rs.py b/tests/test_cln_rs.py index d87eafa93d39..fa4a8196248a 100644 --- a/tests/test_cln_rs.py +++ b/tests/test_cln_rs.py @@ -384,7 +384,6 @@ def test_grpc_decode(node_factory): print(res) -@pytest.mark.xfail(strict=True) def test_rust_plugin_subscribe_wildcard(node_factory): """ Creates a plugin that loads the subscribe_wildcard plugin """