diff --git a/lib/chirp-workflow/core/src/ctx/workflow.rs b/lib/chirp-workflow/core/src/ctx/workflow.rs index 53ac516a34..b2ccc63a3e 100644 --- a/lib/chirp-workflow/core/src/ctx/workflow.rs +++ b/lib/chirp-workflow/core/src/ctx/workflow.rs @@ -31,10 +31,6 @@ use crate::{ workflow::{Workflow, WorkflowInput}, }; -/// Poll interval when polling for signals in-process -const SIGNAL_RETRY: Duration = Duration::from_millis(100); -/// Most in-process signal poll tries -const MAX_SIGNAL_RETRIES: usize = 16; /// Most in-process sub workflow poll tries const MAX_SUB_WORKFLOW_RETRIES: usize = 4; /// Retry interval for failed db actions @@ -620,26 +616,9 @@ impl WorkflowCtx { else { tracing::info!(name=%self.name, id=%self.workflow_id, "listening for signal"); - let mut retries = 0; - let mut interval = tokio::time::interval(SIGNAL_RETRY); - interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); - let ctx = ListenCtx::new(self); - loop { - interval.tick().await; - - match T::listen(&ctx).await { - Ok(res) => break res, - Err(err) if matches!(err, WorkflowError::NoSignalFound(_)) => { - if retries > MAX_SIGNAL_RETRIES { - return Err(err).map_err(GlobalError::raw); - } - retries += 1; - } - err => return err.map_err(GlobalError::raw), - } - } + T::listen(&ctx).await.map_err(GlobalError::raw)? }; // Move to next event @@ -674,26 +653,9 @@ impl WorkflowCtx { else { tracing::info!(name=%self.name, id=%self.workflow_id, "listening for signal"); - let mut retries = 0; - let mut interval = tokio::time::interval(SIGNAL_RETRY); - interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); - let ctx = ListenCtx::new(self); - loop { - interval.tick().await; - - match listener.listen(&ctx).await { - Ok(res) => break res, - Err(err) if matches!(err, WorkflowError::NoSignalFound(_)) => { - if retries > MAX_SIGNAL_RETRIES { - return Err(err).map_err(GlobalError::raw); - } - retries += 1; - } - err => return err.map_err(GlobalError::raw), - } - } + listener.listen(&ctx).await.map_err(GlobalError::raw)? }; // Move to next event diff --git a/lib/chirp-workflow/core/src/db/pg_nats.rs b/lib/chirp-workflow/core/src/db/pg_nats.rs index 8bb424b63d..e003cff2e1 100644 --- a/lib/chirp-workflow/core/src/db/pg_nats.rs +++ b/lib/chirp-workflow/core/src/db/pg_nats.rs @@ -17,9 +17,13 @@ use crate::{ activity::ActivityId, error::{WorkflowError, WorkflowResult}, event::combine_events, - message, worker, + message, + signal::{NatsSignal, NatsSignalTarget}, + utils, worker, }; +/// Time to wait before stopping listening for signals on NATS and going to sleep. +const SIGNAL_NATS_TIMEOUT: Duration = Duration::from_secs(3); /// Max amount of workflows pulled from the database with each call to `pull_workflows`. const MAX_PULLED_WORKFLOWS: i64 = 50; // Base retry for query retry backoff @@ -76,6 +80,93 @@ impl DatabasePgNats { } } + async fn publish_signal_nats(&self, signal_name: &str, target: NatsSignalTarget) -> WorkflowResult<()> { + let nats_signal = NatsSignal { + target, + }; + let nats_signal_buf = serde_json::to_vec(&nats_signal)?; + + nats + .publish(nats_signal_subject(signal_name), &nats_signal_buf) + .await + } + + pull_next_signal_inner() { + let signal = self + .query(|| async { + sqlx::query_as::<_, SignalRow>(indoc!( + " + WITH + -- Finds the oldest signal matching the signal name filter in either the normal signals table + -- or tagged signals table + next_signal AS ( + SELECT false AS tagged, signal_id, create_ts, signal_name, body + FROM db_workflow.signals + WHERE + workflow_id = $1 AND + signal_name = ANY($2) AND + ack_ts IS NULL + UNION ALL + SELECT true AS tagged, signal_id, create_ts, signal_name, body + FROM db_workflow.tagged_signals + WHERE + signal_name = ANY($2) AND + tags <@ (SELECT tags FROM db_workflow.workflows WHERE workflow_id = $1) AND + ack_ts IS NULL + ORDER BY create_ts ASC + LIMIT 1 + ), + -- If the next signal is not tagged, acknowledge it with this statement + ack_signal AS ( + UPDATE db_workflow.signals + SET ack_ts = $4 + WHERE signal_id = ( + SELECT signal_id FROM next_signal WHERE tagged = false + ) + RETURNING 1 + ), + -- If the next signal is tagged, acknowledge it with this statement + ack_tagged_signal AS ( + UPDATE db_workflow.tagged_signals + SET ack_ts = $4 + WHERE signal_id = ( + SELECT signal_id FROM next_signal WHERE tagged = true + ) + RETURNING 1 + ), + -- After acking the signal, add it to the events table + insert_event AS ( + INSERT INTO db_workflow.workflow_signal_events ( + workflow_id, location, signal_id, signal_name, body, ack_ts, loop_location + ) + SELECT + $1 AS workflow_id, + $3 AS location, + signal_id, + signal_name, + body, + $4 AS ack_ts, + $5 AS loop_location + FROM next_signal + RETURNING 1 + ) + SELECT * FROM next_signal + ", + )) + .bind(workflow_id) + .bind(filter) + .bind(location.iter().map(|x| *x as i64).collect::>()) + .bind(rivet_util::timestamp::now()) + .bind(loop_location.map(|l| l.iter().map(|x| *x as i64).collect::>())) + .fetch_optional(&mut *self.conn().await?) + .await + .map_err(WorkflowError::Sqlx) + }) + .await?; + + Ok(signal) + } + /// Executes queries and explicitly handles retry errors. async fn query<'a, F, Fut, T>(&self, mut cb: F) -> WorkflowResult where @@ -621,6 +712,17 @@ impl Database for DatabasePgNats { Ok(()) } + // Breakdown of the logic for this implementation: + // + // - Check the database if the signal exists + // - If not, create a NATS sub for each given signal name + // - Listen on all subs simultaneously + // - When a sub returns a NATS message, serialize it as a signal + // - Check if the signal's target matches + // - If it does, break from the NATS loop with a `true` value + // - If it does not, put the sub back into the queue again to wait for the next message + // - If the timeout is reached, stop the loop with a `false` value + // - Based on the above value, pull the signal from the database (for acking) or return `None` async fn pull_next_signal( &self, workflow_id: Uuid, @@ -628,79 +730,97 @@ impl Database for DatabasePgNats { location: &[usize], loop_location: Option<&[usize]>, ) -> WorkflowResult> { - let signal = self - .query(|| async { - sqlx::query_as::<_, SignalRow>(indoc!( + if let Some(signal) = self.pull_next_signal_inner().await { + return Ok(Some(signal)); + } + + // Future streams don't like references + let signal_names = filter + .iter() + .map(|signal_name| signal_name.to_string()) + .collect::>(); + + let ((tags,), subs) = tokio::try_join!( + // Fetch workflow tags + async { + sqlx::query_as::<_, (Option,)>(indoc!( " - WITH - -- Finds the oldest signal matching the signal name filter in either the normal signals table - -- or tagged signals table - next_signal AS ( - SELECT false AS tagged, signal_id, create_ts, signal_name, body - FROM db_workflow.signals - WHERE - workflow_id = $1 AND - signal_name = ANY($2) AND - ack_ts IS NULL - UNION ALL - SELECT true AS tagged, signal_id, create_ts, signal_name, body - FROM db_workflow.tagged_signals - WHERE - signal_name = ANY($2) AND - tags <@ (SELECT tags FROM db_workflow.workflows WHERE workflow_id = $1) AND - ack_ts IS NULL - ORDER BY create_ts ASC - LIMIT 1 - ), - -- If the next signal is not tagged, acknowledge it with this statement - ack_signal AS ( - UPDATE db_workflow.signals - SET ack_ts = $4 - WHERE signal_id = ( - SELECT signal_id FROM next_signal WHERE tagged = false - ) - RETURNING 1 - ), - -- If the next signal is tagged, acknowledge it with this statement - ack_tagged_signal AS ( - UPDATE db_workflow.tagged_signals - SET ack_ts = $4 - WHERE signal_id = ( - SELECT signal_id FROM next_signal WHERE tagged = true - ) - RETURNING 1 - ), - -- After acking the signal, add it to the events table - insert_event AS ( - INSERT INTO db_workflow.workflow_signal_events ( - workflow_id, location, signal_id, signal_name, body, ack_ts, loop_location - ) - SELECT - $1 AS workflow_id, - $3 AS location, - signal_id, - signal_name, - body, - $4 AS ack_ts, - $5 AS loop_location - FROM next_signal - RETURNING 1 - ) - SELECT * FROM next_signal + SELECT tags + FROM db_workflow.workflows + WHERE workflow_id = $1 ", )) .bind(workflow_id) - .bind(filter) - .bind(location.iter().map(|x| *x as i64).collect::>()) - .bind(rivet_util::timestamp::now()) - .bind(loop_location.map(|l| l.iter().map(|x| *x as i64).collect::>())) - .fetch_optional(&mut *self.conn().await?) + .fetch_one(&mut *self.conn().await?) .await .map_err(WorkflowError::Sqlx) - }) - .await?; + }, + // Create a NATS sub for each signal name + futures_util::stream::iter(signal_names) + .map(|signal_name| async move { + self.nats + .subscribe(signal_nats_subject(&signal_name)) + .await + .map_err(|x| WorkflowError::CreateSubscription(x.into())) + }) + .buffer_unordered(8) + .try_collect::>() + )?; - Ok(signal) + // Listen to each NATS sub + let mut nats_futs = subs + .into_iter() + .map(|mut sub| async move { (sub.next().await, sub) }.boxed()) + .collect::>(); + + let nats_signal_received = async move { + // Select the first sub to return a message + while let Some((nats_res, mut sub)) = nats_futs.next().await { + if let Some(nats_message) = nats_res { + let nats_signal = + serde_json::from_slice::(&nats_message.payload[..]) + .map_err(WorkflowError::DeserializeNatsSignal)?; + + match nats_signal.target { + NatsSignalTarget::WorkflowId(signal_workflow_id) => { + // Workflow ids match + if signal_workflow_id == workflow_id { + return Ok(true); + } + } + NatsSignalTarget::Tags(signal_tags) => { + // Check if workflow has tags + if let Some(tags) = &tags { + // Tags match + if utils::is_value_subset(&tags, &signal_tags) { + return Ok(true); + } + } + } + } + + // Put sub back in queue + nats_futs.push(async move { (sub.next().await, sub) }.boxed()); + } else { + return Err(WorkflowError::SubscriptionUnsubscribed); + } + } + + Ok(false) + }; + + // Add timeout + let nats_signal_received = tokio::time::timeout(SIGNAL_NATS_TIMEOUT, nats_signal_received) + .await + .ok() + .transpose()? + .unwrap_or_default(); + + if !nats_signal_received { + return Ok(None); + } + + self.pull_next_signal_inner().await } async fn publish_signal( @@ -711,24 +831,26 @@ impl Database for DatabasePgNats { signal_name: &str, body: serde_json::Value, ) -> WorkflowResult<()> { - self.query(|| async { - sqlx::query(indoc!( - " - INSERT INTO db_workflow.signals (signal_id, workflow_id, signal_name, body, ray_id, create_ts) - VALUES ($1, $2, $3, $4, $5, $6) - ", - )) - .bind(signal_id) - .bind(workflow_id) - .bind(signal_name) - .bind(&body) - .bind(ray_id) - .bind(rivet_util::timestamp::now()) - .execute(&mut *self.conn().await?) - .await - .map_err(WorkflowError::Sqlx) - }) - .await?; + tokio::try_join!( + self.query(|| async { + sqlx::query(indoc!( + " + INSERT INTO db_workflow.signals (signal_id, workflow_id, signal_name, body, ray_id, create_ts) + VALUES ($1, $2, $3, $4, $5, $6) + ", + )) + .bind(signal_id) + .bind(workflow_id) + .bind(signal_name) + .bind(&body) + .bind(ray_id) + .bind(rivet_util::timestamp::now()) + .execute(&mut *self.conn().await?) + .await + .map_err(WorkflowError::Sqlx) + }), + self.publish_signal_nats(signal_name, NatsSignalTarget::WorkflowId(workflow_id)), + ); self.wake_worker(); @@ -743,24 +865,26 @@ impl Database for DatabasePgNats { signal_name: &str, body: serde_json::Value, ) -> WorkflowResult<()> { - self.query(|| async { - sqlx::query(indoc!( - " - INSERT INTO db_workflow.tagged_signals (signal_id, tags, signal_name, body, ray_id, create_ts) - VALUES ($1, $2, $3, $4, $5, $6) - ", - )) - .bind(signal_id) - .bind(tags) - .bind(signal_name) - .bind(&body) - .bind(ray_id) - .bind(rivet_util::timestamp::now()) - .execute(&mut *self.conn().await?) - .await - .map_err(WorkflowError::Sqlx) - }) - .await?; + tokio::try_join!( + self.query(|| async { + sqlx::query(indoc!( + " + INSERT INTO db_workflow.tagged_signals (signal_id, tags, signal_name, body, ray_id, create_ts) + VALUES ($1, $2, $3, $4, $5, $6) + ", + )) + .bind(signal_id) + .bind(tags) + .bind(signal_name) + .bind(&body) + .bind(ray_id) + .bind(rivet_util::timestamp::now()) + .execute(&mut *self.conn().await?) + .await + .map_err(WorkflowError::Sqlx) + }), + self.publish_signal_nats(signal_name, NatsSignalTarget::Tags(tags.clone())), + )?; self.wake_worker(); @@ -1121,3 +1245,7 @@ impl Database for DatabasePgNats { Ok(()) } } + +fn signal_nats_subject(signal_name: &str) -> String { + format!("chirp.workflow.signal.{signal_name}") +} diff --git a/lib/chirp-workflow/core/src/error.rs b/lib/chirp-workflow/core/src/error.rs index cc1f2afafb..0bd521b017 100644 --- a/lib/chirp-workflow/core/src/error.rs +++ b/lib/chirp-workflow/core/src/error.rs @@ -63,6 +63,12 @@ pub enum WorkflowError { #[error("deserialize signal body: {0}")] DeserializeSignalBody(serde_json::Error), + #[error("serialize nats signal: {0}")] + SerializeNatsSignal(serde_json::Error), + + #[error("deserialize nats signal: {0}")] + DeserializeNatsSignal(serde_json::Error), + #[error("serialize message body: {0}")] SerializeMessageBody(serde_json::Error), diff --git a/lib/chirp-workflow/core/src/signal.rs b/lib/chirp-workflow/core/src/signal.rs index e45749c279..d484d79872 100644 --- a/lib/chirp-workflow/core/src/signal.rs +++ b/lib/chirp-workflow/core/src/signal.rs @@ -1,7 +1,23 @@ +use serde::{Deserialize, Serialize}; +use uuid::Uuid; + pub trait Signal { const NAME: &'static str; } +/// A signal received from a NATS subscription. +#[derive(Debug, Serialize, Deserialize)] +pub(crate) struct NatsSignal { + pub target: NatsSignalTarget, + // body: , +} + +#[derive(Debug, Serialize, Deserialize)] +pub(crate) enum NatsSignalTarget { + WorkflowId(Uuid), + Tags(serde_json::Value), +} + /// Creates an enum that implements `Listen` and selects one of X signals. /// /// Example: