Skip to content

Commit

Permalink
Transport's connection separated to client and server side specific
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexSherbinin committed Feb 28, 2024
1 parent f1b7e0c commit 9bb8c13
Show file tree
Hide file tree
Showing 11 changed files with 104 additions and 75 deletions.
2 changes: 1 addition & 1 deletion rustyrpc/examples/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ async fn main() {
tokio::time::sleep(Duration::from_secs(2)).await; // Waiting to allow HelloService deallocation request to be sent.
}

async fn start_healthcheck<Connection: transport::Connection, Format: EncodingFormat>(
async fn start_healthcheck<Connection: transport::ClientConnection, Format: EncodingFormat>(
hello_service_client: HelloServiceClient<Connection, Format>,
) where
for<'a> RequestKind<'a>: Encode<Format>,
Expand Down
8 changes: 4 additions & 4 deletions rustyrpc/examples/common/auth_service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,13 +104,13 @@ where
}

#[derive_where(Clone)]
pub struct AuthServiceClient<Connection: transport::Connection, Format: EncodingFormat> {
pub struct AuthServiceClient<Connection: transport::ClientConnection, Format: EncodingFormat> {
service_kind: ServiceKind,
service_id: usize,
rpc_client: Arc<Client<Connection, Format>>,
}

impl<Connection: transport::Connection, Format: EncodingFormat>
impl<Connection: transport::ClientConnection, Format: EncodingFormat>
AuthServiceClient<Connection, Format>
where
for<'a> RequestKind<'a>: Encode<Format>,
Expand Down Expand Up @@ -142,8 +142,8 @@ where
}
}

impl<Connection: transport::Connection, Format: EncodingFormat> ServiceClient<Connection, Format>
for AuthServiceClient<Connection, Format>
impl<Connection: transport::ClientConnection, Format: EncodingFormat>
ServiceClient<Connection, Format> for AuthServiceClient<Connection, Format>
{
const SERVICE_NAME: &'static str = SERVICE_NAME;
const SERVICE_CHECKSUM: &'static [u8] = SERVICE_CHECKSUM;
Expand Down
10 changes: 5 additions & 5 deletions rustyrpc/examples/common/hello_service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ where
}
}

pub struct HelloServiceClient<Connection: transport::Connection, Format: EncodingFormat>
pub struct HelloServiceClient<Connection: transport::ClientConnection, Format: EncodingFormat>
where
for<'a> RequestKind<'a>: Encode<Format>,
ServiceCallRequestResult: Decode<Format>,
Expand All @@ -82,8 +82,8 @@ where
rpc_client: Arc<Client<Connection, Format>>,
}

impl<Connection: transport::Connection, Format: EncodingFormat> ServiceClient<Connection, Format>
for HelloServiceClient<Connection, Format>
impl<Connection: transport::ClientConnection, Format: EncodingFormat>
ServiceClient<Connection, Format> for HelloServiceClient<Connection, Format>
where
for<'a> RequestKind<'a>: Encode<Format>,
ServiceCallRequestResult: Decode<Format>,
Expand All @@ -105,7 +105,7 @@ where
}
}

impl<Connection: transport::Connection, Format: EncodingFormat>
impl<Connection: transport::ClientConnection, Format: EncodingFormat>
HelloServiceClient<Connection, Format>
where
for<'a> RequestKind<'a>: Encode<Format>,
Expand All @@ -128,7 +128,7 @@ where
}
}

impl<Connection: transport::Connection, Format: EncodingFormat> Drop
impl<Connection: transport::ClientConnection, Format: EncodingFormat> Drop
for HelloServiceClient<Connection, Format>
where
for<'a> RequestKind<'a>: Encode<Format>,
Expand Down
7 changes: 4 additions & 3 deletions rustyrpc/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,13 @@ use crate::{
};

/// RPC client for calling remote services.
pub struct Client<Connection: transport::Connection, Format: format::EncodingFormat> {
pub struct Client<Connection: transport::ClientConnection, Format: format::EncodingFormat> {
connection: Mutex<DropOwned<ConnectionCloseOnDrop<Connection>>>,
_format: PhantomData<Format>,
}

impl<Connection: transport::Connection, Format: format::EncodingFormat> Client<Connection, Format>
impl<Connection: transport::ClientConnection, Format: format::EncodingFormat>
Client<Connection, Format>
where
for<'a> RequestKind<'a>: Encode<Format>,
{
Expand Down Expand Up @@ -125,7 +126,7 @@ where
}
}

impl<Connection: transport::Connection, Format: EncodingFormat> From<Connection>
impl<Connection: transport::ClientConnection, Format: EncodingFormat> From<Connection>
for Client<Connection, Format>
{
fn from(connection: Connection) -> Self {
Expand Down
7 changes: 4 additions & 3 deletions rustyrpc/src/server/client_connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,13 @@ use crate::{

use super::call_stream::CallStream;

pub(crate) struct ClientConnection<Connection: transport::Connection, Format: EncodingFormat> {
pub(crate) struct ClientConnection<Connection: transport::ServerConnection, Format: EncodingFormat>
{
connection: Connection,
_format: PhantomData<Format>,
}

impl<Connection: transport::Connection, Format: EncodingFormat>
impl<Connection: transport::ServerConnection, Format: EncodingFormat>
ClientConnection<Connection, Format>
{
pub(crate) async fn accept_call_stream(
Expand All @@ -23,7 +24,7 @@ impl<Connection: transport::Connection, Format: EncodingFormat>
}
}

impl<Connection: transport::Connection, Format: EncodingFormat> From<Connection>
impl<Connection: transport::ServerConnection, Format: EncodingFormat> From<Connection>
for ClientConnection<Connection, Format>
{
fn from(connection: Connection) -> Self {
Expand Down
2 changes: 1 addition & 1 deletion rustyrpc/src/server/private_service/service_ref.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ impl ServiceRef {
/// Creates service client from reference and [rpc client][Client]
pub fn into_client<
ServiceClient: service::ServiceClient<Connection, Format>,
Connection: transport::Connection,
Connection: transport::ClientConnection,
Format: EncodingFormat,
>(
self,
Expand Down
2 changes: 1 addition & 1 deletion rustyrpc/src/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use crate::{
};

/// Service client for interaction with specific remote service.
pub trait ServiceClient<Connection: transport::Connection, Format: EncodingFormat>
pub trait ServiceClient<Connection: transport::ClientConnection, Format: EncodingFormat>
where
Self: Sized,
{
Expand Down
19 changes: 15 additions & 4 deletions rustyrpc/src/transport.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,23 +40,34 @@ impl<T: Stream> T {
}
}

/// Transport specific incoming connection.
/// Transport specific connection.
pub trait Connection: Send + 'static {
/// Close connection.
fn close(self) -> impl Future<Output = io::Result<()>> + Send;
}

/// Transport specific connection on client side.
pub trait ClientConnection: Connection {
/// Stream produced by connection.
type Stream: Stream + 'static;

/// Create new stream and notify other side of connection about it.
fn new_stream(&mut self) -> impl Future<Output = io::Result<Self::Stream>> + Send;
}

/// Transport specific connection on server side.
pub trait ServerConnection: Connection {
/// Stream produced by connection.
type Stream: Stream + 'static;

/// Accept new stream created by other side of connection.
fn accept_stream(&mut self) -> impl Future<Output = io::Result<Self::Stream>> + Send;
/// Close connection.
fn close(self) -> impl Future<Output = io::Result<()>> + Send;
}

/// Transport specific incoming connections listener like a [`TcpListener`][`std::net::TcpListener`] or others
pub trait ConnectionListener: Send {
/// Connection produced by listener
type Connection: Connection;
type Connection: ServerConnection;

/// Accepts a new connection
fn accept_connection(&mut self) -> impl Future<Output = io::Result<Self::Connection>>;
Expand Down
55 changes: 4 additions & 51 deletions rustyrpc/src/transport/quic.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
mod connection;
mod listener;

use core::{net::SocketAddr, num::TryFromIntError};
use quinn::{ClientConfig, Endpoint, RecvStream, SendStream, StoppedError, VarInt};
use core::num::TryFromIntError;
use quinn::{RecvStream, SendStream, StoppedError};
use std::io;
use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader, BufWriter};

pub use connection::Connection;
pub use listener::ConnectionListener;
use thiserror::Error;

/// Connection via QUIC protocol.
pub struct Connection(quinn::Connection);

/// Stream via QUIC protocol.
pub struct Stream {
send_stream: BufWriter<SendStream>,
Expand Down Expand Up @@ -92,49 +91,3 @@ impl From<(SendStream, RecvStream)> for Stream {
}
}
}

impl super::Connection for Connection {
type Stream = Stream;

async fn new_stream(&mut self) -> io::Result<Self::Stream> {
Ok(self.0.open_bi().await?.into())
}

async fn accept_stream(&mut self) -> io::Result<Self::Stream> {
Ok(self.0.accept_bi().await?.into())
}

async fn close(self) -> io::Result<()> {
self.0.close(VarInt::from_u32(0), b"Client is closed");
Ok(())
}
}

impl Connection {
/// Establishes connection to server via QUIC protocol.
///
/// # Errors
/// Returns error on fail of connection establishment.
pub async fn connect(
client_config: ClientConfig,
local_address: SocketAddr,
address: SocketAddr,
server_name: &str,
) -> io::Result<Self> {
let mut endpoint = Endpoint::client(local_address)?;
endpoint.set_default_client_config(client_config);

let connection = endpoint
.connect(address, server_name)
.map_err(|err| io::Error::new(io::ErrorKind::Other, err))?
.await?;

Ok(connection.into())
}
}

impl From<quinn::Connection> for Connection {
fn from(connection: quinn::Connection) -> Self {
Self(connection)
}
}
63 changes: 63 additions & 0 deletions rustyrpc/src/transport/quic/connection.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
use core::net::SocketAddr;
use std::io;

use quinn::{ClientConfig, Endpoint, VarInt};

use crate::transport;

use super::Stream;

/// Connection via QUIC protocol.
pub struct Connection(quinn::Connection);

impl transport::Connection for Connection {
async fn close(self) -> io::Result<()> {
self.0.close(VarInt::from_u32(0), b"Client is closed");
Ok(())
}
}

impl transport::ServerConnection for Connection {
type Stream = Stream;

async fn accept_stream(&mut self) -> io::Result<Self::Stream> {
Ok(self.0.accept_bi().await?.into())
}
}

impl transport::ClientConnection for Connection {
type Stream = Stream;

async fn new_stream(&mut self) -> io::Result<Self::Stream> {
Ok(self.0.open_bi().await?.into())
}
}

impl Connection {
/// Establishes connection to server via QUIC protocol.
///
/// # Errors
/// Returns error on fail of connection establishment.
pub async fn connect(
client_config: ClientConfig,
local_address: SocketAddr,
address: SocketAddr,
server_name: &str,
) -> io::Result<Self> {
let mut endpoint = Endpoint::client(local_address)?;
endpoint.set_default_client_config(client_config);

let connection = endpoint
.connect(address, server_name)
.map_err(|err| io::Error::new(io::ErrorKind::Other, err))?
.await?;

Ok(connection.into())
}
}

impl From<quinn::Connection> for Connection {
fn from(connection: quinn::Connection) -> Self {
Self(connection)
}
}
4 changes: 2 additions & 2 deletions rustyrpc/src/transport/quic/listener.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::io;

use quinn::{Endpoint, ServerConfig};

use super::Connection;
use super::connection::Connection;

/// Listener for incoming connections via QUIC protocol.
pub struct ConnectionListener(quinn::Endpoint);
Expand All @@ -18,7 +18,7 @@ impl crate::transport::ConnectionListener for ConnectionListener {
.await
.ok_or_else(|| io::Error::new(io::ErrorKind::NotConnected, "Endpoint is closed"))?
.await
.map(Connection)?)
.map(Into::into)?)
}
}

Expand Down

0 comments on commit 9bb8c13

Please sign in to comment.