fix(workflows): add nats for signals #1149

Closed
9 changes: 9 additions & 0 deletions docs/libraries/workflow/GOTCHAS.md
@@ -140,3 +140,12 @@ Be careful when writing your struct definitions.
When force waking a sleeping workflow by setting `wake_immediate = true`, be aware that if the workflow is
currently on a `sleep` step, it will go back to sleep if it has not yet reached its `wake_deadline`. For all
other steps, the workflow will continue normally (and usually just go back to sleep afterward).

## Long-lived tasks in `ctx.join`

When executing multiple long-lived activities in a `ctx.join` call using a tuple, remember that it
internally uses `tokio::join!` and not `tokio::try_join!`. This means it waits for all branches to finish
and does not short-circuit when any branch returns an `Err`.

So if one activity errors immediately and another takes a while to finish, the `ctx.join` call will not
return until the long-running task completes (or errors itself).
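
To see the difference outside the workflow engine, here is a minimal sketch using plain `tokio`
(assuming `tokio = { version = "1", features = ["full"] }`; the two "activities" are hypothetical
stand-ins, not real workflow APIs):

```rust
use std::time::{Duration, Instant};

// Stand-in for an activity that fails right away.
async fn fails_fast() -> Result<(), &'static str> {
    Err("boom")
}

// Stand-in for an activity that keeps running for a while.
async fn runs_long() -> Result<(), &'static str> {
    tokio::time::sleep(Duration::from_secs(2)).await;
    Ok(())
}

#[tokio::main]
async fn main() {
    // `tokio::join!` (what `ctx.join` uses internally) drives every branch to
    // completion, so this takes ~2s even though one branch failed immediately.
    let start = Instant::now();
    let (a, b) = tokio::join!(fails_fast(), runs_long());
    println!("join!: {:?} {:?} after {:?}", a, b, start.elapsed());

    // `tokio::try_join!` short-circuits on the first `Err`, so this returns
    // almost immediately.
    let start = Instant::now();
    let res = tokio::try_join!(fails_fast(), runs_long());
    println!("try_join!: {:?} after {:?}", res, start.elapsed());
}
```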
56 changes: 29 additions & 27 deletions lib/bolt/core/src/context/service.rs
@@ -510,33 +510,35 @@ impl ServiceContextData {
);
}

let can_depend =
if self.is_monolith_worker() {
matches!(
dep.config().kind,
ServiceKind::Database { .. }
| ServiceKind::Cache { .. } | ServiceKind::Operation { .. }
| ServiceKind::Package { .. }
| ServiceKind::Consumer { .. }
)
} else if matches!(self.config().kind, ServiceKind::Api { .. }) {
matches!(
dep.config().kind,
ServiceKind::Database { .. }
| ServiceKind::Cache { .. } | ServiceKind::Operation { .. }
| ServiceKind::Package { .. }
| ServiceKind::ApiRoutes { .. }
| ServiceKind::Consumer { .. }
)
} else {
matches!(
dep.config().kind,
ServiceKind::Database { .. }
| ServiceKind::Cache { .. } | ServiceKind::Operation { .. }
| ServiceKind::Package { .. }
| ServiceKind::Consumer { .. }
)
};
let can_depend = if self.is_monolith_worker() {
matches!(
dep.config().kind,
ServiceKind::Database { .. }
| ServiceKind::Cache { .. }
| ServiceKind::Operation { .. }
| ServiceKind::Package { .. }
| ServiceKind::Consumer { .. }
)
} else if matches!(self.config().kind, ServiceKind::Api { .. }) {
matches!(
dep.config().kind,
ServiceKind::Database { .. }
| ServiceKind::Cache { .. }
| ServiceKind::Operation { .. }
| ServiceKind::Package { .. }
| ServiceKind::ApiRoutes { .. }
| ServiceKind::Consumer { .. }
)
} else {
matches!(
dep.config().kind,
ServiceKind::Database { .. }
| ServiceKind::Cache { .. }
| ServiceKind::Operation { .. }
| ServiceKind::Package { .. }
| ServiceKind::Consumer { .. }
)
};

if !can_depend {
panic!(
1 change: 0 additions & 1 deletion lib/chirp-workflow/core/Cargo.toml
@@ -17,7 +17,6 @@ indoc = "2.0.5"
lazy_static = "1.4"
prost = "0.12.4"
prost-types = "0.12.4"
rand = "0.8.5"
rivet-cache = { path = "../../cache/build" }
rivet-connection = { path = "../../connection" }
rivet-metrics = { path = "../../metrics" }
2 changes: 1 addition & 1 deletion lib/chirp-workflow/core/src/compat.rs
@@ -102,7 +102,7 @@
M: Message,
B: Debug + Clone,
{
let msg_ctx = MessageCtx::new(ctx.conn(), ctx.req_id(), ctx.ray_id())
let msg_ctx = MessageCtx::new(ctx.conn(), ctx.ray_id())
.await
.map_err(GlobalError::raw)?;

6 changes: 3 additions & 3 deletions lib/chirp-workflow/core/src/ctx/api.rs
@@ -13,7 +13,7 @@ use crate::{
},
db::DatabaseHandle,
error::WorkflowResult,
message::{Message, ReceivedMessage},
message::{Message, NatsMessage},
operation::{Operation, OperationInput},
signal::Signal,
workflow::{Workflow, WorkflowInput},
@@ -55,7 +55,7 @@ impl ApiCtx {
(),
);

let msg_ctx = MessageCtx::new(&conn, req_id, ray_id).await?;
let msg_ctx = MessageCtx::new(&conn, ray_id).await?;

Ok(ApiCtx {
ray_id,
@@ -129,7 +129,7 @@ impl ApiCtx {
pub async fn tail_read<M>(
&self,
tags: serde_json::Value,
) -> GlobalResult<Option<ReceivedMessage<M>>>
) -> GlobalResult<Option<NatsMessage<M>>>
where
M: Message,
{
2 changes: 1 addition & 1 deletion lib/chirp-workflow/core/src/ctx/backfill.rs
@@ -11,7 +11,7 @@ use std::{
};
use uuid::Uuid;

use crate::util::Location;
use crate::utils::Location;

// Yes
type Query = Box<
131 changes: 23 additions & 108 deletions lib/chirp-workflow/core/src/ctx/message.rs
@@ -13,7 +13,8 @@ use uuid::Uuid;

use crate::{
error::{WorkflowError, WorkflowResult},
message::{self, Message, MessageWrapper, ReceivedMessage, TraceEntry},
message::{redis_keys, Message, NatsMessage, NatsMessageWrapper},
utils,
};

/// Time (in ms) that we subtract from the anchor grace period in order to
@@ -29,29 +30,15 @@ pub struct MessageCtx {
/// Used for writing to message tails. This cache is ephemeral.
redis_chirp_ephemeral: RedisPool,

req_id: Uuid,
ray_id: Uuid,
trace: Vec<TraceEntry>,
}

impl MessageCtx {
pub async fn new(
conn: &rivet_connection::Connection,
req_id: Uuid,
ray_id: Uuid,
) -> WorkflowResult<Self> {
pub async fn new(conn: &rivet_connection::Connection, ray_id: Uuid) -> WorkflowResult<Self> {
Ok(MessageCtx {
nats: conn.nats().await?,
redis_chirp_ephemeral: conn.redis_chirp_ephemeral().await?,
req_id,
ray_id,
trace: conn
.chirp()
.trace()
.iter()
.cloned()
.map(TryInto::try_into)
.collect::<WorkflowResult<Vec<_>>>()?,
})
}
}
@@ -109,7 +96,7 @@ impl MessageCtx {
M: Message,
{
let tags_str = cjson::to_string(&tags).map_err(WorkflowError::SerializeMessageTags)?;
let nats_subject = message::serialize_message_nats_subject::<M>(&tags_str);
let nats_subject = M::nats_subject();
let duration_since_epoch = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_else(|err| unreachable!("time is broken: {}", err));
@@ -124,12 +111,11 @@

// Serialize message
let req_id = Uuid::new_v4();
let message = MessageWrapper {
let message = NatsMessageWrapper {
req_id: req_id,
ray_id: self.ray_id,
tags: tags.clone(),
tags,
ts,
trace: self.trace.clone(),
allow_recursive: false, // TODO:
body: &body_buf,
};
@@ -278,8 +264,7 @@ impl MessageCtx {
where
M: Message,
{
let tags_str = cjson::to_string(opts.tags).map_err(WorkflowError::SerializeMessageTags)?;
let nats_subject = message::serialize_message_nats_subject::<M>(&tags_str);
let nats_subject = M::nats_subject();

// Create subscription and flush immediately.
tracing::info!(%nats_subject, tags = ?opts.tags, "creating subscription");
@@ -296,7 +281,7 @@
}

// Return handle
let subscription = SubscriptionHandle::new(nats_subject, subscription, self.req_id);
let subscription = SubscriptionHandle::new(nats_subject, subscription, opts.tags.clone());
Ok(subscription)
}

@@ -305,7 +290,7 @@
pub async fn tail_read<M>(
&self,
tags: serde_json::Value,
) -> WorkflowResult<Option<ReceivedMessage<M>>>
) -> WorkflowResult<Option<NatsMessage<M>>>
where
M: Message,
{
@@ -320,7 +305,7 @@

// Deserialize message
let message = if let Some(message_buf) = message_buf {
let message = ReceivedMessage::<M>::deserialize(message_buf.as_slice())?;
let message = NatsMessage::<M>::deserialize(message_buf.as_slice())?;
tracing::info!(?message, "immediate read tail message");

let recv_lag = (rivet_util::timestamp::now() as f64 - message.ts as f64) / 1000.;
@@ -410,7 +395,7 @@
_guard: DropGuard,
subject: String,
subscription: nats::Subscriber,
req_id: Uuid,
pub tags: serde_json::Value,
}

impl<M> Debug for SubscriptionHandle<M>
@@ -420,6 +405,7 @@
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("SubscriptionHandle")
.field("subject", &self.subject)
.field("tags", &self.tags)
.finish()
}
}
@@ -429,7 +415,7 @@
M: Message,
{
#[tracing::instrument(level = "debug", skip_all)]
fn new(subject: String, subscription: nats::Subscriber, req_id: Uuid) -> Self {
fn new(subject: String, subscription: nats::Subscriber, tags: serde_json::Value) -> Self {
let token = CancellationToken::new();

{
@@ -458,34 +444,15 @@
_guard: token.drop_guard(),
subject,
subscription,
req_id,
tags,
}
}

/// Waits for the next message in the subscription.
///
/// This future can be safely dropped.
#[tracing::instrument]
pub async fn next(&mut self) -> WorkflowResult<ReceivedMessage<M>> {
self.next_inner(false).await
}

// TODO: Add a full config struct to pass to `next` that impl's `Default`
/// Waits for the next message in the subscription that originates from the
/// parent request ID via trace.
///
/// This future can be safely dropped.
#[tracing::instrument]
pub async fn next_with_trace(
&mut self,
filter_trace: bool,
) -> WorkflowResult<ReceivedMessage<M>> {
self.next_inner(filter_trace).await
}

/// This future can be safely dropped.
#[tracing::instrument(level = "trace")]
async fn next_inner(&mut self, filter_trace: bool) -> WorkflowResult<ReceivedMessage<M>> {
pub async fn next(&mut self) -> WorkflowResult<NatsMessage<M>> {
tracing::info!("waiting for message");

loop {
@@ -501,47 +468,22 @@
}
};

if filter_trace {
let message_wrapper =
ReceivedMessage::<M>::deserialize_wrapper(&nats_message.payload[..])?;

// Check if the message trace stack originates from this client
//
// We intentionally use the request ID instead of just checking the ray ID because
// there may be multiple calls to `message_with_subscribe` within the same ray.
// Explicitly checking the parent request ensures the response is unique to this
// message.
if message_wrapper
.trace
.iter()
.rev()
.any(|trace_entry| trace_entry.req_id == self.req_id)
{
let message = ReceivedMessage::<M>::deserialize(&nats_message.payload[..])?;
tracing::info!(?message, "received message");

return Ok(message);
}
} else {
let message = ReceivedMessage::<M>::deserialize(&nats_message.payload[..])?;
tracing::info!(?message, "received message");
let message_wrapper = NatsMessage::<M>::deserialize_wrapper(&nats_message.payload[..])?;

let recv_lag = (rivet_util::timestamp::now() as f64 - message.ts as f64) / 1000.;
crate::metrics::MESSAGE_RECV_LAG
.with_label_values(&[M::NAME])
.observe(recv_lag);
// Check if the subscription tags match a subset of the message tags
if utils::is_value_subset(&self.tags, &message_wrapper.tags) {
let message = NatsMessage::<M>::deserialize_from_wrapper(message_wrapper)?;
tracing::info!(?message, "received message");

return Ok(message);
}

// Message not from parent, continue with loop
// Message tags don't match, continue with loop
}
}

/// Converts the subscription into a stream.
pub fn into_stream(
self,
) -> impl futures_util::Stream<Item = WorkflowResult<ReceivedMessage<M>>> {
pub fn into_stream(self) -> impl futures_util::Stream<Item = WorkflowResult<NatsMessage<M>>> {
futures_util::stream::try_unfold(self, |mut sub| async move {
let message = sub.next().await?;
Ok(Some((message, sub)))
@@ -569,7 +511,7 @@
where
M: Message + Debug,
{
Message(ReceivedMessage<M>),
Message(NatsMessage<M>),

/// Anchor was older than the TTL of the message.
AnchorExpired,
Expand All @@ -589,30 +531,3 @@ where
}
}
}

mod redis_keys {
use std::{
collections::hash_map::DefaultHasher,
hash::{Hash, Hasher},
};

use crate::message::Message;

/// HASH
pub fn message_tail<M>(tags_str: &str) -> String
where
M: Message,
{
// Get hash of the tags
let mut hasher = DefaultHasher::new();
tags_str.hash(&mut hasher);

format!("{{topic:{}:{:x}}}:tail", M::NAME, hasher.finish())
}

pub mod message_tail {
pub const REQUEST_ID: &str = "r";
pub const TS: &str = "t";
pub const BODY: &str = "b";
}
}
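
Note on the new subscription semantics above: `next` now deserializes the `NatsMessageWrapper` and keeps a
message only when the subscription's tags are a subset of the message's tags. The diff does not include
`utils::is_value_subset` itself; a plausible sketch, assuming `is_value_subset(subset, superset)` returns
`true` when every key/value pair in `subset` also appears in `superset`, is:

```rust
use serde_json::{json, Value};

/// Hypothetical sketch of the tag check used by `SubscriptionHandle::next`:
/// an object matches when all of its key/value pairs are present in the
/// superset; non-object values fall back to plain equality.
fn is_value_subset(subset: &Value, superset: &Value) -> bool {
    match (subset, superset) {
        (Value::Object(sub), Value::Object(sup)) => sub
            .iter()
            .all(|(k, v)| sup.get(k).map_or(false, |sv| is_value_subset(v, sv))),
        (sub, sup) => sub == sup,
    }
}

fn main() {
    let sub_tags = json!({ "namespace_id": "abc" });
    let msg_tags = json!({ "namespace_id": "abc", "user_id": "123" });
    assert!(is_value_subset(&sub_tags, &msg_tags));
    assert!(!is_value_subset(&json!({ "namespace_id": "xyz" }), &msg_tags));
    println!("tag subset checks passed");
}
```

Under that assumption, a subscription created with empty tags (`{}`) would match every message published to
`M::nats_subject()`.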