Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Parler (generative speech generation) based TTS (proof of concept) #274

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,113 changes: 1,099 additions & 14 deletions Cargo.lock

Large diffs are not rendered by default.

18 changes: 11 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,11 @@ throughputLimit = 600
integrityCheckDisabled = false

-- Whether or not authentication is required
auth.enabled = false
-- Authentication tokens table with client names and their tokens for split tokens.
auth.enabled = false
-- Authentication tokens table with client names and their tokens for split tokens.
auth.tokens = {
-- client => clientName, token => Any token. Advice to use UTF-8 only. Length not limited explicitly
{ client = "SomeClient", token = "SomeToken" },
-- client => clientName, token => Any token. Advice to use UTF-8 only. Length not limited explicitly
{ client = "SomeClient", token = "SomeToken" },
{ client = "SomeClient2", token = "SomeOtherToken" }
}

Expand Down Expand Up @@ -124,6 +124,9 @@ tts.provider.gcloud.defaultVoice = "en-GB-Neural2-A"
-- Requires at least Windows Server 2019 to work properly.
tts.provider.win.defaultVoice = "David"

-- The default Parler speaker prompt.
tts.provider.parler.defaultSpeaker = "..."

-- Your SRS server's address.
srs.addr = "127.0.0.1:5002"
```
Expand Down Expand Up @@ -211,6 +214,7 @@ The server will be running on port 50051 by default.
-- `= { azure = {} }` / `= { azure = { voice = "..." } }` enable Azure TTS
-- `= { gcloud = {} }` / `= { gcloud = { voice = "..." } }` enable Google Cloud TTS
-- `= { win = {} }` / `= { win = { voice = "..." } }` enable Windows TTS
-- `= { parler = {} }` / `= { parler = { speaker = "..." } }` enable Parler TTS
provider = null,
}
```
Expand All @@ -228,7 +232,7 @@ The gRPC .proto files are available in the `Docs/DCS-gRPC` folder and also avail

### Client Authentication

If authentication is enabled on the server you will have to add `X-API-Key` to the metadata/headers.
If authentication is enabled on the server you will have to add `X-API-Key` to the metadata/headers.
Below are some examples of what it could look like in your code.

#### Examples
Expand All @@ -238,7 +242,7 @@ Below are some example on what it could look like in your code.

You can either set the `Metadata` for each request or you can create a `GrpcChannel` with an interceptor that will set the key each time.

For a single request:
For a single request:

```c#
var client = new MissionService.MissionServiceClient(channel);
Expand All @@ -251,7 +255,7 @@ Metadata metadata = new Metadata()
var response = client.GetScenarioCurrentTime(new GetScenarioCurrentTimeRequest { }, headers: metadata, deadline: DateTime.UtcNow.AddSeconds(2));
```

For all requests on a channel:
For all requests on a channel:
```c#
public GrpcChannel CreateChannel(string host, string post, string? apiKey)
{
Expand Down
64 changes: 32 additions & 32 deletions lua/DCS-gRPC/grpc-mission.lua
Original file line number Diff line number Diff line change
@@ -1,71 +1,71 @@
if not GRPC then
GRPC = {
-- scaffold nested tables to allow direct assignment in config file
tts = { provider = { gcloud = {}, aws = {}, azure = {}, win = {} } },
srs = {},
auth = { tokens = {} }
}
GRPC = {
-- scaffold nested tables to allow direct assignment in config file
tts = { provider = { gcloud = {}, aws = {}, azure = {}, win = {}, parler = {} } },
srs = {},
auth = { tokens = {} }
}
end

-- load settings from `Saved Games/DCS/Config/dcs-grpc.lua`
do
env.info("[GRPC] Checking optional config at `Config/dcs-grpc.lua` ...")
local file, err = io.open(lfs.writedir() .. [[Config\dcs-grpc.lua]], "r")
if file then
local f = assert(loadstring(file:read("*all")))
setfenv(f, GRPC)
f()
env.info("[GRPC] `Config/dcs-grpc.lua` successfully read")
else
env.info("[GRPC] `Config/dcs-grpc.lua` not found (" .. tostring(err) .. ")")
end
env.info("[GRPC] Checking optional config at `Config/dcs-grpc.lua` ...")
local file, err = io.open(lfs.writedir() .. [[Config\dcs-grpc.lua]], "r")
if file then
local f = assert(loadstring(file:read("*all")))
setfenv(f, GRPC)
f()
env.info("[GRPC] `Config/dcs-grpc.lua` successfully read")
else
env.info("[GRPC] `Config/dcs-grpc.lua` not found (" .. tostring(err) .. ")")
end
end

-- Set default settings.
if not GRPC.luaPath then
GRPC.luaPath = lfs.writedir() .. [[Scripts\DCS-gRPC\]]
GRPC.luaPath = lfs.writedir() .. [[Scripts\DCS-gRPC\]]
end
if not GRPC.dllPath then
GRPC.dllPath = lfs.writedir() .. [[Mods\tech\DCS-gRPC\]]
GRPC.dllPath = lfs.writedir() .. [[Mods\tech\DCS-gRPC\]]
end
if GRPC.throughputLimit == nil or GRPC.throughputLimit == 0 or type(GRPC.throughputLimit) ~= "number" then
GRPC.throughputLimit = 600
GRPC.throughputLimit = 600
end

-- load version
dofile(GRPC.luaPath .. [[version.lua]])

-- Let DCS know where to find the DLLs
if not string.find(package.cpath, GRPC.dllPath) then
package.cpath = package.cpath .. [[;]] .. GRPC.dllPath .. [[?.dll;]]
package.cpath = package.cpath .. [[;]] .. GRPC.dllPath .. [[?.dll;]]
end

-- Load DLL before `require` gets sanitized.
local ok, grpc = pcall(require, "dcs_grpc_hot_reload")
if ok then
env.info("[GRPC] loaded hot reload version")
env.info("[GRPC] loaded hot reload version")
else
grpc = require("dcs_grpc")
grpc = require("dcs_grpc")
end

-- Keep a reference to `lfs` before it gets sanitized
local lfs = _G.lfs

local loaded = false
function GRPC.load()
if loaded then
env.info("[GRPC] already loaded")
return
end
if loaded then
env.info("[GRPC] already loaded")
return
end

local env = setmetatable({grpc = grpc, lfs = lfs}, {__index = _G})
local f = setfenv(assert(loadfile(GRPC.luaPath .. [[grpc.lua]])), env)
f()
local env = setmetatable({ grpc = grpc, lfs = lfs }, { __index = _G })
local f = setfenv(assert(loadfile(GRPC.luaPath .. [[grpc.lua]])), env)
f()

loaded = true
loaded = true
end

if GRPC.autostart == true then
env.info("[GRPC] auto starting")
GRPC.load()
env.info("[GRPC] auto starting")
GRPC.load()
end
2 changes: 1 addition & 1 deletion lua/Hooks/DCS-gRPC.lua
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ local function init()
if not GRPC then
_G.GRPC = {
-- scaffold nested tables to allow direct assignment in config file
tts = { provider = { gcloud = {}, aws = {}, azure = {}, win = {} } },
tts = { provider = { gcloud = {}, aws = {}, azure = {}, win = {}, parler = {} } },
srs = {},
auth = { tokens = {} }
}
Expand Down
6 changes: 6 additions & 0 deletions protos/dcs/srs/v0/srs.proto
Original file line number Diff line number Diff line change
Expand Up @@ -75,13 +75,19 @@ message TransmitRequest {
optional string voice = 1;
}

// Options for the Parler TTS provider.
message Parler {
// Speaker prompt to synthesize with. When unset, the server falls back to
// the `tts.provider.parler.defaultSpeaker` value from the config; the
// request is rejected if no non-empty speaker can be determined.
optional string speaker = 1;
}

// Optional TTS provider to be used. Defaults to the one configured in your
// config or to Windows' built-in TTS.
oneof provider {
Aws aws = 8;
Azure azure = 9;
GCloud gcloud = 10;
Windows win = 11;
// Parler does not support SSML, only use it with plain text.
Parler parler = 12;
}
}

Expand Down
3 changes: 2 additions & 1 deletion src/authentication.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use crate::config::AuthConfig;
use tonic::codegen::http::Request;
use tonic::transport::Body;
use tonic::{async_trait, Status};
use tonic_middleware::RequestInterceptor;

use crate::config::AuthConfig;

#[derive(Clone)]
pub struct AuthInterceptor {
pub auth_config: AuthConfig,
Expand Down
8 changes: 8 additions & 0 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ pub struct TtsProviderConfig {
pub azure: Option<AzureConfig>,
pub gcloud: Option<GCloudConfig>,
pub win: Option<WinConfig>,
pub parler: Option<ParlerConfig>,
}

#[derive(Debug, Clone, Default, Deserialize, Serialize)]
Expand All @@ -49,6 +50,7 @@ pub enum TtsProvider {
GCloud,
#[default]
Win,
Parler,
}

#[derive(Clone, Deserialize, Serialize)]
Expand Down Expand Up @@ -81,6 +83,12 @@ pub struct WinConfig {
pub default_voice: Option<String>,
}

/// Configuration for the Parler TTS provider (the `parler` entry of
/// `TtsProviderConfig`).
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct ParlerConfig {
/// Speaker prompt used when a transmit request does not specify one.
/// (De)serialized as `defaultSpeaker` because of the camelCase rename.
pub default_speaker: Option<String>,
}

#[derive(Debug, Clone, Default, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct SrsConfig {
Expand Down
3 changes: 1 addition & 2 deletions src/rpc/metadata.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use stubs::metadata::v0::metadata_service_server::MetadataService;
use stubs::*;
use tonic::async_trait;
use tonic::{Request, Response, Status};
use tonic::{async_trait, Request, Response, Status};

use super::MissionRpc;

Expand Down
27 changes: 26 additions & 1 deletion src/rpc/srs.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
use std::error;
use std::future::Future;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::path::PathBuf;
use std::str::FromStr;
use std::time::{Duration, Instant};

use ::srs::Sender;
#[cfg(target_os = "windows")]
use ::tts::WinConfig;
use ::tts::{AwsConfig, AwsRegion, AzureConfig, GCloudConfig, TtsConfig};
use ::tts::{AwsConfig, AwsRegion, AzureConfig, GCloudConfig, ParlerConfig, TtsConfig};
use futures_util::FutureExt;
use stubs::common::v0::{Coalition, Unit};
use stubs::mission::v0::stream_events_response::{Event, TtsEvent};
Expand All @@ -27,6 +28,7 @@ use crate::srs::SrsClients;
pub struct Srs {
tts_config: crate::config::TtsConfig,
srs_config: crate::config::SrsConfig,
write_dir: PathBuf,
rpc: MissionRpc,
srs_clients: SrsClients,
shutdown_signal: ShutdownHandle,
Expand All @@ -36,13 +38,15 @@ impl Srs {
pub fn new(
tts_config: crate::config::TtsConfig,
srs_config: crate::config::SrsConfig,
write_dir: PathBuf,
rpc: MissionRpc,
srs_clients: SrsClients,
shutdown_signal: ShutdownHandle,
) -> Self {
Self {
tts_config,
srs_config,
write_dir,
rpc,
srs_clients,
shutdown_signal,
Expand Down Expand Up @@ -105,6 +109,9 @@ impl SrsService for Srs {
TtsProvider::Win => {
transmit_request::Provider::Win(transmit_request::Windows { voice: None })
}
TtsProvider::Parler => {
transmit_request::Provider::Parler(transmit_request::Parler { speaker: None })
}
}) {
transmit_request::Provider::Aws(transmit_request::Aws { voice }) => {
TtsConfig::Aws(AwsConfig {
Expand Down Expand Up @@ -215,6 +222,24 @@ impl SrsService for Srs {
"Windows TTS is only available on Windows",
));
}
transmit_request::Provider::Parler(transmit_request::Parler { speaker }) => {
TtsConfig::Parler(ParlerConfig {
speaker: speaker
.or_else(|| {
self.tts_config
.provider
.as_ref()
.and_then(|p| p.parler.as_ref())
.and_then(|p| p.default_speaker.clone())
})
.filter(|v| !v.is_empty())
.ok_or_else(|| {
Status::failed_precondition(
"tts.provider.parler.default_speaker not set",
)
})?,
})
}
};

let frames = ::tts::synthesize(&request.ssml, &config)
Expand Down
19 changes: 13 additions & 6 deletions src/server.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,9 @@
use std::future::Future;
use std::net::SocketAddr;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;

use crate::authentication::AuthInterceptor;
use crate::config::{AuthConfig, Config, SrsConfig, TtsConfig};
use crate::rpc::{HookRpc, MissionRpc, Srs};
use crate::shutdown::{Shutdown, ShutdownHandle};
use crate::srs::SrsClients;
use crate::stats::Stats;
use dcs_module_ipc::IPC;
use futures_util::FutureExt;
use stubs::atmosphere::v0::atmosphere_service_server::AtmosphereServiceServer;
Expand All @@ -34,6 +29,13 @@ use tokio::time::sleep;
use tonic::transport;
use tonic_middleware::RequestInterceptorLayer;

use crate::authentication::AuthInterceptor;
use crate::config::{AuthConfig, Config, SrsConfig, TtsConfig};
use crate::rpc::{HookRpc, MissionRpc, Srs};
use crate::shutdown::{Shutdown, ShutdownHandle};
use crate::srs::SrsClients;
use crate::stats::Stats;

pub struct Server {
runtime: Runtime,
shutdown: Shutdown,
Expand All @@ -51,6 +53,7 @@ struct ServerState {
stats: Stats,
tts_config: TtsConfig,
srs_config: SrsConfig,
write_dir: PathBuf,
srs_transmit: Arc<Mutex<mpsc::Receiver<TransmitRequest>>>,
auth_config: AuthConfig,
}
Expand All @@ -73,6 +76,7 @@ impl Server {
stats: Stats::new(shutdown.handle()),
tts_config: config.tts.clone().unwrap_or_default(),
srs_config: config.srs.clone().unwrap_or_default(),
write_dir: PathBuf::from(&config.write_dir),
srs_transmit: Arc::new(Mutex::new(rx)),
auth_config: config.auth.clone().unwrap_or_default(),
},
Expand Down Expand Up @@ -206,6 +210,7 @@ async fn try_run(
stats,
tts_config,
srs_config,
write_dir,
srs_transmit,
auth_config,
} = state;
Expand All @@ -230,6 +235,7 @@ async fn try_run(
let srs = Srs::new(
tts_config.clone(),
srs_config.clone(),
write_dir.clone(),
mission_rpc.clone(),
srs_clients.clone(),
shutdown_signal.clone(),
Expand Down Expand Up @@ -269,6 +275,7 @@ async fn try_run(
.add_service(SrsServiceServer::new(Srs::new(
tts_config,
srs_config,
write_dir,
mission_rpc.clone(),
srs_clients,
shutdown_signal.clone(),
Expand Down
Loading
Loading