Skip to content

Commit

Permalink
ServiceCallHandler moved to server/call_handler.rs. Some ServiceCallH…
Browse files Browse the repository at this point in the history
…andler refactoring
  • Loading branch information
AlexSherbinin committed Feb 25, 2024
1 parent 227ffd3 commit 28e7d2d
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 80 deletions.
85 changes: 5 additions & 80 deletions rustyrpc/src/server.rs
Original file line number Diff line number Diff line change
@@ -1,24 +1,22 @@
mod builder;
mod call_handler;
mod call_stream;
mod client_connection;
mod private_service;
mod task_pool;

use self::{call_stream::CallHandler, client_connection::ClientConnection, task_pool::TaskPool};
use self::{client_connection::ClientConnection, task_pool::TaskPool};
use crate::{
format::{
DecodeZeroCopy, DecodeZeroCopyFallible, Encode, EncodingFormat, ZeroCopyEncodingFormat,
},
protocol::{
RemoteServiceIdRequestError, RequestKind, ServiceCallRequestError,
ServiceCallRequestResult, ServiceIdRequestResult, ServiceKind,
},
protocol::{RequestKind, ServiceCallRequestResult, ServiceIdRequestResult},
server::call_handler::ServerCallHandler,
service::Service,
transport,
};
use alloc::sync::Arc;
use core::marker::PhantomData;
use derive_where::derive_where;
use futures::lock::Mutex;
use log::trace;
use std::collections::HashMap;
Expand Down Expand Up @@ -73,10 +71,7 @@ where
) {
trace!("New connection accepted");

let call_handler = ServerCallHandler {
server: Arc::clone(&self),
private_service_allocator: Arc::default(),
};
let call_handler = ServerCallHandler::new_for_connection(Arc::clone(&self));

loop {
let call_stream = connection.accept_call_stream().await.unwrap();
Expand All @@ -88,73 +83,3 @@ where
}
}
}

#[derive_where(Clone)]
struct ServerCallHandler<Listener: transport::ConnectionListener, Format: EncodingFormat> {
server: Arc<Server<Listener, Format>>,
private_service_allocator: Arc<PrivateServiceAllocator<Format>>,
}

impl<Listener: transport::ConnectionListener, Format: EncodingFormat> CallHandler
for ServerCallHandler<Listener, Format>
{
async fn handle_call(
self,
kind: ServiceKind,
service_id: u32,
function_id: u32,
args: Vec<u8>,
) -> Result<Vec<u8>, ServiceCallRequestError> {
trace!("Received service call. Kind: {kind:?}, service id: {service_id}, function_id: {function_id}");

#[allow(clippy::map_err_ignore)]
let service_id: usize = service_id
.try_into()
.map_err(|_| ServiceCallRequestError::InvalidServiceId)?;

match kind {
ServiceKind::Public if let Some(service) = self.server.services.get(service_id) => {
service
.call(
Arc::clone(&self.private_service_allocator),
function_id,
args,
)
.await
}
ServiceKind::Private
if let Some(service) = self.private_service_allocator.get(service_id).await =>
{
service
.call(
Arc::clone(&self.private_service_allocator),
function_id,
args,
)
.await
}
ServiceKind::Public | ServiceKind::Private => {
Err(ServiceCallRequestError::InvalidServiceId)
}
}
}

async fn handle_service_request(
self,
name: &str,
checksum: &[u8],
) -> Result<u32, RemoteServiceIdRequestError> {
trace!("Received service request. Service name: {name}, checksum: {checksum:?}");

let (expected_checksum, service_id) = self
.server
.service_map
.get(name)
.ok_or(RemoteServiceIdRequestError::ServiceNotFound)?;
if checksum == &**expected_checksum {
Ok(*service_id)
} else {
Err(RemoteServiceIdRequestError::InvalidChecksum)
}
}
}
114 changes: 114 additions & 0 deletions rustyrpc/src/server/call_handler.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
use super::{call_stream::CallHandler, PrivateServiceAllocator, Server};
use crate::{
format::EncodingFormat,
protocol::{RemoteServiceIdRequestError, ServiceCallRequestError, ServiceKind},
service::Service,
transport,
};
use alloc::sync::Arc;
use derive_where::derive_where;
use log::trace;

#[derive_where(Clone)]
pub(super) struct ServerCallHandler<Listener: transport::ConnectionListener, Format: EncodingFormat>
{
server: Arc<Server<Listener, Format>>,
private_service_allocator: Arc<PrivateServiceAllocator<Format>>,
}

impl<Listener: transport::ConnectionListener, Format: EncodingFormat>
ServerCallHandler<Listener, Format>
{
pub(super) fn new_for_connection(server: Arc<Server<Listener, Format>>) -> Self {
Self {
server,
private_service_allocator: Arc::default(),
}
}

async fn handle_private_service_call(
self,
service: super::private_service::ServiceRefLock<'_, Format>,
function_id: u32,
args: Vec<u8>,
) -> Result<Vec<u8>, ServiceCallRequestError> {
service
.call(
Arc::clone(&self.private_service_allocator),
function_id,
args,
)
.await
}

async fn handle_public_service_call(
self,
service: &dyn Service<Format>,
function_id: u32,
args: Vec<u8>,
) -> Result<Vec<u8>, ServiceCallRequestError> {
service
.call(
Arc::clone(&self.private_service_allocator),
function_id,
args,
)
.await
}
}

impl<Listener: transport::ConnectionListener, Format: EncodingFormat> CallHandler
for ServerCallHandler<Listener, Format>
{
async fn handle_call(
self,
kind: ServiceKind,
service_id: u32,
function_id: u32,
args: Vec<u8>,
) -> Result<Vec<u8>, ServiceCallRequestError> {
trace!("Received service call. Kind: {kind:?}, service id: {service_id}, function_id: {function_id}");

#[allow(clippy::map_err_ignore)]
let service_id: usize = service_id
.try_into()
.map_err(|_| ServiceCallRequestError::InvalidServiceId)?;

match kind {
ServiceKind::Public if let Some(service) = self.server.services.get(service_id) => {
self.clone()
.handle_public_service_call(service.as_ref(), function_id, args)
.await
}
ServiceKind::Private
if let Some(service) = self.private_service_allocator.get(service_id).await =>
{
self.clone()
.handle_private_service_call(service, function_id, args)
.await
}
ServiceKind::Public | ServiceKind::Private => {
Err(ServiceCallRequestError::InvalidServiceId)
}
}
}

async fn handle_service_request(
self,
name: &str,
checksum: &[u8],
) -> Result<u32, RemoteServiceIdRequestError> {
trace!("Received service request. Service name: {name}, checksum: {checksum:?}");

let (expected_checksum, service_id) = self
.server
.service_map
.get(name)
.ok_or(RemoteServiceIdRequestError::ServiceNotFound)?;
if checksum == &**expected_checksum {
Ok(*service_id)
} else {
Err(RemoteServiceIdRequestError::InvalidChecksum)
}
}
}

0 comments on commit 28e7d2d

Please sign in to comment.