From 9839914069bbb3951b729554783697f85616d7a7 Mon Sep 17 00:00:00 2001 From: cbonaudo Date: Sat, 22 May 2021 20:24:31 +0200 Subject: [PATCH 1/2] add request_timeout function --- src/bastion/src/distributor.rs | 142 ++++++++++++++++++++++++++++++++- 1 file changed, 139 insertions(+), 3 deletions(-) diff --git a/src/bastion/src/distributor.rs b/src/bastion/src/distributor.rs index 23c9565d..e125f32e 100644 --- a/src/bastion/src/distributor.rs +++ b/src/bastion/src/distributor.rs @@ -6,11 +6,13 @@ use crate::{ system::{STRING_INTERNER, SYSTEM}, }; use anyhow::Result as AnyResult; -use futures::channel::oneshot; +use futures::{channel::oneshot, FutureExt}; +use futures_timer::Delay; use lasso::Spur; use std::{ fmt::Debug, sync::mpsc::{channel, Receiver}, + time::Duration, }; // Copy is fine here because we're working @@ -44,7 +46,7 @@ impl Distributor { /// /// This can be achieved manually using a `MessageHandler` and `ask_one`. /// Ask a question to a recipient attached to the `Distributor` - /// + /// # Example /// /// ```no_run /// # use bastion::prelude::*; @@ -228,6 +230,116 @@ impl Distributor { receiver } + /// Ask a question to a recipient attached to the `Distributor` + /// and wait for a reply. + /// + /// This is the timeout variant of the 'request' function. If the request cannot be performed + /// in the provided duration, an Error is propagated. + /// # Example + /// + /// ```no_run + /// # use core::time::Duration; + /// # use bastion::prelude::*; + /// # #[cfg(feature = "tokio-runtime")] + /// # #[tokio::main] + /// # async fn main() { + /// # run(); + /// # } + /// # + /// # #[cfg(not(feature = "tokio-runtime"))] + /// # fn main() { + /// # run(); + /// # } + /// # + /// # async fn run() { + /// # Bastion::init(); + /// # Bastion::start(); + /// + /// # Bastion::supervisor(|supervisor| { + /// # supervisor.children(|children| { + /// // attach a named distributor to the children + /// children + /// # .with_redundancy(1) + /// .with_distributor(Distributor::named("my distributor")) + /// .with_exec(|ctx: BastionContext| { + /// async move { + /// loop { + /// // The message handler needs an `on_question` section + /// // that matches the `question` you're going to send, + /// // and that will reply with the Type the request expects. + /// // In our example, we ask a `&str` question, and expect a `bool` reply. + /// MessageHandler::new(ctx.recv().await?) + /// .on_question(|message: &str, sender| { + /// if message == "is it raining today?" { + /// sender.reply(true).unwrap(); + /// } + /// }); + /// } + /// Ok(()) + /// } + /// }) + /// # }) + /// # }); + /// + /// let distributor = Distributor::named("my distributor"); + /// + /// let timeout = Duration::from_millis(10); + /// let reply: Result = distributor + /// .request_timeout("is it raining today?", timeout) + /// .await + /// .expect("couldn't receive reply"); + /// + /// # Bastion::stop(); + /// # Bastion::block_until_stopped(); + /// # } + /// ``` + pub fn request_timeout( + &self, + question: impl Message, + timeout: Duration, + ) -> oneshot::Receiver> { + let (sender, receiver) = oneshot::channel(); + let s = *self; + spawn!(async move { + match SYSTEM.dispatcher().ask(s, question) { + Ok(response) => { + futures::select! { + response_awaited = response.fuse() => { + match response_awaited { + Ok(message) => { + let message_to_send = MessageHandler::new(message) + .on_tell(|reply: R, _| Ok(reply)) + .on_fallback(|_, _| { + Err(SendError::Other(anyhow::anyhow!( + "received a message with the wrong type" + ))) + }); + let _ = sender.send(message_to_send); + } + Err(e) => { + let _ = sender.send(Err(SendError::Other(anyhow::anyhow!( + "couldn't receive reply: {:?}", + e + )))); + } + } + }, + _duration = Delay::new(timeout).fuse() => { + let _ = sender.send(Err(SendError::Other(anyhow::anyhow!( + "operation timed out before finish" + )))); + } + } + } + Err(error) => { + let _ = sender.send(Err(error)); + } + } + }); + + receiver + } + /// Ask a question to a recipient attached to the `Distributor` /// /// # Example @@ -539,8 +651,10 @@ impl Distributor { #[cfg(test)] mod distributor_tests { use crate::prelude::*; + use core::time; use futures::channel::mpsc::channel; use futures::{SinkExt, StreamExt}; + use std::{thread, time::Duration}; const TEST_DISTRIBUTOR: &str = "test distributor"; const SUBSCRIBE_TEST_DISTRIBUTOR: &str = "subscribe test"; @@ -675,12 +789,33 @@ mod distributor_tests { }); let answer_sync: u8 = test_distributor - .request_sync(question) + .request_sync(question.clone()) .recv() .unwrap() .unwrap(); assert_eq!(42, answer_sync); + + run!(async { + let timeout = Duration::from_millis(10); + let answer_timeout: u8 = test_distributor + .request_timeout(question.clone(), timeout) + .await + .unwrap() + .unwrap(); + assert_eq!(42, answer_timeout); + }); + + run!(async { + let timeout = Duration::from_nanos(1); + let answer_timeout: Result = test_distributor + .request_timeout(question.clone(), timeout) + .await + .unwrap(); + + let err_msg: SendError = answer_timeout.unwrap_err(); + assert!(matches!(err_msg, SendError::Other { .. })); + }); } fn setup() { @@ -709,6 +844,7 @@ mod distributor_tests { let child_ref = ctx.current().clone(); MessageHandler::new(ctx.recv().await?) .on_question(|_: String, sender| { + thread::sleep(time::Duration::from_millis(10)); let _ = sender.reply(42_u8); }) // send your child ref From 82c8addadfacdf9efbffc05f488a3642eb88b5b6 Mon Sep 17 00:00:00 2001 From: cbonaudo Date: Sat, 22 May 2021 20:26:14 +0200 Subject: [PATCH 2/2] add millis --- src/bastion/src/distributor.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/bastion/src/distributor.rs b/src/bastion/src/distributor.rs index e125f32e..f60cacd1 100644 --- a/src/bastion/src/distributor.rs +++ b/src/bastion/src/distributor.rs @@ -797,7 +797,7 @@ mod distributor_tests { assert_eq!(42, answer_sync); run!(async { - let timeout = Duration::from_millis(10); + let timeout = Duration::from_millis(100); let answer_timeout: u8 = test_distributor .request_timeout(question.clone(), timeout) .await