Skip to content

Commit

Permalink
Merge pull request #1 from Burning1020/v0.7.1-kuasar
Browse files Browse the repository at this point in the history
cherry-pick containerd#220 containerd#222 containerd#220 from containerd/ttrpc-rust
  • Loading branch information
abel-von authored Jul 30, 2024
2 parents 5d1d5dc + 564bb21 commit db83ba8
Show file tree
Hide file tree
Showing 6 changed files with 87 additions and 23 deletions.
13 changes: 12 additions & 1 deletion compiler/src/codegen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -487,6 +487,12 @@ impl<'a> ServiceGen<'a> {
.any(|method| !matches!(method.method_type().0, MethodType::Unary))
}

fn has_unary_method(&self) -> bool {
self.methods
.iter()
.any(|method| matches!(method.method_type().0, MethodType::Unary))
}

fn write_client(&self, w: &mut CodeWriter) {
if async_on(self.customize, "client") {
self.write_async_client(w)
Expand Down Expand Up @@ -589,9 +595,14 @@ impl<'a> ServiceGen<'a> {
);

let has_stream_method = self.has_stream_method();
let has_unary_method = self.has_unary_method();
w.pub_fn(&s, |w| {
w.write_line("let mut ret = HashMap::new();");
w.write_line("let mut methods = HashMap::new();");
if has_unary_method {
w.write_line("let mut methods = HashMap::new();");
} else {
w.write_line("let methods = HashMap::new();");
}
if has_stream_method {
w.write_line("let mut streams = HashMap::new();");
} else {
Expand Down
8 changes: 5 additions & 3 deletions src/asynchronous/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ use crate::r#async::stream::{
};
use crate::r#async::utils;

use super::stream::SendingMessage;

/// A ttrpc Client (async).
#[derive(Clone)]
pub struct Client {
Expand Down Expand Up @@ -78,7 +80,7 @@ impl Client {
self.streams.lock().unwrap().insert(stream_id, tx);

self.req_tx
.send(msg)
.send(SendingMessage::new(msg))
.await
.map_err(|e| Error::Others(format!("Send packet to sender error {:?}", e)))?;

Expand Down Expand Up @@ -131,7 +133,7 @@ impl Client {
// TODO: check return
self.streams.lock().unwrap().insert(stream_id, tx);
self.req_tx
.send(msg)
.send(SendingMessage::new(msg))
.await
.map_err(|e| Error::Others(format!("Send packet to sender error {:?}", e)))?;

Expand Down Expand Up @@ -196,7 +198,7 @@ struct ClientWriter {

#[async_trait]
impl WriterDelegate for ClientWriter {
async fn recv(&mut self) -> Option<GenMessage> {
async fn recv(&mut self) -> Option<SendingMessage> {
self.rx.recv().await
}

Expand Down
14 changes: 9 additions & 5 deletions src/asynchronous/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ use tokio::{
use crate::error::Error;
use crate::proto::GenMessage;

use super::stream::SendingMessage;

pub trait Builder {
type Reader;
type Writer;
Expand All @@ -25,7 +27,7 @@ pub trait Builder {

#[async_trait]
pub trait WriterDelegate {
async fn recv(&mut self) -> Option<GenMessage>;
async fn recv(&mut self) -> Option<SendingMessage>;
async fn disconnect(&self, msg: &GenMessage, e: Error);
async fn exit(&self);
}
Expand Down Expand Up @@ -57,12 +59,14 @@ where
let (reader_delegate, mut writer_delegate) = builder.build();

let writer_task = tokio::spawn(async move {
while let Some(msg) = writer_delegate.recv().await {
trace!("write message: {:?}", msg);
if let Err(e) = msg.write_to(&mut writer).await {
while let Some(mut sending_msg) = writer_delegate.recv().await {
trace!("write message: {:?}", sending_msg.msg);
if let Err(e) = sending_msg.msg.write_to(&mut writer).await {
error!("write_message got error: {:?}", e);
writer_delegate.disconnect(&msg, e).await;
sending_msg.send_result(Err(e.clone()));
writer_delegate.disconnect(&sending_msg.msg, e).await;
}
sending_msg.send_result(Ok(()));
}
writer_delegate.exit().await;
trace!("Writer task exit.");
Expand Down
31 changes: 22 additions & 9 deletions src/asynchronous/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ use tokio::{
#[cfg(target_os = "linux")]
use tokio_vsock::VsockListener;

use crate::asynchronous::unix_incoming::UnixIncoming;
use crate::asynchronous::{stream::SendingMessage, unix_incoming::UnixIncoming};
use crate::common::{self, Domain};
use crate::context;
use crate::error::{get_status, Error, Result};
Expand Down Expand Up @@ -329,7 +329,7 @@ struct ServerWriter {

#[async_trait]
impl WriterDelegate for ServerWriter {
async fn recv(&mut self) -> Option<GenMessage> {
async fn recv(&mut self) -> Option<SendingMessage> {
self.rx.recv().await
}
async fn disconnect(&self, _msg: &GenMessage, _: Error) {}
Expand Down Expand Up @@ -371,12 +371,14 @@ impl ReaderDelegate for ServerReader {
async fn handle_msg(&self, msg: GenMessage) {
let handler_shutdown_waiter = self.handler_shutdown.subscribe();
let context = self.context();
let (wait_tx, wait_rx) = tokio::sync::oneshot::channel::<()>();
spawn(async move {
select! {
_ = context.handle_msg(msg) => {}
_ = context.handle_msg(msg, wait_tx) => {}
_ = handler_shutdown_waiter.wait_shutdown() => {}
}
});
wait_rx.await.unwrap_or_default();
}
}

Expand All @@ -402,7 +404,7 @@ struct HandlerContext {
}

impl HandlerContext {
async fn handle_msg(&self, msg: GenMessage) {
async fn handle_msg(&self, msg: GenMessage, wait_tx: tokio::sync::oneshot::Sender<()>) {
let stream_id = msg.header.stream_id;

if (stream_id % 2) != 1 {
Expand All @@ -416,7 +418,7 @@ impl HandlerContext {
}

match msg.header.type_ {
MESSAGE_TYPE_REQUEST => match self.handle_request(msg).await {
MESSAGE_TYPE_REQUEST => match self.handle_request(msg, wait_tx).await {
Ok(opt_msg) => match opt_msg {
Some(msg) => {
Self::respond(self.tx.clone(), stream_id, msg)
Expand All @@ -435,7 +437,7 @@ impl HandlerContext {
};

self.tx
.send(msg)
.send(SendingMessage::new(msg))
.await
.map_err(err_to_others_err!(e, "Send packet to sender error "))
.ok();
Expand All @@ -444,6 +446,8 @@ impl HandlerContext {
Err(status) => Self::respond_with_status(self.tx.clone(), stream_id, status).await,
},
MESSAGE_TYPE_DATA => {
// no need to wait data message handling
drop(wait_tx);
// TODO(wllenyj): Compatible with golang behavior.
if (msg.header.flags & FLAG_REMOTE_CLOSED) == FLAG_REMOTE_CLOSED
&& !msg.payload.is_empty()
Expand Down Expand Up @@ -492,7 +496,11 @@ impl HandlerContext {
}
}

async fn handle_request(&self, msg: GenMessage) -> StdResult<Option<Response>, Status> {
async fn handle_request(
&self,
msg: GenMessage,
wait_tx: tokio::sync::oneshot::Sender<()>,
) -> StdResult<Option<Response>, Status> {
//TODO:
//if header.stream_id <= self.last_stream_id {
// return Err;
Expand All @@ -513,10 +521,11 @@ impl HandlerContext {
})?;

if let Some(method) = srv.get_method(&req.method) {
drop(wait_tx);
return self.handle_method(method, req_msg).await;
}
if let Some(stream) = srv.get_stream(&req.method) {
return self.handle_stream(stream, req_msg).await;
return self.handle_stream(stream, req_msg, wait_tx).await;
}
Err(get_status(
Code::UNIMPLEMENTED,
Expand Down Expand Up @@ -572,6 +581,7 @@ impl HandlerContext {
&self,
stream: Arc<dyn StreamHandler + Send + Sync>,
req_msg: Message<Request>,
wait_tx: tokio::sync::oneshot::Sender<()>,
) -> StdResult<Option<Response>, Status> {
let stream_id = req_msg.header.stream_id;
let req = req_msg.payload;
Expand All @@ -583,6 +593,9 @@ impl HandlerContext {

let _remote_close = (req_msg.header.flags & FLAG_REMOTE_CLOSED) == FLAG_REMOTE_CLOSED;
let _remote_open = (req_msg.header.flags & FLAG_REMOTE_OPEN) == FLAG_REMOTE_OPEN;

drop(wait_tx);

let si = StreamInner::new(
stream_id,
self.tx.clone(),
Expand Down Expand Up @@ -631,7 +644,7 @@ impl HandlerContext {
header: MessageHeader::new_response(stream_id, payload.len() as u32),
payload,
};
tx.send(msg)
tx.send(SendingMessage::new(msg))
.await
.map_err(err_to_others_err!(e, "Send packet to sender error "))
}
Expand Down
42 changes: 38 additions & 4 deletions src/asynchronous/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,42 @@ use crate::proto::{
MESSAGE_TYPE_DATA, MESSAGE_TYPE_RESPONSE,
};

pub type MessageSender = mpsc::Sender<GenMessage>;
pub type MessageReceiver = mpsc::Receiver<GenMessage>;
pub type MessageSender = mpsc::Sender<SendingMessage>;
pub type MessageReceiver = mpsc::Receiver<SendingMessage>;

pub type ResultSender = mpsc::Sender<Result<GenMessage>>;
pub type ResultReceiver = mpsc::Receiver<Result<GenMessage>>;

#[derive(Debug)]
pub struct SendingMessage {
pub msg: GenMessage,
pub result_chan: Option<tokio::sync::oneshot::Sender<Result<()>>>,
}

impl SendingMessage {
pub fn new(msg: GenMessage) -> Self {
Self {
msg,
result_chan: None,
}
}
pub fn new_with_result(
msg: GenMessage,
result_chan: tokio::sync::oneshot::Sender<Result<()>>,
) -> Self {
Self {
msg,
result_chan: Some(result_chan),
}
}

pub fn send_result(&mut self, result: Result<()>) {
if let Some(result_ch) = self.result_chan.take() {
result_ch.send(result).unwrap_or_default();
}
}
}

#[derive(Debug)]
pub struct ClientStream<Q, P> {
tx: CSSender<Q>,
Expand Down Expand Up @@ -317,9 +347,13 @@ async fn _recv(rx: &mut ResultReceiver) -> Result<GenMessage> {
}

async fn _send(tx: &MessageSender, msg: GenMessage) -> Result<()> {
tx.send(msg)
let (res_tx, res_rx) = tokio::sync::oneshot::channel();
tx.send(SendingMessage::new_with_result(msg, res_tx))
.await
.map_err(|e| Error::Others(format!("Send data packet to sender error {:?}", e)))?;
res_rx
.await
.map_err(|e| Error::Others(format!("Send data packet to sender error {:?}", e)))
.map_err(|e| Error::Others(format!("Failed to wait send result {:?}", e)))?
}

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
Expand Down
2 changes: 1 addition & 1 deletion ttrpc-codegen/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@ readme = "README.md"
protobuf-support = "3.1.0"
protobuf = { version = "2.27.1" }
protobuf-codegen = "3.1.0"
ttrpc-compiler = "0.6.1"
ttrpc-compiler = { path = "../ttrpc-compiler" }

0 comments on commit db83ba8

Please sign in to comment.