From 3355f7d2c487634dd891a3425cee088b6cf06df0 Mon Sep 17 00:00:00 2001 From: Abel Feng Date: Thu, 14 Mar 2024 12:07:24 +0800 Subject: [PATCH] add channel to get send result Currently the send() method of stream implemented by send the value to an unbounded channel, so even the connection is closed for a long time, the send function still return succeed. This commit adds a channel to the message so that we can wait until the message is truely written to the connection. Signed-off-by: Abel Feng --- src/asynchronous/client.rs | 8 ++++--- src/asynchronous/connection.rs | 14 ++++++++---- src/asynchronous/server.rs | 8 +++---- src/asynchronous/stream.rs | 42 ++++++++++++++++++++++++++++++---- 4 files changed, 56 insertions(+), 16 deletions(-) diff --git a/src/asynchronous/client.rs b/src/asynchronous/client.rs index ed2ee3a8..80d4e7b0 100644 --- a/src/asynchronous/client.rs +++ b/src/asynchronous/client.rs @@ -27,6 +27,8 @@ use crate::r#async::stream::{ }; use crate::r#async::utils; +use super::stream::SendingMessage; + /// A ttrpc Client (async). #[derive(Clone)] pub struct Client { @@ -78,7 +80,7 @@ impl Client { self.streams.lock().unwrap().insert(stream_id, tx); self.req_tx - .send(msg) + .send(SendingMessage::new(msg)) .await .map_err(|e| Error::Others(format!("Send packet to sender error {e:?}")))?; @@ -139,7 +141,7 @@ impl Client { // TODO: check return self.streams.lock().unwrap().insert(stream_id, tx); self.req_tx - .send(msg) + .send(SendingMessage::new(msg)) .await .map_err(|e| Error::Others(format!("Send packet to sender error {e:?}")))?; @@ -204,7 +206,7 @@ struct ClientWriter { #[async_trait] impl WriterDelegate for ClientWriter { - async fn recv(&mut self) -> Option { + async fn recv(&mut self) -> Option { self.rx.recv().await } diff --git a/src/asynchronous/connection.rs b/src/asynchronous/connection.rs index 6372b25f..3ea062c4 100644 --- a/src/asynchronous/connection.rs +++ b/src/asynchronous/connection.rs @@ -16,6 +16,8 @@ use tokio::{ use crate::error::Error; use crate::proto::{GenMessage, GenMessageError, MessageHeader}; +use super::stream::SendingMessage; + pub trait Builder { type Reader; type Writer; @@ -25,7 +27,7 @@ pub trait Builder { #[async_trait] pub trait WriterDelegate { - async fn recv(&mut self) -> Option; + async fn recv(&mut self) -> Option; async fn disconnect(&self, msg: &GenMessage, e: Error); async fn exit(&self); } @@ -58,12 +60,14 @@ where let (reader_delegate, mut writer_delegate) = builder.build(); let writer_task = tokio::spawn(async move { - while let Some(msg) = writer_delegate.recv().await { - trace!("write message: {:?}", msg); - if let Err(e) = msg.write_to(&mut writer).await { + while let Some(mut sending_msg) = writer_delegate.recv().await { + trace!("write message: {:?}", sending_msg.msg); + if let Err(e) = sending_msg.msg.write_to(&mut writer).await { error!("write_message got error: {:?}", e); - writer_delegate.disconnect(&msg, e).await; + sending_msg.send_result(Err(e.clone())); + writer_delegate.disconnect(&sending_msg.msg, e).await; } + sending_msg.send_result(Ok(())); } writer_delegate.exit().await; trace!("Writer task exit."); diff --git a/src/asynchronous/server.rs b/src/asynchronous/server.rs index 26c49f2c..09348314 100644 --- a/src/asynchronous/server.rs +++ b/src/asynchronous/server.rs @@ -31,7 +31,7 @@ use tokio::{ #[cfg(any(target_os = "linux", target_os = "android"))] use tokio_vsock::VsockListener; -use crate::asynchronous::unix_incoming::UnixIncoming; +use crate::asynchronous::{stream::SendingMessage, unix_incoming::UnixIncoming}; use crate::common::{self, Domain}; use crate::context; use crate::error::{get_status, Error, Result}; @@ -339,7 +339,7 @@ struct ServerWriter { #[async_trait] impl WriterDelegate for ServerWriter { - async fn recv(&mut self) -> Option { + async fn recv(&mut self) -> Option { self.rx.recv().await } async fn disconnect(&self, _msg: &GenMessage, _: Error) {} @@ -462,7 +462,7 @@ impl HandlerContext { }; self.tx - .send(msg) + .send(SendingMessage::new(msg)) .await .map_err(err_to_others_err!(e, "Send packet to sender error ")) .ok(); @@ -652,7 +652,7 @@ impl HandlerContext { header: MessageHeader::new_response(stream_id, payload.len() as u32), payload, }; - tx.send(msg) + tx.send(SendingMessage::new(msg)) .await .map_err(err_to_others_err!(e, "Send packet to sender error ")) } diff --git a/src/asynchronous/stream.rs b/src/asynchronous/stream.rs index d3db18d6..5172ddb4 100644 --- a/src/asynchronous/stream.rs +++ b/src/asynchronous/stream.rs @@ -17,12 +17,42 @@ use crate::proto::{ MESSAGE_TYPE_DATA, MESSAGE_TYPE_RESPONSE, }; -pub type MessageSender = mpsc::Sender; -pub type MessageReceiver = mpsc::Receiver; +pub type MessageSender = mpsc::Sender; +pub type MessageReceiver = mpsc::Receiver; pub type ResultSender = mpsc::Sender>; pub type ResultReceiver = mpsc::Receiver>; +#[derive(Debug)] +pub struct SendingMessage { + pub msg: GenMessage, + pub result_chan: Option>>, +} + +impl SendingMessage { + pub fn new(msg: GenMessage) -> Self { + Self { + msg, + result_chan: None, + } + } + pub fn new_with_result( + msg: GenMessage, + result_chan: tokio::sync::oneshot::Sender>, + ) -> Self { + Self { + msg, + result_chan: Some(result_chan), + } + } + + pub fn send_result(&mut self, result: Result<()>) { + if let Some(result_ch) = self.result_chan.take() { + result_ch.send(result).unwrap_or_default(); + } + } +} + #[derive(Debug)] pub struct ClientStream { tx: CSSender, @@ -317,9 +347,13 @@ async fn _recv(rx: &mut ResultReceiver) -> Result { } async fn _send(tx: &MessageSender, msg: GenMessage) -> Result<()> { - tx.send(msg) + let (res_tx, res_rx) = tokio::sync::oneshot::channel(); + tx.send(SendingMessage::new_with_result(msg, res_tx)) + .await + .map_err(|e| Error::Others(format!("Send data packet to sender error {:?}", e)))?; + res_rx .await - .map_err(|e| Error::Others(format!("Send data packet to sender error {e:?}"))) + .map_err(|e| Error::Others(format!("Failed to wait send result {:?}", e)))? } #[derive(Clone, Copy, Debug, PartialEq, Eq)]