diff --git a/Cargo.lock b/Cargo.lock index 158d0e31..60996af2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -440,6 +440,7 @@ dependencies = [ "tokio", "tokio-stream", "tonic", + "tonic-middleware", "walkdir", ] @@ -918,9 +919,9 @@ dependencies = [ [[package]] name = "hyper" -version = "1.3.1" +version = "1.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fe575dd17d0862a9a33781c8c4696a55c320909004a67a00fb286ba8b1bc496d" +checksum = "50dfd22e0e76d0f662d429a5f80fcaf3855009297eab6a0a9f8543834744ba05" dependencies = [ "bytes", "futures-channel", @@ -958,7 +959,7 @@ checksum = "a0bea761b46ae2b24eb4aef630d8d1c398157b6fc29e6350ecf090a0b70c952c" dependencies = [ "futures-util", "http 1.1.0", - "hyper 1.3.1", + "hyper 1.4.1", "hyper-util", "rustls 0.22.4", "rustls-pki-types", @@ -981,16 +982,16 @@ dependencies = [ [[package]] name = "hyper-util" -version = "0.1.3" +version = "0.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ca38ef113da30126bbff9cd1705f9273e15d45498615d138b0c20279ac7a76aa" +checksum = "cde7055719c54e36e95e8719f95883f22072a48ede39db7fc17a4e1d5281e9b9" dependencies = [ "bytes", "futures-channel", "futures-util", "http 1.1.0", "http-body 1.0.0", - "hyper 1.3.1", + "hyper 1.4.1", "pin-project-lite", "socket2", "tokio", @@ -1639,7 +1640,7 @@ dependencies = [ "http 1.1.0", "http-body 1.0.0", "http-body-util", - "hyper 1.3.1", + "hyper 1.4.1", "hyper-rustls 0.26.0", "hyper-util", "ipnet", @@ -2361,6 +2362,18 @@ dependencies = [ "syn 2.0.66", ] +[[package]] +name = "tonic-middleware" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92d34dab0f18194ddb9164685a3d8cf777ff35042752aba2be208b1384d7a304" +dependencies = [ + "async-trait", + "futures-util", + "tonic", + "tower", +] + [[package]] name = "tower" version = "0.4.13" diff --git a/Cargo.toml b/Cargo.toml index c0fb3404..bd9e10c2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -65,6 +65,7 @@ time = { version = "0.3", features = ["formatting", "parsing"] } tokio.workspace = true tokio-stream.workspace = true tonic.workspace = true +tonic-middleware = "0.1.4" [build-dependencies] walkdir = "2.3" diff --git a/lua/DCS-gRPC/grpc-mission.lua b/lua/DCS-gRPC/grpc-mission.lua index da41487e..e1d90a50 100644 --- a/lua/DCS-gRPC/grpc-mission.lua +++ b/lua/DCS-gRPC/grpc-mission.lua @@ -3,6 +3,7 @@ if not GRPC then -- scaffold nested tables to allow direct assignment in config file tts = { provider = { gcloud = {}, aws = {}, azure = {}, win = {} } }, srs = {}, + auth = {} } end diff --git a/lua/DCS-gRPC/grpc.lua b/lua/DCS-gRPC/grpc.lua index 6fa42a7b..b6ee79b9 100644 --- a/lua/DCS-gRPC/grpc.lua +++ b/lua/DCS-gRPC/grpc.lua @@ -21,6 +21,7 @@ if isMissionEnv then integrityCheckDisabled = GRPC.integrityCheckDisabled, tts = GRPC.tts, srs = GRPC.srs, + auth = GRPC.auth })) end diff --git a/lua/Hooks/DCS-gRPC.lua b/lua/Hooks/DCS-gRPC.lua index c14e8a97..09f659b7 100644 --- a/lua/Hooks/DCS-gRPC.lua +++ b/lua/Hooks/DCS-gRPC.lua @@ -9,6 +9,7 @@ local function init() -- scaffold nested tables to allow direct assignment in config file tts = { provider = { gcloud = {}, aws = {}, azure = {}, win = {} } }, srs = {}, + auth = {} } end diff --git a/src/authentication.rs b/src/authentication.rs new file mode 100644 index 00000000..69dde31a --- /dev/null +++ b/src/authentication.rs @@ -0,0 +1,27 @@ +use crate::config::AuthConfig; +use tonic::codegen::http::Request; +use tonic::transport::Body; +use tonic::{async_trait, Status}; +use tonic_middleware::RequestInterceptor; + +#[derive(Clone)] +pub struct AuthInterceptor { + pub auth_config: AuthConfig, +} + +#[async_trait] +impl RequestInterceptor for AuthInterceptor { + async fn intercept(&self, req: Request) -> Result, Status> { + match req.headers().get("X-API-Key").map(|v| v.to_str()) { + Some(Ok(token)) => { + //check if token is correct if auth is enabled + if self.auth_config.enabled == false || token == self.auth_config.token { + Ok(req) + } else { + Err(Status::unauthenticated("Unauthenticated")) + } + } + _ => Err(Status::unauthenticated("Unauthenticated")), + } + } +} diff --git a/src/config.rs b/src/config.rs index f626090b..a24395a6 100644 --- a/src/config.rs +++ b/src/config.rs @@ -21,6 +21,7 @@ pub struct Config { pub integrity_check_disabled: bool, pub tts: Option, pub srs: Option, + pub auth: Option, } #[derive(Debug, Clone, Default, Deserialize, Serialize)] @@ -87,6 +88,14 @@ pub struct SrsConfig { pub addr: Option, } +#[derive(Debug, Clone, Default, Deserialize, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct AuthConfig { + #[serde(default)] + pub enabled: bool, + pub token: String, +} + fn default_host() -> String { String::from("127.0.0.1") } diff --git a/src/lib.rs b/src/lib.rs index fa518a1b..888ba849 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,7 @@ #![allow(dead_code)] #![recursion_limit = "256"] +mod authentication; mod config; mod fps; #[cfg(feature = "hot-reload")] diff --git a/src/server.rs b/src/server.rs index c70b8acf..28803b1b 100644 --- a/src/server.rs +++ b/src/server.rs @@ -3,6 +3,12 @@ use std::net::SocketAddr; use std::sync::Arc; use std::time::Duration; +use crate::authentication::AuthInterceptor; +use crate::config::{AuthConfig, Config, SrsConfig, TtsConfig}; +use crate::rpc::{HookRpc, MissionRpc, Srs}; +use crate::shutdown::{Shutdown, ShutdownHandle}; +use crate::srs::SrsClients; +use crate::stats::Stats; use dcs_module_ipc::IPC; use futures_util::FutureExt; use stubs::atmosphere::v0::atmosphere_service_server::AtmosphereServiceServer; @@ -25,12 +31,7 @@ use tokio::sync::oneshot::{self, Receiver}; use tokio::sync::{mpsc, Mutex}; use tokio::time::sleep; use tonic::transport; - -use crate::config::{Config, SrsConfig, TtsConfig}; -use crate::rpc::{HookRpc, MissionRpc, Srs}; -use crate::shutdown::{Shutdown, ShutdownHandle}; -use crate::srs::SrsClients; -use crate::stats::Stats; +use tonic_middleware::RequestInterceptorLayer; pub struct Server { runtime: Runtime, @@ -50,6 +51,7 @@ struct ServerState { tts_config: TtsConfig, srs_config: SrsConfig, srs_transmit: Arc>>, + auth_config: AuthConfig, } impl Server { @@ -71,6 +73,7 @@ impl Server { tts_config: config.tts.clone().unwrap_or_default(), srs_config: config.srs.clone().unwrap_or_default(), srs_transmit: Arc::new(Mutex::new(rx)), + auth_config: config.auth.clone().unwrap_or_default(), }, srs_transmit: tx, shutdown, @@ -203,6 +206,7 @@ async fn try_run( tts_config, srs_config, srs_transmit, + auth_config, } = state; let mut mission_rpc = @@ -242,7 +246,14 @@ async fn try_run( } }); + let auth_interceptor = AuthInterceptor { + auth_config: auth_config.clone(), + }; + + log::info!("Authentication enabled: {}", auth_config.enabled); + transport::Server::builder() + .layer(RequestInterceptorLayer::new(auth_interceptor.clone())) .add_service(AtmosphereServiceServer::new(mission_rpc.clone())) .add_service(CoalitionServiceServer::new(mission_rpc.clone())) .add_service(ControllerServiceServer::new(mission_rpc.clone()))