diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index b0cc136c..198475a2 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -18,7 +18,7 @@ jobs: with: access_token: ${{ github.token }} - uses: actions/checkout@v3 - - uses: dtolnay/rust-toolchain@stable + - uses: dtolnay/rust-toolchain@nightly with: targets: mipsel-unknown-linux-gnu - run: cargo check --all-features @@ -33,7 +33,7 @@ jobs: with: access_token: ${{ github.token }} - uses: actions/checkout@v3 - - uses: dtolnay/rust-toolchain@stable + - uses: dtolnay/rust-toolchain@nightly - run: cargo test - run: cargo test --manifest-path tarpc/Cargo.toml --features serde1 - run: cargo test --manifest-path tarpc/Cargo.toml --features tokio1 @@ -50,7 +50,7 @@ jobs: with: access_token: ${{ github.token }} - uses: actions/checkout@v3 - - uses: dtolnay/rust-toolchain@stable + - uses: dtolnay/rust-toolchain@nightly with: components: rustfmt - run: cargo fmt --all -- --check @@ -64,7 +64,7 @@ jobs: with: access_token: ${{ github.token }} - uses: actions/checkout@v3 - - uses: dtolnay/rust-toolchain@stable + - uses: dtolnay/rust-toolchain@nightly with: components: clippy - run: cargo clippy --all-features -- -D warnings diff --git a/README.md b/README.md index 0b5b9f46..419e4e8b 100644 --- a/README.md +++ b/README.md @@ -67,7 +67,7 @@ Some other features of tarpc: Add to your `Cargo.toml` dependencies: ```toml -tarpc = "0.33" +tarpc = "0.31" ``` The `tarpc::service` attribute expands to a collection of items that form an rpc service. @@ -127,7 +127,7 @@ impl World for HelloServer { type HelloFut = Ready; - fn hello(self, _: context::Context, name: String) -> Self::HelloFut { + fn hello(self, _: &mut context::Context, name: String) -> Self::HelloFut { future::ready(format!("Hello, {name}!")) } } diff --git a/RELEASES.md b/RELEASES.md index 8ea6ca37..a6ce438b 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -1,27 +1,3 @@ -## 0.33.0 (2023-04-01) - -### Breaking Changes - -Opentelemetry dependency version increased to 0.18. - -## 0.32.0 (2023-03-24) - -### Breaking Changes - -- As part of a fix to return more channel errors in RPC results, a few error types have changed: - - 0. `client::RpcError::Disconnected` was split into the following errors: - - Shutdown: the client was shutdown, either intentionally or due to an error. If due to an - error, pending RPCs should see the more specific errors below. - - Send: an RPC message failed to send over the transport. Only the RPC that failed to be sent - will see this error. - - Receive: a fatal error occurred while receiving from the transport. All in-flight RPCs will - receive this error. - 0. `client::ChannelError` and `server::ChannelError` are unified in `tarpc::ChannelError`. - Previously, server transport errors would not indicate during which activity the transport - error occurred. Now, just like the client already was, it will be specific: reading, readying, - sending, flushing, or closing. 
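The removed 0.32.0 notes above describe the client-side error split; the corresponding variants (`Shutdown`, `Send`, `Receive`, alongside `DeadlineExceeded` and `Server`) are visible on the pre-change side of `tarpc/src/client.rs` later in this diff. A minimal sketch of handling them, assuming that 0.32/0.33 client API; the helper name is illustrative:

```rust
use tarpc::client::RpcError;

// Illustrative helper: maps each failure class from the 0.32.0 notes to a log line.
fn log_rpc_error(err: &RpcError) {
    match err {
        // The client was shut down, intentionally or due to an error.
        RpcError::Shutdown => eprintln!("client was already shut down"),
        // Only the RPC that failed to send sees this.
        RpcError::Send(source) => eprintln!("request failed to send: {source}"),
        // A fatal transport error; all in-flight RPCs see this.
        RpcError::Receive(source) => eprintln!("transport receive error: {source}"),
        other => eprintln!("rpc failed: {other}"),
    }
}
```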
- ## 0.31.0 (2022-11-03) ### New Features diff --git a/example-service/Cargo.toml b/example-service/Cargo.toml index 8b325a4f..e7601117 100644 --- a/example-service/Cargo.toml +++ b/example-service/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tarpc-example-service" -version = "0.15.0" +version = "0.13.0" rust-version = "1.56" authors = ["Tim Kuehn "] edition = "2021" @@ -21,7 +21,7 @@ futures = "0.3" opentelemetry = { version = "0.17", features = ["rt-tokio"] } opentelemetry-jaeger = { version = "0.16", features = ["rt-tokio"] } rand = "0.8" -tarpc = { version = "0.33", path = "../tarpc", features = ["full"] } +tarpc = { version = "0.31", path = "../tarpc", features = ["full"] } tokio = { version = "1", features = ["macros", "net", "rt-multi-thread"] } tracing = { version = "0.1" } tracing-opentelemetry = "0.17" diff --git a/example-service/src/client.rs b/example-service/src/client.rs index 2877c815..f59003bd 100644 --- a/example-service/src/client.rs +++ b/example-service/src/client.rs @@ -26,8 +26,7 @@ async fn main() -> anyhow::Result<()> { let flags = Flags::parse(); init_tracing("Tarpc Example Client")?; - let mut transport = tarpc::serde_transport::tcp::connect(flags.server_addr, Json::default); - transport.config_mut().max_frame_length(usize::MAX); + let transport = tarpc::serde_transport::tcp::connect(flags.server_addr, Json::default); // WorldClient is generated by the service attribute. It has a constructor `new` that takes a // config and any Transport as input. @@ -43,10 +42,7 @@ async fn main() -> anyhow::Result<()> { .instrument(tracing::info_span!("Two Hellos")) .await; - match hello { - Ok(hello) => tracing::info!("{hello:?}"), - Err(e) => tracing::warn!("{:?}", anyhow::Error::from(e)), - } + tracing::info!("{:?}", hello); // Let the background span processor finish. sleep(Duration::from_micros(1)).await; diff --git a/example-service/src/server.rs b/example-service/src/server.rs index b0281e98..3cd4b43f 100644 --- a/example-service/src/server.rs +++ b/example-service/src/server.rs @@ -3,7 +3,6 @@ // Use of this source code is governed by an MIT-style // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. - use clap::Parser; use futures::{future, prelude::*}; use rand::{ @@ -34,9 +33,8 @@ struct Flags { #[derive(Clone)] struct HelloServer(SocketAddr); -#[tarpc::server] impl World for HelloServer { - async fn hello(self, _: context::Context, name: String) -> String { + async fn hello(self, _: &mut context::Context, name: String) -> String { let sleep_time = Duration::from_millis(Uniform::new_inclusive(1, 10).sample(&mut thread_rng())); time::sleep(sleep_time).await; @@ -44,6 +42,10 @@ impl World for HelloServer { } } +async fn spawn(fut: impl Future + Send + 'static) { + tokio::spawn(fut); +} + #[tokio::main] async fn main() -> anyhow::Result<()> { let flags = Flags::parse(); @@ -66,7 +68,7 @@ async fn main() -> anyhow::Result<()> { // the generated World trait. .map(|channel| { let server = HelloServer(channel.transport().peer_addr().unwrap()); - channel.execute(server.serve()) + channel.execute(server.serve()).for_each(spawn) }) // Max 10 channels. .buffer_unordered(10) diff --git a/hooks/pre-push b/hooks/pre-push index 7b527e0a..1e5500d6 100755 --- a/hooks/pre-push +++ b/hooks/pre-push @@ -84,12 +84,12 @@ command -v rustup &>/dev/null if [ "$?" == 0 ]; then printf "${SUCCESS}\n" - try_run "Building ... " cargo +stable build --color=always - try_run "Testing ... 
" cargo +stable test --color=always - try_run "Testing with all features enabled ... " cargo +stable test --all-features --color=always - for EXAMPLE in $(cargo +stable run --example 2>&1 | grep ' ' | awk '{print $1}') + try_run "Building ... " cargo build --color=always + try_run "Testing ... " cargo test --color=always + try_run "Testing with all features enabled ... " cargo test --all-features --color=always + for EXAMPLE in $(cargo run --example 2>&1 | grep ' ' | awk '{print $1}') do - try_run "Running example \"$EXAMPLE\" ... " cargo +stable run --example $EXAMPLE + try_run "Running example \"$EXAMPLE\" ... " cargo run --example $EXAMPLE done check_toolchain nightly diff --git a/plugins/src/lib.rs b/plugins/src/lib.rs index 1b83c324..15490017 100644 --- a/plugins/src/lib.rs +++ b/plugins/src/lib.rs @@ -12,18 +12,18 @@ extern crate quote; extern crate syn; use proc_macro::TokenStream; -use proc_macro2::{Span, TokenStream as TokenStream2}; +use proc_macro2::TokenStream as TokenStream2; use quote::{format_ident, quote, ToTokens}; use syn::{ braced, ext::IdentExt, parenthesized, parse::{Parse, ParseStream}, - parse_macro_input, parse_quote, parse_str, + parse_macro_input, parse_quote, spanned::Spanned, token::Comma, - Attribute, FnArg, Ident, ImplItem, ImplItemMethod, ImplItemType, ItemImpl, Lit, LitBool, - MetaNameValue, Pat, PatType, ReturnType, Token, Type, Visibility, + Attribute, FnArg, Ident, Lit, LitBool, MetaNameValue, Pat, PatType, ReturnType, Token, Type, + Visibility, }; /// Accumulates multiple errors into a result. @@ -257,7 +257,6 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream { .map(|rpc| snake_to_camel(&rpc.ident.unraw().to_string())) .collect(); let args: &[&[PatType]] = &rpcs.iter().map(|rpc| &*rpc.args).collect::>(); - let response_fut_name = &format!("{}ResponseFut", ident.unraw()); let derive_serialize = if derive_serde.0 { Some( quote! {#[derive(tarpc::serde::Serialize, tarpc::serde::Deserialize)] @@ -274,10 +273,9 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream { .collect::>(); ServiceGenerator { - response_fut_name, service_ident: ident, + client_stub_ident: &format_ident!("{}Stub", ident), server_ident: &format_ident!("Serve{}", ident), - response_fut_ident: &Ident::new(response_fut_name, ident.span()), client_ident: &format_ident!("{}Client", ident), request_ident: &format_ident!("{}Request", ident), response_ident: &format_ident!("{}Response", ident), @@ -304,137 +302,18 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream { .zip(camel_case_fn_names.iter()) .map(|(rpc, name)| Ident::new(name, rpc.ident.span())) .collect::>(), - future_types: &camel_case_fn_names - .iter() - .map(|name| parse_str(&format!("{name}Fut")).unwrap()) - .collect::>(), derive_serialize: derive_serialize.as_ref(), } .into_token_stream() .into() } -/// generate an identifier consisting of the method name to CamelCase with -/// Fut appended to it. -fn associated_type_for_rpc(method: &ImplItemMethod) -> String { - snake_to_camel(&method.sig.ident.unraw().to_string()) + "Fut" -} - -/// Transforms an async function into a sync one, returning a type declaration -/// for the return type (a future). -fn transform_method(method: &mut ImplItemMethod) -> ImplItemType { - method.sig.asyncness = None; - - // get either the return type or (). 
- let ret = match &method.sig.output { - ReturnType::Default => quote!(()), - ReturnType::Type(_, ret) => quote!(#ret), - }; - - let fut_name = associated_type_for_rpc(method); - let fut_name_ident = Ident::new(&fut_name, method.sig.ident.span()); - - // generate the updated return signature. - method.sig.output = parse_quote! { - -> ::core::pin::Pin + ::core::marker::Send - >> - }; - - // transform the body of the method into Box::pin(async move { body }). - let block = method.block.clone(); - method.block = parse_quote! [{ - Box::pin(async move - #block - ) - }]; - - // generate and return type declaration for return type. - let t: ImplItemType = parse_quote! { - type #fut_name_ident = ::core::pin::Pin + ::core::marker::Send>>; - }; - - t -} - -#[proc_macro_attribute] -pub fn server(_attr: TokenStream, input: TokenStream) -> TokenStream { - let mut item = syn::parse_macro_input!(input as ItemImpl); - let span = item.span(); - - // the generated type declarations - let mut types: Vec = Vec::new(); - let mut expected_non_async_types: Vec<(&ImplItemMethod, String)> = Vec::new(); - let mut found_non_async_types: Vec<&ImplItemType> = Vec::new(); - - for inner in &mut item.items { - match inner { - ImplItem::Method(method) => { - if method.sig.asyncness.is_some() { - // if this function is declared async, transform it into a regular function - let typedecl = transform_method(method); - types.push(typedecl); - } else { - // If it's not async, keep track of all required associated types for better - // error reporting. - expected_non_async_types.push((method, associated_type_for_rpc(method))); - } - } - ImplItem::Type(typedecl) => found_non_async_types.push(typedecl), - _ => {} - } - } - - if let Err(e) = - verify_types_were_provided(span, &expected_non_async_types, &found_non_async_types) - { - return TokenStream::from(e.to_compile_error()); - } - - // add the type declarations into the impl block - for t in types.into_iter() { - item.items.push(syn::ImplItem::Type(t)); - } - - TokenStream::from(quote!(#item)) -} - -fn verify_types_were_provided( - span: Span, - expected: &[(&ImplItemMethod, String)], - provided: &[&ImplItemType], -) -> syn::Result<()> { - let mut result = Ok(()); - for (method, expected) in expected { - if !provided.iter().any(|typedecl| typedecl.ident == expected) { - let mut e = syn::Error::new( - span, - format!("not all trait items implemented, missing: `{expected}`"), - ); - let fn_span = method.sig.fn_token.span(); - e.extend(syn::Error::new( - fn_span.join(method.sig.ident.span()).unwrap_or(fn_span), - format!( - "hint: `#[tarpc::server]` only rewrites async fns, and `fn {}` is not async", - method.sig.ident - ), - )); - match result { - Ok(_) => result = Err(e), - Err(ref mut error) => error.extend(Some(e)), - } - } - } - result -} - // Things needed to generate the service items: trait, serve impl, request/response enums, and // the client stub. 
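For orientation, here is a minimal sketch of how a service is declared and implemented once these items are generated, in the attribute-free style exercised by `plugins/tests/service.rs` later in this diff; the `Echo`/`EchoServer` names are hypothetical:

```rust
use tarpc::context;

// The attribute expands to the trait plus ServeEcho, EchoClient, EchoStub,
// EchoRequest, and EchoResponse (the items listed in the comment above).
#[tarpc::service]
trait Echo {
    async fn echo(s: String) -> String;
}

#[derive(Clone)]
struct EchoServer;

// Implementations are plain `async fn`s taking `&mut context::Context`;
// no `#[tarpc::server]` attribute or associated `*Fut` types are required.
impl Echo for EchoServer {
    async fn echo(self, _: &mut context::Context, s: String) -> String {
        s
    }
}
```

Calling `EchoServer.serve()` then yields the generated `ServeEcho<EchoServer>` wrapper used with `BaseChannel::execute` or `InFlightRequest::execute`.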
struct ServiceGenerator<'a> { service_ident: &'a Ident, + client_stub_ident: &'a Ident, server_ident: &'a Ident, - response_fut_ident: &'a Ident, - response_fut_name: &'a str, client_ident: &'a Ident, request_ident: &'a Ident, response_ident: &'a Ident, @@ -442,7 +321,6 @@ struct ServiceGenerator<'a> { attrs: &'a [Attribute], rpcs: &'a [RpcMethod], camel_case_idents: &'a [Ident], - future_types: &'a [Type], method_idents: &'a [&'a Ident], request_names: &'a [String], method_attrs: &'a [&'a [Attribute]], @@ -458,42 +336,37 @@ impl<'a> ServiceGenerator<'a> { attrs, rpcs, vis, - future_types, return_types, service_ident, + client_stub_ident, + request_ident, + response_ident, server_ident, .. } = self; - let types_and_fns = rpcs + let rpc_fns = rpcs .iter() - .zip(future_types.iter()) .zip(return_types.iter()) .map( |( - ( - RpcMethod { - attrs, ident, args, .. - }, - future_type, - ), + RpcMethod { + attrs, ident, args, .. + }, output, )| { - let ty_doc = format!("The response future returned by [`{service_ident}::{ident}`]."); quote! { - #[doc = #ty_doc] - type #future_type: std::future::Future; - #( #attrs )* - fn #ident(self, context: tarpc::context::Context, #( #args ),*) -> Self::#future_type; + async fn #ident(self, context: &mut tarpc::context::Context, #( #args ),*) -> #output; } }, ); + let stub_doc = format!("The stub trait for service [`{service_ident}`]."); quote! { #( #attrs )* #vis trait #service_ident: Sized { - #( #types_and_fns )* + #( #rpc_fns )* /// Returns a serving function to use with /// [InFlightRequest::execute](tarpc::server::InFlightRequest::execute). @@ -501,6 +374,15 @@ impl<'a> ServiceGenerator<'a> { #server_ident { service: self } } } + + #[doc = #stub_doc] + #vis trait #client_stub_ident: tarpc::client::stub::Stub { + } + + impl #client_stub_ident for S + where S: tarpc::client::stub::Stub + { + } } } @@ -524,7 +406,6 @@ impl<'a> ServiceGenerator<'a> { server_ident, service_ident, response_ident, - response_fut_ident, camel_case_idents, arg_pats, method_idents, @@ -533,11 +414,11 @@ impl<'a> ServiceGenerator<'a> { } = self; quote! { - impl tarpc::server::Serve<#request_ident> for #server_ident + impl tarpc::server::Serve for #server_ident where S: #service_ident { + type Req = #request_ident; type Resp = #response_ident; - type Fut = #response_fut_ident; fn method(&self, req: &#request_ident) -> Option<&'static str> { Some(match req { @@ -549,15 +430,16 @@ impl<'a> ServiceGenerator<'a> { }) } - fn serve(self, ctx: tarpc::context::Context, req: #request_ident) -> Self::Fut { + async fn serve(self, ctx: &mut tarpc::context::Context, req: #request_ident) + -> Result<#response_ident, tarpc::ServerError> { match req { #( #request_ident::#camel_case_idents{ #( #arg_pats ),* } => { - #response_fut_ident::#camel_case_idents( + Ok(#response_ident::#camel_case_idents( #service_ident::#method_idents( self.service, ctx, #( #arg_pats ),* - ) - ) + ).await + )) } )* } @@ -608,73 +490,6 @@ impl<'a> ServiceGenerator<'a> { } } - fn enum_response_future(&self) -> TokenStream2 { - let &Self { - vis, - service_ident, - response_fut_ident, - camel_case_idents, - future_types, - .. - } = self; - - quote! { - /// A future resolving to a server response. - #[allow(missing_docs)] - #vis enum #response_fut_ident { - #( #camel_case_idents(::#future_types) ),* - } - } - } - - fn impl_debug_for_response_future(&self) -> TokenStream2 { - let &Self { - service_ident, - response_fut_ident, - response_fut_name, - .. - } = self; - - quote! 
{ - impl std::fmt::Debug for #response_fut_ident { - fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result { - fmt.debug_struct(#response_fut_name).finish() - } - } - } - } - - fn impl_future_for_response_future(&self) -> TokenStream2 { - let &Self { - service_ident, - response_fut_ident, - response_ident, - camel_case_idents, - .. - } = self; - - quote! { - impl std::future::Future for #response_fut_ident { - type Output = #response_ident; - - fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) - -> std::task::Poll<#response_ident> - { - unsafe { - match std::pin::Pin::get_unchecked_mut(self) { - #( - #response_fut_ident::#camel_case_idents(resp) => - std::pin::Pin::new_unchecked(resp) - .poll(cx) - .map(#response_ident::#camel_case_idents), - )* - } - } - } - } - } - } - fn struct_client(&self) -> TokenStream2 { let &Self { vis, @@ -689,7 +504,9 @@ impl<'a> ServiceGenerator<'a> { #[derive(Clone, Debug)] /// The client stub that makes RPC calls to the server. All request methods return /// [Futures](std::future::Future). - #vis struct #client_ident(tarpc::client::Channel<#request_ident, #response_ident>); + #vis struct #client_ident< + Stub = tarpc::client::Channel<#request_ident, #response_ident> + >(Stub); } } @@ -719,6 +536,17 @@ impl<'a> ServiceGenerator<'a> { dispatch: new_client.dispatch, } } + } + + impl From for #client_ident + where Stub: tarpc::client::stub::Stub< + Req = #request_ident, + Resp = #response_ident> + { + /// Returns a new client stub that sends requests over the given transport. + fn from(stub: Stub) -> Self { + #client_ident(stub) + } } } @@ -741,7 +569,11 @@ impl<'a> ServiceGenerator<'a> { } = self; quote! { - impl #client_ident { + impl #client_ident + where Stub: tarpc::client::stub::Stub< + Req = #request_ident, + Resp = #response_ident> + { #( #[allow(unused)] #( #method_attrs )* @@ -770,9 +602,6 @@ impl<'a> ToTokens for ServiceGenerator<'a> { self.impl_serve_for_server(), self.enum_request(), self.enum_response(), - self.enum_response_future(), - self.impl_debug_for_response_future(), - self.impl_future_for_response_future(), self.struct_client(), self.impl_client_new(), self.impl_client_rpc_methods(), diff --git a/plugins/tests/server.rs b/plugins/tests/server.rs index f0222ffd..9d412969 100644 --- a/plugins/tests/server.rs +++ b/plugins/tests/server.rs @@ -1,8 +1,3 @@ -use assert_type_eq::assert_type_eq; -use futures::Future; -use std::pin::Pin; -use tarpc::context; - // these need to be out here rather than inside the function so that the // assert_type_eq macro can pick them up. #[tarpc::service] @@ -12,42 +7,6 @@ trait Foo { async fn baz(); } -#[test] -fn type_generation_works() { - #[tarpc::server] - impl Foo for () { - async fn two_part(self, _: context::Context, s: String, i: i32) -> (String, i32) { - (s, i) - } - - async fn bar(self, _: context::Context, s: String) -> String { - s - } - - async fn baz(self, _: context::Context) {} - } - - // the assert_type_eq macro can only be used once per block. 
- { - assert_type_eq!( - <() as Foo>::TwoPartFut, - Pin + Send>> - ); - } - { - assert_type_eq!( - <() as Foo>::BarFut, - Pin + Send>> - ); - } - { - assert_type_eq!( - <() as Foo>::BazFut, - Pin + Send>> - ); - } -} - #[allow(non_camel_case_types)] #[test] fn raw_idents_work() { @@ -59,24 +18,6 @@ fn raw_idents_work() { async fn r#fn(r#impl: r#yield) -> r#yield; async fn r#async(); } - - #[tarpc::server] - impl r#trait for () { - async fn r#await( - self, - _: context::Context, - r#struct: r#yield, - r#enum: i32, - ) -> (r#yield, i32) { - (r#struct, r#enum) - } - - async fn r#fn(self, _: context::Context, r#impl: r#yield) -> r#yield { - r#impl - } - - async fn r#async(self, _: context::Context) {} - } } #[test] @@ -100,45 +41,4 @@ fn syntax() { #[doc = "attr"] async fn one_arg_implicit_return_error(one: String); } - - #[tarpc::server] - impl Syntax for () { - #[deny(warnings)] - #[allow(non_snake_case)] - async fn TestCamelCaseDoesntConflict(self, _: context::Context) {} - - async fn hello(self, _: context::Context) -> String { - String::new() - } - - async fn attr(self, _: context::Context, _s: String) -> String { - String::new() - } - - async fn no_args_no_return(self, _: context::Context) {} - - async fn no_args(self, _: context::Context) -> () {} - - async fn one_arg(self, _: context::Context, _one: String) -> i32 { - 0 - } - - async fn two_args_no_return(self, _: context::Context, _one: String, _two: u64) {} - - async fn two_args(self, _: context::Context, _one: String, _two: u64) -> String { - String::new() - } - - async fn no_args_ret_error(self, _: context::Context) -> i32 { - 0 - } - - async fn one_arg_ret_error(self, _: context::Context, _one: String) -> String { - String::new() - } - - async fn no_arg_implicit_return_error(self, _: context::Context) {} - - async fn one_arg_implicit_return_error(self, _: context::Context, _one: String) {} - } } diff --git a/plugins/tests/service.rs b/plugins/tests/service.rs index b37cbcea..afb62ce8 100644 --- a/plugins/tests/service.rs +++ b/plugins/tests/service.rs @@ -2,8 +2,6 @@ use tarpc::context; #[test] fn att_service_trait() { - use futures::future::{ready, Ready}; - #[tarpc::service] trait Foo { async fn two_part(s: String, i: i32) -> (String, i32); @@ -12,19 +10,16 @@ fn att_service_trait() { } impl Foo for () { - type TwoPartFut = Ready<(String, i32)>; - fn two_part(self, _: context::Context, s: String, i: i32) -> Self::TwoPartFut { - ready((s, i)) + async fn two_part(self, _: &mut context::Context, s: String, i: i32) -> (String, i32) { + (s, i) } - type BarFut = Ready; - fn bar(self, _: context::Context, s: String) -> Self::BarFut { - ready(s) + async fn bar(self, _: &mut context::Context, s: String) -> String { + s } - type BazFut = Ready<()>; - fn baz(self, _: context::Context) -> Self::BazFut { - ready(()) + async fn baz(self, _: &mut context::Context) { + () } } } @@ -32,8 +27,6 @@ fn att_service_trait() { #[allow(non_camel_case_types)] #[test] fn raw_idents() { - use futures::future::{ready, Ready}; - type r#yield = String; #[tarpc::service] @@ -44,19 +37,21 @@ fn raw_idents() { } impl r#trait for () { - type AwaitFut = Ready<(r#yield, i32)>; - fn r#await(self, _: context::Context, r#struct: r#yield, r#enum: i32) -> Self::AwaitFut { - ready((r#struct, r#enum)) + async fn r#await( + self, + _: &mut context::Context, + r#struct: r#yield, + r#enum: i32, + ) -> (r#yield, i32) { + (r#struct, r#enum) } - type FnFut = Ready; - fn r#fn(self, _: context::Context, r#impl: r#yield) -> Self::FnFut { - ready(r#impl) + async fn r#fn(self, _: 
&mut context::Context, r#impl: r#yield) -> r#yield { + r#impl } - type AsyncFut = Ready<()>; - fn r#async(self, _: context::Context) -> Self::AsyncFut { - ready(()) + async fn r#async(self, _: &mut context::Context) { + () } } } diff --git a/tarpc/Cargo.toml b/tarpc/Cargo.toml index 97ac9523..a36a62e7 100644 --- a/tarpc/Cargo.toml +++ b/tarpc/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tarpc" -version = "0.33.0" +version = "0.31.0" rust-version = "1.58.0" authors = [ "Adam Wright ", @@ -19,7 +19,7 @@ description = "An RPC framework for Rust with a focus on ease of use." [features] default = [] -serde1 = ["tarpc-plugins/serde1", "serde", "serde/derive"] +serde1 = ["tarpc-plugins/serde1", "serde", "serde/derive", "serde/rc"] tokio1 = ["tokio/rt"] serde-transport = ["serde1", "tokio1", "tokio-serde", "tokio-util/codec"] serde-transport-json = ["tokio-serde/json"] @@ -42,6 +42,7 @@ travis-ci = { repository = "google/tarpc" } [dependencies] anyhow = "1.0" +anymap = "0.12.1" fnv = "1.0" futures = "0.3" humantime = "2.0" @@ -58,8 +59,8 @@ tracing = { version = "0.1", default-features = false, features = [ "attributes", "log", ] } -tracing-opentelemetry = { version = "0.18.0", default-features = false } -opentelemetry = { version = "0.18.0", default-features = false } +tracing-opentelemetry = { version = "0.17.2", default-features = false } +opentelemetry = { version = "0.17.0", default-features = false } [dev-dependencies] @@ -68,18 +69,17 @@ bincode = "1.3" bytes = { version = "1", features = ["serde"] } flate2 = "1.0" futures-test = "0.3" -opentelemetry = { version = "0.18.0", default-features = false, features = [ +opentelemetry = { version = "0.17.0", default-features = false, features = [ "rt-tokio", ] } -opentelemetry-jaeger = { version = "0.17.0", features = ["rt-tokio"] } +opentelemetry-jaeger = { version = "0.16.0", features = ["rt-tokio"] } pin-utils = "0.1.0-alpha" serde_bytes = "0.11" tracing-subscriber = { version = "0.3", features = ["env-filter"] } -tokio = { version = "1", features = ["full", "test-util"] } +tokio = { version = "1", features = ["full", "test-util", "tracing"] } +console-subscriber = "0.1" tokio-serde = { version = "0.8", features = ["json", "bincode"] } trybuild = "1.0" -tokio-rustls = "0.23" -rustls-pemfile = "1.0" [package.metadata.docs.rs] all-features = true @@ -105,10 +105,6 @@ required-features = ["full"] name = "custom_transport" required-features = ["serde1", "tokio1", "serde-transport"] -[[example]] -name = "tls_over_tcp" -required-features = ["full"] - [[test]] name = "service_functional" required-features = ["serde-transport"] diff --git a/tarpc/examples/certs/eddsa/client.cert b/tarpc/examples/certs/eddsa/client.cert deleted file mode 100644 index 0d314458..00000000 --- a/tarpc/examples/certs/eddsa/client.cert +++ /dev/null @@ -1,11 +0,0 @@ ------BEGIN CERTIFICATE----- -MIIBlDCCAUagAwIBAgICAxUwBQYDK2VwMC4xLDAqBgNVBAMMI3Bvbnl0b3duIEVk -RFNBIGxldmVsIDIgaW50ZXJtZWRpYXRlMB4XDTIzMDMxNzEwMTEwNFoXDTI4MDkw -NjEwMTEwNFowGjEYMBYGA1UEAwwPcG9ueXRvd24gY2xpZW50MCowBQYDK2VwAyEA -NTKuLume19IhJfEFd/5OZUuYDKZH6xvy4AGver17OoejgZswgZgwDAYDVR0TAQH/ -BAIwADALBgNVHQ8EBAMCBsAwFgYDVR0lAQH/BAwwCgYIKwYBBQUHAwIwHQYDVR0O -BBYEFDjdrlMu4tyw5MHtbg7WnzSGRBpFMEQGA1UdIwQ9MDuAFHIl7fHKWP6/l8FE -fI2YEIM3oHxKoSCkHjAcMRowGAYDVQQDDBFwb255dG93biBFZERTQSBDQYIBezAF -BgMrZXADQQCaahfj/QLxoCOpvl6y0ZQ9CpojPqBnxV3460j5nUOp040Va2MpF137 -izCBY7LwgUE/YG6E+kH30G4jMEnqVEYK ------END CERTIFICATE----- diff --git a/tarpc/examples/certs/eddsa/client.chain b/tarpc/examples/certs/eddsa/client.chain deleted 
file mode 100644 index cd760dc2..00000000 --- a/tarpc/examples/certs/eddsa/client.chain +++ /dev/null @@ -1,19 +0,0 @@ ------BEGIN CERTIFICATE----- -MIIBeDCCASqgAwIBAgIBezAFBgMrZXAwHDEaMBgGA1UEAwwRcG9ueXRvd24gRWRE -U0EgQ0EwHhcNMjMwMzE3MTAxMTA0WhcNMzMwMzE0MTAxMTA0WjAuMSwwKgYDVQQD -DCNwb255dG93biBFZERTQSBsZXZlbCAyIGludGVybWVkaWF0ZTAqMAUGAytlcAMh -AEFsAexz4x2R4k4+PnTbvRVn0r3F/qw/zVnNBxfGcoEpo38wfTAdBgNVHQ4EFgQU -ciXt8cpY/r+XwUR8jZgQgzegfEowIAYDVR0lAQH/BBYwFAYIKwYBBQUHAwEGCCsG -AQUFBwMCMAwGA1UdEwQFMAMBAf8wCwYDVR0PBAQDAgH+MB8GA1UdIwQYMBaAFKYU -oLdKeY7mp7QgMZKrkVtSWYBKMAUGAytlcANBAHVpNpCV8nu4fkH3Smikx5A9qtHc -zgLIyp+wrF1a4YSa6sfTvuQmJd5aF23OXgq5grCOPXtdpHO50Mx5Qy74zQg= ------END CERTIFICATE----- ------BEGIN CERTIFICATE----- -MIIBTDCB/6ADAgECAhRZLuF0TWjDs/31OO8VeKHkNIJQaDAFBgMrZXAwHDEaMBgG -A1UEAwwRcG9ueXRvd24gRWREU0EgQ0EwHhcNMjMwMzE3MTAxMTA0WhcNMzMwMzE0 -MTAxMTA0WjAcMRowGAYDVQQDDBFwb255dG93biBFZERTQSBDQTAqMAUGAytlcAMh -ABRPZ4TiuBE8CqAFByZvqpMo/unjnnryfG2AkkWGXpa3o1MwUTAdBgNVHQ4EFgQU -phSgt0p5juantCAxkquRW1JZgEowHwYDVR0jBBgwFoAUphSgt0p5juantCAxkquR -W1JZgEowDwYDVR0TAQH/BAUwAwEB/zAFBgMrZXADQQB29o8erJA0/a8/xOHilOCC -t/s5wPHHnS5NSKx/m2N2nRn3zPxEnETlrAmGulJoeKOx8OblwmPi9rBT2K+QY2UB ------END CERTIFICATE----- diff --git a/tarpc/examples/certs/eddsa/client.key b/tarpc/examples/certs/eddsa/client.key deleted file mode 100644 index a407ea84..00000000 --- a/tarpc/examples/certs/eddsa/client.key +++ /dev/null @@ -1,3 +0,0 @@ ------BEGIN PRIVATE KEY----- -MC4CAQAwBQYDK2VwBCIEIIJX9ThTHpVS1SNZb6HP4myg4fRInIVGunTRdgnc+weH ------END PRIVATE KEY----- diff --git a/tarpc/examples/certs/eddsa/end.cert b/tarpc/examples/certs/eddsa/end.cert deleted file mode 100644 index b2eb159f..00000000 --- a/tarpc/examples/certs/eddsa/end.cert +++ /dev/null @@ -1,12 +0,0 @@ ------BEGIN CERTIFICATE----- -MIIBuDCCAWqgAwIBAgICAcgwBQYDK2VwMC4xLDAqBgNVBAMMI3Bvbnl0b3duIEVk -RFNBIGxldmVsIDIgaW50ZXJtZWRpYXRlMB4XDTIzMDMxNzEwMTEwNFoXDTI4MDkw -NjEwMTEwNFowGTEXMBUGA1UEAwwOdGVzdHNlcnZlci5jb20wKjAFBgMrZXADIQDc -RLl3/N2tPoWnzBV3noVn/oheEl8IUtiY11Vg/QXTUKOBwDCBvTAMBgNVHRMBAf8E -AjAAMAsGA1UdDwQEAwIGwDAdBgNVHQ4EFgQUk7U2mnxedNWBAH84BsNy5si3ZQow -RAYDVR0jBD0wO4AUciXt8cpY/r+XwUR8jZgQgzegfEqhIKQeMBwxGjAYBgNVBAMM -EXBvbnl0b3duIEVkRFNBIENBggF7MDsGA1UdEQQ0MDKCDnRlc3RzZXJ2ZXIuY29t -ghVzZWNvbmQudGVzdHNlcnZlci5jb22CCWxvY2FsaG9zdDAFBgMrZXADQQCFWIcF -9FiztCuUNzgXDNu5kshuflt0RjkjWpGlWzQjGoYM2IvYhNVPeqnCiY92gqwDSBtq -amD2TBup4eNUCsQB ------END CERTIFICATE----- diff --git a/tarpc/examples/certs/eddsa/end.chain b/tarpc/examples/certs/eddsa/end.chain deleted file mode 100644 index cd760dc2..00000000 --- a/tarpc/examples/certs/eddsa/end.chain +++ /dev/null @@ -1,19 +0,0 @@ ------BEGIN CERTIFICATE----- -MIIBeDCCASqgAwIBAgIBezAFBgMrZXAwHDEaMBgGA1UEAwwRcG9ueXRvd24gRWRE -U0EgQ0EwHhcNMjMwMzE3MTAxMTA0WhcNMzMwMzE0MTAxMTA0WjAuMSwwKgYDVQQD -DCNwb255dG93biBFZERTQSBsZXZlbCAyIGludGVybWVkaWF0ZTAqMAUGAytlcAMh -AEFsAexz4x2R4k4+PnTbvRVn0r3F/qw/zVnNBxfGcoEpo38wfTAdBgNVHQ4EFgQU -ciXt8cpY/r+XwUR8jZgQgzegfEowIAYDVR0lAQH/BBYwFAYIKwYBBQUHAwEGCCsG -AQUFBwMCMAwGA1UdEwQFMAMBAf8wCwYDVR0PBAQDAgH+MB8GA1UdIwQYMBaAFKYU -oLdKeY7mp7QgMZKrkVtSWYBKMAUGAytlcANBAHVpNpCV8nu4fkH3Smikx5A9qtHc -zgLIyp+wrF1a4YSa6sfTvuQmJd5aF23OXgq5grCOPXtdpHO50Mx5Qy74zQg= ------END CERTIFICATE----- ------BEGIN CERTIFICATE----- -MIIBTDCB/6ADAgECAhRZLuF0TWjDs/31OO8VeKHkNIJQaDAFBgMrZXAwHDEaMBgG -A1UEAwwRcG9ueXRvd24gRWREU0EgQ0EwHhcNMjMwMzE3MTAxMTA0WhcNMzMwMzE0 -MTAxMTA0WjAcMRowGAYDVQQDDBFwb255dG93biBFZERTQSBDQTAqMAUGAytlcAMh -ABRPZ4TiuBE8CqAFByZvqpMo/unjnnryfG2AkkWGXpa3o1MwUTAdBgNVHQ4EFgQU 
-phSgt0p5juantCAxkquRW1JZgEowHwYDVR0jBBgwFoAUphSgt0p5juantCAxkquR -W1JZgEowDwYDVR0TAQH/BAUwAwEB/zAFBgMrZXADQQB29o8erJA0/a8/xOHilOCC -t/s5wPHHnS5NSKx/m2N2nRn3zPxEnETlrAmGulJoeKOx8OblwmPi9rBT2K+QY2UB ------END CERTIFICATE----- diff --git a/tarpc/examples/certs/eddsa/end.key b/tarpc/examples/certs/eddsa/end.key deleted file mode 100644 index f5541b32..00000000 --- a/tarpc/examples/certs/eddsa/end.key +++ /dev/null @@ -1,3 +0,0 @@ ------BEGIN PRIVATE KEY----- -MC4CAQAwBQYDK2VwBCIEIMU6xGVe8JTpZ3bN/wajHfw6pEHt0Rd7wPBxds9eEFy2 ------END PRIVATE KEY----- diff --git a/tarpc/examples/compression.rs b/tarpc/examples/compression.rs index 942fdc8a..40084978 100644 --- a/tarpc/examples/compression.rs +++ b/tarpc/examples/compression.rs @@ -1,5 +1,11 @@ +// Copyright 2022 Google LLC +// +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file or at +// https://opensource.org/licenses/MIT. + use flate2::{read::DeflateDecoder, write::DeflateEncoder, Compression}; -use futures::{Sink, SinkExt, Stream, StreamExt, TryStreamExt}; +use futures::{prelude::*, Sink, SinkExt, Stream, StreamExt, TryStreamExt}; use serde::{Deserialize, Serialize}; use serde_bytes::ByteBuf; use std::{io, io::Read, io::Write}; @@ -99,13 +105,16 @@ pub trait World { #[derive(Clone, Debug)] struct HelloServer; -#[tarpc::server] impl World for HelloServer { - async fn hello(self, _: context::Context, name: String) -> String { + async fn hello(self, _: &mut context::Context, name: String) -> String { format!("Hey, {name}!") } } +async fn spawn(fut: impl Future + Send + 'static) { + tokio::spawn(fut); +} + #[tokio::main] async fn main() -> anyhow::Result<()> { let mut incoming = tcp::listen("localhost:0", Bincode::default).await?; @@ -114,6 +123,7 @@ async fn main() -> anyhow::Result<()> { let transport = incoming.next().await.unwrap().unwrap(); BaseChannel::with_defaults(add_compression(transport)) .execute(HelloServer.serve()) + .for_each(spawn) .await; }); diff --git a/tarpc/examples/custom_transport.rs b/tarpc/examples/custom_transport.rs index e7e2ce3d..07686611 100644 --- a/tarpc/examples/custom_transport.rs +++ b/tarpc/examples/custom_transport.rs @@ -1,3 +1,10 @@ +// Copyright 2022 Google LLC +// +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file or at +// https://opensource.org/licenses/MIT. 
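Because the server attribute is gone, `Channel::execute` in the updated examples (server.rs and compression.rs above, custom_transport.rs and pubsub.rs below) yields a stream containing one future per in-flight request, and each example forwards those futures to the runtime through a small `spawn` helper. A hedged end-to-end sketch of that pattern over the in-memory transport, reusing the hypothetical `Echo` service from the earlier sketch:

```rust
use futures::prelude::*;
use tarpc::{client, context, server::{self, Channel}};

// Each item yielded by `execute` is the future for one in-flight request; handing
// it to tokio::spawn lets requests on the same channel run concurrently.
async fn spawn(fut: impl Future<Output = ()> + Send + 'static) {
    tokio::spawn(fut);
}

// Assumes the hypothetical `Echo` service, `EchoServer`, and generated `EchoClient`
// sketched earlier; the readme example does the same with `World`/`HelloServer`.
async fn run_in_process() -> anyhow::Result<()> {
    let (client_transport, server_transport) = tarpc::transport::channel::unbounded();
    let server = server::BaseChannel::with_defaults(server_transport);
    tokio::spawn(server.execute(EchoServer.serve()).for_each(spawn));

    let client = EchoClient::new(client::Config::default(), client_transport).spawn();
    let echoed = client.echo(context::current(), "hi".to_string()).await?;
    assert_eq!(echoed, "hi");
    Ok(())
}
```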
+ +use futures::prelude::*; use tarpc::context::Context; use tarpc::serde_transport as transport; use tarpc::server::{BaseChannel, Channel}; @@ -13,9 +20,8 @@ pub trait PingService { #[derive(Clone)] struct Service; -#[tarpc::server] impl PingService for Service { - async fn ping(self, _: Context) {} + async fn ping(self, _: &mut Context) {} } #[tokio::main] @@ -26,13 +32,18 @@ async fn main() -> anyhow::Result<()> { let listener = UnixListener::bind(bind_addr).unwrap(); let codec_builder = LengthDelimitedCodec::builder(); + async fn spawn(fut: impl Future + Send + 'static) { + tokio::spawn(fut); + } tokio::spawn(async move { loop { let (conn, _addr) = listener.accept().await.unwrap(); let framed = codec_builder.new_framed(conn); let transport = transport::new(framed, Bincode::default()); - let fut = BaseChannel::with_defaults(transport).execute(Service.serve()); + let fut = BaseChannel::with_defaults(transport) + .execute(Service.serve()) + .for_each(spawn); tokio::spawn(fut); } }); diff --git a/tarpc/examples/pubsub.rs b/tarpc/examples/pubsub.rs index 910ab535..8e80a912 100644 --- a/tarpc/examples/pubsub.rs +++ b/tarpc/examples/pubsub.rs @@ -3,35 +3,6 @@ // Use of this source code is governed by an MIT-style // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. - -/// - The PubSub server sets up TCP listeners on 2 ports, the "subscriber" port and the "publisher" -/// port. Because both publishers and subscribers initiate their connections to the PubSub -/// server, the server requires no prior knowledge of either publishers or subscribers. -/// -/// - Subscribers connect to the server on the server's "subscriber" port. Once a connection is -/// established, the server acts as the client of the Subscriber service, initially requesting -/// the topics the subscriber is interested in, and subsequently sending topical messages to the -/// subscriber. -/// -/// - Publishers connect to the server on the "publisher" port and, once connected, they send -/// topical messages via Publisher service to the server. The server then broadcasts each -/// messages to all clients subscribed to the topic of that message. 
-/// -/// Subscriber Publisher PubSub Server -/// T1 | | | -/// T2 |-----Connect------------------------------------------------------>| -/// T3 | | | -/// T2 |<-------------------------------------------------------Topics-----| -/// T2 |-----(OK) Topics-------------------------------------------------->| -/// T3 | | | -/// T4 | |-----Connect-------------------->| -/// T5 | | | -/// T6 | |-----Publish-------------------->| -/// T7 | | | -/// T8 |<------------------------------------------------------Receive-----| -/// T9 |-----(OK) Receive------------------------------------------------->| -/// T10 | | | -/// T11 | |<--------------(OK) Publish------| use anyhow::anyhow; use futures::{ channel::oneshot, @@ -79,13 +50,12 @@ struct Subscriber { topics: Vec, } -#[tarpc::server] impl subscriber::Subscriber for Subscriber { - async fn topics(self, _: context::Context) -> Vec { + async fn topics(self, _: &mut context::Context) -> Vec { self.topics.clone() } - async fn receive(self, _: context::Context, topic: String, message: String) { + async fn receive(self, _: &mut context::Context, topic: String, message: String) { info!(local_addr = %self.local_addr, %topic, %message, "ReceivedMessage") } } @@ -117,7 +87,8 @@ impl Subscriber { )) } }; - let (handler, abort_handle) = future::abortable(handler.execute(subscriber.serve())); + let (handler, abort_handle) = + future::abortable(handler.execute(subscriber.serve()).for_each(spawn)); tokio::spawn(async move { match handler.await { Ok(()) | Err(future::Aborted) => info!(?local_addr, "subscriber shutdown."), @@ -143,6 +114,10 @@ struct PublisherAddrs { subscriptions: SocketAddr, } +async fn spawn(fut: impl Future + Send + 'static) { + tokio::spawn(fut); +} + impl Publisher { async fn start(self) -> io::Result { let mut connecting_publishers = tcp::listen("localhost:0", Json::default).await?; @@ -162,6 +137,7 @@ impl Publisher { server::BaseChannel::with_defaults(publisher) .execute(self.serve()) + .for_each(spawn) .await }); @@ -257,9 +233,8 @@ impl Publisher { } } -#[tarpc::server] impl publisher::Publisher for Publisher { - async fn publish(self, _: context::Context, topic: String, message: String) { + async fn publish(self, _: &mut context::Context, topic: String, message: String) { info!("received message to publish."); let mut subscribers = match self.subscriptions.read().unwrap().get(&topic) { None => return, @@ -282,7 +257,7 @@ impl publisher::Publisher for Publisher { /// Initializes an OpenTelemetry tracing subscriber with a Jaeger backend. fn init_tracing(service_name: &str) -> anyhow::Result<()> { env::set_var("OTEL_BSP_MAX_EXPORT_BATCH_SIZE", "12"); - let tracer = opentelemetry_jaeger::new_agent_pipeline() + let tracer = opentelemetry_jaeger::new_pipeline() .with_service_name(service_name) .with_max_packet_size(2usize.pow(13)) .install_batch(opentelemetry::runtime::Tokio)?; diff --git a/tarpc/examples/readme.rs b/tarpc/examples/readme.rs index 80792314..dcae3943 100644 --- a/tarpc/examples/readme.rs +++ b/tarpc/examples/readme.rs @@ -3,8 +3,7 @@ // Use of this source code is governed by an MIT-style // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. 
- -use futures::future::{self, Ready}; +use futures::prelude::*; use tarpc::{ client, context, server::{self, Channel}, @@ -23,22 +22,21 @@ pub trait World { struct HelloServer; impl World for HelloServer { - // Each defined rpc generates two items in the trait, a fn that serves the RPC, and - // an associated type representing the future output by the fn. - - type HelloFut = Ready; - - fn hello(self, _: context::Context, name: String) -> Self::HelloFut { - future::ready(format!("Hello, {name}!")) + async fn hello(self, _: &mut context::Context, name: String) -> String { + format!("Hello, {name}!") } } +async fn spawn(fut: impl Future + Send + 'static) { + tokio::spawn(fut); +} + #[tokio::main] async fn main() -> anyhow::Result<()> { let (client_transport, server_transport) = tarpc::transport::channel::unbounded(); let server = server::BaseChannel::with_defaults(server_transport); - tokio::spawn(server.execute(HelloServer.serve())); + tokio::spawn(server.execute(HelloServer.serve()).for_each(spawn)); // WorldClient is generated by the #[tarpc::service] attribute. It has a constructor `new` // that takes a config and any Transport as input. diff --git a/tarpc/examples/tls_over_tcp.rs b/tarpc/examples/tls_over_tcp.rs deleted file mode 100644 index 92d76c98..00000000 --- a/tarpc/examples/tls_over_tcp.rs +++ /dev/null @@ -1,152 +0,0 @@ -use rustls_pemfile::certs; -use std::io::{BufReader, Cursor}; -use std::net::{IpAddr, Ipv4Addr}; -use tokio_rustls::rustls::server::AllowAnyAuthenticatedClient; - -use std::sync::Arc; -use tokio::net::TcpListener; -use tokio::net::TcpStream; -use tokio_rustls::rustls::{self, Certificate, OwnedTrustAnchor, RootCertStore}; -use tokio_rustls::{webpki, TlsAcceptor, TlsConnector}; - -use tarpc::context::Context; -use tarpc::serde_transport as transport; -use tarpc::server::{BaseChannel, Channel}; -use tarpc::tokio_serde::formats::Bincode; -use tarpc::tokio_util::codec::length_delimited::LengthDelimitedCodec; - -#[tarpc::service] -pub trait PingService { - async fn ping() -> String; -} - -#[derive(Clone)] -struct Service; - -#[tarpc::server] -impl PingService for Service { - async fn ping(self, _: Context) -> String { - "🔒".to_owned() - } -} - -// certs were generated with openssl 3 https://github.com/rustls/rustls/tree/main/test-ca -// used on client-side for server tls -const END_CHAIN: &[u8] = include_bytes!("certs/eddsa/end.chain"); -// used on client-side for client-auth -const CLIENT_PRIVATEKEY_CLIENT_AUTH: &str = include_str!("certs/eddsa/client.key"); -const CLIENT_CERT_CLIENT_AUTH: &str = include_str!("certs/eddsa/client.cert"); - -// used on server-side for server tls -const END_CERT: &str = include_str!("certs/eddsa/end.cert"); -const END_PRIVATEKEY: &str = include_str!("certs/eddsa/end.key"); -// used on server-side for client-auth -const CLIENT_CHAIN_CLIENT_AUTH: &str = include_str!("certs/eddsa/client.chain"); - -pub fn load_private_key(key: &str) -> rustls::PrivateKey { - let mut reader = BufReader::new(Cursor::new(key)); - loop { - match rustls_pemfile::read_one(&mut reader).expect("cannot parse private key .pem file") { - Some(rustls_pemfile::Item::RSAKey(key)) => return rustls::PrivateKey(key), - Some(rustls_pemfile::Item::PKCS8Key(key)) => return rustls::PrivateKey(key), - Some(rustls_pemfile::Item::ECKey(key)) => return rustls::PrivateKey(key), - None => break, - _ => {} - } - } - panic!("no keys found in {:?} (encrypted keys not supported)", key); -} - -#[tokio::main] -async fn main() -> anyhow::Result<()> { - // -------------------- start here 
to setup tls tcp tokio stream -------------------------- - // ref certs and loading from: https://github.com/tokio-rs/tls/blob/master/tokio-rustls/tests/test.rs - // ref basic tls server setup from: https://github.com/tokio-rs/tls/blob/master/tokio-rustls/examples/server/src/main.rs - let cert = certs(&mut BufReader::new(Cursor::new(END_CERT))) - .unwrap() - .into_iter() - .map(rustls::Certificate) - .collect(); - let key = load_private_key(END_PRIVATEKEY); - let server_addr = (IpAddr::V4(Ipv4Addr::LOCALHOST), 5000); - - // ------------- server side client_auth cert loading start - let roots: Vec = certs(&mut BufReader::new(Cursor::new(CLIENT_CHAIN_CLIENT_AUTH))) - .unwrap() - .into_iter() - .map(rustls::Certificate) - .collect(); - let mut client_auth_roots = RootCertStore::empty(); - for root in roots { - client_auth_roots.add(&root).unwrap(); - } - let client_auth = AllowAnyAuthenticatedClient::new(client_auth_roots); - // ------------- server side client_auth cert loading end - - let config = rustls::ServerConfig::builder() - .with_safe_defaults() - .with_client_cert_verifier(client_auth) // use .with_no_client_auth() instead if you don't want client-auth - .with_single_cert(cert, key) - .unwrap(); - let acceptor = TlsAcceptor::from(Arc::new(config)); - let listener = TcpListener::bind(&server_addr).await.unwrap(); - let codec_builder = LengthDelimitedCodec::builder(); - - // ref ./custom_transport.rs server side - tokio::spawn(async move { - loop { - let (stream, _peer_addr) = listener.accept().await.unwrap(); - let acceptor = acceptor.clone(); - let tls_stream = acceptor.accept(stream).await.unwrap(); - let framed = codec_builder.new_framed(tls_stream); - - let transport = transport::new(framed, Bincode::default()); - - let fut = BaseChannel::with_defaults(transport).execute(Service.serve()); - tokio::spawn(fut); - } - }); - - // ---------------------- client connection --------------------- - // cert loading from: https://github.com/tokio-rs/tls/blob/357bc562483dcf04c1f8d08bd1a831b144bf7d4c/tokio-rustls/tests/test.rs#L113 - // tls client connection from https://github.com/tokio-rs/tls/blob/master/tokio-rustls/examples/client/src/main.rs - let chain = certs(&mut std::io::Cursor::new(END_CHAIN)).unwrap(); - let mut root_store = rustls::RootCertStore::empty(); - root_store.add_server_trust_anchors(chain.iter().map(|cert| { - let ta = webpki::TrustAnchor::try_from_cert_der(&cert[..]).unwrap(); - OwnedTrustAnchor::from_subject_spki_name_constraints( - ta.subject, - ta.spki, - ta.name_constraints, - ) - })); - - let client_auth_private_key = load_private_key(CLIENT_PRIVATEKEY_CLIENT_AUTH); - let client_auth_certs: Vec = - certs(&mut BufReader::new(Cursor::new(CLIENT_CERT_CLIENT_AUTH))) - .unwrap() - .into_iter() - .map(rustls::Certificate) - .collect(); - - let config = rustls::ClientConfig::builder() - .with_safe_defaults() - .with_root_certificates(root_store) - .with_single_cert(client_auth_certs, client_auth_private_key)?; // use .with_no_client_auth() instead if you don't want client-auth - - let domain = rustls::ServerName::try_from("localhost")?; - let connector = TlsConnector::from(Arc::new(config)); - - let stream = TcpStream::connect(server_addr).await?; - let stream = connector.connect(domain, stream).await?; - - let transport = transport::new(codec_builder.new_framed(stream), Bincode::default()); - let answer = PingServiceClient::new(Default::default(), transport) - .spawn() - .ping(tarpc::context::current()) - .await?; - - println!("ping answer: {answer}"); - - Ok(()) -} diff 
--git a/tarpc/examples/tracing.rs b/tarpc/examples/tracing.rs index 27561468..78846e19 100644 --- a/tarpc/examples/tracing.rs +++ b/tarpc/examples/tracing.rs @@ -3,14 +3,33 @@ // Use of this source code is governed by an MIT-style // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. - -use crate::{add::Add as AddService, double::Double as DoubleService}; +use crate::{ + add::{Add as AddService, AddStub}, + double::Double as DoubleService, +}; use futures::{future, prelude::*}; +use std::{ + io, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, +}; use tarpc::{ - client, context, - server::{incoming::Incoming, BaseChannel}, + client::{ + self, + stub::{load_balance, retry}, + RpcError, + }, + context, serde_transport, + server::{ + incoming::{spawn_incoming, Incoming}, + BaseChannel, Serve, + }, tokio_serde::formats::Json, + ClientMessage, Response, ServerError, Transport, }; +use tokio::net::TcpStream; use tracing_subscriber::prelude::*; pub mod add { @@ -32,21 +51,23 @@ pub mod double { #[derive(Clone)] struct AddServer; -#[tarpc::server] impl AddService for AddServer { - async fn add(self, _: context::Context, x: i32, y: i32) -> i32 { + async fn add(self, _: &mut context::Context, x: i32, y: i32) -> i32 { x + y } } #[derive(Clone)] -struct DoubleServer { - add_client: add::AddClient, +struct DoubleServer { + add_client: add::AddClient, } -#[tarpc::server] -impl DoubleService for DoubleServer { - async fn double(self, _: context::Context, x: i32) -> Result { +impl DoubleService for DoubleServer +where + Stub: AddStub + Clone + Send + Sync + 'static, + for<'a> Stub::RespFut<'a>: Send, +{ + async fn double(self, _: &mut context::Context, x: i32) -> Result { self.add_client .add(context::current(), x, x) .await @@ -55,7 +76,7 @@ impl DoubleService for DoubleServer { } fn init_tracing(service_name: &str) -> anyhow::Result<()> { - let tracer = opentelemetry_jaeger::new_agent_pipeline() + let tracer = opentelemetry_jaeger::new_pipeline() .with_service_name(service_name) .with_auto_split_batch(true) .with_max_packet_size(2usize.pow(13)) @@ -70,32 +91,86 @@ fn init_tracing(service_name: &str) -> anyhow::Result<()> { Ok(()) } +async fn listen_on_random_port() -> anyhow::Result<( + impl Stream>>, + std::net::SocketAddr, +)> +where + Item: for<'de> serde::Deserialize<'de>, + SinkItem: serde::Serialize, +{ + let listener = tarpc::serde_transport::tcp::listen("localhost:0", Json::default) + .await? + .filter_map(|r| future::ready(r.ok())) + .take(1); + let addr = listener.get_ref().get_ref().local_addr(); + Ok((listener, addr)) +} + +fn make_stub( + backends: [impl Transport>, Response> + Send + Sync + 'static; N], +) -> retry::Retry< + impl Fn(&Result, u32) -> bool + Clone, + load_balance::RoundRobin, Resp>>, +> +where + Req: Send + Sync + 'static, + Resp: Send + Sync + 'static, +{ + let stub = load_balance::RoundRobin::new( + backends + .into_iter() + .map(|transport| tarpc::client::new(client::Config::default(), transport).spawn()) + .collect(), + ); + let stub = retry::Retry::new(stub, |resp, attempts| { + if let Err(e) = resp { + tracing::warn!("Got an error: {e:?}"); + attempts < 3 + } else { + false + } + }); + stub +} + #[tokio::main] async fn main() -> anyhow::Result<()> { init_tracing("tarpc_tracing_example")?; - let add_listener = tarpc::serde_transport::tcp::listen("localhost:0", Json::default) - .await? 
- .filter_map(|r| future::ready(r.ok())); - let addr = add_listener.get_ref().local_addr(); - let add_server = add_listener - .map(BaseChannel::with_defaults) - .take(1) - .execute(AddServer.serve()); - tokio::spawn(add_server); - - let to_add_server = tarpc::serde_transport::tcp::connect(addr, Json::default).await?; - let add_client = add::AddClient::new(client::Config::default(), to_add_server).spawn(); + let (add_listener1, addr1) = listen_on_random_port().await?; + let (add_listener2, addr2) = listen_on_random_port().await?; + let something_bad_happened = Arc::new(AtomicBool::new(false)); + let server = AddServer.serve().before(move |_: &mut _, _: &_| { + let something_bad_happened = something_bad_happened.clone(); + async move { + if something_bad_happened.fetch_xor(true, Ordering::Relaxed) { + Err(ServerError::new( + io::ErrorKind::NotFound, + "Gamma Ray!".into(), + )) + } else { + Ok(()) + } + } + }); + let add_server = add_listener1 + .chain(add_listener2) + .map(BaseChannel::with_defaults); + tokio::spawn(spawn_incoming(add_server.execute(server))); + + let add_client = add::AddClient::from(make_stub([ + tarpc::serde_transport::tcp::connect(addr1, Json::default).await?, + tarpc::serde_transport::tcp::connect(addr2, Json::default).await?, + ])); let double_listener = tarpc::serde_transport::tcp::listen("localhost:0", Json::default) .await? .filter_map(|r| future::ready(r.ok())); let addr = double_listener.get_ref().local_addr(); - let double_server = double_listener - .map(BaseChannel::with_defaults) - .take(1) - .execute(DoubleServer { add_client }.serve()); - tokio::spawn(double_server); + let double_server = double_listener.map(BaseChannel::with_defaults).take(1); + let server = DoubleServer { add_client }.serve(); + tokio::spawn(spawn_incoming(double_server.execute(server))); let to_double_server = tarpc::serde_transport::tcp::connect(addr, Json::default).await?; let double_client = @@ -103,7 +178,7 @@ async fn main() -> anyhow::Result<()> { let ctx = context::current(); for _ in 1..=5 { - tracing::info!("{:?}", double_client.double(ctx, 1).await?); + tracing::info!("{:?}", double_client.double(ctx.clone(), 1).await?); } opentelemetry::global::shutdown_tracer_provider(); diff --git a/tarpc/src/client.rs b/tarpc/src/client.rs index 109ee8ff..dba5833a 100644 --- a/tarpc/src/client.rs +++ b/tarpc/src/client.rs @@ -7,16 +7,18 @@ //! Provides a client that connects to a server and sends multiplexed requests. 
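The tracing example above composes the new client-stub machinery: several spawned channels are wrapped in a round-robin balancer and a retry policy, and the result is handed to the generated client via `From`. The `stub` module is introduced by this change, so treat the exact API as illustrative; a condensed sketch mirroring `make_stub` from `examples/tracing.rs` (`add::AddClient` and the `add` service come from that example, and `addr1`/`addr2` are assumed to be two running Add servers):

```rust
use tarpc::{
    client::{self, stub::{load_balance, retry}},
    context, serde_transport,
    tokio_serde::formats::Json,
};

// Sketch only: round-robin over two channels, retrying a failed call up to
// three times before giving up, as in examples/tracing.rs.
async fn add_with_failover(
    addr1: std::net::SocketAddr,
    addr2: std::net::SocketAddr,
) -> anyhow::Result<i32> {
    let channels = vec![
        client::new(client::Config::default(),
            serde_transport::tcp::connect(addr1, Json::default).await?).spawn(),
        client::new(client::Config::default(),
            serde_transport::tcp::connect(addr2, Json::default).await?).spawn(),
    ];
    let stub = load_balance::RoundRobin::new(channels);
    let stub = retry::Retry::new(stub, |resp, attempts| resp.is_err() && attempts < 3);
    let add_client = add::AddClient::from(stub);
    Ok(add_client.add(context::current(), 1, 2).await?)
}
```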
mod in_flight_requests; +pub mod stub; use crate::{ cancellations::{cancellations, CanceledRequests, RequestCancellation}, - context, trace, ChannelError, ClientMessage, Request, Response, ServerError, Transport, + context, trace, ClientMessage, Request, Response, ServerError, Transport, }; use futures::{prelude::*, ready, stream::Fuse, task::*}; -use in_flight_requests::InFlightRequests; +use in_flight_requests::{DeadlineExceededError, InFlightRequests}; use pin_project::pin_project; use std::{ convert::TryFrom, + error::Error, fmt, pin::Pin, sync::{ @@ -116,7 +118,7 @@ impl Channel { skip(self, ctx, request_name, request), fields( rpc.trace_id = tracing::field::Empty, - rpc.deadline = %humantime::format_rfc3339(ctx.deadline), + rpc.deadline = %humantime::format_rfc3339(*ctx.deadline), otel.kind = "client", otel.name = request_name) )] @@ -157,7 +159,7 @@ impl Channel { response_completion, }) .await - .map_err(|mpsc::error::SendError(_)| RpcError::Shutdown)?; + .map_err(|mpsc::error::SendError(_)| RpcError::Disconnected)?; response_guard.response().await } } @@ -165,7 +167,7 @@ impl Channel { /// A server response that is completed by request dispatch when the corresponding response /// arrives off the wire. struct ResponseGuard<'a, Resp> { - response: &'a mut oneshot::Receiver>, + response: &'a mut oneshot::Receiver, DeadlineExceededError>>, cancellation: &'a RequestCancellation, request_id: u64, cancel: bool, @@ -173,17 +175,12 @@ struct ResponseGuard<'a, Resp> { /// An error that can occur in the processing of an RPC. This is not request-specific errors but /// rather cross-cutting errors that can always occur. -#[derive(thiserror::Error, Debug)] +#[derive(thiserror::Error, Clone, Debug, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] pub enum RpcError { /// The client disconnected from the server. - #[error("the connection to the server was already shutdown")] - Shutdown, - /// The client failed to send the request. - #[error("the client failed to send the request")] - Send(#[source] Box), - /// An error occurred while waiting for the server response. - #[error("an error occurred while waiting for the server response")] - Receive(#[source] Arc), + #[error("the client disconnected from the server")] + Disconnected, /// The request exceeded its deadline. #[error("the request exceeded its deadline")] DeadlineExceeded, @@ -192,18 +189,24 @@ pub enum RpcError { Server(#[from] ServerError), } +impl From for RpcError { + fn from(_: DeadlineExceededError) -> Self { + RpcError::DeadlineExceeded + } +} + impl ResponseGuard<'_, Resp> { async fn response(mut self) -> Result { let response = (&mut self.response).await; // Cancel drop logic once a response has been received. self.cancel = false; match response { - Ok(response) => response, + Ok(resp) => Ok(resp?.message?), Err(oneshot::error::RecvError { .. }) => { // The oneshot is Canceled when the dispatch task ends. In that case, // there's nothing listening on the other side, so there's no point in // propagating cancellation. - Err(RpcError::Shutdown) + Err(RpcError::Disconnected) } } } @@ -240,6 +243,7 @@ where { let (to_dispatch, pending_requests) = mpsc::channel(config.pending_request_buffer); let (cancellation, canceled_requests) = cancellations(); + let canceled_requests = canceled_requests; NewClient { client: Channel { @@ -271,18 +275,42 @@ pub struct RequestDispatch { /// Requests that were dropped. 
canceled_requests: CanceledRequests, /// Requests already written to the wire that haven't yet received responses. - in_flight_requests: InFlightRequests>, + in_flight_requests: InFlightRequests, /// Configures limits to prevent unlimited resource usage. config: Config, } +/// Critical errors that result in a Channel disconnecting. +#[derive(thiserror::Error, Debug)] +pub enum ChannelError +where + E: Error + Send + Sync + 'static, +{ + /// Could not read from the transport. + #[error("could not read from the transport")] + Read(#[source] E), + /// Could not ready the transport for writes. + #[error("could not ready the transport for writes")] + Ready(#[source] E), + /// Could not write to the transport. + #[error("could not write to the transport")] + Write(#[source] E), + /// Could not flush the transport. + #[error("could not flush the transport")] + Flush(#[source] E), + /// Could not close the write end of the transport. + #[error("could not close the write end of the transport")] + Close(#[source] E), + /// Could not poll expired requests. + #[error("could not poll expired requests")] + Timer(#[source] tokio::time::error::Error), +} + impl RequestDispatch where C: Transport, Response>, { - fn in_flight_requests<'a>( - self: &'a mut Pin<&mut Self>, - ) -> &'a mut InFlightRequests> { + fn in_flight_requests<'a>(self: &'a mut Pin<&mut Self>) -> &'a mut InFlightRequests { self.as_mut().project().in_flight_requests } @@ -342,17 +370,7 @@ where ) -> Poll>>> { self.transport_pin_mut() .poll_next(cx) - .map_err(|e| { - let e = Arc::new(e); - for span in self - .in_flight_requests() - .complete_all_requests(|| Err(RpcError::Receive(e.clone()))) - { - let _entered = span.enter(); - tracing::info!("ReceiveError"); - } - ChannelError::Read(e) - }) + .map_err(ChannelError::Read) .map_ok(|response| { self.complete(response); }) @@ -382,10 +400,7 @@ where // Receiving Poll::Ready(None) when polling expired requests never indicates "Closed", // because there can temporarily be zero in-flight rquests. Therefore, there is no need to // track the status like is done with pending and cancelled requests. - if let Poll::Ready(Some(_)) = self - .in_flight_requests() - .poll_expired(cx, || Err(RpcError::DeadlineExceeded)) - { + if let Poll::Ready(Some(_)) = self.in_flight_requests().poll_expired(cx) { // Expired requests are considered complete; there is no compelling reason to send a // cancellation message to the server, since it will have already exhausted its // allotted processing time. @@ -496,29 +511,23 @@ where Some(dispatch_request) => dispatch_request, None => return Poll::Ready(None), }; - let _entered = span.enter(); + let entered = span.enter(); // poll_next_request only returns Ready if there is room to buffer another request. // Therefore, we can call write_request without fear of erroring due to a full // buffer. 
let request_id = request_id; let request = ClientMessage::Request(Request { - id: request_id, + request_id: request_id, message: request, - context: context::Context { - deadline: ctx.deadline, - trace_context: ctx.trace_context, - }, + context: ctx.clone(), }); + self.start_send(request)?; + tracing::info!("SendRequest"); + drop(entered); + self.in_flight_requests() - .insert_request(request_id, ctx, span.clone(), response_completion) + .insert_request(request_id, ctx, span, response_completion) .expect("Request IDs should be unique"); - match self.start_send(request) { - Ok(()) => tracing::info!("SendRequest"), - Err(e) => { - self.in_flight_requests() - .complete_request(request_id, Err(RpcError::Send(Box::new(e)))); - } - } Poll::Ready(Some(Ok(()))) } @@ -543,10 +552,7 @@ where /// Sends a server response to the client task that initiated the associated request. fn complete(mut self: Pin<&mut Self>, response: Response) -> bool { - self.in_flight_requests().complete_request( - response.request_id, - response.message.map_err(RpcError::Server), - ) + self.in_flight_requests().complete_request(response) } } @@ -595,37 +601,30 @@ struct DispatchRequest { pub span: Span, pub request_id: u64, pub request: Req, - pub response_completion: oneshot::Sender>, + pub response_completion: oneshot::Sender, DeadlineExceededError>>, } #[cfg(test)] mod tests { - use super::{ - cancellations, Channel, DispatchRequest, RequestDispatch, ResponseGuard, RpcError, - }; + use super::{cancellations, Channel, DispatchRequest, RequestDispatch, ResponseGuard}; use crate::{ - client::{in_flight_requests::InFlightRequests, Config}, - context::{self, current}, + client::{ + in_flight_requests::{DeadlineExceededError, InFlightRequests}, + Config, + }, + context, transport::{self, channel::UnboundedChannel}, - ChannelError, ClientMessage, Response, + ClientMessage, Response, }; use assert_matches::assert_matches; use futures::{prelude::*, task::*}; use std::{ convert::TryFrom, - fmt::Display, - marker::PhantomData, pin::Pin, - sync::{ - atomic::{AtomicUsize, Ordering}, - Arc, - }, - }; - use thiserror::Error; - use tokio::sync::{ - mpsc::{self}, - oneshot, + sync::atomic::{AtomicUsize, Ordering}, + sync::Arc, }; + use tokio::sync::{mpsc, oneshot}; use tracing::Span; #[tokio::test] @@ -634,19 +633,22 @@ mod tests { let cx = &mut Context::from_waker(noop_waker_ref()); let (tx, mut rx) = oneshot::channel(); + let ctx = context::current(); + dispatch .in_flight_requests - .insert_request(0, context::current(), Span::current(), tx) + .insert_request(0, ctx.clone(), Span::current(), tx) .unwrap(); server_channel .send(Response { request_id: 0, + context: ctx, message: Ok("Resp".into()), }) .await .unwrap(); assert_matches!(dispatch.as_mut().poll(cx), Poll::Pending); - assert_matches!(rx.try_recv(), Ok(Ok(resp)) if resp == "Resp"); + assert_matches!(rx.try_recv(), Ok(Ok(Response { request_id: 0, message: Ok(resp), context: _ctx })) if resp == "Resp"); } #[tokio::test] @@ -670,6 +672,7 @@ mod tests { let (tx, mut response) = oneshot::channel(); tx.send(Ok(Response { request_id: 0, + context: context::current(), message: Ok("well done"), })) .unwrap(); @@ -720,6 +723,7 @@ mod tests { &mut server_channel, Response { request_id: 0, + context: context::current(), message: Ok("hello".into()), }, ) @@ -780,185 +784,6 @@ mod tests { assert!(dispatch.as_mut().poll_next_request(cx).is_pending()); } - #[tokio::test] - async fn test_shutdown_error() { - let _ = tracing_subscriber::fmt().with_test_writer().try_init(); - let (dispatch, mut 
channel, _) = set_up(); - let (tx, mut rx) = oneshot::channel(); - // send succeeds - let resp = send_request(&mut channel, "hi", tx, &mut rx).await; - drop(dispatch); - // error on receive - assert_matches!(resp.response().await, Err(RpcError::Shutdown)); - let (dispatch, channel, _) = set_up(); - drop(dispatch); - // error on send - let resp = channel - .call(current(), "test_request", "hi".to_string()) - .await; - assert_matches!(resp, Err(RpcError::Shutdown)); - } - - #[tokio::test] - async fn test_transport_error_write() { - let cause = TransportError::Write; - let (mut dispatch, mut channel, mut cx) = setup_always_err(cause); - let (tx, mut rx) = oneshot::channel(); - - let resp = send_request(&mut channel, "hi", tx, &mut rx).await; - assert!(dispatch.as_mut().poll(&mut cx).is_pending()); - let res = resp.response().await; - assert_matches!(res, Err(RpcError::Send(_))); - let client_error: anyhow::Error = res.unwrap_err().into(); - let mut chain = client_error.chain(); - chain.next(); // original RpcError - assert_eq!( - chain - .next() - .unwrap() - .downcast_ref::>(), - Some(&ChannelError::Write(cause)) - ); - assert_eq!( - client_error.root_cause().downcast_ref::(), - Some(&cause) - ); - } - - #[tokio::test] - async fn test_transport_error_read() { - let cause = TransportError::Read; - let (mut dispatch, mut channel, mut cx) = setup_always_err(cause); - let (tx, mut rx) = oneshot::channel(); - let resp = send_request(&mut channel, "hi", tx, &mut rx).await; - assert_eq!( - dispatch.as_mut().pump_write(&mut cx), - Poll::Ready(Some(Ok(()))) - ); - assert_eq!( - dispatch.as_mut().pump_read(&mut cx), - Poll::Ready(Some(Err(ChannelError::Read(Arc::new(cause))))) - ); - assert_matches!(resp.response().await, Err(RpcError::Receive(_))); - } - - #[tokio::test] - async fn test_transport_error_ready() { - let cause = TransportError::Ready; - let (mut dispatch, _, mut cx) = setup_always_err(cause); - assert_eq!( - dispatch.as_mut().poll(&mut cx), - Poll::Ready(Err(ChannelError::Ready(cause))) - ); - } - - #[tokio::test] - async fn test_transport_error_flush() { - let cause = TransportError::Flush; - let (mut dispatch, _, mut cx) = setup_always_err(cause); - assert_eq!( - dispatch.as_mut().poll(&mut cx), - Poll::Ready(Err(ChannelError::Flush(cause))) - ); - } - - #[tokio::test] - async fn test_transport_error_close() { - let cause = TransportError::Close; - let (mut dispatch, channel, mut cx) = setup_always_err(cause); - drop(channel); - assert_eq!( - dispatch.as_mut().poll(&mut cx), - Poll::Ready(Err(ChannelError::Close(cause))) - ); - } - - fn setup_always_err( - cause: TransportError, - ) -> ( - Pin>>>, - Channel, - Context<'static>, - ) { - let (to_dispatch, pending_requests) = mpsc::channel(1); - let (cancellation, canceled_requests) = cancellations(); - let transport: AlwaysErrorTransport = AlwaysErrorTransport(cause, PhantomData); - let dispatch = Box::pin(RequestDispatch:: { - transport: transport.fuse(), - pending_requests, - canceled_requests, - in_flight_requests: InFlightRequests::default(), - config: Config::default(), - }); - let channel = Channel { - to_dispatch, - cancellation, - next_request_id: Arc::new(AtomicUsize::new(0)), - }; - let cx = Context::from_waker(noop_waker_ref()); - (dispatch, channel, cx) - } - - struct AlwaysErrorTransport(TransportError, PhantomData); - - #[derive(Debug, Error, PartialEq, Eq, Clone, Copy)] - enum TransportError { - Read, - Ready, - Write, - Flush, - Close, - } - - impl Display for TransportError { - fn fmt(&self, f: &mut 
std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_str(&format!("{self:?}")) - } - } - - impl Sink for AlwaysErrorTransport { - type Error = TransportError; - fn poll_ready(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { - match self.0 { - TransportError::Ready => Poll::Ready(Err(self.0)), - TransportError::Flush => Poll::Pending, - _ => Poll::Ready(Ok(())), - } - } - fn start_send(self: Pin<&mut Self>, _: S) -> Result<(), Self::Error> { - if matches!(self.0, TransportError::Write) { - Err(self.0) - } else { - Ok(()) - } - } - fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { - if matches!(self.0, TransportError::Flush) { - Poll::Ready(Err(self.0)) - } else { - Poll::Ready(Ok(())) - } - } - fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { - if matches!(self.0, TransportError::Close) { - Poll::Ready(Err(self.0)) - } else { - Poll::Ready(Ok(())) - } - } - } - - impl Stream for AlwaysErrorTransport { - type Item = Result, TransportError>; - fn poll_next(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { - if matches!(self.0, TransportError::Read) { - Poll::Ready(Some(Err(self.0))) - } else { - Poll::Pending - } - } - } - fn set_up() -> ( Pin< Box< @@ -998,8 +823,8 @@ mod tests { async fn send_request<'a>( channel: &'a mut Channel, request: &str, - response_completion: oneshot::Sender>, - response: &'a mut oneshot::Receiver>, + response_completion: oneshot::Sender, DeadlineExceededError>>, + response: &'a mut oneshot::Receiver, DeadlineExceededError>>, ) -> ResponseGuard<'a, String> { let request_id = u64::try_from(channel.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap(); diff --git a/tarpc/src/client/in_flight_requests.rs b/tarpc/src/client/in_flight_requests.rs index cb69f680..a7e5fb53 100644 --- a/tarpc/src/client/in_flight_requests.rs +++ b/tarpc/src/client/in_flight_requests.rs @@ -1,6 +1,7 @@ use crate::{ context, util::{Compact, TimeUntil}, + Response, }; use fnv::FnvHashMap; use std::{ @@ -27,11 +28,17 @@ impl Default for InFlightRequests { } } +/// The request exceeded its deadline. +#[derive(thiserror::Error, Debug)] +#[non_exhaustive] +#[error("the request exceeded its deadline")] +pub struct DeadlineExceededError; + #[derive(Debug)] -struct RequestData { +struct RequestData { ctx: context::Context, span: Span, - response_completion: oneshot::Sender, + response_completion: oneshot::Sender, DeadlineExceededError>>, /// The key to remove the timer for the request's deadline. deadline_key: delay_queue::Key, } @@ -41,7 +48,7 @@ struct RequestData { #[derive(Debug)] pub struct AlreadyExistsError; -impl InFlightRequests { +impl InFlightRequests { /// Returns the number of in-flight requests. pub fn len(&self) -> usize { self.request_data.len() @@ -58,7 +65,7 @@ impl InFlightRequests { request_id: u64, ctx: context::Context, span: Span, - response_completion: oneshot::Sender, + response_completion: oneshot::Sender, DeadlineExceededError>>, ) -> Result<(), AlreadyExistsError> { match self.request_data.entry(request_id) { hash_map::Entry::Vacant(vacant) => { @@ -77,35 +84,25 @@ impl InFlightRequests { } /// Removes a request without aborting. Returns true iff the request was found. 
- pub fn complete_request(&mut self, request_id: u64, result: Res) -> bool { - if let Some(request_data) = self.request_data.remove(&request_id) { + pub fn complete_request(&mut self, response: Response) -> bool { + if let Some(request_data) = self.request_data.remove(&response.request_id) { let _entered = request_data.span.enter(); tracing::info!("ReceiveResponse"); self.request_data.compact(0.1); self.deadlines.remove(&request_data.deadline_key); - let _ = request_data.response_completion.send(result); + let _ = request_data.response_completion.send(Ok(response)); return true; } - tracing::debug!("No in-flight request found for request_id = {request_id}."); + tracing::debug!( + "No in-flight request found for request_id = {}.", + response.request_id + ); // If the response completion was absent, then the request was already canceled. false } - /// Completes all requests using the provided function. - /// Returns Spans for all completes requests. - pub fn complete_all_requests<'a>( - &'a mut self, - mut result: impl FnMut() -> Res + 'a, - ) -> impl Iterator + 'a { - self.deadlines.clear(); - self.request_data.drain().map(move |(_, request_data)| { - let _ = request_data.response_completion.send(result()); - request_data.span - }) - } - /// Cancels a request without completing (typically used when a request handle was dropped /// before the request completed). pub fn cancel_request(&mut self, request_id: u64) -> Option<(context::Context, Span)> { @@ -120,18 +117,16 @@ impl InFlightRequests { /// Yields a request that has expired, completing it with a TimedOut error. /// The caller should send cancellation messages for any yielded request ID. - pub fn poll_expired( - &mut self, - cx: &mut Context, - expired_error: impl Fn() -> Res, - ) -> Poll> { + pub fn poll_expired(&mut self, cx: &mut Context) -> Poll> { self.deadlines.poll_expired(cx).map(|expired| { let request_id = expired?.into_inner(); if let Some(request_data) = self.request_data.remove(&request_id) { let _entered = request_data.span.enter(); tracing::error!("DeadlineExceeded"); self.request_data.compact(0.1); - let _ = request_data.response_completion.send(expired_error()); + let _ = request_data + .response_completion + .send(Err(DeadlineExceededError)); } Some(request_id) }) diff --git a/tarpc/src/client/stub.rs b/tarpc/src/client/stub.rs new file mode 100644 index 00000000..894e2efb --- /dev/null +++ b/tarpc/src/client/stub.rs @@ -0,0 +1,57 @@ +//! Provides a Stub trait, implemented by types that can call remote services. + +use crate::{ + client::{Channel, RpcError}, + context, +}; +use futures::prelude::*; + +pub mod load_balance; +pub mod retry; + +#[cfg(test)] +mod mock; + +/// A connection to a remote service. +/// Calls the service with requests of type `Req` and receives responses of type `Resp`. +pub trait Stub { + /// The service request type. + type Req; + + /// The service response type. + type Resp; + + /// The type of the future returned by `Stub::call`. + type RespFut<'a>: Future> + where + Self: 'a, + Self::Req: 'a, + Self::Resp: 'a; + + /// Calls a remote service. 
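The new `Stub` trait above abstracts over anything that can issue a call: a plain client `Channel`, or the retrying and load-balancing wrappers added later in this PR. A sketch of code written against any stub (the `greet` helper and `"Hello"` method name are illustrative, and this branch's `type_alias_impl_trait` usage means a nightly toolchain is assumed):

```rust
use tarpc::{client::{stub::Stub, RpcError}, context};

// Generic over any stub: a raw Channel, a Retry wrapper, a load balancer, etc.
async fn greet<S>(stub: &S, name: String) -> Result<String, RpcError>
where
    S: Stub<Req = String, Resp = String>,
{
    stub.call(context::current(), "Hello", name).await
}
```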
+ fn call<'a>( + &'a self, + ctx: context::Context, + request_name: &'static str, + request: Self::Req, + ) -> Self::RespFut<'a>; +} + +impl Stub for Channel { + type Req = Req; + type Resp = Resp; + type RespFut<'a> = RespFut<'a, Req, Resp> + where Self: 'a; + + fn call<'a>( + &'a self, + ctx: context::Context, + request_name: &'static str, + request: Req, + ) -> Self::RespFut<'a> { + Self::call(self, ctx, request_name, request) + } +} + +/// A type alias for a response future +pub type RespFut<'a, Req: 'a, Resp: 'a> = impl Future> + 'a; diff --git a/tarpc/src/client/stub/load_balance.rs b/tarpc/src/client/stub/load_balance.rs new file mode 100644 index 00000000..2fa67ec8 --- /dev/null +++ b/tarpc/src/client/stub/load_balance.rs @@ -0,0 +1,305 @@ +//! Provides load-balancing [Stubs](crate::client::stub::Stub). + +pub use consistent_hash::ConsistentHash; +pub use round_robin::RoundRobin; + +/// Provides a stub that load-balances with a simple round-robin strategy. +mod round_robin { + use crate::{ + client::{stub, RpcError}, + context, + }; + use cycle::AtomicCycle; + use futures::prelude::*; + + impl stub::Stub for RoundRobin + where + Stub: stub::Stub, + { + type Req = Stub::Req; + type Resp = Stub::Resp; + type RespFut<'a> = RespFut<'a, Stub> + where Self: 'a; + + fn call<'a>( + &'a self, + ctx: context::Context, + request_name: &'static str, + request: Self::Req, + ) -> Self::RespFut<'a> { + Self::call(self, ctx, request_name, request) + } + } + + pub type RespFut<'a, Stub: stub::Stub + 'a> = + impl Future> + 'a; + + /// A Stub that load-balances across backing stubs by round robin. + #[derive(Clone, Debug)] + pub struct RoundRobin { + stubs: AtomicCycle, + } + + impl RoundRobin + where + Stub: stub::Stub, + { + /// Returns a new RoundRobin stub. + pub fn new(stubs: Vec) -> Self { + Self { + stubs: AtomicCycle::new(stubs), + } + } + + async fn call( + &self, + ctx: context::Context, + request_name: &'static str, + request: Stub::Req, + ) -> Result { + let next = self.stubs.next(); + next.call(ctx, request_name, request).await + } + } + + mod cycle { + use std::sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }; + + /// Cycles endlessly and atomically over a collection of elements of type T. + #[derive(Clone, Debug)] + pub struct AtomicCycle(Arc>); + + #[derive(Debug)] + struct State { + elements: Vec, + next: AtomicUsize, + } + + impl AtomicCycle { + pub fn new(elements: Vec) -> Self { + Self(Arc::new(State { + elements, + next: Default::default(), + })) + } + + pub fn next(&self) -> &T { + self.0.next() + } + } + + impl State { + pub fn next(&self) -> &T { + let next = self.next.fetch_add(1, Ordering::Relaxed); + &self.elements[next % self.elements.len()] + } + } + + #[test] + fn test_cycle() { + let cycle = AtomicCycle::new(vec![1, 2, 3]); + assert_eq!(cycle.next(), &1); + assert_eq!(cycle.next(), &2); + assert_eq!(cycle.next(), &3); + assert_eq!(cycle.next(), &1); + } + } +} + +/// Provides a stub that load-balances with a consistent hashing strategy. +/// +/// Each request is hashed, then mapped to a stub based on the hash. Equivalent requests will use +/// the same stub. 
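A usage sketch for the `RoundRobin` stub defined above, following the in-process transport pattern used in this PR's own doc examples. The two backend servers, the `AddOne` method name, and the `tokio1` feature are assumptions for illustration:

```rust
use futures::prelude::*;
use tarpc::{
    client::{self, stub::{load_balance::RoundRobin, Stub}},
    context,
    server::{serve, BaseChannel, Channel},
    transport,
};

#[tokio::main]
async fn main() {
    // Two identical in-process servers, each behind its own client channel.
    let mut backends = Vec::new();
    for _ in 0..2 {
        let (tx, rx) = transport::channel::unbounded();
        tokio::spawn(
            BaseChannel::with_defaults(rx)
                .execute(serve(|_, i: i32| async move { Ok(i + 1) }))
                .for_each(|response| async move { tokio::spawn(response); }),
        );
        backends.push(client::new(client::Config::default(), tx).spawn());
    }

    // Successive calls alternate between the two underlying channels.
    let balanced = RoundRobin::new(backends);
    assert_eq!(balanced.call(context::current(), "AddOne", 1).await.unwrap(), 2);
    assert_eq!(balanced.call(context::current(), "AddOne", 2).await.unwrap(), 3);
}
```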
+mod consistent_hash { + use crate::{ + client::{stub, RpcError}, + context, + }; + use futures::prelude::*; + use std::{ + collections::hash_map::RandomState, + hash::{BuildHasher, Hash, Hasher}, + num::TryFromIntError, + }; + + impl stub::Stub for ConsistentHash + where + Stub: stub::Stub, + Stub::Req: Hash, + { + type Req = Stub::Req; + type Resp = Stub::Resp; + type RespFut<'a> = RespFut<'a, Stub> + where Self: 'a; + + fn call<'a>( + &'a self, + ctx: context::Context, + request_name: &'static str, + request: Self::Req, + ) -> Self::RespFut<'a> { + Self::call(self, ctx, request_name, request) + } + } + + pub type RespFut<'a, Stub: stub::Stub + 'a> = + impl Future> + 'a; + + /// A Stub that load-balances across backing stubs by round robin. + #[derive(Clone, Debug)] + pub struct ConsistentHash { + stubs: Vec, + stubs_len: u64, + hasher: S, + } + + impl ConsistentHash + where + Stub: stub::Stub, + Stub::Req: Hash, + { + /// Returns a new RoundRobin stub. + /// Returns an err if the length of `stubs` overflows a u64. + pub fn new(stubs: Vec) -> Result { + Ok(Self { + stubs_len: stubs.len().try_into()?, + stubs, + hasher: RandomState::new(), + }) + } + } + + impl ConsistentHash + where + Stub: stub::Stub, + Stub::Req: Hash, + S: BuildHasher, + { + /// Returns a new RoundRobin stub. + /// Returns an err if the length of `stubs` overflows a u64. + pub fn with_hasher(stubs: Vec, hasher: S) -> Result { + Ok(Self { + stubs_len: stubs.len().try_into()?, + stubs, + hasher, + }) + } + + async fn call( + &self, + ctx: context::Context, + request_name: &'static str, + request: Stub::Req, + ) -> Result { + let index = usize::try_from(self.hash_request(&request) % self.stubs_len).expect( + "invariant broken: stubs_len is not larger than a usize, \ + so the hash modulo stubs_len should always fit in a usize", + ); + let next = &self.stubs[index]; + next.call(ctx, request_name, request).await + } + + fn hash_request(&self, req: &Stub::Req) -> u64 { + let mut hasher = self.hasher.build_hasher(); + req.hash(&mut hasher); + hasher.finish() + } + } + + #[cfg(test)] + mod tests { + use super::ConsistentHash; + use crate::{client::stub::mock::Mock, context}; + use std::{ + collections::HashMap, + hash::{BuildHasher, Hash, Hasher}, + rc::Rc, + }; + + #[tokio::test] + async fn test() -> anyhow::Result<()> { + let stub = ConsistentHash::with_hasher( + vec![ + // For easier reading of the assertions made in this test, each Mock's response + // value is equal to a hash value that should map to its index: 3 % 3 = 0, 1 % + // 3 = 1, etc. 
+ Mock::new([('a', 3), ('b', 3), ('c', 3)]), + Mock::new([('a', 1), ('b', 1), ('c', 1)]), + Mock::new([('a', 2), ('b', 2), ('c', 2)]), + ], + FakeHasherBuilder::new([('a', 1), ('b', 2), ('c', 3)]), + )?; + + for _ in 0..2 { + let resp = stub.call(context::current(), "", 'a').await?; + assert_eq!(resp, 1); + + let resp = stub.call(context::current(), "", 'b').await?; + assert_eq!(resp, 2); + + let resp = stub.call(context::current(), "", 'c').await?; + assert_eq!(resp, 3); + } + + Ok(()) + } + + struct HashRecorder(Vec); + impl Hasher for HashRecorder { + fn write(&mut self, bytes: &[u8]) { + self.0 = Vec::from(bytes); + } + fn finish(&self) -> u64 { + 0 + } + } + + struct FakeHasherBuilder { + recorded_hashes: Rc, u64>>, + } + + struct FakeHasher { + recorded_hashes: Rc, u64>>, + output: u64, + } + + impl BuildHasher for FakeHasherBuilder { + type Hasher = FakeHasher; + + fn build_hasher(&self) -> Self::Hasher { + FakeHasher { + recorded_hashes: self.recorded_hashes.clone(), + output: 0, + } + } + } + + impl FakeHasherBuilder { + fn new(fake_hashes: [(T, u64); N]) -> Self { + let mut recorded_hashes = HashMap::new(); + for (to_hash, fake_hash) in fake_hashes { + let mut recorder = HashRecorder(vec![]); + to_hash.hash(&mut recorder); + recorded_hashes.insert(recorder.0, fake_hash); + } + Self { + recorded_hashes: Rc::new(recorded_hashes), + } + } + } + + impl Hasher for FakeHasher { + fn write(&mut self, bytes: &[u8]) { + if let Some(hash) = self.recorded_hashes.get(bytes) { + self.output = *hash; + } + } + fn finish(&self) -> u64 { + self.output + } + } + } +} diff --git a/tarpc/src/client/stub/mock.rs b/tarpc/src/client/stub/mock.rs new file mode 100644 index 00000000..99a54422 --- /dev/null +++ b/tarpc/src/client/stub/mock.rs @@ -0,0 +1,54 @@ +use crate::{ + client::{stub::Stub, RpcError}, + context, ServerError, +}; +use futures::future; +use std::{collections::HashMap, hash::Hash, io}; + +/// A mock stub that returns user-specified responses. +pub struct Mock { + responses: HashMap, +} + +impl Mock +where + Req: Eq + Hash, +{ + /// Returns a new mock, mocking the specified (request, response) pairs. + pub fn new(responses: [(Req, Resp); N]) -> Self { + Self { + responses: HashMap::from(responses), + } + } +} + +impl Stub for Mock +where + Req: Eq + Hash, + Resp: Clone, +{ + type Req = Req; + type Resp = Resp; + type RespFut<'a> = future::Ready> + where Self: 'a; + + fn call<'a>( + &'a self, + _: context::Context, + _: &'static str, + request: Self::Req, + ) -> Self::RespFut<'a> { + future::ready( + self.responses + .get(&request) + .cloned() + .map(Ok) + .unwrap_or_else(|| { + Err(RpcError::Server(ServerError { + kind: io::ErrorKind::NotFound, + detail: "mock (request, response) entry not found".into(), + })) + }), + ) + } +} diff --git a/tarpc/src/client/stub/retry.rs b/tarpc/src/client/stub/retry.rs new file mode 100644 index 00000000..138a1922 --- /dev/null +++ b/tarpc/src/client/stub/retry.rs @@ -0,0 +1,76 @@ +//! Provides a stub that retries requests based on response contents.. 
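For the `ConsistentHash` stub from the previous hunk, a sketch of sharding by request key. The two in-process "shards" and the `Get` method name are illustrative assumptions; the request type must implement `Hash`, and `new` only fails if the number of stubs overflows a `u64`:

```rust
use futures::prelude::*;
use tarpc::{
    client::{self, stub::{load_balance::ConsistentHash, Stub}},
    context,
    server::{serve, BaseChannel, Channel},
    transport,
};

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Two in-process shards; requests with equal keys always hit the same shard.
    let mut shards = Vec::new();
    for shard in 0..2u32 {
        let (tx, rx) = transport::channel::unbounded();
        tokio::spawn(
            BaseChannel::with_defaults(rx)
                .execute(serve(move |_, key: String| async move {
                    Ok(format!("{key} served by shard {shard}"))
                }))
                .for_each(|response| async move { tokio::spawn(response); }),
        );
        shards.push(client::new(client::Config::default(), tx).spawn());
    }

    let sharded = ConsistentHash::new(shards)?;
    let a = sharded.call(context::current(), "Get", "alpha".to_string()).await?;
    let b = sharded.call(context::current(), "Get", "alpha".to_string()).await?;
    assert_eq!(a, b); // same key, same shard
    Ok(())
}
```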
+ +use crate::{ + client::{stub, RpcError}, + context, +}; +use futures::prelude::*; +use std::sync::Arc; + +impl stub::Stub for Retry +where + Stub: stub::Stub>, + F: Fn(&Result, u32) -> bool, +{ + type Req = Req; + type Resp = Stub::Resp; + type RespFut<'a> = RespFut<'a, Stub, Req, F> + where Self: 'a, + Self::Req: 'a; + + fn call<'a>( + &'a self, + ctx: context::Context, + request_name: &'static str, + request: Self::Req, + ) -> Self::RespFut<'a> { + Self::call(self, ctx, request_name, request) + } +} + +/// A type alias for a response future +pub type RespFut<'a, Stub: stub::Stub + 'a, Req: 'a, F: 'a> = + impl Future> + 'a; + +/// A Stub that retries requests based on response contents. +/// Note: to use this stub with Serde serialization, the "rc" feature of Serde needs to be enabled. +#[derive(Clone, Debug)] +pub struct Retry { + should_retry: F, + stub: Stub, +} + +impl Retry +where + Stub: stub::Stub>, + F: Fn(&Result, u32) -> bool, +{ + /// Creates a new Retry stub that delegates calls to the underlying `stub`. + pub fn new(stub: Stub, should_retry: F) -> Self { + Self { stub, should_retry } + } + + async fn call<'a, 'b>( + &'a self, + ctx: context::Context, + request_name: &'static str, + request: Req, + ) -> Result + where + Req: 'b, + { + let request = Arc::new(request); + for i in 1.. { + let result = self + .stub + .call(ctx.clone(), request_name, Arc::clone(&request)) + .await; + if (self.should_retry)(&result, i) { + tracing::trace!("Retrying on attempt {i}"); + continue; + } + return result; + } + unreachable!("Wow, that was a lot of attempts!"); + } +} diff --git a/tarpc/src/context.rs b/tarpc/src/context.rs index e3a6aff1..2d90b33c 100644 --- a/tarpc/src/context.rs +++ b/tarpc/src/context.rs @@ -14,6 +14,9 @@ use std::{ convert::TryFrom, time::{Duration, SystemTime}, }; +use std::hash::{Hash, Hasher}; +use std::ops::{Deref, DerefMut}; +use anymap::any::CloneAny; use tracing_opentelemetry::OpenTelemetrySpanExt; /// A request context that carries request-scoped information like deadlines and trace information. @@ -21,54 +24,73 @@ use tracing_opentelemetry::OpenTelemetrySpanExt; /// /// The context should not be stored directly in a server implementation, because the context will /// be different for each request in scope. -#[derive(Clone, Copy, Debug)] +#[derive(Clone, Debug, Default)] #[non_exhaustive] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] pub struct Context { /// When the client expects the request to be complete by. The server should cancel the request /// if it is not complete by this time. - #[cfg_attr(feature = "serde1", serde(default = "ten_seconds_from_now"))] // Serialized as a Duration to prevent clock skew issues. #[cfg_attr(feature = "serde1", serde(with = "absolute_to_relative_time"))] - pub deadline: SystemTime, + pub deadline: Deadline, /// Uniquely identifies requests originating from the same source. /// When a service handles a request by making requests itself, those requests should /// include the same `trace_id` as that included on the original request. This way, /// users can trace related actions across a distributed system. 
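To illustrate the `Retry` stub defined above: the wrapped stub must accept `Arc<Req>` so the same request can be re-sent, and `should_retry` sees both the result and the attempt count. A sketch under those constraints (the `Hello` service and the retry-up-to-three-attempts policy are assumptions):

```rust
use std::sync::Arc;
use futures::prelude::*;
use tarpc::{
    client::{self, stub::{retry::Retry, Stub}, RpcError},
    context,
    server::{serve, BaseChannel, Channel},
    transport,
};

#[tokio::main]
async fn main() {
    let (tx, rx) = transport::channel::unbounded();
    tokio::spawn(
        BaseChannel::with_defaults(rx)
            // The server side sees Arc<String> requests.
            .execute(serve(|_, name: Arc<String>| async move {
                Ok(format!("Hello, {name}!"))
            }))
            .for_each(|response| async move { tokio::spawn(response); }),
    );

    let channel = client::new(client::Config::default(), tx).spawn();
    // Retry transport-level failures up to three attempts.
    let retrying = Retry::new(channel, |result: &Result<String, RpcError>, attempt| {
        result.is_err() && attempt < 3
    });
    let greeting = retrying
        .call(context::current(), "Hello", "Tim".to_string())
        .await
        .unwrap();
    assert_eq!(greeting, "Hello, Tim!");
}
```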
pub trace_context: trace::Context, + + /// Any extra information can be requested + #[cfg_attr(feature = "serde1", serde(skip))] + pub extensions: Extensions +} + +impl PartialEq for Context { + fn eq(&self, other: &Self) -> bool { + self.trace_context == other.trace_context && self.deadline == other.deadline + } +} + +impl Eq for Context { } + +impl Hash for Context { + fn hash(&self, state: &mut H) { + self.trace_context.hash(state); + self.deadline.hash(state); + } } #[cfg(feature = "serde1")] mod absolute_to_relative_time { pub use serde::{Deserialize, Deserializer, Serialize, Serializer}; pub use std::time::{Duration, SystemTime}; + use crate::context::Deadline; - pub fn serialize(deadline: &SystemTime, serializer: S) -> Result + pub fn serialize(deadline: &Deadline, serializer: S) -> Result where S: Serializer, { - let deadline = deadline + let deadline = deadline.0 .duration_since(SystemTime::now()) .unwrap_or(Duration::ZERO); deadline.serialize(serializer) } - pub fn deserialize<'de, D>(deserializer: D) -> Result + pub fn deserialize<'de, D>(deserializer: D) -> Result where D: Deserializer<'de>, { let deadline = Duration::deserialize(deserializer)?; - Ok(SystemTime::now() + deadline) + Ok(Deadline(SystemTime::now() + deadline)) } #[cfg(test)] - #[derive(serde::Serialize, serde::Deserialize)] - struct AbsoluteToRelative(#[serde(with = "self")] SystemTime); + #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] + struct AbsoluteToRelative(#[serde(with = "self")] Deadline); #[test] fn test_serialize() { let now = SystemTime::now(); - let deadline = now + Duration::from_secs(10); + let deadline = Deadline(now + Duration::from_secs(10)); let serialized_deadline = bincode::serialize(&AbsoluteToRelative(deadline)).unwrap(); let deserialized_deadline: Duration = bincode::deserialize(&serialized_deadline).unwrap(); // TODO: how to avoid flakiness? @@ -82,14 +104,14 @@ mod absolute_to_relative_time { let AbsoluteToRelative(deserialized_deadline) = bincode::deserialize(&serialized_deadline).unwrap(); // TODO: how to avoid flakiness? - assert!(deserialized_deadline > SystemTime::now() + Duration::from_secs(9)); + assert!(*deserialized_deadline > SystemTime::now() + Duration::from_secs(9)); } } assert_impl_all!(Context: Send, Sync); -fn ten_seconds_from_now() -> SystemTime { - SystemTime::now() + Duration::from_secs(10) +fn ten_seconds_from_now() -> Deadline { + Deadline(SystemTime::now() + Duration::from_secs(10)) } /// Returns the context for the current request, or a default Context if no request is active. 
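With the hunk above, `Context` is no longer `Copy`, its deadline is a `Deadline` newtype (serialized as a duration relative to now to sidestep clock skew), and an extensions map carries request-scoped data. A small sketch; the `TenantId` type is hypothetical, and it assumes the anymap-backed map exposes `insert`/`get` as in anymap 0.12:

```rust
use std::time::{Duration, SystemTime};
use tarpc::context;

// Hypothetical request-scoped value; extension values must be Clone (+ Send + Sync).
#[derive(Clone, Debug, PartialEq)]
struct TenantId(u32);

fn main() {
    let mut ctx = context::current();
    // Deadline derefs to SystemTime; shorten this request's budget to 2 seconds.
    *ctx.deadline = SystemTime::now() + Duration::from_secs(2);
    // Stash extra data for this request.
    ctx.extensions.insert(TenantId(42));
    assert_eq!(ctx.extensions.get::<TenantId>(), Some(&TenantId(42)));
}
```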
@@ -97,12 +119,58 @@ pub fn current() -> Context { Context::current() } -#[derive(Clone)] -struct Deadline(SystemTime); +/// Deadline for executing a request +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub struct Deadline(pub(self) SystemTime); impl Default for Deadline { fn default() -> Self { - Self(ten_seconds_from_now()) + ten_seconds_from_now() + } +} + +impl Deref for Deadline { + type Target = SystemTime; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for Deadline { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +impl Deadline { + /// Creates a new deadline + pub fn new(t: SystemTime) -> Deadline { + Deadline(t) + } +} + +/// Extensions associated with a request +#[derive(Clone, Debug)] +pub struct Extensions(anymap::Map); + +impl Default for Extensions { + fn default() -> Self { + Self(anymap::Map::new()) + } +} + +impl Deref for Extensions { + type Target = anymap::Map; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for Extensions { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 } } @@ -110,15 +178,16 @@ impl Context { /// Returns the context for the current request, or a default Context if no request is active. pub fn current() -> Self { let span = tracing::Span::current(); + Self { trace_context: trace::Context::try_from(&span) .unwrap_or_else(|_| trace::Context::default()), + extensions: Default::default(), // span is always cloned so saving this doesn't make sense. deadline: span .context() .get::() .cloned() - .unwrap_or_default() - .0, + .unwrap_or_default(), } } @@ -146,7 +215,7 @@ impl SpanExt for tracing::Span { true, opentelemetry::trace::TraceState::default(), )) - .with_value(Deadline(context.deadline)), + .with_value(context.deadline.clone()) ); } } diff --git a/tarpc/src/lib.rs b/tarpc/src/lib.rs index 418cedd8..91ac63c9 100644 --- a/tarpc/src/lib.rs +++ b/tarpc/src/lib.rs @@ -126,13 +126,9 @@ //! struct HelloServer; //! //! impl World for HelloServer { -//! // Each defined rpc generates two items in the trait, a fn that serves the RPC, and -//! // an associated type representing the future output by the fn. -//! -//! type HelloFut = Ready; -//! -//! fn hello(self, _: context::Context, name: String) -> Self::HelloFut { -//! future::ready(format!("Hello, {name}!")) +//! // Each defined rpc generates an async fn that serves the RPC +//! async fn hello(self, _: &mut context::Context, name: String) -> String { +//! format!("Hello, {name}!") //! } //! } //! ``` @@ -164,11 +160,9 @@ //! # #[derive(Clone)] //! # struct HelloServer; //! # impl World for HelloServer { -//! # // Each defined rpc generates two items in the trait, a fn that serves the RPC, and -//! # // an associated type representing the future output by the fn. -//! # type HelloFut = Ready; -//! # fn hello(self, _: context::Context, name: String) -> Self::HelloFut { -//! # future::ready(format!("Hello, {name}!")) +//! // Each defined rpc generates an async fn that serves the RPC +//! # async fn hello(self, _: &mut context::Context, name: String) -> String { +//! # format!("Hello, {name}!") //! # } //! # } //! # #[cfg(not(feature = "tokio1"))] @@ -179,7 +173,12 @@ //! let (client_transport, server_transport) = tarpc::transport::channel::unbounded(); //! //! let server = server::BaseChannel::with_defaults(server_transport); -//! tokio::spawn(server.execute(HelloServer.serve())); +//! tokio::spawn( +//! server.execute(HelloServer.serve()) +//! // Handle all requests concurrently. +//! 
.for_each(|response| async move { +//! tokio::spawn(response); +//! })); //! //! // WorldClient is generated by the #[tarpc::service] attribute. It has a constructor `new` //! // that takes a config and any Transport as input. @@ -200,6 +199,12 @@ //! //! Use `cargo doc` as you normally would to see the documentation created for all //! items expanded by a `service!` invocation. +#![feature( + iter_intersperse, + type_alias_impl_trait, +)] +#![cfg_attr(feature = "serde1", feature(async_closure))] + #![deny(missing_docs)] #![allow(clippy::type_complexity)] #![cfg_attr(docsrs, feature(doc_cfg))] @@ -244,62 +249,6 @@ pub use tarpc_plugins::derive_serde; /// * `fn new_stub` -- creates a new Client stub. pub use tarpc_plugins::service; -/// A utility macro that can be used for RPC server implementations. -/// -/// Syntactic sugar to make using async functions in the server implementation -/// easier. It does this by rewriting code like this, which would normally not -/// compile because async functions are disallowed in trait implementations: -/// -/// ```rust -/// # use tarpc::context; -/// # use std::net::SocketAddr; -/// #[tarpc::service] -/// trait World { -/// async fn hello(name: String) -> String; -/// } -/// -/// #[derive(Clone)] -/// struct HelloServer(SocketAddr); -/// -/// #[tarpc::server] -/// impl World for HelloServer { -/// async fn hello(self, _: context::Context, name: String) -> String { -/// format!("Hello, {name}! You are connected from {:?}.", self.0) -/// } -/// } -/// ``` -/// -/// Into code like this, which matches the service trait definition: -/// -/// ```rust -/// # use tarpc::context; -/// # use std::pin::Pin; -/// # use futures::Future; -/// # use std::net::SocketAddr; -/// #[derive(Clone)] -/// struct HelloServer(SocketAddr); -/// -/// #[tarpc::service] -/// trait World { -/// async fn hello(name: String) -> String; -/// } -/// -/// impl World for HelloServer { -/// type HelloFut = Pin + Send>>; -/// -/// fn hello(self, _: context::Context, name: String) -> Pin -/// + Send>> { -/// Box::pin(async move { -/// format!("Hello, {name}! You are connected from {:?}.", self.0) -/// }) -/// } -/// } -/// ``` -/// -/// Note that this won't touch functions unless they have been annotated with -/// `async`, meaning that this should not break existing code. -pub use tarpc_plugins::server; - pub(crate) mod cancellations; pub mod client; pub mod context; @@ -311,7 +260,6 @@ pub use crate::transport::sealed::Transport; use anyhow::Context as _; use futures::task::*; -use std::sync::Arc; use std::{error::Error, fmt::Display, io, time::SystemTime}; /// A message from a client to a server. @@ -341,14 +289,14 @@ pub enum ClientMessage { } /// A request from a client to a server. -#[derive(Clone, Copy, Debug)] +#[derive(Clone, Debug)] #[non_exhaustive] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] pub struct Request { /// Trace context, deadline, and other cross-cutting concerns. pub context: context::Context, /// Uniquely identifies the request across all requests sent over a single channel. - pub id: u64, + pub request_id: u64, /// The request body. pub message: T, } @@ -358,6 +306,9 @@ pub struct Request { #[non_exhaustive] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] pub struct Response { + /// Trace context, deadline, and other cross-cutting concerns. + #[cfg_attr(feature = "serde1", serde(skip))] + pub context: context::Context, /// The ID of the request being responded to. 
pub request_id: u64, /// The response body, or an error if the request failed. @@ -384,27 +335,11 @@ pub struct ServerError { pub detail: String, } -/// Critical errors that result in a Channel disconnecting. -#[derive(thiserror::Error, Debug, PartialEq, Eq)] -pub enum ChannelError -where - E: Error + Send + Sync + 'static, -{ - /// Could not read from the transport. - #[error("could not read from the transport")] - Read(#[source] Arc), - /// Could not ready the transport for writes. - #[error("could not ready the transport for writes")] - Ready(#[source] E), - /// Could not write to the transport. - #[error("could not write to the transport")] - Write(#[source] E), - /// Could not flush the transport. - #[error("could not flush the transport")] - Flush(#[source] E), - /// Could not close the write end of the transport. - #[error("could not close the write end of the transport")] - Close(#[source] E), +impl ServerError { + /// Returns a new server error with `kind` and `detail`. + pub fn new(kind: io::ErrorKind, detail: String) -> ServerError { + Self { kind, detail } + } } impl Request { diff --git a/tarpc/src/server.rs b/tarpc/src/server.rs index a06e8f8a..1691a574 100644 --- a/tarpc/src/server.rs +++ b/tarpc/src/server.rs @@ -9,7 +9,7 @@ use crate::{ cancellations::{cancellations, CanceledRequests, RequestCancellation}, context::{self, SpanExt}, - trace, ChannelError, ClientMessage, Request, Response, Transport, + trace, ClientMessage, Request, Response, ServerError, Transport, }; use ::tokio::sync::mpsc; use futures::{ @@ -21,10 +21,11 @@ use futures::{ }; use in_flight_requests::{AlreadyExistsError, InFlightRequests}; use pin_project::pin_project; -use std::{convert::TryFrom, error::Error, fmt, marker::PhantomData, pin::Pin, sync::Arc}; +use std::{convert::TryFrom, error::Error, fmt, marker::PhantomData, pin::Pin}; use tracing::{info_span, instrument::Instrument, Span}; mod in_flight_requests; +pub mod request_hook; #[cfg(test)] mod testing; @@ -34,10 +35,9 @@ pub mod limits; /// Provides helper methods for streams of Channels. pub mod incoming; -/// Provides convenience functionality for tokio-enabled applications. -#[cfg(feature = "tokio1")] -#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))] -pub mod tokio; +use request_hook::{ + AfterRequest, AfterRequestHook, BeforeAndAfterRequestHook, BeforeRequest, BeforeRequestHook, +}; /// Settings that control the behavior of [channels](Channel). #[derive(Clone, Debug)] @@ -67,32 +67,211 @@ impl Config { } /// Equivalent to a `FnOnce(Req) -> impl Future`. -pub trait Serve { +pub trait Serve { + /// Type of request. + type Req; + /// Type of response. type Resp; - /// Type of response future. - type Fut: Future; + /// Responds to a single request. + fn serve(self, ctx: &mut context::Context, req: Self::Req) -> impl Future>; /// Extracts a method name from the request. - fn method(&self, _request: &Req) -> Option<&'static str> { + fn method(&self, _request: &Self::Req) -> Option<&'static str> { None } - /// Responds to a single request. - fn serve(self, ctx: context::Context, req: Req) -> Self::Fut; + /// Runs a hook before execution of the request. + /// + /// If the hook returns an error, the request will not be executed and the error will be + /// returned instead. + /// + /// The hook can also modify the request context. This could be used, for example, to enforce a + /// maximum deadline on all requests. + /// + /// Any type that implements [`BeforeRequest`] can be used as the hook. 
Types that implement + /// `FnMut(&mut Context, &RequestType) -> impl Future>` can + /// also be used. + /// + /// # Example + /// + /// ```rust + /// use futures::{executor::block_on, future}; + /// use tarpc::{context, ServerError, server::{Serve, serve}}; + /// use std::io; + /// + /// let serve = serve(|_ctx, i| async move { Ok(i + 1) }) + /// .before(|_ctx: &mut context::Context, req: &i32| { + /// future::ready( + /// if *req == 1 { + /// Err(ServerError::new( + /// io::ErrorKind::Other, + /// format!("I don't like {req}"))) + /// } else { + /// Ok(()) + /// }) + /// }); + /// let mut ctx = context::current(); + /// let response = serve.serve(&mut ctx, 1); + /// assert!(block_on(response).is_err()); + /// ``` + fn before(self, hook: Hook) -> BeforeRequestHook + where + Hook: BeforeRequest, + Self: Sized, + { + BeforeRequestHook::new(self, hook) + } + + /// Runs a hook after completion of a request. + /// + /// The hook can modify the request context and the response. + /// + /// Any type that implements [`AfterRequest`] can be used as the hook. Types that implement + /// `FnMut(&mut Context, &mut Result) -> impl Future` + /// can also be used. + /// + /// # Example + /// + /// ```rust + /// use futures::{executor::block_on, future}; + /// use tarpc::{context, ServerError, server::{Serve, serve}}; + /// use std::io; + /// + /// let serve = serve( + /// |_ctx, i| async move { + /// if i == 1 { + /// Err(ServerError::new( + /// io::ErrorKind::Other, + /// format!("{i} is the loneliest number"))) + /// } else { + /// Ok(i + 1) + /// } + /// }) + /// .after(|_ctx: &mut context::Context, resp: &mut Result| { + /// if let Err(e) = resp { + /// eprintln!("server error: {e:?}"); + /// } + /// future::ready(()) + /// }); + /// + /// let mut ctx = context::current(); + /// let response = serve.serve(&mut ctx, 1); + /// assert!(block_on(response).is_err()); + /// ``` + fn after(self, hook: Hook) -> AfterRequestHook + where + Hook: AfterRequest, + Self: Sized, + { + AfterRequestHook::new(self, hook) + } + + /// Runs a hook before and after execution of the request. + /// + /// If the hook returns an error, the request will not be executed and the error will be + /// returned instead. + /// + /// The hook can also modify the request context and the response. This could be used, for + /// example, to enforce a maximum deadline on all requests. 
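The deadline-capping idea mentioned above can be written with a closure hook in the same style as the `before` example earlier in this trait. A sketch; the one-second cap and the add-one serve function are illustrative:

```rust
use futures::{executor::block_on, future};
use std::time::{Duration, SystemTime};
use tarpc::{context, server::{serve, Serve}};

fn main() {
    // Cap every request at one second, even if the client asked for more.
    let capped = serve(|_ctx, i: i32| async move { Ok(i + 1) })
        .before(|ctx: &mut context::Context, _: &i32| {
            let cap = SystemTime::now() + Duration::from_secs(1);
            if *ctx.deadline > cap {
                *ctx.deadline = cap;
            }
            future::ready(Ok(()))
        });

    let mut ctx = context::current();
    assert!(block_on(capped.serve(&mut ctx, 1)).is_ok());
    // The hook shortened the default ten-second deadline.
    assert!(*ctx.deadline <= SystemTime::now() + Duration::from_secs(1));
}
```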
+ /// + /// # Example + /// + /// ```rust + /// use futures::{executor::block_on, future}; + /// use tarpc::{ + /// context, ServerError, server::{Serve, serve, request_hook::{BeforeRequest, AfterRequest}} + /// }; + /// use std::{io, time::Instant}; + /// + /// struct PrintLatency(Instant); + /// + /// impl BeforeRequest for PrintLatency { + /// type Fut<'a> = future::Ready> where Self: 'a, Req: 'a; + /// + /// fn before<'a>(&'a mut self, _: &'a mut context::Context, _: &'a Req) -> Self::Fut<'a> { + /// self.0 = Instant::now(); + /// future::ready(Ok(())) + /// } + /// } + /// + /// impl AfterRequest for PrintLatency { + /// type Fut<'a> = future::Ready<()> where Self:'a, Resp:'a; + /// + /// fn after<'a>( + /// &'a mut self, + /// _: &'a mut context::Context, + /// _: &'a mut Result, + /// ) -> Self::Fut<'a> { + /// tracing::info!("Elapsed: {:?}", self.0.elapsed()); + /// future::ready(()) + /// } + /// } + /// + /// let serve = serve(|_ctx, i| async move { + /// Ok(i + 1) + /// }).before_and_after(PrintLatency(Instant::now())); + /// let mut ctx = context::current(); + /// let response = serve.serve(&mut ctx, 1); + /// assert!(block_on(response).is_ok()); + /// ``` + fn before_and_after( + self, + hook: Hook, + ) -> BeforeAndAfterRequestHook + where + Hook: BeforeRequest + AfterRequest, + Self: Sized, + { + BeforeAndAfterRequestHook::new(self, hook) + } +} + +/// A Serve wrapper around a Fn. +#[derive(Debug)] +pub struct ServeFn { + f: F, + data: PhantomData Resp>, } -impl Serve for F +impl Clone for ServeFn where - F: FnOnce(context::Context, Req) -> Fut, - Fut: Future, + F: Clone, { + fn clone(&self) -> Self { + Self { + f: self.f.clone(), + data: PhantomData, + } + } +} + +impl Copy for ServeFn where F: Copy {} + +/// Creates a [`Serve`] wrapper around a `FnOnce(context::Context, Req) -> impl Future>`. +pub fn serve(f: F) -> ServeFn +where + F: FnOnce(&mut context::Context, Req) -> Fut, + Fut: Future>, +{ + ServeFn { + f, + data: PhantomData, + } +} + +impl Serve for ServeFn +where + F: FnOnce(&mut context::Context, Req) -> Fut, + Fut: Future>, +{ + type Req = Req; type Resp = Resp; - type Fut = Fut; - fn serve(self, ctx: context::Context, req: Req) -> Self::Fut { - self(ctx, req) + async fn serve(self, ctx: &mut context::Context, req: Req) -> Result { + (self.f)(ctx, req).await } } @@ -120,7 +299,7 @@ pub struct BaseChannel { /// Holds data necessary to clean up in-flight requests. in_flight_requests: InFlightRequests, /// Types the request and response. - ghost: PhantomData<(Req, Resp)>, + ghost: PhantomData<(fn() -> Req, fn(Resp))>, } impl BaseChannel @@ -176,7 +355,7 @@ where let span = info_span!( "RPC", rpc.trace_id = %request.context.trace_id(), - rpc.deadline = %humantime::format_rfc3339(request.context.deadline), + rpc.deadline = %humantime::format_rfc3339(*request.context.deadline), otel.kind = "server", otel.name = tracing::field::Empty, ); @@ -191,8 +370,8 @@ where let entered = span.enter(); tracing::info!("ReceiveRequest"); let start = self.in_flight_requests_mut().start_request( - request.id, - request.context.deadline, + request.request_id, + *request.context.deadline, span.clone(), ); match start { @@ -202,7 +381,7 @@ where abort_registration, span, response_guard: ResponseGuard { - request_id: request.id, + request_id: request.request_id, request_cancellation: self.request_cancellation.clone(), cancel: false, }, @@ -307,6 +486,34 @@ where /// This is a terminal operation. 
After calling `requests`, the channel cannot be retrieved, /// and the only way to complete requests is via [`Requests::execute`] or /// [`InFlightRequest::execute`]. + /// + /// # Example + /// + /// ```rust + /// use tarpc::{ + /// context, + /// client::{self, NewClient}, + /// server::{self, BaseChannel, Channel, serve}, + /// transport, + /// }; + /// use futures::prelude::*; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, rx) = transport::channel::unbounded(); + /// let server = BaseChannel::new(server::Config::default(), rx); + /// let NewClient { client, dispatch } = client::new(client::Config::default(), tx); + /// tokio::spawn(dispatch); + /// + /// let mut requests = server.requests(); + /// tokio::spawn(async move { + /// while let Some(Ok(request)) = requests.next().await { + /// tokio::spawn(request.execute(serve(|_, i| async move { Ok(i + 1) }))); + /// } + /// }); + /// assert_eq!(client.call(context::current(), "AddOne", 1).await.unwrap(), 2); + /// } + /// ``` fn requests(self) -> Requests where Self: Sized, @@ -320,23 +527,61 @@ where } } - /// Runs the channel until completion by executing all requests using the given service - /// function. Request handlers are run concurrently by [spawning](::tokio::spawn) on tokio's - /// default executor. - #[cfg(feature = "tokio1")] - #[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))] - fn execute(self, serve: S) -> self::tokio::TokioChannelExecutor, S> + /// Returns a stream of request execution futures. Each future represents an in-flight request + /// being responded to by the server. The futures must be awaited or spawned to complete their + /// requests. + /// + /// # Example + /// + /// ```rust + /// use tarpc::{context, client, server::{self, BaseChannel, Channel, serve}, transport}; + /// use futures::prelude::*; + /// use tracing_subscriber::prelude::*; + /// + /// #[derive(PartialEq, Eq, Debug)] + /// struct MyInt(i32); + /// + /// # #[cfg(not(feature = "tokio1"))] + /// # fn main() {} + /// # #[cfg(feature = "tokio1")] + /// #[tokio::main] + /// async fn main() { + /// let (tx, rx) = transport::channel::unbounded(); + /// let client = client::new(client::Config::default(), tx).spawn(); + /// let channel = BaseChannel::with_defaults(rx); + /// tokio::spawn( + /// channel.execute(serve(|_, MyInt(i)| async move { Ok(MyInt(i + 1)) })) + /// .for_each(|response| async move { + /// tokio::spawn(response); + /// })); + /// assert_eq!( + /// client.call(context::current(), "AddOne", MyInt(1)).await.unwrap(), + /// MyInt(2)); + /// } + /// ``` + fn execute(self, serve: S) -> impl Stream> where Self: Sized, - S: Serve + Send + 'static, - S::Fut: Send, - Self::Req: Send + 'static, - Self::Resp: Send + 'static, + S: Serve + Clone, { self.requests().execute(serve) } } +/// Critical errors that result in a Channel disconnecting. +#[derive(thiserror::Error, Debug)] +pub enum ChannelError +where + E: Error + Send + Sync + 'static, +{ + /// An error occurred reading from, or writing to, the transport. + #[error("an error occurred in the transport")] + Transport(#[source] E), + /// An error occurred while polling expired requests. + #[error("an error occurred while polling expired requests")] + Timer(#[source] ::tokio::time::error::Error), +} + impl Stream for BaseChannel where T: Transport, ClientMessage>, @@ -393,7 +638,7 @@ where let request_status = match self .transport_pin_mut() .poll_next(cx) - .map_err(|e| ChannelError::Read(Arc::new(e)))? + .map_err(ChannelError::Transport)? 
{ Poll::Ready(Some(message)) => match message { ClientMessage::Request(request) => { @@ -409,12 +654,12 @@ where } } ClientMessage::Cancel { - trace_context, + trace_context: context, request_id, } => { if !self.in_flight_requests_mut().cancel_request(request_id) { tracing::trace!( - rpc.trace_id = %trace_context.trace_id, + rpc.trace_id = %context.trace_id, "Received cancellation, but response handler is already complete.", ); } @@ -425,15 +670,17 @@ where Poll::Pending => Pending, }; + let status = cancellation_status + .combine(expiration_status) + .combine(request_status); + tracing::trace!( - "Expired requests: {:?}, Inbound: {:?}", - expiration_status, - request_status + "Cancellations: {cancellation_status:?}, \ + Expired requests: {expiration_status:?}, \ + Inbound: {request_status:?}, \ + Overall: {status:?}", ); - match cancellation_status - .combine(expiration_status) - .combine(request_status) - { + match status { Ready => continue, Closed => return Poll::Ready(None), Pending => return Poll::Pending, @@ -453,7 +700,7 @@ where self.project() .transport .poll_ready(cx) - .map_err(ChannelError::Ready) + .map_err(ChannelError::Transport) } fn start_send(mut self: Pin<&mut Self>, response: Response) -> Result<(), Self::Error> { @@ -466,7 +713,7 @@ where self.project() .transport .start_send(response) - .map_err(ChannelError::Write) + .map_err(ChannelError::Transport) } else { // If the request isn't tracked anymore, there's no need to send the response. Ok(()) @@ -478,14 +725,14 @@ where self.project() .transport .poll_flush(cx) - .map_err(ChannelError::Flush) + .map_err(ChannelError::Transport) } fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { self.project() .transport .poll_close(cx) - .map_err(ChannelError::Close) + .map_err(ChannelError::Transport) } } @@ -563,6 +810,10 @@ where span, mut response_guard, }| { + { + let _entered = span.enter(); + tracing::info!("BeginRequest"); + } // The response guard becomes active once in an InFlightRequest. response_guard.cancel = true; InFlightRequest { @@ -639,6 +890,51 @@ where } Poll::Ready(Some(Ok(()))) } + + /// Returns a stream of request execution futures. Each future represents an in-flight request + /// being responded to by the server. The futures must be awaited or spawned to complete their + /// requests. + /// + /// If the channel encounters an error, the stream is terminated and the error is logged. 
+ /// + /// # Example + /// + /// ```rust + /// use tarpc::{context, client, server::{self, BaseChannel, Channel, serve}, transport}; + /// use futures::prelude::*; + /// + /// # #[cfg(not(feature = "tokio1"))] + /// # fn main() {} + /// # #[cfg(feature = "tokio1")] + /// #[tokio::main] + /// async fn main() { + /// let (tx, rx) = transport::channel::unbounded(); + /// let requests = BaseChannel::new(server::Config::default(), rx).requests(); + /// let client = client::new(client::Config::default(), tx).spawn(); + /// tokio::spawn( + /// requests.execute(serve(|_, i| async move { Ok(i + 1) })) + /// .for_each(|response| async move { + /// tokio::spawn(response); + /// })); + /// assert_eq!(client.call(context::current(), "AddOne", 1).await.unwrap(), 2); + /// } + /// ``` + pub fn execute(self, serve: S) -> impl Stream> + where + S: Serve + Clone, + { + self.take_while(|result| { + if let Err(e) = result { + tracing::warn!("Requests stream errored out: {}", e); + } + futures::future::ready(result.is_ok()) + }) + .filter_map(|result| async move { result.ok() }) + .map(move |request| { + let serve = serve.clone(); + request.execute(serve) + }) + } } impl fmt::Debug for Requests @@ -700,9 +996,39 @@ impl InFlightRequest { /// /// If the returned Future is dropped before completion, a cancellation message will be sent to /// the Channel to clean up associated request state. + /// + /// # Example + /// + /// ```rust + /// use tarpc::{ + /// context, + /// client::{self, NewClient}, + /// server::{self, BaseChannel, Channel, serve}, + /// transport, + /// }; + /// use futures::prelude::*; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, rx) = transport::channel::unbounded(); + /// let server = BaseChannel::new(server::Config::default(), rx); + /// let NewClient { client, dispatch } = client::new(client::Config::default(), tx); + /// tokio::spawn(dispatch); + /// + /// tokio::spawn(async move { + /// let mut requests = server.requests(); + /// while let Some(Ok(in_flight_request)) = requests.next().await { + /// in_flight_request.execute(serve(|_, i| async move { Ok(i + 1) })).await; + /// } + /// + /// }); + /// assert_eq!(client.call(context::current(), "AddOne", 1).await.unwrap(), 2); + /// } + /// ``` + /// pub async fn execute(self, serve: S) where - S: Serve, + S: Serve, { let Self { response_tx, @@ -711,9 +1037,9 @@ impl InFlightRequest { span, request: Request { - context, + mut context, message, - id: request_id, + request_id, }, } = self; let method = serve.method(&message); @@ -723,12 +1049,12 @@ impl InFlightRequest { span.record("otel.name", &method.unwrap_or("")); let _ = Abortable::new( async move { - tracing::info!("BeginRequest"); - let response = serve.serve(context, message).await; + let message = serve.serve(&mut context, message).await; tracing::info!("CompleteRequest"); let response = Response { + context, request_id, - message: Ok(response), + message, }; let _ = response_tx.send(response).await; tracing::info!("BufferResponse"); @@ -744,6 +1070,13 @@ impl InFlightRequest { } } +fn print_err(e: &(dyn Error + 'static)) -> String { + anyhow::Chain::new(e) + .map(|e| e.to_string()) + .intersperse(": ".into()) + .collect::() +} + impl Stream for Requests where C: Channel, @@ -752,17 +1085,33 @@ where fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { loop { - let read = self.as_mut().pump_read(cx)?; + let read = self.as_mut().pump_read(cx).map_err(|e| { + tracing::trace!("read: {}", print_err(&e)); + e + })?; let read_closed = matches!(read, 
Poll::Ready(None)); - match (read, self.as_mut().pump_write(cx, read_closed)?) { + let write = self.as_mut().pump_write(cx, read_closed).map_err(|e| { + tracing::trace!("write: {}", print_err(&e)); + e + })?; + match (read, write) { (Poll::Ready(None), Poll::Ready(None)) => { + tracing::trace!("read: Poll::Ready(None), write: Poll::Ready(None)"); return Poll::Ready(None); } (Poll::Ready(Some(request_handler)), _) => { + tracing::trace!("read: Poll::Ready(Some), write: _"); return Poll::Ready(Some(Ok(request_handler))); } - (_, Poll::Ready(Some(()))) => {} - _ => { + (_, Poll::Ready(Some(()))) => { + tracing::trace!("read: _, write: Poll::Ready(Some)"); + } + (read @ Poll::Pending, write) | (read, write @ Poll::Pending) => { + tracing::trace!( + "read pending: {}, write pending: {}", + read.is_pending(), + write.is_pending() + ); return Poll::Pending; } } @@ -772,11 +1121,14 @@ where #[cfg(test)] mod tests { - use super::{in_flight_requests::AlreadyExistsError, BaseChannel, Channel, Config, Requests}; + use super::{ + in_flight_requests::AlreadyExistsError, serve, AfterRequest, BaseChannel, BeforeRequest, + Channel, Config, Requests, Serve, + }; use crate::{ context, trace, transport::channel::{self, UnboundedChannel}, - ClientMessage, Request, Response, + ClientMessage, Request, Response, ServerError, }; use assert_matches::assert_matches; use futures::{ @@ -785,7 +1137,13 @@ mod tests { Future, }; use futures_test::task::noop_context; - use std::{pin::Pin, task::Poll}; + use std::{ + io, + pin::Pin, + task::Poll, + time::{Duration, Instant, SystemTime}, + }; + use std::ops::Deref; fn test_channel() -> ( Pin, Response>>>>, @@ -835,7 +1193,7 @@ mod tests { fn fake_request(req: Req) -> ClientMessage { ClientMessage::Request(Request { context: context::current(), - id: 0, + request_id: 0, message: req, }) } @@ -846,6 +1204,105 @@ mod tests { Abortable::new(pending(), abort_registration) } + #[tokio::test] + async fn test_serve() { + let serve = serve(|_, i| async move { Ok(i) }); + assert_matches!(serve.serve(&mut context::current(), 7).await, Ok(7)); + } + + #[tokio::test] + async fn serve_before_mutates_context() -> anyhow::Result<()> { + struct SetDeadline(SystemTime); + type SetDeadlineFut<'a, Req: 'a> = impl Future> + 'a; + impl BeforeRequest for SetDeadline { + type Fut<'a> = SetDeadlineFut<'a, Req> where Self: 'a, Req: 'a; + fn before<'a>( + &'a mut self, + ctx: &'a mut context::Context, + _: &'a Req, + ) -> Self::Fut<'a> { + async move { + *ctx.deadline = self.0; + Ok(()) + } + } + } + + let some_time = SystemTime::UNIX_EPOCH + Duration::from_secs(37); + let some_other_time = SystemTime::UNIX_EPOCH + Duration::from_secs(83); + + let serve = serve(|ctx: &mut context::Context, i| { + let deadline = ctx.deadline.deref().clone(); + + async move { + assert_eq!(deadline, some_time); + Ok(i) + } + }); + let deadline_hook = serve.before(SetDeadline(some_time)); + let mut ctx = context::current(); + *ctx.deadline = some_other_time; + deadline_hook.serve(&mut ctx, 7).await?; + Ok(()) + } + + #[tokio::test] + async fn serve_before_and_after() -> anyhow::Result<()> { + let _ = tracing_subscriber::fmt::try_init(); + + struct PrintLatency { + start: Instant, + } + impl PrintLatency { + fn new() -> Self { + Self { + start: Instant::now(), + } + } + } + type StartFut<'a, Req: 'a> = impl Future> + 'a; + type EndFut<'a, Resp: 'a> = impl Future + 'a; + impl BeforeRequest for PrintLatency { + type Fut<'a> = StartFut<'a, Req> where Self: 'a, Req: 'a; + fn before<'a>(&'a mut self, _: &'a mut 
context::Context, _: &'a Req) -> Self::Fut<'a> { + async move { + self.start = Instant::now(); + Ok(()) + } + } + } + impl AfterRequest for PrintLatency { + type Fut<'a> = EndFut<'a, Resp> where Self: 'a, Resp: 'a; + fn after<'a>( + &'a mut self, + _: &'a mut context::Context, + _: &'a mut Result, + ) -> Self::Fut<'a> { + async move { + tracing::info!("Elapsed: {:?}", self.start.elapsed()); + } + } + } + + let serve = serve(move |_: &mut context::Context, i| async move { Ok(i) }); + serve + .before_and_after(PrintLatency::new()) + .serve(&mut context::current(), 7) + .await?; + Ok(()) + } + + #[tokio::test] + async fn serve_before_error_aborts_request() -> anyhow::Result<()> { + let serve = serve(|_, _| async { panic!("Shouldn't get here") }); + let deadline_hook = serve.before(|_: &mut context::Context, _: &i32| async { + Err(ServerError::new(io::ErrorKind::Other, "oops".into())) + }); + let resp: Result = deadline_hook.serve(&mut context::current(), 7).await; + assert_matches!(resp, Err(_)); + Ok(()) + } + #[tokio::test] async fn base_channel_start_send_duplicate_request_returns_error() { let (mut channel, _tx) = test_channel::<(), ()>(); @@ -853,14 +1310,14 @@ mod tests { channel .as_mut() .start_request(Request { - id: 0, + request_id: 0, context: context::current(), message: (), }) .unwrap(); assert_matches!( channel.as_mut().start_request(Request { - id: 0, + request_id: 0, context: context::current(), message: () }), @@ -876,7 +1333,7 @@ mod tests { let req0 = channel .as_mut() .start_request(Request { - id: 0, + request_id: 0, context: context::current(), message: (), }) @@ -884,7 +1341,7 @@ mod tests { let req1 = channel .as_mut() .start_request(Request { - id: 1, + request_id: 1, context: context::current(), message: (), }) @@ -907,7 +1364,7 @@ mod tests { let req = channel .as_mut() .start_request(Request { - id: 0, + request_id: 0, context: context::current(), message: (), }) @@ -936,7 +1393,7 @@ mod tests { let _abort_registration = channel .as_mut() .start_request(Request { - id: 0, + request_id: 0, context: context::current(), message: (), }) @@ -978,7 +1435,7 @@ mod tests { let req = channel .as_mut() .start_request(Request { - id: 0, + request_id: 0, context: context::current(), message: (), }) @@ -1001,7 +1458,7 @@ mod tests { channel .as_mut() .start_request(Request { - id: 0, + request_id: 0, context: context::current(), message: (), }) @@ -1010,6 +1467,7 @@ mod tests { channel .as_mut() .start_send(Response { + context: context::current(), request_id: 0, message: Ok(()), }) @@ -1046,7 +1504,7 @@ mod tests { Poll::Ready(Some(Ok(request))) => request, result => panic!("Unexpected result: {:?}", result), }; - request.execute(|_, _| async {}).await; + request.execute(serve(|_, _| async { Ok(()) })).await; assert!(requests .as_mut() .channel_pin_mut() @@ -1064,7 +1522,7 @@ mod tests { .as_mut() .channel_pin_mut() .start_request(Request { - id: 0, + request_id: 0, context: context::current(), message: (), }) @@ -1073,6 +1531,7 @@ mod tests { .as_mut() .channel_pin_mut() .start_send(Response { + context: context::current(), request_id: 0, message: Ok(()), }) @@ -1084,6 +1543,7 @@ mod tests { .project() .responses_tx .send(Response { + context: context::current(), request_id: 1, message: Ok(()), }) @@ -1094,7 +1554,7 @@ mod tests { .as_mut() .channel_pin_mut() .start_request(Request { - id: 1, + request_id: 1, context: context::current(), message: (), }) @@ -1115,7 +1575,7 @@ mod tests { .as_mut() .channel_pin_mut() .start_request(Request { - id: 0, + request_id: 0, context: 
context::current(), message: (), }) @@ -1124,6 +1584,7 @@ mod tests { .as_mut() .channel_pin_mut() .start_send(Response { + context: context::current(), request_id: 0, message: Ok(()), }) @@ -1134,7 +1595,7 @@ mod tests { .as_mut() .channel_pin_mut() .start_request(Request { - id: 1, + request_id: 1, context: context::current(), message: (), }) @@ -1144,6 +1605,7 @@ mod tests { .project() .responses_tx .send(Response { + context: context::current(), request_id: 1, message: Ok(()), }) diff --git a/tarpc/src/server/incoming.rs b/tarpc/src/server/incoming.rs index 445fc3e8..9195ee30 100644 --- a/tarpc/src/server/incoming.rs +++ b/tarpc/src/server/incoming.rs @@ -1,13 +1,10 @@ use super::{ limits::{channels_per_key::MaxChannelsPerKey, requests_per_channel::MaxRequestsPerChannel}, - Channel, + Channel, Serve, }; use futures::prelude::*; use std::{fmt, hash::Hash}; -#[cfg(feature = "tokio1")] -use super::{tokio::TokioServerExecutor, Serve}; - /// An extension trait for [streams](futures::prelude::Stream) of [`Channels`](Channel). pub trait Incoming where @@ -28,16 +25,62 @@ where MaxRequestsPerChannel::new(self, n) } - /// [Executes](Channel::execute) each incoming channel. Each channel will be handled - /// concurrently by spawning on tokio's default executor, and each request will be also - /// be spawned on tokio's default executor. - #[cfg(feature = "tokio1")] - #[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))] - fn execute(self, serve: S) -> TokioServerExecutor + /// Returns a stream of channels in execution. Each channel in execution is a stream of + /// futures, where each future is an in-flight request being responded to. + fn execute( + self, + serve: S, + ) -> impl Stream>> where - S: Serve, + S: Serve + Clone, { - TokioServerExecutor::new(self, serve) + self.map(move |channel| channel.execute(serve.clone())) + } +} + +#[cfg(feature = "tokio1")] +/// Spawns all channels-in-execution, delegating to the tokio runtime to manage their completion. +/// Each channel is spawned, and each request from each channel is spawned. +/// Note that this function is generic over any stream-of-streams-of-futures, but it is intended +/// for spawning streams of channels.
+/// +/// # Example +/// ```rust +/// use tarpc::{ +/// context, +/// client::{self, NewClient}, +/// server::{self, BaseChannel, Channel, incoming::{Incoming, spawn_incoming}, serve}, +/// transport, +/// }; +/// use futures::prelude::*; +/// +/// #[tokio::main] +/// async fn main() { +/// let (tx, rx) = transport::channel::unbounded(); +/// let NewClient { client, dispatch } = client::new(client::Config::default(), tx); +/// tokio::spawn(dispatch); +/// +/// let incoming = stream::once(async move { +/// BaseChannel::new(server::Config::default(), rx) +/// }).execute(serve(|_, i| async move { Ok(i + 1) })); +/// tokio::spawn(spawn_incoming(incoming)); +/// assert_eq!(client.call(context::current(), "AddOne", 1).await.unwrap(), 2); +/// } +/// ``` +pub async fn spawn_incoming( + incoming: impl Stream< + Item = impl Stream + Send + 'static> + Send + 'static, + >, +) { + use futures::pin_mut; + pin_mut!(incoming); + while let Some(channel) = incoming.next().await { + tokio::spawn(async move { + pin_mut!(channel); + while let Some(request) = channel.next().await { + tokio::spawn(request); + } + }); } } diff --git a/tarpc/src/server/limits/requests_per_channel.rs b/tarpc/src/server/limits/requests_per_channel.rs index 3c29836a..a6d5e344 100644 --- a/tarpc/src/server/limits/requests_per_channel.rs +++ b/tarpc/src/server/limits/requests_per_channel.rs @@ -66,7 +66,8 @@ where ); self.as_mut().start_send(Response { - request_id: r.request.id, + context: r.request.context, + request_id: r.request.request_id, message: Err(ServerError { kind: io::ErrorKind::WouldBlock, detail: "server throttled the request.".into(), @@ -188,6 +189,7 @@ mod tests { time::{Duration, SystemTime}, }; use tracing::Span; + use crate::context; #[tokio::test] async fn throttler_in_flight_requests() { @@ -236,7 +238,7 @@ mod tests { throttler .as_mut() .poll_next(&mut testing::cx())? - .map(|r| r.map(|r| (r.request.id, r.request.message))), + .map(|r| r.map(|r| (r.request.request_id, r.request.message))), Poll::Ready(Some((0, 1))) ); Ok(()) @@ -335,15 +337,13 @@ mod tests { .start_send(Response { request_id: 0, message: Ok(1), + context: context::current() }) .unwrap(); assert_eq!(throttler.inner.in_flight_requests.len(), 0); assert_eq!( - throttler.inner.sink.get(0), - Some(&Response { - request_id: 0, - message: Ok(1), - }) + throttler.inner.sink.get(0).map(|resp| (resp.request_id, &resp.message)), + Some((0, &Ok(1))), ); } } diff --git a/tarpc/src/server/request_hook.rs b/tarpc/src/server/request_hook.rs new file mode 100644 index 00000000..ef23d73b --- /dev/null +++ b/tarpc/src/server/request_hook.rs @@ -0,0 +1,22 @@ +// Copyright 2022 Google LLC +// +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file or at +// https://opensource.org/licenses/MIT. + +//! Hooks for horizontal functionality that can run either before or after a request is executed. + +/// A request hook that runs before a request is executed. +mod before; + +/// A request hook that runs after a request is completed. +mod after; + +/// A request hook that runs both before a request is executed and after it is completed. 
+mod before_and_after; + +pub use { + after::{AfterRequest, AfterRequestHook}, + before::{BeforeRequest, BeforeRequestHook}, + before_and_after::BeforeAndAfterRequestHook, +}; diff --git a/tarpc/src/server/request_hook/after.rs b/tarpc/src/server/request_hook/after.rs new file mode 100644 index 00000000..4c108a64 --- /dev/null +++ b/tarpc/src/server/request_hook/after.rs @@ -0,0 +1,87 @@ +// Copyright 2022 Google LLC +// +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file or at +// https://opensource.org/licenses/MIT. + +//! Provides a hook that runs after request execution. + +use crate::{context, server::Serve, ServerError}; +use futures::prelude::*; + +/// A hook that runs after request execution. +pub trait AfterRequest { + /// The type of future returned by the hook. + type Fut<'a>: Future + where + Self: 'a, + Resp: 'a; + + /// The function that is called after request execution. + /// + /// The hook can modify the request context and the response. + fn after<'a>( + &'a mut self, + ctx: &'a mut context::Context, + resp: &'a mut Result, + ) -> Self::Fut<'a>; +} + +impl AfterRequest for F +where + F: FnMut(&mut context::Context, &mut Result) -> Fut, + Fut: Future, +{ + type Fut<'a> = Fut where Self: 'a, Resp: 'a; + + fn after<'a>( + &'a mut self, + ctx: &'a mut context::Context, + resp: &'a mut Result, + ) -> Self::Fut<'a> { + self(ctx, resp) + } +} + +/// A Service function that runs a hook after request execution. +pub struct AfterRequestHook { + serve: Serv, + hook: Hook, +} + +impl AfterRequestHook { + pub(crate) fn new(serve: Serv, hook: Hook) -> Self { + Self { serve, hook } + } +} + +impl Clone for AfterRequestHook { + fn clone(&self) -> Self { + Self { + serve: self.serve.clone(), + hook: self.hook.clone(), + } + } +} + +impl Serve for AfterRequestHook +where + Serv: Serve, + Hook: AfterRequest, +{ + type Req = Serv::Req; + type Resp = Serv::Resp; + + async fn serve( + self, + ctx: &mut context::Context, + req: Serv::Req, + ) -> Result { + let AfterRequestHook { + serve, mut hook, .. + } = self; + let mut resp = serve.serve(ctx, req).await; + hook.after(ctx, &mut resp).await; + resp + } +} diff --git a/tarpc/src/server/request_hook/before.rs b/tarpc/src/server/request_hook/before.rs new file mode 100644 index 00000000..a0040e63 --- /dev/null +++ b/tarpc/src/server/request_hook/before.rs @@ -0,0 +1,82 @@ +// Copyright 2022 Google LLC +// +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file or at +// https://opensource.org/licenses/MIT. + +//! Provides a hook that runs before request execution. + +use crate::{context, server::Serve, ServerError}; +use futures::prelude::*; + +/// A hook that runs before request execution. +pub trait BeforeRequest { + /// The type of future returned by the hook. + type Fut<'a>: Future> + where + Self: 'a, + Req: 'a; + + /// The function that is called before request execution. + /// + /// If this function returns an error, the request will not be executed and the error will be + /// returned instead. + /// + /// This function can also modify the request context. This could be used, for example, to + /// enforce a maximum deadline on all requests. 
+ fn before<'a>(&'a mut self, ctx: &'a mut context::Context, req: &'a Req) -> Self::Fut<'a>; +} + +impl BeforeRequest for F +where + F: FnMut(&mut context::Context, &Req) -> Fut, + Fut: Future>, +{ + type Fut<'a> = Fut where Self: 'a, Req: 'a; + + fn before<'a>(&'a mut self, ctx: &'a mut context::Context, req: &'a Req) -> Self::Fut<'a> { + self(ctx, req) + } +} + +/// A Service function that runs a hook before request execution. +pub struct BeforeRequestHook { + serve: Serv, + hook: Hook, +} + +impl BeforeRequestHook { + pub(crate) fn new(serve: Serv, hook: Hook) -> Self { + Self { serve, hook } + } +} + +impl Clone for BeforeRequestHook { + fn clone(&self) -> Self { + Self { + serve: self.serve.clone(), + hook: self.hook.clone(), + } + } +} + +impl Serve for BeforeRequestHook +where + Serv: Serve, + Hook: BeforeRequest, +{ + type Req = Serv::Req; + type Resp = Serv::Resp; + + async fn serve( + self, + ctx: &mut context::Context, + req: Self::Req, + ) -> Result { + let BeforeRequestHook { + serve, mut hook, .. + } = self; + hook.before(ctx, &req).await?; + serve.serve(ctx, req).await + } +} diff --git a/tarpc/src/server/request_hook/before_and_after.rs b/tarpc/src/server/request_hook/before_and_after.rs new file mode 100644 index 00000000..5acdaea6 --- /dev/null +++ b/tarpc/src/server/request_hook/before_and_after.rs @@ -0,0 +1,59 @@ +// Copyright 2022 Google LLC +// +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file or at +// https://opensource.org/licenses/MIT. + +//! Provides a hook that runs both before and after request execution. + +use super::{after::AfterRequest, before::BeforeRequest}; +use crate::{context, server::Serve, ServerError}; +use std::marker::PhantomData; + +/// A Service function that runs a hook both before and after request execution. +pub struct BeforeAndAfterRequestHook { + serve: Serv, + hook: Hook, + fns: PhantomData<(fn(Req), fn(Resp))>, +} + +impl BeforeAndAfterRequestHook { + pub(crate) fn new(serve: Serv, hook: Hook) -> Self { + Self { + serve, + hook, + fns: PhantomData, + } + } +} + +impl Clone + for BeforeAndAfterRequestHook +{ + fn clone(&self) -> Self { + Self { + serve: self.serve.clone(), + hook: self.hook.clone(), + fns: PhantomData, + } + } +} + +impl Serve for BeforeAndAfterRequestHook +where + Serv: Serve, + Hook: BeforeRequest + AfterRequest, +{ + type Req = Req; + type Resp = Resp; + + async fn serve(self, ctx: &mut context::Context, req: Req) -> Result { + let BeforeAndAfterRequestHook { + serve, mut hook, .. 
+ } = self; + hook.before(ctx, &req).await?; + let mut resp = serve.serve(ctx, req).await; + hook.after(ctx, &mut resp).await; + resp + } +} diff --git a/tarpc/src/server/testing.rs b/tarpc/src/server/testing.rs index 938865c0..82bfb71d 100644 --- a/tarpc/src/server/testing.rs +++ b/tarpc/src/server/testing.rs @@ -14,6 +14,7 @@ use futures::{task::*, Sink, Stream}; use pin_project::pin_project; use std::{collections::VecDeque, io, pin::Pin, time::SystemTime}; use tracing::Span; +use crate::context::Deadline; #[pin_project] pub(crate) struct FakeChannel { @@ -93,10 +94,11 @@ impl FakeChannel>, Response> { self.stream.push_back(Ok(TrackedRequest { request: Request { context: context::Context { - deadline: SystemTime::UNIX_EPOCH, + deadline: Deadline::new(SystemTime::UNIX_EPOCH), trace_context: Default::default(), + extensions: Default::default() }, - id, + request_id: id, message, }, abort_registration, diff --git a/tarpc/src/server/tokio.rs b/tarpc/src/server/tokio.rs deleted file mode 100644 index a44e8469..00000000 --- a/tarpc/src/server/tokio.rs +++ /dev/null @@ -1,113 +0,0 @@ -use super::{Channel, Requests, Serve}; -use futures::{prelude::*, ready, task::*}; -use pin_project::pin_project; -use std::pin::Pin; - -/// A future that drives the server by [spawning](tokio::spawn) a [`TokioChannelExecutor`](TokioChannelExecutor) -/// for each new channel. Returned by -/// [`Incoming::execute`](crate::server::incoming::Incoming::execute). -#[must_use] -#[pin_project] -#[derive(Debug)] -pub struct TokioServerExecutor { - #[pin] - inner: T, - serve: S, -} - -impl TokioServerExecutor { - pub(crate) fn new(inner: T, serve: S) -> Self { - Self { inner, serve } - } -} - -/// A future that drives the server by [spawning](tokio::spawn) each [response -/// handler](super::InFlightRequest::execute) on tokio's default executor. Returned by -/// [`Channel::execute`](crate::server::Channel::execute). -#[must_use] -#[pin_project] -#[derive(Debug)] -pub struct TokioChannelExecutor { - #[pin] - inner: T, - serve: S, -} - -impl TokioServerExecutor { - fn inner_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut T> { - self.as_mut().project().inner - } -} - -impl TokioChannelExecutor { - fn inner_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut T> { - self.as_mut().project().inner - } -} - -// Send + 'static execution helper methods. - -impl Requests -where - C: Channel, - C::Req: Send + 'static, - C::Resp: Send + 'static, -{ - /// Executes all requests using the given service function. Requests are handled concurrently - /// by [spawning](::tokio::spawn) each handler on tokio's default executor. 
- pub fn execute(self, serve: S) -> TokioChannelExecutor - where - S: Serve + Send + 'static, - { - TokioChannelExecutor { inner: self, serve } - } -} - -impl Future for TokioServerExecutor -where - St: Sized + Stream, - C: Channel + Send + 'static, - C::Req: Send + 'static, - C::Resp: Send + 'static, - Se: Serve + Send + 'static + Clone, - Se::Fut: Send, -{ - type Output = (); - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { - while let Some(channel) = ready!(self.inner_pin_mut().poll_next(cx)) { - tokio::spawn(channel.execute(self.serve.clone())); - } - tracing::info!("Server shutting down."); - Poll::Ready(()) - } -} - -impl Future for TokioChannelExecutor, S> -where - C: Channel + 'static, - C::Req: Send + 'static, - C::Resp: Send + 'static, - S: Serve + Send + 'static + Clone, - S::Fut: Send, -{ - type Output = (); - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - while let Some(response_handler) = ready!(self.inner_pin_mut().poll_next(cx)) { - match response_handler { - Ok(resp) => { - let server = self.serve.clone(); - tokio::spawn(async move { - resp.execute(server).await; - }); - } - Err(e) => { - tracing::warn!("Requests stream errored out: {}", e); - break; - } - } - } - Poll::Ready(()) - } -} diff --git a/tarpc/src/transport/channel.rs b/tarpc/src/transport/channel.rs index 529ae8f5..98ea0aac 100644 --- a/tarpc/src/transport/channel.rs +++ b/tarpc/src/transport/channel.rs @@ -14,9 +14,15 @@ use tokio::sync::mpsc; /// Errors that occur in the sending or receiving of messages over a channel. #[derive(thiserror::Error, Debug)] pub enum ChannelError { - /// An error occurred sending over the channel. - #[error("an error occurred sending over the channel")] + /// An error occurred readying to send into the channel. + #[error("an error occurred readying to send into the channel")] + Ready(#[source] Box), + /// An error occurred sending into the channel. + #[error("an error occurred sending into the channel")] Send(#[source] Box), + /// An error occurred receiving from the channel. + #[error("an error occurred receiving from the channel")] + Receive(#[source] Box), } /// Returns two unbounded channel peers. 
Each [`Stream`] yields items sent through the other's @@ -48,7 +54,10 @@ impl Stream for UnboundedChannel { mut self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll>> { - self.rx.poll_recv(cx).map(|option| option.map(Ok)) + self.rx + .poll_recv(cx) + .map(|option| option.map(Ok)) + .map_err(ChannelError::Receive) } } @@ -59,7 +68,7 @@ impl Sink for UnboundedChannel { fn poll_ready(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { Poll::Ready(if self.tx.is_closed() { - Err(ChannelError::Send(CLOSED_MESSAGE.into())) + Err(ChannelError::Ready(CLOSED_MESSAGE.into())) } else { Ok(()) }) @@ -110,7 +119,11 @@ impl Stream for Channel { self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll>> { - self.project().rx.poll_next(cx).map(|option| option.map(Ok)) + self.project() + .rx + .poll_next(cx) + .map(|option| option.map(Ok)) + .map_err(ChannelError::Receive) } } @@ -121,7 +134,7 @@ impl Sink for Channel { self.project() .tx .poll_ready(cx) - .map_err(|e| ChannelError::Send(Box::new(e))) + .map_err(|e| ChannelError::Ready(Box::new(e))) } fn start_send(self: Pin<&mut Self>, item: SinkItem) -> Result<(), Self::Error> { @@ -146,16 +159,17 @@ impl Sink for Channel { } } -#[cfg(test)] -#[cfg(feature = "tokio1")] +#[cfg(all(test, feature = "tokio1"))] mod tests { use crate::{ - client, context, - server::{incoming::Incoming, BaseChannel}, + client::{self, RpcError}, + context, + server::{incoming::Incoming, serve, BaseChannel}, transport::{ self, channel::{Channel, UnboundedChannel}, }, + ServerError, }; use assert_matches::assert_matches; use futures::{prelude::*, stream}; @@ -177,25 +191,28 @@ mod tests { tokio::spawn( stream::once(future::ready(server_channel)) .map(BaseChannel::with_defaults) - .execute(|_ctx, request: String| { - future::ready(request.parse::().map_err(|_| { - io::Error::new( + .execute(serve(|_ctx, request: String| async move { + request.parse::().map_err(|_| { + ServerError::new( io::ErrorKind::InvalidInput, format!("{request:?} is not an int"), ) - })) + }) + })) + .for_each(|channel| async move { + tokio::spawn(channel.for_each(|response| response)); }), ); let client = client::new(client::Config::default(), client_channel).spawn(); - let response1 = client.call(context::current(), "", "123".into()).await?; - let response2 = client.call(context::current(), "", "abc".into()).await?; + let response1 = client.call(context::current(), "", "123".into()).await; + let response2 = client.call(context::current(), "", "abc".into()).await; trace!("response1: {:?}, response2: {:?}", response1, response2); assert_matches!(response1, Ok(123)); - assert_matches!(response2, Err(ref e) if e.kind() == io::ErrorKind::InvalidInput); + assert_matches!(response2, Err(RpcError::Server(e)) if e.kind == io::ErrorKind::InvalidInput); Ok(()) } diff --git a/tarpc/tests/compile_fail.rs b/tarpc/tests/compile_fail.rs index 4c5a28ec..c28fe2fa 100644 --- a/tarpc/tests/compile_fail.rs +++ b/tarpc/tests/compile_fail.rs @@ -2,8 +2,6 @@ fn ui() { let t = trybuild::TestCases::new(); t.compile_fail("tests/compile_fail/*.rs"); - #[cfg(feature = "tokio1")] - t.compile_fail("tests/compile_fail/tokio/*.rs"); #[cfg(all(feature = "serde-transport", feature = "tcp"))] t.compile_fail("tests/compile_fail/serde_transport/*.rs"); } diff --git a/tarpc/tests/compile_fail/must_use_request_dispatch.stderr b/tarpc/tests/compile_fail/must_use_request_dispatch.stderr index f7aa3ea6..e652cc8e 100644 --- a/tarpc/tests/compile_fail/must_use_request_dispatch.stderr +++ b/tarpc/tests/compile_fail/must_use_request_dispatch.stderr @@ 
-9,3 +9,7 @@ note: the lint level is defined here | 11 | #[deny(unused_must_use)] | ^^^^^^^^^^^^^^^ +help: use `let _ = ...` to ignore the resulting value + | +13 | let _ = WorldClient::new(client::Config::default(), client_transport).dispatch; + | +++++++ diff --git a/tarpc/tests/compile_fail/serde_transport/must_use_tcp_connect.stderr b/tarpc/tests/compile_fail/serde_transport/must_use_tcp_connect.stderr index d3f4eb62..b6e9bdef 100644 --- a/tarpc/tests/compile_fail/serde_transport/must_use_tcp_connect.stderr +++ b/tarpc/tests/compile_fail/serde_transport/must_use_tcp_connect.stderr @@ -9,3 +9,7 @@ note: the lint level is defined here | 5 | #[deny(unused_must_use)] | ^^^^^^^^^^^^^^^ +help: use `let _ = ...` to ignore the resulting value + | +7 | let _ = serde_transport::tcp::connect::<_, (), (), _, _>("0.0.0.0:0", Json::default); + | +++++++ diff --git a/tarpc/tests/compile_fail/tarpc_server_missing_async.rs b/tarpc/tests/compile_fail/tarpc_server_missing_async.rs deleted file mode 100644 index 99d858b6..00000000 --- a/tarpc/tests/compile_fail/tarpc_server_missing_async.rs +++ /dev/null @@ -1,15 +0,0 @@ -#[tarpc::service(derive_serde = false)] -trait World { - async fn hello(name: String) -> String; -} - -struct HelloServer; - -#[tarpc::server] -impl World for HelloServer { - fn hello(name: String) -> String { - format!("Hello, {name}!", name) - } -} - -fn main() {} diff --git a/tarpc/tests/compile_fail/tarpc_server_missing_async.stderr b/tarpc/tests/compile_fail/tarpc_server_missing_async.stderr deleted file mode 100644 index 28106e63..00000000 --- a/tarpc/tests/compile_fail/tarpc_server_missing_async.stderr +++ /dev/null @@ -1,11 +0,0 @@ -error: not all trait items implemented, missing: `HelloFut` - --> $DIR/tarpc_server_missing_async.rs:9:1 - | -9 | impl World for HelloServer { - | ^^^^ - -error: hint: `#[tarpc::server]` only rewrites async fns, and `fn hello` is not async - --> $DIR/tarpc_server_missing_async.rs:10:5 - | -10 | fn hello(name: String) -> String { - | ^^ diff --git a/tarpc/tests/compile_fail/tokio/must_use_channel_executor.rs b/tarpc/tests/compile_fail/tokio/must_use_channel_executor.rs deleted file mode 100644 index 6fc2f2bf..00000000 --- a/tarpc/tests/compile_fail/tokio/must_use_channel_executor.rs +++ /dev/null @@ -1,29 +0,0 @@ -use tarpc::{ - context, - server::{self, Channel}, -}; - -#[tarpc::service] -trait World { - async fn hello(name: String) -> String; -} - -#[derive(Clone)] -struct HelloServer; - -#[tarpc::server] -impl World for HelloServer { - async fn hello(self, _: context::Context, name: String) -> String { - format!("Hello, {name}!") - } -} - -fn main() { - let (_, server_transport) = tarpc::transport::channel::unbounded(); - let server = server::BaseChannel::with_defaults(server_transport); - - #[deny(unused_must_use)] - { - server.execute(HelloServer.serve()); - } -} diff --git a/tarpc/tests/compile_fail/tokio/must_use_channel_executor.stderr b/tarpc/tests/compile_fail/tokio/must_use_channel_executor.stderr deleted file mode 100644 index 446f224f..00000000 --- a/tarpc/tests/compile_fail/tokio/must_use_channel_executor.stderr +++ /dev/null @@ -1,11 +0,0 @@ -error: unused `TokioChannelExecutor` that must be used - --> tests/compile_fail/tokio/must_use_channel_executor.rs:27:9 - | -27 | server.execute(HelloServer.serve()); - | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - | -note: the lint level is defined here - --> tests/compile_fail/tokio/must_use_channel_executor.rs:25:12 - | -25 | #[deny(unused_must_use)] - | ^^^^^^^^^^^^^^^ diff --git 
a/tarpc/tests/compile_fail/tokio/must_use_server_executor.rs b/tarpc/tests/compile_fail/tokio/must_use_server_executor.rs deleted file mode 100644 index 950cf74e..00000000 --- a/tarpc/tests/compile_fail/tokio/must_use_server_executor.rs +++ /dev/null @@ -1,30 +0,0 @@ -use futures::stream::once; -use tarpc::{ - context, - server::{self, incoming::Incoming}, -}; - -#[tarpc::service] -trait World { - async fn hello(name: String) -> String; -} - -#[derive(Clone)] -struct HelloServer; - -#[tarpc::server] -impl World for HelloServer { - async fn hello(self, _: context::Context, name: String) -> String { - format!("Hello, {name}!") - } -} - -fn main() { - let (_, server_transport) = tarpc::transport::channel::unbounded(); - let server = once(async move { server::BaseChannel::with_defaults(server_transport) }); - - #[deny(unused_must_use)] - { - server.execute(HelloServer.serve()); - } -} diff --git a/tarpc/tests/compile_fail/tokio/must_use_server_executor.stderr b/tarpc/tests/compile_fail/tokio/must_use_server_executor.stderr deleted file mode 100644 index 07d4b5a9..00000000 --- a/tarpc/tests/compile_fail/tokio/must_use_server_executor.stderr +++ /dev/null @@ -1,11 +0,0 @@ -error: unused `TokioServerExecutor` that must be used - --> tests/compile_fail/tokio/must_use_server_executor.rs:28:9 - | -28 | server.execute(HelloServer.serve()); - | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - | -note: the lint level is defined here - --> tests/compile_fail/tokio/must_use_server_executor.rs:26:12 - | -26 | #[deny(unused_must_use)] - | ^^^^^^^^^^^^^^^ diff --git a/tarpc/tests/dataservice.rs b/tarpc/tests/dataservice.rs index 365594bd..78e28cd7 100644 --- a/tarpc/tests/dataservice.rs +++ b/tarpc/tests/dataservice.rs @@ -21,9 +21,8 @@ pub trait ColorProtocol { #[derive(Clone)] struct ColorServer; -#[tarpc::server] impl ColorProtocol for ColorServer { - async fn get_opposite_color(self, _: context::Context, color: TestData) -> TestData { + async fn get_opposite_color(self, _: &mut context::Context, color: TestData) -> TestData { match color { TestData::White => TestData::Black, TestData::Black => TestData::White, @@ -31,6 +30,11 @@ impl ColorProtocol for ColorServer { } } +#[cfg(test)] +async fn spawn(fut: impl Future + Send + 'static) { + tokio::spawn(fut); +} + #[tokio::test] async fn test_call() -> anyhow::Result<()> { let transport = tarpc::serde_transport::tcp::listen("localhost:56797", Json::default).await?; @@ -40,7 +44,9 @@ async fn test_call() -> anyhow::Result<()> { .take(1) .filter_map(|r| async { r.ok() }) .map(BaseChannel::with_defaults) - .execute(ColorServer.serve()), + .execute(ColorServer.serve()) + .map(|channel| channel.for_each(spawn)) + .for_each(spawn), ); let transport = serde_transport::tcp::connect(addr, Json::default).await?; diff --git a/tarpc/tests/service_functional.rs b/tarpc/tests/service_functional.rs index 50d19b0e..58066b41 100644 --- a/tarpc/tests/service_functional.rs +++ b/tarpc/tests/service_functional.rs @@ -1,13 +1,13 @@ use assert_matches::assert_matches; use futures::{ - future::{join_all, ready, Ready}, + future::{join_all, ready}, prelude::*, }; use std::time::{Duration, SystemTime}; use tarpc::{ client::{self}, context, - server::{self, incoming::Incoming, BaseChannel, Channel}, + server::{incoming::Incoming, BaseChannel, Channel}, transport::channel, }; use tokio::join; @@ -22,39 +22,29 @@ trait Service { struct Server; impl Service for Server { - type AddFut = Ready; - - fn add(self, _: context::Context, x: i32, y: i32) -> Self::AddFut { - ready(x + y) + async fn 
add(self, _: &mut context::Context, x: i32, y: i32) -> i32 { + x + y } - type HeyFut = Ready; - - fn hey(self, _: context::Context, name: String) -> Self::HeyFut { - ready(format!("Hey, {name}.")) + async fn hey(self, _: &mut context::Context, name: String) -> String { + format!("Hey, {name}.") } } #[tokio::test] -async fn sequential() -> anyhow::Result<()> { - let _ = tracing_subscriber::fmt::try_init(); - - let (tx, rx) = channel::unbounded(); - +async fn sequential() { + let (tx, rx) = tarpc::transport::channel::unbounded(); + let client = client::new(client::Config::default(), tx).spawn(); + let channel = BaseChannel::with_defaults(rx); tokio::spawn( - BaseChannel::new(server::Config::default(), rx) - .requests() - .execute(Server.serve()), + channel + .execute(tarpc::server::serve(|_, i| async move { Ok(i + 1) })) + .for_each(|response| response), + ); + assert_eq!( + client.call(context::current(), "AddOne", 1).await.unwrap(), + 2 ); - - let client = ServiceClient::new(client::Config::default(), tx).spawn(); - - assert_matches!(client.add(context::current(), 1, 2).await, Ok(3)); - assert_matches!( - client.hey(context::current(), "Tim".into()).await, - Ok(ref s) if s == "Hey, Tim."); - - Ok(()) } #[tokio::test] @@ -70,9 +60,8 @@ async fn dropped_channel_aborts_in_flight_requests() -> anyhow::Result<()> { #[derive(Debug)] struct AllHandlersComplete; - #[tarpc::server] impl Loop for LoopServer { - async fn r#loop(self, _: context::Context) { + async fn r#loop(self, _: &mut context::Context) { loop { futures::pending!(); } @@ -89,7 +78,7 @@ async fn dropped_channel_aborts_in_flight_requests() -> anyhow::Result<()> { let client = LoopClient::new(client::Config::default(), tx).spawn(); let mut ctx = context::current(); - ctx.deadline = SystemTime::now() + Duration::from_secs(60 * 60); + *ctx.deadline = SystemTime::now() + Duration::from_secs(60 * 60); let _ = client.r#loop(ctx).await; }); @@ -121,7 +110,9 @@ async fn serde_tcp() -> anyhow::Result<()> { .take(1) .filter_map(|r| async { r.ok() }) .map(BaseChannel::with_defaults) - .execute(Server.serve()), + .execute(Server.serve()) + .map(|channel| channel.for_each(spawn)) + .for_each(spawn), ); let transport = serde_transport::tcp::connect(addr, Json::default).await?; @@ -151,7 +142,9 @@ async fn serde_uds() -> anyhow::Result<()> { .take(1) .filter_map(|r| async { r.ok() }) .map(BaseChannel::with_defaults) - .execute(Server.serve()), + .execute(Server.serve()) + .map(|channel| channel.for_each(spawn)) + .for_each(spawn), ); let transport = serde_transport::unix::connect(&sock, Json::default).await?; @@ -175,7 +168,9 @@ async fn concurrent() -> anyhow::Result<()> { tokio::spawn( stream::once(ready(rx)) .map(BaseChannel::with_defaults) - .execute(Server.serve()), + .execute(Server.serve()) + .map(|channel| channel.for_each(spawn)) + .for_each(spawn), ); let client = ServiceClient::new(client::Config::default(), tx).spawn(); @@ -199,7 +194,9 @@ async fn concurrent_join() -> anyhow::Result<()> { tokio::spawn( stream::once(ready(rx)) .map(BaseChannel::with_defaults) - .execute(Server.serve()), + .execute(Server.serve()) + .map(|channel| channel.for_each(spawn)) + .for_each(spawn), ); let client = ServiceClient::new(client::Config::default(), tx).spawn(); @@ -216,15 +213,20 @@ async fn concurrent_join() -> anyhow::Result<()> { Ok(()) } +#[cfg(test)] +async fn spawn(fut: impl Future + Send + 'static) { + tokio::spawn(fut); +} + #[tokio::test] async fn concurrent_join_all() -> anyhow::Result<()> { let _ = tracing_subscriber::fmt::try_init(); let 
(tx, rx) = channel::unbounded(); tokio::spawn( - stream::once(ready(rx)) - .map(BaseChannel::with_defaults) - .execute(Server.serve()), + BaseChannel::with_defaults(rx) + .execute(Server.serve()) + .for_each(spawn), ); let client = ServiceClient::new(client::Config::default(), tx).spawn(); @@ -249,11 +251,9 @@ async fn counter() -> anyhow::Result<()> { struct CountService(u32); impl Counter for &mut CountService { - type CountFut = futures::future::Ready; - - fn count(self, _: context::Context) -> Self::CountFut { + async fn count(self, _: &mut context::Context) -> u32 { self.0 += 1; - futures::future::ready(self.0) + self.0 } }
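
The updated tests above (`sequential`, the serde transport tests, `concurrent_join_all`, and the dataservice test) all exercise the same reworked pattern: handler impls are plain `async fn`s taking `&mut context::Context`, with no `#[tarpc::server]` attribute or `*Fut` associated types, and `Channel::execute` now yields a stream of response futures that the caller spawns. Below is a minimal end-to-end sketch of that pattern, assuming only what the diff itself shows; the `Echo` service and the `EchoServer`/`EchoClient` names are illustrative, not part of this change.

```rust
use futures::prelude::*;
use tarpc::{
    client, context,
    server::{BaseChannel, Channel},
};

#[tarpc::service]
trait Echo {
    async fn echo(msg: String) -> String;
}

#[derive(Clone)]
struct EchoServer;

impl Echo for EchoServer {
    // Handlers are now plain async fns taking `&mut context::Context`.
    async fn echo(self, _: &mut context::Context, msg: String) -> String {
        msg
    }
}

// `Channel::execute` yields one future per request; spawning each future lets
// requests run concurrently, mirroring the `spawn` helper in the tests.
async fn spawn(fut: impl Future<Output = ()> + Send + 'static) {
    tokio::spawn(fut);
}

#[tokio::main]
async fn main() {
    let (tx, rx) = tarpc::transport::channel::unbounded();
    tokio::spawn(
        BaseChannel::with_defaults(rx)
            .execute(EchoServer.serve())
            .for_each(spawn),
    );

    let client = EchoClient::new(client::Config::default(), tx).spawn();
    assert_eq!(
        client.echo(context::current(), "hi".into()).await.unwrap(),
        "hi"
    );
}
```

As the channel transport test now asserts, server-side failures surface to callers as `client::RpcError::Server` rather than as a bare `io::Error`.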
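The `BeforeRequest` documentation notes that a hook can modify the request context, for example to enforce a maximum deadline, and the new tests attach hooks with `Serve::before` and `Serve::before_and_after`. Here is a rough sketch of a deadline-capping hook using the closure form exercised in `serve_before_error_aborts_request`; the ten-second cap and the `add_one` serve function are made up for illustration, and the context is mutated synchronously in the closure body so the returned future borrows nothing.

```rust
use futures::future;
use std::time::{Duration, SystemTime};
use tarpc::{
    context,
    server::{serve, Serve},
    ServerError,
};

#[tokio::main]
async fn main() {
    // A closure-based serve function, as used throughout the updated tests.
    let add_one =
        serve(|_: &mut context::Context, i: i32| async move { Ok::<_, ServerError>(i + 1) });

    // A `BeforeRequest` hook in closure form: cap every request's deadline at
    // ten seconds from now. `Context::deadline` is now a `Deadline` wrapper
    // that derefs to `SystemTime`, hence the `*` on reads and writes.
    let capped = add_one.before(|ctx: &mut context::Context, _: &i32| {
        let max = SystemTime::now() + Duration::from_secs(10);
        if *ctx.deadline > max {
            *ctx.deadline = max;
        }
        future::ready(Ok::<_, ServerError>(()))
    });

    // The hooked serve function is still a `Serve`, so it can be passed to
    // `Channel::execute` or, as in the hook tests, called directly.
    let mut ctx = context::current();
    assert_eq!(capped.serve(&mut ctx, 1).await.unwrap(), 2);
}
```

Hooks that need to borrow the context across an await point (like `SetDeadline` in the tests) implement `BeforeRequest` directly rather than using the closure form.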