Skip to content

Commit

Permalink
Add Parler based TTS (POC)
Browse files Browse the repository at this point in the history
  • Loading branch information
rkusa committed Nov 10, 2024
1 parent 803b060 commit 1845eea
Show file tree
Hide file tree
Showing 14 changed files with 1,942 additions and 66 deletions.
1,113 changes: 1,099 additions & 14 deletions Cargo.lock

Large diffs are not rendered by default.

18 changes: 11 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,11 @@ throughputLimit = 600
integrityCheckDisabled = false

-- Whether or not authentication is required
auth.enabled = false
-- Authentication tokens table with client names and their tokens for split tokens.
auth.enabled = false
-- Authentication tokens table with client names and their tokens for split tokens.
auth.tokens = {
-- client => clientName, token => Any token. Advice to use UTF-8 only. Length not limited explicitly
{ client = "SomeClient", token = "SomeToken" },
-- client => clientName, token => Any token. Advice to use UTF-8 only. Length not limited explicitly
{ client = "SomeClient", token = "SomeToken" },
{ client = "SomeClient2", token = "SomeOtherToken" }
}

Expand Down Expand Up @@ -124,6 +124,9 @@ tts.provider.gcloud.defaultVoice = "en-GB-Neural2-A"
-- Requires at least Windows Server 2019 to work properly.
tts.provider.win.defaultVoice = "David"

-- The default Parler speaker prompt.
tts.provider.parler.defaultSpeaker = "..."

-- Your SRS server's address.
srs.addr = "127.0.0.1:5002"
```
Expand Down Expand Up @@ -211,6 +214,7 @@ The server will be running on port 50051 by default.
-- `= { azure = {} }` / `= { azure = { voice = "..." } }` enable Azure TTS
-- `= { gcloud = {} }` / `= { gcloud = { voice = "..." } }` enable Google Cloud TTS
-- `= { win = {} }` / `= { win = { voice = "..." } }` enable Windows TTS
-- `= { parler = {} }` / `= { parler = { voice = "...", speed = 1.0 } }` enable Parler TTS
provider = null,
}
```
Expand All @@ -228,7 +232,7 @@ The gRPC .proto files are available in the `Docs/DCS-gRPC` folder and also avail

### Client Authentication

If authentication is enabled on the server you will have to add `X-API-Key` to the metadata/headers.
If authentication is enabled on the server you will have to add `X-API-Key` to the metadata/headers.
Below are some example on what it could look like in your code.

#### Examples
Expand All @@ -238,7 +242,7 @@ Below are some example on what it could look like in your code.

You can either set the `Metadata` for each request or you can create a `GrpcChannel` with an interceptor that will set the key each time.

For a single request:
For a single request:

```c#
var client = new MissionService.MissionServiceClient(channel);
Expand All @@ -251,7 +255,7 @@ Metadata metadata = new Metadata()
var response = client.GetScenarioCurrentTime(new GetScenarioCurrentTimeRequest { }, headers: metadata, deadline: DateTime.UtcNow.AddSeconds(2));
```

For all requests on a channel:
For all requests on a channel:
```c#
public GrpcChannel CreateChannel(string host, string post, string? apiKey)
{
Expand Down
64 changes: 32 additions & 32 deletions lua/DCS-gRPC/grpc-mission.lua
Original file line number Diff line number Diff line change
@@ -1,71 +1,71 @@
if not GRPC then
GRPC = {
-- scaffold nested tables to allow direct assignment in config file
tts = { provider = { gcloud = {}, aws = {}, azure = {}, win = {} } },
srs = {},
auth = { tokens = {} }
}
GRPC = {
-- scaffold nested tables to allow direct assignment in config file
tts = { provider = { gcloud = {}, aws = {}, azure = {}, win = {}, parler = {} } },
srs = {},
auth = { tokens = {} }
}
end

-- load settings from `Saved Games/DCS/Config/dcs-grpc.lua`
do
env.info("[GRPC] Checking optional config at `Config/dcs-grpc.lua` ...")
local file, err = io.open(lfs.writedir() .. [[Config\dcs-grpc.lua]], "r")
if file then
local f = assert(loadstring(file:read("*all")))
setfenv(f, GRPC)
f()
env.info("[GRPC] `Config/dcs-grpc.lua` successfully read")
else
env.info("[GRPC] `Config/dcs-grpc.lua` not found (" .. tostring(err) .. ")")
end
env.info("[GRPC] Checking optional config at `Config/dcs-grpc.lua` ...")
local file, err = io.open(lfs.writedir() .. [[Config\dcs-grpc.lua]], "r")
if file then
local f = assert(loadstring(file:read("*all")))
setfenv(f, GRPC)
f()
env.info("[GRPC] `Config/dcs-grpc.lua` successfully read")
else
env.info("[GRPC] `Config/dcs-grpc.lua` not found (" .. tostring(err) .. ")")
end
end

-- Set default settings.
if not GRPC.luaPath then
GRPC.luaPath = lfs.writedir() .. [[Scripts\DCS-gRPC\]]
GRPC.luaPath = lfs.writedir() .. [[Scripts\DCS-gRPC\]]
end
if not GRPC.dllPath then
GRPC.dllPath = lfs.writedir() .. [[Mods\tech\DCS-gRPC\]]
GRPC.dllPath = lfs.writedir() .. [[Mods\tech\DCS-gRPC\]]
end
if GRPC.throughputLimit == nil or GRPC.throughputLimit == 0 or type(GRPC.throughputLimit) ~= "number" then
GRPC.throughputLimit = 600
GRPC.throughputLimit = 600
end

-- load version
dofile(GRPC.luaPath .. [[version.lua]])

-- Let DCS know where to find the DLLs
if not string.find(package.cpath, GRPC.dllPath) then
package.cpath = package.cpath .. [[;]] .. GRPC.dllPath .. [[?.dll;]]
package.cpath = package.cpath .. [[;]] .. GRPC.dllPath .. [[?.dll;]]
end

-- Load DLL before `require` gets sanitized.
local ok, grpc = pcall(require, "dcs_grpc_hot_reload")
if ok then
env.info("[GRPC] loaded hot reload version")
env.info("[GRPC] loaded hot reload version")
else
grpc = require("dcs_grpc")
grpc = require("dcs_grpc")
end

-- Keep a reference to `lfs` before it gets sanitized
local lfs = _G.lfs

local loaded = false
function GRPC.load()
if loaded then
env.info("[GRPC] already loaded")
return
end
if loaded then
env.info("[GRPC] already loaded")
return
end

local env = setmetatable({grpc = grpc, lfs = lfs}, {__index = _G})
local f = setfenv(assert(loadfile(GRPC.luaPath .. [[grpc.lua]])), env)
f()
local env = setmetatable({ grpc = grpc, lfs = lfs }, { __index = _G })
local f = setfenv(assert(loadfile(GRPC.luaPath .. [[grpc.lua]])), env)
f()

loaded = true
loaded = true
end

if GRPC.autostart == true then
env.info("[GRPC] auto starting")
GRPC.load()
env.info("[GRPC] auto starting")
GRPC.load()
end
2 changes: 1 addition & 1 deletion lua/Hooks/DCS-gRPC.lua
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ local function init()
if not GRPC then
_G.GRPC = {
-- scaffold nested tables to allow direct assignment in config file
tts = { provider = { gcloud = {}, aws = {}, azure = {}, win = {} } },
tts = { provider = { gcloud = {}, aws = {}, azure = {}, win = {}, parler = {} } },
srs = {},
auth = { tokens = {} }
}
Expand Down
6 changes: 6 additions & 0 deletions protos/dcs/srs/v0/srs.proto
Original file line number Diff line number Diff line change
Expand Up @@ -75,13 +75,19 @@ message TransmitRequest {
optional string voice = 1;
}

message Parler {
optional string speaker = 1;
}

// Optional TTS provider to be use. Defaults to the one configured in your
// config or to Windows' built-in TTS.
oneof provider {
Aws aws = 8;
Azure azure = 9;
GCloud gcloud = 10;
Windows win = 11;
// Parler does not support SSML, only use it with plain text.
Parler parler = 12;
}
}

Expand Down
3 changes: 2 additions & 1 deletion src/authentication.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use crate::config::AuthConfig;
use tonic::codegen::http::Request;
use tonic::transport::Body;
use tonic::{async_trait, Status};
use tonic_middleware::RequestInterceptor;

use crate::config::AuthConfig;

#[derive(Clone)]
pub struct AuthInterceptor {
pub auth_config: AuthConfig,
Expand Down
8 changes: 8 additions & 0 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ pub struct TtsProviderConfig {
pub azure: Option<AzureConfig>,
pub gcloud: Option<GCloudConfig>,
pub win: Option<WinConfig>,
pub parler: Option<ParlerConfig>,
}

#[derive(Debug, Clone, Default, Deserialize, Serialize)]
Expand All @@ -49,6 +50,7 @@ pub enum TtsProvider {
GCloud,
#[default]
Win,
Parler,
}

#[derive(Clone, Deserialize, Serialize)]
Expand Down Expand Up @@ -81,6 +83,12 @@ pub struct WinConfig {
pub default_voice: Option<String>,
}

#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct ParlerConfig {
pub default_speaker: Option<String>,
}

#[derive(Debug, Clone, Default, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct SrsConfig {
Expand Down
3 changes: 1 addition & 2 deletions src/rpc/metadata.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use stubs::metadata::v0::metadata_service_server::MetadataService;
use stubs::*;
use tonic::async_trait;
use tonic::{Request, Response, Status};
use tonic::{async_trait, Request, Response, Status};

use super::MissionRpc;

Expand Down
27 changes: 26 additions & 1 deletion src/rpc/srs.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
use std::error;
use std::future::Future;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::path::PathBuf;
use std::str::FromStr;
use std::time::{Duration, Instant};

use ::srs::Sender;
#[cfg(target_os = "windows")]
use ::tts::WinConfig;
use ::tts::{AwsConfig, AwsRegion, AzureConfig, GCloudConfig, TtsConfig};
use ::tts::{AwsConfig, AwsRegion, AzureConfig, GCloudConfig, ParlerConfig, TtsConfig};
use futures_util::FutureExt;
use stubs::common::v0::{Coalition, Unit};
use stubs::mission::v0::stream_events_response::{Event, TtsEvent};
Expand All @@ -27,6 +28,7 @@ use crate::srs::SrsClients;
pub struct Srs {
tts_config: crate::config::TtsConfig,
srs_config: crate::config::SrsConfig,
write_dir: PathBuf,
rpc: MissionRpc,
srs_clients: SrsClients,
shutdown_signal: ShutdownHandle,
Expand All @@ -36,13 +38,15 @@ impl Srs {
pub fn new(
tts_config: crate::config::TtsConfig,
srs_config: crate::config::SrsConfig,
write_dir: PathBuf,
rpc: MissionRpc,
srs_clients: SrsClients,
shutdown_signal: ShutdownHandle,
) -> Self {
Self {
tts_config,
srs_config,
write_dir,
rpc,
srs_clients,
shutdown_signal,
Expand Down Expand Up @@ -105,6 +109,9 @@ impl SrsService for Srs {
TtsProvider::Win => {
transmit_request::Provider::Win(transmit_request::Windows { voice: None })
}
TtsProvider::Parler => {
transmit_request::Provider::Parler(transmit_request::Parler { speaker: None })
}
}) {
transmit_request::Provider::Aws(transmit_request::Aws { voice }) => {
TtsConfig::Aws(AwsConfig {
Expand Down Expand Up @@ -215,6 +222,24 @@ impl SrsService for Srs {
"Windows TTS is only available on Windows",
));
}
transmit_request::Provider::Parler(transmit_request::Parler { speaker }) => {
TtsConfig::Parler(ParlerConfig {
speaker: speaker
.or_else(|| {
self.tts_config
.provider
.as_ref()
.and_then(|p| p.parler.as_ref())
.and_then(|p| p.default_speaker.clone())
})
.filter(|v| !v.is_empty())
.ok_or_else(|| {
Status::failed_precondition(
"tts.provider.parler.default_speaker not set",
)
})?,
})
}
};

let frames = ::tts::synthesize(&request.ssml, &config)
Expand Down
19 changes: 13 additions & 6 deletions src/server.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,9 @@
use std::future::Future;
use std::net::SocketAddr;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;

use crate::authentication::AuthInterceptor;
use crate::config::{AuthConfig, Config, SrsConfig, TtsConfig};
use crate::rpc::{HookRpc, MissionRpc, Srs};
use crate::shutdown::{Shutdown, ShutdownHandle};
use crate::srs::SrsClients;
use crate::stats::Stats;
use dcs_module_ipc::IPC;
use futures_util::FutureExt;
use stubs::atmosphere::v0::atmosphere_service_server::AtmosphereServiceServer;
Expand All @@ -34,6 +29,13 @@ use tokio::time::sleep;
use tonic::transport;
use tonic_middleware::RequestInterceptorLayer;

use crate::authentication::AuthInterceptor;
use crate::config::{AuthConfig, Config, SrsConfig, TtsConfig};
use crate::rpc::{HookRpc, MissionRpc, Srs};
use crate::shutdown::{Shutdown, ShutdownHandle};
use crate::srs::SrsClients;
use crate::stats::Stats;

pub struct Server {
runtime: Runtime,
shutdown: Shutdown,
Expand All @@ -51,6 +53,7 @@ struct ServerState {
stats: Stats,
tts_config: TtsConfig,
srs_config: SrsConfig,
write_dir: PathBuf,
srs_transmit: Arc<Mutex<mpsc::Receiver<TransmitRequest>>>,
auth_config: AuthConfig,
}
Expand All @@ -73,6 +76,7 @@ impl Server {
stats: Stats::new(shutdown.handle()),
tts_config: config.tts.clone().unwrap_or_default(),
srs_config: config.srs.clone().unwrap_or_default(),
write_dir: PathBuf::from(&config.write_dir),
srs_transmit: Arc::new(Mutex::new(rx)),
auth_config: config.auth.clone().unwrap_or_default(),
},
Expand Down Expand Up @@ -206,6 +210,7 @@ async fn try_run(
stats,
tts_config,
srs_config,
write_dir,
srs_transmit,
auth_config,
} = state;
Expand All @@ -230,6 +235,7 @@ async fn try_run(
let srs = Srs::new(
tts_config.clone(),
srs_config.clone(),
write_dir.clone(),
mission_rpc.clone(),
srs_clients.clone(),
shutdown_signal.clone(),
Expand Down Expand Up @@ -269,6 +275,7 @@ async fn try_run(
.add_service(SrsServiceServer::new(Srs::new(
tts_config,
srs_config,
write_dir,
mission_rpc.clone(),
srs_clients,
shutdown_signal.clone(),
Expand Down
Loading

0 comments on commit 1845eea

Please sign in to comment.