diff --git a/dnp3/src/master/association.rs b/dnp3/src/master/association.rs index 5cebffdd..e1fe626e 100644 --- a/dnp3/src/master/association.rs +++ b/dnp3/src/master/association.rs @@ -261,7 +261,7 @@ impl TaskStates { if self.time_sync.is_pending() { if let Some(procedure) = config.auto_time_sync { return self.time_sync.create_next_task(|| { - TimeSync(TimeSyncTask::get_procedure(procedure, Promise::None)).wrap() + TimeSync(TimeSyncTask::get_procedure(procedure, None)).wrap() }); } } @@ -680,7 +680,7 @@ impl Association { None => Next::None, Some(next) => { if now >= next { - Next::Now(Task::LinkStatus(Promise::None)) + Next::Now(Task::LinkStatus(Promise::null())) } else { Next::NotBefore(next) } diff --git a/dnp3/src/master/handler.rs b/dnp3/src/master/handler.rs index 94daf00a..81f63ffe 100644 --- a/dnp3/src/master/handler.rs +++ b/dnp3/src/master/handler.rs @@ -129,8 +129,8 @@ impl MasterChannel { /// Get the current decoding level used by this master pub async fn get_decode_level(&mut self) -> Result { - let (tx, rx) = tokio::sync::oneshot::channel::>(); - self.send_master_message(MasterMsg::GetDecodeLevel(Promise::OneShot(tx))) + let (promise, rx) = Promise::one_shot(); + self.send_master_message(MasterMsg::GetDecodeLevel(promise)) .await?; rx.await? } @@ -161,7 +161,7 @@ impl MasterChannel { ) -> Result { self.assert_channel_type(MasterChannelType::Stream)?; - let (tx, rx) = tokio::sync::oneshot::channel::>(); + let (promise, rx) = Promise::one_shot(); let addr = FragmentAddr { link: address, phys: PhysAddr::None, @@ -172,7 +172,7 @@ impl MasterChannel { read_handler, assoc_handler, assoc_information, - Promise::OneShot(tx), + promise, )) .await?; rx.await? @@ -196,7 +196,7 @@ impl MasterChannel { ) -> Result { self.assert_channel_type(MasterChannelType::Udp)?; - let (tx, rx) = tokio::sync::oneshot::channel::>(); + let (promise, rx) = Promise::one_shot(); let addr = FragmentAddr { link: address, phys: PhysAddr::Udp(destination), @@ -207,7 +207,7 @@ impl MasterChannel { read_handler, assoc_handler, assoc_information, - Promise::OneShot(tx), + promise, )) .await?; rx.await? @@ -266,14 +266,9 @@ impl AssociationHandle { request: ReadRequest, period: Duration, ) -> Result { - let (tx, rx) = tokio::sync::oneshot::channel::>(); - self.send_poll_message(PollMsg::AddPoll( - self.clone(), - request, - period, - Promise::OneShot(tx), - )) - .await?; + let (promise, rx) = Promise::one_shot(); + self.send_poll_message(PollMsg::AddPoll(self.clone(), request, period, promise)) + .await?; rx.await? } @@ -289,8 +284,8 @@ impl AssociationHandle { /// /// If successful, the [ReadHandler](ReadHandler) will process the received measurement data pub async fn read(&mut self, request: ReadRequest) -> Result<(), TaskError> { - let (tx, rx) = tokio::sync::oneshot::channel::>(); - let task = SingleReadTask::new(request, Promise::OneShot(tx)); + let (promise, rx) = Promise::one_shot(); + let task = SingleReadTask::new(request, promise); self.send_task(task).await?; rx.await? } @@ -305,8 +300,8 @@ impl AssociationHandle { function: FunctionCode, headers: Headers, ) -> Result<(), WriteError> { - let (tx, rx) = tokio::sync::oneshot::channel::>(); - let task = EmptyResponseTask::new(function, headers, Promise::OneShot(tx)); + let (promise, rx) = Promise::one_shot(); + let task = EmptyResponseTask::new(function, headers, promise); self.send_task(task).await?; rx.await? } @@ -319,8 +314,8 @@ impl AssociationHandle { request: ReadRequest, handler: Box, ) -> Result<(), TaskError> { - let (tx, rx) = tokio::sync::oneshot::channel::>(); - let task = SingleReadTask::new_with_custom_handler(request, handler, Promise::OneShot(tx)); + let (promise, rx) = Promise::one_shot(); + let task = SingleReadTask::new_with_custom_handler(request, handler, promise); self.send_task(task).await?; rx.await? } @@ -333,8 +328,8 @@ impl AssociationHandle { mode: CommandMode, headers: CommandHeaders, ) -> Result<(), CommandError> { - let (tx, rx) = tokio::sync::oneshot::channel::>(); - let task = CommandTask::from_mode(mode, headers, Promise::OneShot(tx)); + let (promise, rx) = Promise::one_shot(); + let task = CommandTask::from_mode(mode, headers, promise); self.send_task(task).await?; rx.await? } @@ -354,8 +349,8 @@ impl AssociationHandle { } async fn restart(&mut self, restart_type: RestartType) -> Result { - let (tx, rx) = tokio::sync::oneshot::channel::>(); - let task = RestartTask::new(restart_type, Promise::OneShot(tx)); + let (promise, rx) = Promise::one_shot(); + let task = RestartTask::new(restart_type, promise); self.send_task(task).await?; rx.await? } @@ -365,8 +360,8 @@ impl AssociationHandle { &mut self, procedure: TimeSyncProcedure, ) -> Result<(), TimeSyncError> { - let (tx, rx) = tokio::sync::oneshot::channel::>(); - let task = TimeSyncTask::get_procedure(procedure, Promise::OneShot(tx)); + let (promise, rx) = Promise::one_shot(); + let task = TimeSyncTask::get_procedure(procedure, Some(promise)); self.send_task(task).await?; rx.await? } @@ -376,8 +371,8 @@ impl AssociationHandle { &mut self, headers: Vec, ) -> Result<(), WriteError> { - let (tx, rx) = tokio::sync::oneshot::channel::>(); - let task = WriteDeadBandsTask::new(headers, Promise::OneShot(tx)); + let (promise, rx) = Promise::one_shot(); + let task = WriteDeadBandsTask::new(headers, promise); self.send_task(task).await?; rx.await? } @@ -390,9 +385,8 @@ impl AssociationHandle { /// If a [`TaskError::UnexpectedResponseHeaders`] is returned, the link might be alive /// but it didn't answer with the expected `LINK_STATUS`. pub async fn check_link_status(&mut self) -> Result<(), TaskError> { - let (tx, rx) = tokio::sync::oneshot::channel::>(); - self.send_task(Task::LinkStatus(Promise::OneShot(tx))) - .await?; + let (promise, rx) = Promise::one_shot(); + self.send_task(Task::LinkStatus(promise)).await?; rx.await? } diff --git a/dnp3/src/master/promise.rs b/dnp3/src/master/promise.rs index e949c500..c104daed 100644 --- a/dnp3/src/master/promise.rs +++ b/dnp3/src/master/promise.rs @@ -1,24 +1,54 @@ +pub(crate) type CallbackType = Box; + /// A generic callback type that must be invoked once and only once. /// The user can select to implement it using FnOnce or a /// one-shot reply channel -pub(crate) enum Promise { - /// nothing happens when the promise is completed - None, +enum Inner { /// one-shot reply channel is consumed when the promise is completed OneShot(tokio::sync::oneshot::Sender), + /// Boxed FnOnce + #[allow(dead_code)] + CallBack(CallbackType, T), +} + +pub(crate) struct Promise { + inner: Option>, } impl Promise { + pub(crate) fn null() -> Self { + Self { inner: None } + } + + fn new(inner: Inner) -> Self { + Self { inner: Some(inner) } + } + pub(crate) fn one_shot() -> (Self, tokio::sync::oneshot::Receiver) { let (tx, rx) = tokio::sync::oneshot::channel(); - (Self::OneShot(tx), rx) + (Self::new(Inner::OneShot(tx)), rx) + } + + pub(crate) fn complete(mut self, value: T) { + if let Some(x) = self.inner.take() { + match x { + Inner::OneShot(s) => { + s.send(value).ok(); + } + Inner::CallBack(cb, _) => cb(value), + } + } } +} - pub(crate) fn complete(self, value: T) { - match self { - Promise::None => {} - Promise::OneShot(s) => { - s.send(value).ok(); +impl Drop for Promise { + fn drop(&mut self) { + if let Some(x) = self.inner.take() { + match x { + Inner::OneShot(_) => {} + Inner::CallBack(cb, default) => { + cb(default); + } } } } diff --git a/dnp3/src/master/task.rs b/dnp3/src/master/task.rs index 623c8077..70a39855 100644 --- a/dnp3/src/master/task.rs +++ b/dnp3/src/master/task.rs @@ -318,19 +318,11 @@ impl MasterSession { res.map(|_| ()) } Task::LinkStatus(promise) => { - match self + let res = self .run_link_status_task(io, task.dest, writer, reader) - .await - { - Ok(result) => { - promise.complete(Ok(result)); - Ok(()) - } - Err(err) => { - promise.complete(Err(err)); - Err(err) - } - } + .await; + promise.complete(res); + res } }; diff --git a/dnp3/src/master/tasks/time.rs b/dnp3/src/master/tasks/time.rs index f9f093b1..0f33f8dd 100644 --- a/dnp3/src/master/tasks/time.rs +++ b/dnp3/src/master/tasks/time.rs @@ -23,7 +23,7 @@ enum State { pub(crate) struct TimeSyncTask { state: State, - promise: Promise>, + promise: Option>>, } impl From for Task { @@ -42,7 +42,7 @@ impl TimeSyncProcedure { } impl TimeSyncTask { - fn new(state: State, promise: Promise>) -> Self { + fn new(state: State, promise: Option>>) -> Self { Self { state, promise } } @@ -52,7 +52,7 @@ impl TimeSyncTask { pub(crate) fn get_procedure( procedure: TimeSyncProcedure, - promise: Promise>, + promise: Option>>, ) -> Self { Self::new(procedure.get_start_state(), promise) } @@ -110,12 +110,12 @@ impl TimeSyncTask { pub(crate) fn on_task_error(self, association: Option<&mut Association>, err: TaskError) { match self.promise { - Promise::None => { + None => { if let Some(association) = association { association.on_time_sync_failure(err.into()); } } - _ => self.promise.complete(Err(err.into())), + Some(x) => x.complete(Err(err.into())), } } @@ -292,15 +292,15 @@ impl TimeSyncTask { fn report_success(self, association: &mut Association) { match self.promise { - Promise::None => association.on_time_sync_success(), - _ => self.promise.complete(Ok(())), + None => association.on_time_sync_success(), + Some(x) => x.complete(Ok(())), } } fn report_error(self, association: &mut Association, error: TimeSyncError) { match self.promise { - Promise::None => association.on_time_sync_failure(error), - _ => self.promise.complete(Err(error)), + None => association.on_time_sync_failure(error), + Some(x) => x.complete(Err(error)), } } } @@ -395,9 +395,8 @@ mod tests { handler(system_time), Box::new(NullAssociationInformation), ); - let (tx, rx) = tokio::sync::oneshot::channel(); - let task = - NonReadTask::TimeSync(TimeSyncTask::get_procedure(procedure, Promise::OneShot(tx))); + let (promise, rx) = Promise::one_shot(); + let task = NonReadTask::TimeSync(TimeSyncTask::get_procedure(procedure, Some(promise))); (task, system_time, association, rx) }