diff --git a/rustyrpc/src/server.rs b/rustyrpc/src/server.rs index 1d933f2..a7af7ef 100644 --- a/rustyrpc/src/server.rs +++ b/rustyrpc/src/server.rs @@ -1,24 +1,22 @@ mod builder; +mod call_handler; mod call_stream; mod client_connection; mod private_service; mod task_pool; -use self::{call_stream::CallHandler, client_connection::ClientConnection, task_pool::TaskPool}; +use self::{client_connection::ClientConnection, task_pool::TaskPool}; use crate::{ format::{ DecodeZeroCopy, DecodeZeroCopyFallible, Encode, EncodingFormat, ZeroCopyEncodingFormat, }, - protocol::{ - RemoteServiceIdRequestError, RequestKind, ServiceCallRequestError, - ServiceCallRequestResult, ServiceIdRequestResult, ServiceKind, - }, + protocol::{RequestKind, ServiceCallRequestResult, ServiceIdRequestResult}, + server::call_handler::ServerCallHandler, service::Service, transport, }; use alloc::sync::Arc; use core::marker::PhantomData; -use derive_where::derive_where; use futures::lock::Mutex; use log::trace; use std::collections::HashMap; @@ -73,10 +71,7 @@ where ) { trace!("New connection accepted"); - let call_handler = ServerCallHandler { - server: Arc::clone(&self), - private_service_allocator: Arc::default(), - }; + let call_handler = ServerCallHandler::new_for_connection(Arc::clone(&self)); loop { let call_stream = connection.accept_call_stream().await.unwrap(); @@ -88,73 +83,3 @@ where } } } - -#[derive_where(Clone)] -struct ServerCallHandler { - server: Arc>, - private_service_allocator: Arc>, -} - -impl CallHandler - for ServerCallHandler -{ - async fn handle_call( - self, - kind: ServiceKind, - service_id: u32, - function_id: u32, - args: Vec, - ) -> Result, ServiceCallRequestError> { - trace!("Received service call. Kind: {kind:?}, service id: {service_id}, function_id: {function_id}"); - - #[allow(clippy::map_err_ignore)] - let service_id: usize = service_id - .try_into() - .map_err(|_| ServiceCallRequestError::InvalidServiceId)?; - - match kind { - ServiceKind::Public if let Some(service) = self.server.services.get(service_id) => { - service - .call( - Arc::clone(&self.private_service_allocator), - function_id, - args, - ) - .await - } - ServiceKind::Private - if let Some(service) = self.private_service_allocator.get(service_id).await => - { - service - .call( - Arc::clone(&self.private_service_allocator), - function_id, - args, - ) - .await - } - ServiceKind::Public | ServiceKind::Private => { - Err(ServiceCallRequestError::InvalidServiceId) - } - } - } - - async fn handle_service_request( - self, - name: &str, - checksum: &[u8], - ) -> Result { - trace!("Received service request. Service name: {name}, checksum: {checksum:?}"); - - let (expected_checksum, service_id) = self - .server - .service_map - .get(name) - .ok_or(RemoteServiceIdRequestError::ServiceNotFound)?; - if checksum == &**expected_checksum { - Ok(*service_id) - } else { - Err(RemoteServiceIdRequestError::InvalidChecksum) - } - } -} diff --git a/rustyrpc/src/server/call_handler.rs b/rustyrpc/src/server/call_handler.rs new file mode 100644 index 0000000..88c1942 --- /dev/null +++ b/rustyrpc/src/server/call_handler.rs @@ -0,0 +1,114 @@ +use super::{call_stream::CallHandler, PrivateServiceAllocator, Server}; +use crate::{ + format::EncodingFormat, + protocol::{RemoteServiceIdRequestError, ServiceCallRequestError, ServiceKind}, + service::Service, + transport, +}; +use alloc::sync::Arc; +use derive_where::derive_where; +use log::trace; + +#[derive_where(Clone)] +pub(super) struct ServerCallHandler +{ + server: Arc>, + private_service_allocator: Arc>, +} + +impl + ServerCallHandler +{ + pub(super) fn new_for_connection(server: Arc>) -> Self { + Self { + server, + private_service_allocator: Arc::default(), + } + } + + async fn handle_private_service_call( + self, + service: super::private_service::ServiceRefLock<'_, Format>, + function_id: u32, + args: Vec, + ) -> Result, ServiceCallRequestError> { + service + .call( + Arc::clone(&self.private_service_allocator), + function_id, + args, + ) + .await + } + + async fn handle_public_service_call( + self, + service: &dyn Service, + function_id: u32, + args: Vec, + ) -> Result, ServiceCallRequestError> { + service + .call( + Arc::clone(&self.private_service_allocator), + function_id, + args, + ) + .await + } +} + +impl CallHandler + for ServerCallHandler +{ + async fn handle_call( + self, + kind: ServiceKind, + service_id: u32, + function_id: u32, + args: Vec, + ) -> Result, ServiceCallRequestError> { + trace!("Received service call. Kind: {kind:?}, service id: {service_id}, function_id: {function_id}"); + + #[allow(clippy::map_err_ignore)] + let service_id: usize = service_id + .try_into() + .map_err(|_| ServiceCallRequestError::InvalidServiceId)?; + + match kind { + ServiceKind::Public if let Some(service) = self.server.services.get(service_id) => { + self.clone() + .handle_public_service_call(service.as_ref(), function_id, args) + .await + } + ServiceKind::Private + if let Some(service) = self.private_service_allocator.get(service_id).await => + { + self.clone() + .handle_private_service_call(service, function_id, args) + .await + } + ServiceKind::Public | ServiceKind::Private => { + Err(ServiceCallRequestError::InvalidServiceId) + } + } + } + + async fn handle_service_request( + self, + name: &str, + checksum: &[u8], + ) -> Result { + trace!("Received service request. Service name: {name}, checksum: {checksum:?}"); + + let (expected_checksum, service_id) = self + .server + .service_map + .get(name) + .ok_or(RemoteServiceIdRequestError::ServiceNotFound)?; + if checksum == &**expected_checksum { + Ok(*service_id) + } else { + Err(RemoteServiceIdRequestError::InvalidChecksum) + } + } +}