diff --git a/src/asynchronous/server.rs b/src/asynchronous/server.rs index 26c49f2c..9158e648 100644 --- a/src/asynchronous/server.rs +++ b/src/asynchronous/server.rs @@ -381,12 +381,14 @@ impl ReaderDelegate for ServerReader { async fn handle_msg(&self, msg: GenMessage) { let handler_shutdown_waiter = self.handler_shutdown.subscribe(); let context = self.context(); + let (wait_tx, wait_rx) = tokio::sync::oneshot::channel::<()>(); spawn(async move { select! { - _ = context.handle_msg(msg) => {} + _ = context.handle_msg(msg, wait_tx) => {} _ = handler_shutdown_waiter.wait_shutdown() => {} } }); + wait_rx.await.unwrap_or_default(); } async fn handle_err(&self, header: MessageHeader, e: Error) { @@ -424,7 +426,7 @@ impl HandlerContext { }) .ok(); } - async fn handle_msg(&self, msg: GenMessage) { + async fn handle_msg(&self, msg: GenMessage, wait_tx: tokio::sync::oneshot::Sender<()>) { let stream_id = msg.header.stream_id; if (stream_id % 2) != 1 { @@ -438,7 +440,7 @@ impl HandlerContext { } match msg.header.type_ { - MESSAGE_TYPE_REQUEST => match self.handle_request(msg).await { + MESSAGE_TYPE_REQUEST => match self.handle_request(msg, wait_tx).await { Ok(opt_msg) => match opt_msg { Some(mut resp) => { // Server: check size before sending to client @@ -471,6 +473,8 @@ impl HandlerContext { Err(status) => Self::respond_with_status(self.tx.clone(), stream_id, status).await, }, MESSAGE_TYPE_DATA => { + // no need to wait data message handling + drop(wait_tx); // TODO(wllenyj): Compatible with golang behavior. if (msg.header.flags & FLAG_REMOTE_CLOSED) == FLAG_REMOTE_CLOSED && !msg.payload.is_empty() @@ -518,7 +522,11 @@ impl HandlerContext { } } - async fn handle_request(&self, msg: GenMessage) -> StdResult, Status> { + async fn handle_request( + &self, + msg: GenMessage, + wait_tx: tokio::sync::oneshot::Sender<()>, + ) -> StdResult, Status> { //TODO: //if header.stream_id <= self.last_stream_id { // return Err; @@ -539,10 +547,11 @@ impl HandlerContext { })?; if let Some(method) = srv.get_method(&req.method) { + drop(wait_tx); return self.handle_method(method, req_msg).await; } if let Some(stream) = srv.get_stream(&req.method) { - return self.handle_stream(stream, req_msg).await; + return self.handle_stream(stream, req_msg, wait_tx).await; } Err(get_status( Code::UNIMPLEMENTED, @@ -598,6 +607,7 @@ impl HandlerContext { &self, stream: Arc, req_msg: Message, + wait_tx: tokio::sync::oneshot::Sender<()>, ) -> StdResult, Status> { let stream_id = req_msg.header.stream_id; let req = req_msg.payload; @@ -609,6 +619,8 @@ impl HandlerContext { let no_data = (req_msg.header.flags & FLAG_NO_DATA) == FLAG_NO_DATA; + drop(wait_tx); + let si = StreamInner::new( stream_id, self.tx.clone(),