Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Distributor: Add request_timeout function #332

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 139 additions & 3 deletions src/bastion/src/distributor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@ use crate::{
system::{STRING_INTERNER, SYSTEM},
};
use anyhow::Result as AnyResult;
use futures::channel::oneshot;
use futures::{channel::oneshot, FutureExt};
use futures_timer::Delay;
use lasso::Spur;
use std::{
fmt::Debug,
sync::mpsc::{channel, Receiver},
time::Duration,
};

// Copy is fine here because we're working
Expand Down Expand Up @@ -44,7 +46,7 @@ impl Distributor {
///
/// This can be achieved manually using a `MessageHandler` and `ask_one`.
/// Ask a question to a recipient attached to the `Distributor`
///
/// # Example
///
/// ```no_run
/// # use bastion::prelude::*;
Expand Down Expand Up @@ -228,6 +230,116 @@ impl Distributor {
receiver
}

/// Ask a question to a recipient attached to the `Distributor`
/// and wait for a reply.
///
/// This is the timeout variant of the 'request' function. If the request cannot be performed
/// in the provided duration, an Error is propagated.
/// # Example
///
/// ```no_run
/// # use core::time::Duration;
/// # use bastion::prelude::*;
/// # #[cfg(feature = "tokio-runtime")]
/// # #[tokio::main]
/// # async fn main() {
/// # run();
/// # }
/// #
/// # #[cfg(not(feature = "tokio-runtime"))]
/// # fn main() {
/// # run();
/// # }
/// #
/// # async fn run() {
/// # Bastion::init();
/// # Bastion::start();
///
/// # Bastion::supervisor(|supervisor| {
/// # supervisor.children(|children| {
/// // attach a named distributor to the children
/// children
/// # .with_redundancy(1)
/// .with_distributor(Distributor::named("my distributor"))
/// .with_exec(|ctx: BastionContext| {
/// async move {
/// loop {
/// // The message handler needs an `on_question` section
/// // that matches the `question` you're going to send,
/// // and that will reply with the Type the request expects.
/// // In our example, we ask a `&str` question, and expect a `bool` reply.
/// MessageHandler::new(ctx.recv().await?)
/// .on_question(|message: &str, sender| {
/// if message == "is it raining today?" {
/// sender.reply(true).unwrap();
/// }
/// });
/// }
/// Ok(())
/// }
/// })
/// # })
/// # });
///
/// let distributor = Distributor::named("my distributor");
///
/// let timeout = Duration::from_millis(10);
/// let reply: Result<String, SendError> = distributor
/// .request_timeout("is it raining today?", timeout)
/// .await
/// .expect("couldn't receive reply");
///
/// # Bastion::stop();
/// # Bastion::block_until_stopped();
/// # }
/// ```
pub fn request_timeout<R: Message>(
&self,
question: impl Message,
timeout: Duration,
) -> oneshot::Receiver<Result<R, SendError>> {
let (sender, receiver) = oneshot::channel();
let s = *self;
spawn!(async move {
match SYSTEM.dispatcher().ask(s, question) {
Ok(response) => {
futures::select! {
response_awaited = response.fuse() => {
match response_awaited {
Ok(message) => {
let message_to_send = MessageHandler::new(message)
.on_tell(|reply: R, _| Ok(reply))
.on_fallback(|_, _| {
Err(SendError::Other(anyhow::anyhow!(
"received a message with the wrong type"
)))
});
let _ = sender.send(message_to_send);
}
Err(e) => {
let _ = sender.send(Err(SendError::Other(anyhow::anyhow!(
"couldn't receive reply: {:?}",
e
))));
}
}
},
_duration = Delay::new(timeout).fuse() => {
let _ = sender.send(Err(SendError::Other(anyhow::anyhow!(
"operation timed out before finish"
))));
}
}
}
Err(error) => {
let _ = sender.send(Err(error));
}
}
});

receiver
}

/// Ask a question to a recipient attached to the `Distributor`
///
/// # Example
Expand Down Expand Up @@ -539,8 +651,10 @@ impl Distributor {
#[cfg(test)]
mod distributor_tests {
use crate::prelude::*;
use core::time;
use futures::channel::mpsc::channel;
use futures::{SinkExt, StreamExt};
use std::{thread, time::Duration};

const TEST_DISTRIBUTOR: &str = "test distributor";
const SUBSCRIBE_TEST_DISTRIBUTOR: &str = "subscribe test";
Expand Down Expand Up @@ -675,12 +789,33 @@ mod distributor_tests {
});

let answer_sync: u8 = test_distributor
.request_sync(question)
.request_sync(question.clone())
.recv()
.unwrap()
.unwrap();

assert_eq!(42, answer_sync);

run!(async {
let timeout = Duration::from_millis(100);
let answer_timeout: u8 = test_distributor
.request_timeout(question.clone(), timeout)
.await
.unwrap()
.unwrap();
assert_eq!(42, answer_timeout);
});

run!(async {
let timeout = Duration::from_nanos(1);
let answer_timeout: Result<u8, SendError> = test_distributor
.request_timeout(question.clone(), timeout)
.await
.unwrap();

let err_msg: SendError = answer_timeout.unwrap_err();
assert!(matches!(err_msg, SendError::Other { .. }));
});
}

fn setup() {
Expand Down Expand Up @@ -709,6 +844,7 @@ mod distributor_tests {
let child_ref = ctx.current().clone();
MessageHandler::new(ctx.recv().await?)
.on_question(|_: String, sender| {
thread::sleep(time::Duration::from_millis(10));
let _ = sender.reply(42_u8);
})
// send your child ref
Expand Down