diff --git a/lib/pegboard/manager/Cargo.toml b/lib/pegboard/manager/Cargo.toml index 286442406f..fd47ba187f 100644 --- a/lib/pegboard/manager/Cargo.toml +++ b/lib/pegboard/manager/Cargo.toml @@ -14,6 +14,7 @@ lazy_static = "1.4" nix = { version = "0.27", default-features = false, features = ["user", "signal"] } notify = { version = "6.1.1", default-features = false, features = [ "serde" ] } prometheus = "0.13" +rand = "0.8" reqwest = { version = "0.11", features = ["stream"] } serde = { version = "1.0.195", features = ["derive"] } serde_json = "1.0.111" diff --git a/lib/pegboard/manager/src/container/mod.rs b/lib/pegboard/manager/src/container/mod.rs index ac666d1c60..7724bd2ec8 100644 --- a/lib/pegboard/manager/src/container/mod.rs +++ b/lib/pegboard/manager/src/container/mod.rs @@ -8,6 +8,7 @@ use anyhow::*; use futures_util::{stream::FuturesUnordered, FutureExt, StreamExt}; use indoc::indoc; use nix::{ + errno::Errno, sys::signal::{kill, Signal}, unistd::Pid, }; @@ -295,6 +296,22 @@ impl Container { }) .await?; + // Unbind ports + utils::query(|| async { + sqlx::query(indoc!( + " + UPDATE container_ports + SET delete_ts = ?2 + WHERE container_id = ?1 + ", + )) + .bind(self.container_id) + .bind(utils::now()) + .execute(&mut *ctx.sql().await?) + .await + }) + .await?; + ctx.event(protocol::Event::ContainerStateUpdate { container_id: self.container_id, state: protocol::ContainerState::Exited { exit_code }, @@ -306,31 +323,31 @@ impl Container { Ok(()) } - pub async fn stop(self: &Arc, ctx: &Arc) -> Result<()> { - tracing::info!(container_id=?self.container_id, "stopping"); + pub async fn signal(self: &Arc, ctx: &Arc, signal: Signal) -> Result<()> { + tracing::info!(container_id=?self.container_id, ?signal, "sending signal"); let self2 = self.clone(); let ctx2 = ctx.clone(); tokio::spawn(async move { - if let Err(err) = self2.stop_inner(&ctx2).await { - tracing::error!(?err, "container stop failed"); + if let Err(err) = self2.signal_inner(&ctx2, signal).await { + tracing::error!(?err, "container signal failed"); } - - self2.cleanup(&ctx2).await }); Ok(()) } - async fn stop_inner(self: &Arc, ctx: &Arc) -> Result<()> { + async fn signal_inner(self: &Arc, ctx: &Arc, signal: Signal) -> Result<()> { let mut i = 0; + // Signal command might be sent before the container has a valid PID. This loop waits for the PID to + // be set let pid = loop { if let Some(pid) = *self.pid.lock().await { break Some(pid); } - tracing::warn!(container_id=?self.container_id, "waiting for pid to stop workflow"); + tracing::warn!(container_id=?self.container_id, "waiting for pid to signal container"); if i > STOP_PID_RETRIES { tracing::error!( @@ -345,31 +362,42 @@ impl Container { tokio::time::sleep(STOP_PID_INTERVAL).await; }; - // Kill if PID found + // Kill if PID set if let Some(pid) = pid { - kill(pid, Signal::SIGTERM)?; + use std::result::Result::{Err, Ok}; + + match kill(pid, signal) { + Ok(_) => {} + Err(Errno::ESRCH) => { + tracing::warn!(container_id=?self.container_id, ?pid, "pid not found for signalling") + } + Err(err) => return Err(err.into()), + } } - utils::query(|| async { - sqlx::query(indoc!( - " - UPDATE containers - SET stop_ts = ?2 - container_id = ?1 - ", - )) - .bind(self.container_id) - .bind(utils::now()) - .execute(&mut *ctx.sql().await?) - .await - }) - .await?; + // Update stop_ts + if matches!(signal, Signal::SIGTERM) || pid.is_none() { + utils::query(|| async { + sqlx::query(indoc!( + " + UPDATE containers + SET stop_ts = ?2 + container_id = ?1 + ", + )) + .bind(self.container_id) + .bind(utils::now()) + .execute(&mut *ctx.sql().await?) + .await + }) + .await?; - ctx.event(protocol::Event::ContainerStateUpdate { - container_id: self.container_id, - state: protocol::ContainerState::Stopped, - }) - .await?; + ctx.event(protocol::Event::ContainerStateUpdate { + container_id: self.container_id, + state: protocol::ContainerState::Stopped, + }) + .await?; + } Ok(()) } diff --git a/lib/pegboard/manager/src/container/setup.rs b/lib/pegboard/manager/src/container/setup.rs index 0ce96e1456..3cd9e902bf 100644 --- a/lib/pegboard/manager/src/container/setup.rs +++ b/lib/pegboard/manager/src/container/setup.rs @@ -12,6 +12,7 @@ use nix::{ unistd::{fork, pipe, read, write, ForkResult, Pid}, }; use pegboard::protocol; +use rand::Rng; use serde_json::{json, Value}; use tokio::{ fs::{self, File}, @@ -20,9 +21,11 @@ use tokio::{ }; use super::{oci_config, Container}; -use crate::ctx::Ctx; +use crate::{ctx::Ctx, utils}; const NETWORK_NAME: &str = "rivet-pegboard"; +const MIN_INGRESS_PORT: u16 = 20000; +const MAX_INGRESS_PORT: u16 = 31999; impl Container { pub async fn setup_oci_bundle(&self, ctx: &Ctx) -> Result<()> { @@ -334,21 +337,7 @@ impl Container { tracing::info!(container_id=?self.container_id, "writing cni params"); - let cni_port_mappings = self - .config - .ports - .iter() - .map(|(_, port)| { - // Pick random port that isn't taken - let host_port = todo!(); - - json!({ - "HostPort": host_port, - "ContainerPort": port.internal_port, - "Protocol": port.proxy_protocol.to_string(), - }) - }) - .collect::>(); + let cni_port_mappings = self.bind_ports(ctx).await?; // MARK: Generate CNI parameters // @@ -414,6 +403,137 @@ impl Container { Ok(()) } + pub(crate) async fn bind_ports(&self, ctx: &Ctx) -> Result> { + let mut tcp_count = 0; + let mut udp_count = 0; + + // Count ports + for (_, port) in &self.config.ports { + match port.proxy_protocol { + protocol::TransportProtocol::Tcp => tcp_count += 1, + protocol::TransportProtocol::Udp => udp_count += 1, + } + } + + let max = MAX_INGRESS_PORT - MIN_INGRESS_PORT; + let tcp_offset = rand::thread_rng().gen_range(0..max); + let udp_offset = rand::thread_rng().gen_range(0..max); + + // Selects available TCP and UDP ports + let rows = utils::query(|| async { + sqlx::query_as::<_, (i64, i64)>(indoc!( + " + INSERT INTO container_ports (container_id, port, protocol) + SELECT ?1, port, protocol + -- Select TCP ports + FROM ( + WITH RECURSIVE + nums(n, i) AS ( + SELECT ?4, ?4 + UNION ALL + SELECT (n + 1) % (?6 + 1), i + 1 + FROM nums + WHERE i < ?6 + ?4 + ), + available_ports(port) AS ( + SELECT nums.n + FROM nums + LEFT JOIN container_ports AS p + ON + nums.n = p.port AND + p.protocol = 0 AND + delete_ts IS NULL + WHERE + p.port IS NULL OR + delete_ts IS NOT NULL + LIMIT ?2 + ) + SELECT port, 0 AS protocol FROM available_ports + ) + UNION ALL + SELECT ?1, port, protocol + -- Select UDP ports + FROM ( + WITH RECURSIVE + nums(n, i) AS ( + SELECT ?5, ?5 + UNION ALL + SELECT (n + 1) % (?6 + 1), i + 1 + FROM nums + WHERE i < ?6 + ?5 + ), + available_ports(port) AS ( + SELECT nums.n + FROM nums + LEFT JOIN container_ports AS p + ON + nums.n = p.port AND + p.protocol = 1 AND + delete_ts IS NULL + WHERE + p.port IS NULL OR + delete_ts IS NOT NULL + LIMIT ?3 + ) + SELECT port, 1 AS protocol FROM available_ports + ) + RETURNING port, protocol + ", + )) + .bind(self.container_id) + .bind(tcp_count as i64) // ?2 + .bind(udp_count as i64) // ?3 + .bind(tcp_offset as i64) // ?4 + .bind(udp_offset as i64) // ?5 + .bind(max as i64) // ?6 + .fetch_all(&mut *ctx.sql().await?) + .await + }) + .await?; + + if rows.len() != tcp_count + udp_count { + bail!("not enough available ports"); + } + + let cni_port_mappings = self + .config + .ports + .iter() + .filter(|(_, port)| matches!(port.proxy_protocol, protocol::TransportProtocol::Tcp)) + .zip( + rows.iter() + .filter(|(_, protocol)| *protocol == protocol::TransportProtocol::Tcp as i64), + ) + .map(|((_, port), (host_port, _))| { + json!({ + "HostPort": host_port, + "ContainerPort": port.internal_port, + "Protocol": port.proxy_protocol.to_string(), + }) + }) + .chain( + self.config + .ports + .iter() + .filter(|(_, port)| { + matches!(port.proxy_protocol, protocol::TransportProtocol::Udp) + }) + .zip(rows.iter().filter(|(_, protocol)| { + *protocol == protocol::TransportProtocol::Udp as i64 + })) + .map(|((_, port), (host_port, _))| { + json!({ + "HostPort": host_port, + "ContainerPort": port.internal_port, + "Protocol": port.proxy_protocol.to_string(), + }) + }), + ) + .collect::>(); + + Ok(cni_port_mappings) + } + #[tracing::instrument(skip_all)] pub async fn cleanup(&self, ctx: &Ctx) -> Result<()> { use std::result::Result::{Err, Ok}; diff --git a/lib/pegboard/manager/src/ctx.rs b/lib/pegboard/manager/src/ctx.rs index 6ca37d511d..3ddfd01c47 100644 --- a/lib/pegboard/manager/src/ctx.rs +++ b/lib/pegboard/manager/src/ctx.rs @@ -236,9 +236,12 @@ impl Ctx { // Spawn container container.start(&self).await?; } - protocol::Command::StopContainer { container_id } => { + protocol::Command::SignalContainer { + container_id, + signal, + } => { if let Some(container) = self.containers.read().await.get(&container_id) { - container.stop(&self).await?; + container.signal(&self, signal.try_into()?).await?; } else { tracing::warn!( ?container_id, @@ -318,7 +321,7 @@ impl Ctx { " SELECT container_id, config, pid FROM containers - WHERE stop_ts IS NULL AND exit_ts IS NULL + WHERE exit_ts IS NULL ", )) .fetch_all(&mut *self.sql().await?) diff --git a/lib/pegboard/manager/src/utils.rs b/lib/pegboard/manager/src/utils.rs index a0da46278a..6d298d64cb 100644 --- a/lib/pegboard/manager/src/utils.rs +++ b/lib/pegboard/manager/src/utils.rs @@ -179,6 +179,39 @@ pub async fn init_sqlite_schema(pool: &SqlitePool) -> Result<()> { .execute(&mut *conn) .await?; + sqlx::query(indoc!( + " + CREATE TABLE IF NOT EXISTS container_ports ( + container_id TEXT NOT NULL, -- UUID + port INT NOT NULL, + protocol INT NOT NULL, -- protocol::TransportProtocol + + delete_ts INT + ) + ", + )) + .execute(&mut *conn) + .await?; + + sqlx::query(indoc!( + " + CREATE INDEX IF NOT EXISTS container_ports_id_idx + ON container_ports(container_id) + ", + )) + .execute(&mut *conn) + .await?; + + sqlx::query(indoc!( + " + CREATE UNIQUE INDEX IF NOT EXISTS container_ports_unique_idx + ON container_ports(port, protocol) + WHERE delete_ts IS NULL + ", + )) + .execute(&mut *conn) + .await?; + Ok(()) } diff --git a/lib/util/core/src/serde.rs b/lib/util/core/src/serde.rs index 7d1c09c1b6..5aecd9e5f8 100644 --- a/lib/util/core/src/serde.rs +++ b/lib/util/core/src/serde.rs @@ -82,6 +82,33 @@ impl FromIterator<(K, V)> for HashableMap { } } +impl IntoIterator for HashableMap { + type Item = (K, V); + type IntoIter = indexmap::map::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.0.into_iter() + } +} + +impl<'a, K: Eq + Hash, V: Hash> IntoIterator for &'a HashableMap { + type Item = (&'a K, &'a V); + type IntoIter = indexmap::map::Iter<'a, K, V>; + + fn into_iter(self) -> Self::IntoIter { + self.0.iter() + } +} + +impl<'a, K: Eq + Hash, V: Hash> IntoIterator for &'a mut HashableMap { + type Item = (&'a K, &'a mut V); + type IntoIter = indexmap::map::IterMut<'a, K, V>; + + fn into_iter(self) -> Self::IntoIter { + self.0.iter_mut() + } +} + /// Allows partial json ser/de. /// Effectively a `serde_json::value::RawValue` with type information. pub struct Raw { diff --git a/svc/pkg/pegboard/Cargo.toml b/svc/pkg/pegboard/Cargo.toml index d953b4dea0..e243c4f798 100644 --- a/svc/pkg/pegboard/Cargo.toml +++ b/svc/pkg/pegboard/Cargo.toml @@ -8,6 +8,7 @@ license = "Apache-2.0" [dependencies] chirp-workflow = { path = "../../../lib/chirp-workflow/core" } serde = { version = "1.0.198", features = ["derive"] } +strum = { version = "0.24", features = ["derive"] } thiserror = "1.0" [dependencies.sqlx] diff --git a/svc/pkg/pegboard/db/pegboard/migrations/20240913005543_init.up.sql b/svc/pkg/pegboard/db/pegboard/migrations/20240913005543_init.up.sql index 3d36199a36..82b2f53bd3 100644 --- a/svc/pkg/pegboard/db/pegboard/migrations/20240913005543_init.up.sql +++ b/svc/pkg/pegboard/db/pegboard/migrations/20240913005543_init.up.sql @@ -1,8 +1,11 @@ CREATE TABLE clients ( client_id UUID PRIMARY KEY, create_ts INT NOT NULL, + last_ping_ts INT NOT NULL, last_event_idx INT NOT NULL DEFAULT 0, - last_command_idx INT NOT NULL DEFAULT 0 + last_command_idx INT NOT NULL DEFAULT 0, + + drain_ts INT ); CREATE TABLE client_events ( diff --git a/svc/pkg/pegboard/src/ops/client/usage_get.rs b/svc/pkg/pegboard/src/ops/client/usage_get.rs index 0d8e0ad70d..192fa284e4 100644 --- a/svc/pkg/pegboard/src/ops/client/usage_get.rs +++ b/svc/pkg/pegboard/src/ops/client/usage_get.rs @@ -48,7 +48,6 @@ pub async fn pegboard_client_usage_get(ctx: &OperationCtx, input: &Input) -> Glo FROM db_pegboard.containers WHERE client_id = ANY($1) AND - stop_ts IS NULL AND exit_ts IS NULL GROUP BY client_id ", diff --git a/svc/pkg/pegboard/src/protocol.rs b/svc/pkg/pegboard/src/protocol.rs index 1c28fc6605..274f4107e9 100644 --- a/svc/pkg/pegboard/src/protocol.rs +++ b/svc/pkg/pegboard/src/protocol.rs @@ -1,4 +1,5 @@ use chirp_workflow::prelude::*; +use strum::FromRepr; // Reexport for ease of use in pegboard manager pub use util::serde::{HashableMap, Raw}; @@ -59,8 +60,10 @@ pub enum Command { container_id: Uuid, config: Box, }, - StopContainer { + SignalContainer { container_id: Uuid, + // See nix::sys::signal::Signal + signal: i32, }, } @@ -102,10 +105,10 @@ pub struct Port { pub proxy_protocol: TransportProtocol, } -#[derive(Serialize, Deserialize, Hash, Debug, Clone, Copy, PartialEq, Eq)] +#[derive(Serialize, Deserialize, Hash, Debug, Clone, Copy, PartialEq, Eq, FromRepr)] pub enum TransportProtocol { - Tcp, - Udp, + Tcp = 0, + Udp = 1, } impl std::fmt::Display for TransportProtocol { diff --git a/svc/pkg/pegboard/src/workflows/client.rs b/svc/pkg/pegboard/src/workflows/client.rs index e0e711ae74..5aecb85ac6 100644 --- a/svc/pkg/pegboard/src/workflows/client.rs +++ b/svc/pkg/pegboard/src/workflows/client.rs @@ -271,11 +271,11 @@ async fn set_drain(ctx: &ActivityCtx, input: &SetDrainInput) -> GlobalResult<()> [ctx] " UPDATE db_pegboard.clients - SET draining = $2 + SET drain_ts = $2 WHERE client_id = $1 ", input.client_id, - input.drain, + input.drain.then(util::timestamp::now), ) .await?; diff --git a/svc/pkg/pegboard/standalone/ws/src/lib.rs b/svc/pkg/pegboard/standalone/ws/src/lib.rs index d59ff2c1de..540c493b87 100644 --- a/svc/pkg/pegboard/standalone/ws/src/lib.rs +++ b/svc/pkg/pegboard/standalone/ws/src/lib.rs @@ -155,8 +155,8 @@ async fn handle_connection_inner( .send() .await?; } - // TODO: Implement timeout for clients that haven't pinged in a while Message::Ping(_) => { + update_ping(ctx, client_id).await?; tx.lock().await.send(Message::Pong(Vec::new())).await?; } Message::Close(_) => { @@ -167,7 +167,7 @@ async fn handle_connection_inner( } } } - + bail!(format!("stream closed {client_id}")); // Only way I could figure out to help the complier infer type @@ -180,8 +180,8 @@ async fn upsert_client(ctx: &StandaloneCtx, client_id: Uuid) -> GlobalResult<()> let inserted = sql_fetch_optional!( [ctx, (i64,)] " - INSERT INTO db_pegboard.clients (client_id, create_ts) - VALUES ($1, $2) + INSERT INTO db_pegboard.clients (client_id, create_ts, last_ping_ts) + VALUES ($1, $2, $2) ON CONFLICT (client_id) DO NOTHING RETURNING 1 ", @@ -202,6 +202,22 @@ async fn upsert_client(ctx: &StandaloneCtx, client_id: Uuid) -> GlobalResult<()> Ok(()) } +async fn update_ping(ctx: &StandaloneCtx, client_id: Uuid) -> GlobalResult<()> { + sql_execute!( + [ctx, (i64,)] + " + UPDATE db_pegboard.clients + SET last_ping_ts = $2 + WHERE client_id = $1 + ", + client_id, + util::timestamp::now(), + ) + .await?; + + Ok(()) +} + async fn signal_thread(ctx: &StandaloneCtx, conns: Arc>) -> GlobalResult<()> { // Listen for commands from client workflows let mut sub = ctx