From 3e6d00a55cb10b929936c7c46974c8de68fb61ef Mon Sep 17 00:00:00 2001 From: Alec Embke Date: Wed, 4 Sep 2024 16:29:00 -0700 Subject: [PATCH] 9.2.0 (#291) * feat: glommio support * feat: add credential provider * fix: mocked pipeline commands * feat: support pipelined transactions --------- Co-authored-by: Nuutti Kotivuori --- .circleci/config.yml | 38 +- .gitignore | 1 + CHANGELOG.md | 7 + Cargo.toml | 38 +- README.md | 48 +- examples/README.md | 2 + examples/glommio.rs | 86 +++ examples/pubsub.rs | 7 +- examples/transactions.rs | 2 - src/_tokio.rs | 395 +++++++++++++ src/clients/mod.rs | 5 +- src/clients/options.rs | 5 +- src/clients/pipeline.rs | 25 +- src/clients/pool.rs | 166 +++--- src/clients/pubsub.rs | 162 +++--- src/clients/redis.rs | 11 +- src/clients/replica.rs | 12 +- src/clients/sentinel.rs | 11 +- src/clients/transaction.rs | 131 +++-- src/commands/impls/cluster.rs | 2 +- src/commands/impls/lua.rs | 6 +- src/commands/impls/scan.rs | 58 +- src/commands/impls/server.rs | 7 +- src/commands/impls/tracking.rs | 2 +- src/commands/interfaces/acl.rs | 6 +- src/commands/interfaces/client.rs | 12 +- src/commands/interfaces/cluster.rs | 4 +- src/commands/interfaces/config.rs | 2 + src/commands/interfaces/geo.rs | 5 +- src/commands/interfaces/hashes.rs | 5 +- src/commands/interfaces/hyperloglog.rs | 5 +- src/commands/interfaces/keys.rs | 5 +- src/commands/interfaces/lists.rs | 7 +- src/commands/interfaces/lua.rs | 3 + src/commands/interfaces/memory.rs | 2 + src/commands/interfaces/pubsub.rs | 4 +- src/commands/interfaces/redis_json.rs | 2 + src/commands/interfaces/redisearch.rs | 2 + src/commands/interfaces/sentinel.rs | 2 + src/commands/interfaces/server.rs | 2 + src/commands/interfaces/sets.rs | 5 +- src/commands/interfaces/slowlog.rs | 2 + src/commands/interfaces/sorted_sets.rs | 5 +- src/commands/interfaces/streams.rs | 2 + src/commands/interfaces/timeseries.rs | 2 + src/commands/interfaces/tracking.rs | 6 +- src/commands/mod.rs | 1 - src/error.rs | 22 +- src/glommio/README.md | 45 ++ src/glommio/broadcast.rs | 93 ++++ src/glommio/interfaces.rs | 361 ++++++++++++ src/glommio/io_compat.rs | 68 +++ src/glommio/mod.rs | 118 ++++ src/glommio/mpsc.rs | 83 +++ src/glommio/notes.md | 0 src/glommio/sync.rs | 163 ++++++ src/interfaces.rs | 375 +------------ src/lib.rs | 11 + src/modules/backchannel.rs | 13 +- src/modules/inner.rs | 306 ++++++---- src/modules/mocks.rs | 22 +- src/monitor/parser.rs | 10 +- src/monitor/utils.rs | 39 +- src/protocol/cluster.rs | 13 +- src/protocol/codec.rs | 12 +- src/protocol/command.rs | 69 +-- src/protocol/connection.rs | 143 +++-- src/protocol/responders.rs | 103 ++-- src/protocol/types.rs | 63 ++- src/protocol/utils.rs | 13 +- src/router/centralized.rs | 36 +- src/router/clustered.rs | 113 ++-- src/router/commands.rs | 128 +++-- src/router/mod.rs | 41 +- src/router/reader.rs | 1 - src/router/replicas.rs | 18 +- src/router/responses.rs | 31 +- src/router/sentinel.rs | 75 ++- src/router/transactions.rs | 558 ++++++++++++------- src/router/utils.rs | 54 +- src/trace/disabled.rs | 8 +- src/trace/enabled.rs | 20 +- src/types/builder.rs | 2 + src/types/config.rs | 128 +++-- src/types/mod.rs | 3 +- src/types/scan.rs | 11 +- src/utils.rs | 64 +-- tests/doc-glommio.sh | 7 + tests/doc.sh | 7 +- tests/docker/compose/base.yml | 4 + tests/docker/compose/glommio.yml | 25 + tests/docker/runners/bash/all-features.sh | 2 +- tests/docker/runners/bash/check-glommio.sh | 7 + tests/docker/runners/bash/mocks.sh | 4 +- tests/docker/runners/images/base.dockerfile | 2 +- tests/docker/runners/images/ci.dockerfile | 2 +- tests/docker/runners/images/debug.dockerfile | 7 +- tests/integration/centralized.rs | 20 +- tests/integration/clustered.rs | 26 +- tests/integration/other/mod.rs | 69 ++- tests/runners/check-glommio.sh | 4 + tests/runners/mocks.sh | 2 +- tests/scripts/build-gh-pages.sh | 19 + tests/scripts/check_glommio_features.sh | 14 + 104 files changed, 3509 insertions(+), 1466 deletions(-) create mode 100644 examples/glommio.rs create mode 100644 src/_tokio.rs create mode 100644 src/glommio/README.md create mode 100644 src/glommio/broadcast.rs create mode 100644 src/glommio/interfaces.rs create mode 100644 src/glommio/io_compat.rs create mode 100644 src/glommio/mod.rs create mode 100644 src/glommio/mpsc.rs create mode 100644 src/glommio/notes.md create mode 100644 src/glommio/sync.rs delete mode 100644 src/router/reader.rs create mode 100755 tests/doc-glommio.sh create mode 100644 tests/docker/compose/glommio.yml create mode 100755 tests/docker/runners/bash/check-glommio.sh create mode 100755 tests/runners/check-glommio.sh create mode 100755 tests/scripts/build-gh-pages.sh create mode 100755 tests/scripts/check_glommio_features.sh diff --git a/.circleci/config.yml b/.circleci/config.yml index 28d31d8c..2fab27b5 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -234,7 +234,7 @@ jobs: - test_valkey test-misc: docker: - - image: cimg/rust:1.78 + - image: cimg/rust:1.80 environment: CARGO_NET_GIT_FETCH_WITH_CLI: true steps: @@ -246,27 +246,53 @@ jobs: command: cargo check --features partial-tracing check-all-interface-features: docker: - - image: cimg/rust:1.78 + - image: rust:1.80-slim-bullseye environment: CARGO_NET_GIT_FETCH_WITH_CLI: true + YQ_VERSION: v4.44.3 steps: - checkout + - run: + name: Install build dependencies + command: | + apt-get update && apt-get install -y build-essential libssl-dev pkg-config git wget + wget -qO /usr/local/bin/yq https://github.com/mikefarah/yq/releases/download/${YQ_VERSION}/yq_linux_amd64 + chmod +x /usr/local/bin/yq + rustup component add clippy - run: name: Check all features command: tests/scripts/check_features.sh + - run: + name: Check all Glommio features + command: tests/scripts/check_glommio_features.sh clippy-lint: docker: - - image: cimg/rust:1.78 + - image: rust:1.80-slim-bullseye environment: CARGO_NET_GIT_FETCH_WITH_CLI: true steps: - checkout - run: - name: Clippy - command: cargo clippy --all-features --lib -p fred -- -Dwarnings + name: Install build dependencies + command: apt-get update && apt-get install -y build-essential libssl-dev pkg-config git + - run: + name: Install clippy + command: rustup component add clippy + - run: + name: Clippy Tokio features + command: | + cargo clippy --features "i-all i-redis-stack transactions blocking-encoding dns metrics mocks monitor \ + replicas sentinel-auth sentinel-client serde-json subscriber-client unix-sockets credential-provider \ + enable-rustls enable-native-tls full-tracing" --lib -p fred -- -Dwarnings + - run: + name: Clippy Glommio features + command: | + cargo clippy --features "i-all i-redis-stack transactions blocking-encoding dns metrics mocks monitor \ + replicas sentinel-auth sentinel-client serde-json subscriber-client glommio credential-provider \ + enable-rustls enable-native-tls full-tracing" --lib -p fred -- -Dwarnings cargo-fmt: docker: - - image: cimg/rust:1.78 + - image: cimg/rust:1.80 environment: CARGO_NET_GIT_FETCH_WITH_CLI: true steps: diff --git a/.gitignore b/.gitignore index 4d26f224..becd1b3f 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,7 @@ target Cargo.lock .idea +.doc tests/tmp/* !tests/tmp/.gitkeep dump.rdb diff --git a/CHANGELOG.md b/CHANGELOG.md index 9c47e0f8..9852a42e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,10 @@ +## 9.2.0 + +* Add initial support for the [Glommio](https://github.com/DataDog/glommio) runtime +* Add `credential-provider` feature +* Fix pipeline processing in mocks +* Support pipelined transactions + ## 9.1.2 * Fix `FT.AGGREGATE` command with `SORTBY` operation diff --git a/Cargo.toml b/Cargo.toml index a1582cf0..f44952fe 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,10 +11,30 @@ name = "fred" readme = "README.md" repository = "https://github.com/aembke/fred.rs" rust-version = "1.75" -version = "9.1.2" +version = "9.2.0" [package.metadata.docs.rs] -all-features = true +# do not show the glommio version of the docs +features = [ + "i-all", + "i-redis-stack", + "transactions", + "blocking-encoding", + "dns", + "metrics", + "mocks", + "monnitor", + "replicas", + "sentinel-auth", + "sentinel-client", + "serde-json", + "subscriber-client", + "unix-sockets", + "enable-rustls", + "enable-native-tls", + "full-tracing", + "credential-provider" +] rustdoc-args = ["--cfg", "docsrs"] [lib] @@ -40,6 +60,10 @@ subscriber-client = ["i-pubsub"] transactions = [] trust-dns-resolver = ["dep:trust-dns-resolver"] unix-sockets = [] +credential-provider = [] + +# Enable experimental support for the Glommio runtime. +glommio = ["dep:glommio", "futures-io", "pin-project", "fred-macros/enabled", "oneshot", "futures-lite"] # Enables rustls with the rustls/aws_lc_rs crypto backend enable-rustls = [ @@ -170,6 +194,12 @@ trust-dns-resolver = { version = "0.23", optional = true, features = ["tokio"] } hickory-resolver = { version = "0.24.1", optional = true, features = ["tokio"] } url = "2.4" urlencoding = "2.1" +fred-macros = "0.1" +glommio = { version = "0.9.0", optional = true } +futures-io = { version = "0.3", optional = true } +pin-project = { version = "1.1.5", optional = true } +oneshot = { version = "0.1.8", optional = true, features = ["async"] } +futures-lite = { version = "2.3", optional = true } [dev-dependencies] axum = { version = "0.7", features = ["macros"] } @@ -180,6 +210,10 @@ serde = { version = "1.0", features = ["derive"] } subprocess = "0.2" tokio-stream = { version = "0.1", features = ["sync"] } +[[example]] +name = "glommio" +required-features = ["glommio", "i-std"] + [[example]] name = "misc" required-features = ["i-all"] diff --git a/README.md b/README.md index 3e06a44b..05e03b41 100644 --- a/README.md +++ b/README.md @@ -58,29 +58,31 @@ See the build features for more information. ## Client Features -| Name | Default | Description | -|---------------------------|---------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| `transactions` | x | Enable a [Transaction](https://redis.io/docs/interact/transactions/) interface. | -| `enable-native-tls` | | Enable TLS support via [native-tls](https://crates.io/crates/native-tls). | -| `enable-rustls` | | Enable TLS support via [rustls](https://crates.io/crates/rustls) with the default crypto backend features. | -| `enable-rustls-ring` | | Enable TLS support via [rustls](https://crates.io/crates/rustls) and the ring crypto backend. | -| `vendored-openssl` | | Enable the `native-tls/vendored` feature. | -| `metrics` | | Enable the metrics interface to track overall latency, network latency, and request/response sizes. | -| `full-tracing` | | Enable full [tracing](./src/trace/README.md) support. This can emit a lot of data. | -| `partial-tracing` | | Enable partial [tracing](./src/trace/README.md) support, only emitting traces for top level commands and network latency. | -| `blocking-encoding` | | Use a blocking task for encoding or decoding frames. This can be useful for clients that send or receive large payloads, but requires a multi-thread Tokio runtime. | -| `custom-reconnect-errors` | | Enable an interface for callers to customize the types of errors that should automatically trigger reconnection logic. | -| `monitor` | | Enable an interface for running the `MONITOR` command. | -| `sentinel-client` | | Enable an interface for communicating directly with Sentinel nodes. This is not necessary to use normal Redis clients behind a sentinel layer. | -| `sentinel-auth` | | Enable an interface for using different authentication credentials to sentinel nodes. | -| `subscriber-client` | | Enable a subscriber client interface that manages channel subscription state for callers. | -| `serde-json` | | Enable an interface to automatically convert Redis types to JSON via `serde-json`. | -| `mocks` | | Enable a mocking layer interface that can be used to intercept and process commands in tests. | -| `dns` | | Enable an interface that allows callers to override the DNS lookup logic. | -| `replicas` | | Enable an interface that routes commands to replica nodes. | -| `default-nil-types` | | Enable a looser parsing interface for `nil` values. | -| `sha-1` | | Enable an interface for hashing Lua scripts. | -| `unix-sockets` | | Enable Unix socket support. | +| Name | Default | Description | +|---------------------------|---------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| `transactions` | x | Enable a [Transaction](https://redis.io/docs/interact/transactions/) interface. | +| `enable-native-tls` | | Enable TLS support via [native-tls](https://crates.io/crates/native-tls). | +| `enable-rustls` | | Enable TLS support via [rustls](https://crates.io/crates/rustls) with the default crypto backend features. | +| `enable-rustls-ring` | | Enable TLS support via [rustls](https://crates.io/crates/rustls) and the ring crypto backend. | +| `vendored-openssl` | | Enable the `native-tls/vendored` feature. | +| `metrics` | | Enable the metrics interface to track overall latency, network latency, and request/response sizes. | +| `full-tracing` | | Enable full [tracing](./src/trace/README.md) support. This can emit a lot of data. | +| `partial-tracing` | | Enable partial [tracing](./src/trace/README.md) support, only emitting traces for top level commands and network latency. | +| `blocking-encoding` | | Use a blocking task for encoding or decoding frames. This can be useful for clients that send or receive large payloads, but requires a multi-thread Tokio runtime. | +| `custom-reconnect-errors` | | Enable an interface for callers to customize the types of errors that should automatically trigger reconnection logic. | +| `monitor` | | Enable an interface for running the `MONITOR` command. | +| `sentinel-client` | | Enable an interface for communicating directly with Sentinel nodes. This is not necessary to use normal Redis clients behind a sentinel layer. | +| `sentinel-auth` | | Enable an interface for using different authentication credentials to sentinel nodes. | +| `subscriber-client` | | Enable a subscriber client interface that manages channel subscription state for callers. | +| `serde-json` | | Enable an interface to automatically convert Redis types to JSON via `serde-json`. | +| `mocks` | | Enable a mocking layer interface that can be used to intercept and process commands in tests. | +| `dns` | | Enable an interface that allows callers to override the DNS lookup logic. | +| `replicas` | | Enable an interface that routes commands to replica nodes. | +| `default-nil-types` | | Enable a looser parsing interface for `nil` values. | +| `sha-1` | | Enable an interface for hashing Lua scripts. | +| `unix-sockets` | | Enable Unix socket support. | +| `glommio` | | Enable experimental [Glommio](https://github.com/DataDog/glommio) support. See the [Glommio Runtime](https://github.com/aembke/fred.rs/blob/main/src/glommio/README.md) docs for more information. When enabled the client will no longer work with Tokio runtimes. | +| `credential-provider` | | Enable an interface that can dynamically load auth credentials at runtime. | ## Interface Features diff --git a/examples/README.md b/examples/README.md index 71d836fe..45438553 100644 --- a/examples/README.md +++ b/examples/README.md @@ -26,5 +26,7 @@ Examples the [keyspace notifications](https://redis.io/docs/manual/keyspace-notifications/) interface. * [Misc](./misc.rs) - Miscellaneous or advanced features. * [Replicas](./replicas.rs) - Interact with cluster replica nodes via a `RedisPool`. +* [Glommio](./glommio.rs) - Use the [Glommio](https://github.com/DataDog/glommio) runtime. + See [the source docs](../src/glommio/README.md) for more information. Or see the [tests](../tests/integration) for more examples. \ No newline at end of file diff --git a/examples/glommio.rs b/examples/glommio.rs new file mode 100644 index 00000000..031e32fb --- /dev/null +++ b/examples/glommio.rs @@ -0,0 +1,86 @@ +use fred::prelude::*; +use futures::future::try_join_all; +use glommio::{prelude::*, DefaultStallDetectionHandler}; +use log::info; +use std::{cell::RefCell, rc::Rc, time::SystemTime}; + +/// The number of threads in the Glommio pool builder. +const THREADS: usize = 8; +/// The total number of Redis clients used across all threads. +const POOL_SIZE: usize = 16; +/// The number of concurrent tasks spawned on each thread. +const CONCURRENCY: usize = 500; +/// The total number of increment commands sent to the servers. +const COUNT: usize = 100_000_000; + +fn main() { + pretty_env_logger::init(); + let config = RedisConfig::from_url("redis-cluster://foo:bar@redis-cluster-1:30001").unwrap(); + let builder = Builder::from_config(config); + let started = SystemTime::now(); + + LocalExecutorPoolBuilder::new(PoolPlacement::Unbound(THREADS)) + .on_all_shards(move || { + // Each thread sends `COUNT / THREADS` commands to the server, sharing a client pool of `POOL_SIZE / THREADS` + // clients among `CONCURRENCY` local tasks. + let mut builder = builder.clone(); + let thread_id = executor().id(); + + async move { + // customize the task queues used by the client, if needed + builder.with_connection_config(|config| { + config.connection_task_queue = + Some(executor().create_task_queue(Shares::default(), Latency::NotImportant, "connection_queue")); + config.router_task_queue = + Some(executor().create_task_queue(Shares::default(), Latency::NotImportant, "router_queue")); + }); + + let clients = POOL_SIZE / THREADS; + let pool = builder.build_pool(clients)?; + info!("{}: Connecting to Redis with {} clients", thread_id, clients); + pool.init().await?; + info!("{}: Starting incr loop", thread_id); + incr_foo(&pool).await?; + + pool.quit().await?; + Ok::<_, RedisError>(thread_id) + } + }) + .unwrap() + .join_all() + .into_iter() + .for_each(|result| match result { + Ok(Ok(id)) => println!("Finished thread {}", id), + Ok(Err(e)) => println!("Redis error: {:?}", e), + Err(e) => println!("Glommio error: {:?}", e), + }); + + let dur = SystemTime::now().duration_since(started).unwrap(); + let dur_sec = dur.as_secs() as f64 + (dur.subsec_millis() as f64 / 1000.0); + println!( + "Performed {} operations in: {:?}. Throughput: {} req/sec", + COUNT, + dur, + (COUNT as f64 / dur_sec) as u64 + ); +} + +async fn incr_foo(pool: &RedisPool) -> Result<(), RedisError> { + let counter = Rc::new(RefCell::new(0)); + let mut tasks = Vec::with_capacity(CONCURRENCY); + for _ in 0 .. CONCURRENCY { + let counter = counter.clone(); + let pool = pool.clone(); + tasks.push(spawn_local(async move { + while *counter.borrow() < COUNT / THREADS { + pool.incr::<(), _>("foo").await?; + *counter.borrow_mut() += 1; + } + + Ok::<_, RedisError>(()) + })); + } + try_join_all(tasks).await?; + + Ok(()) +} diff --git a/examples/pubsub.rs b/examples/pubsub.rs index ca3de715..17a5ec45 100644 --- a/examples/pubsub.rs +++ b/examples/pubsub.rs @@ -9,7 +9,12 @@ use tokio::time::sleep; #[tokio::main] async fn main() -> Result<(), RedisError> { - let publisher_client = RedisClient::default(); + let publisher_client = Builder::default_centralized() + .with_performance_config(|config| { + // change the buffer size of the broadcast channels used by the EventInterface + config.broadcast_channel_capacity = 64; + }) + .build()?; let subscriber_client = publisher_client.clone_new(); publisher_client.init().await?; subscriber_client.init().await?; diff --git a/examples/transactions.rs b/examples/transactions.rs index 0f83c83a..f46fec9d 100644 --- a/examples/transactions.rs +++ b/examples/transactions.rs @@ -17,8 +17,6 @@ async fn main() -> Result<(), RedisError> { let result: RedisValue = trx.get("foo").await?; assert!(result.is_queued()); - // automatically send `WATCH ...` before `MULTI` - trx.watch_before(vec!["foo", "bar"]); let values: (Option, (), String) = trx.exec(true).await?; println!("Transaction results: {:?}", values); diff --git a/src/_tokio.rs b/src/_tokio.rs new file mode 100644 index 00000000..8720a234 --- /dev/null +++ b/src/_tokio.rs @@ -0,0 +1,395 @@ +use crate::{ + clients::WithOptions, + commands, + error::RedisError, + interfaces::{default_send_command, RedisResult}, + modules::inner::RedisClientInner, + protocol::command::RedisCommand, + router::commands as router_commands, + types::{ + ClientState, + ConnectHandle, + ConnectionConfig, + CustomCommand, + FromRedis, + InfoKind, + Options, + PerformanceConfig, + ReconnectPolicy, + RedisConfig, + RedisValue, + Resp3Frame, + RespVersion, + Server, + Version, + }, + utils, +}; +use arc_swap::ArcSwapAny; +use futures::Stream; +use std::{future::Future, sync::Arc}; +use tokio::sync::broadcast::{Receiver, Sender}; +pub use tokio::{ + spawn, + sync::{ + broadcast::{self, error::SendError as BroadcastSendError}, + mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender}, + oneshot::{channel as oneshot_channel, Receiver as OneshotReceiver, Sender as OneshotSender}, + RwLock as AsyncRwLock, + }, + task::JoinHandle, + time::sleep, +}; +use tokio_stream::wrappers::UnboundedReceiverStream; + +#[cfg(any(feature = "dns", feature = "trust-dns-resolver"))] +use crate::protocol::types::Resolve; + +#[cfg(feature = "i-server")] +use crate::types::ShutdownFlags; + +/// The reference counting container type. +/// +/// This type may change based on the runtime feature flags used. +pub type RefCount = Arc; + +pub type AtomicBool = std::sync::atomic::AtomicBool; +pub type AtomicUsize = std::sync::atomic::AtomicUsize; +pub type Mutex = parking_lot::Mutex; +pub type RwLock = parking_lot::RwLock; +pub type RefSwap = ArcSwapAny; +pub type BroadcastSender = Sender; +pub type BroadcastReceiver = Receiver; + +pub fn broadcast_send(tx: &BroadcastSender, msg: &T, func: F) { + if let Err(BroadcastSendError(val)) = tx.send(msg.clone()) { + func(&val); + } +} + +pub fn broadcast_channel(capacity: usize) -> (BroadcastSender, BroadcastReceiver) { + broadcast::channel(capacity) +} + +pub fn rx_stream(rx: UnboundedReceiver) -> impl Stream { + UnboundedReceiverStream::new(rx) +} + +/// Any Redis client that implements any part of the Redis interface. +pub trait ClientLike: Clone + Send + Sync + Sized { + #[doc(hidden)] + fn inner(&self) -> &Arc; + + /// Helper function to intercept and modify a command without affecting how it is sent to the connection layer. + #[doc(hidden)] + fn change_command(&self, _: &mut RedisCommand) {} + + /// Helper function to intercept and customize how a command is sent to the connection layer. + #[doc(hidden)] + fn send_command(&self, command: C) -> Result<(), RedisError> + where + C: Into, + { + let mut command: RedisCommand = command.into(); + self.change_command(&mut command); + default_send_command(self.inner(), command) + } + + /// The unique ID identifying this client and underlying connections. + fn id(&self) -> &str { + &self.inner().id + } + + /// Read the config used to initialize the client. + fn client_config(&self) -> RedisConfig { + self.inner().config.as_ref().clone() + } + + /// Read the reconnect policy used to initialize the client. + fn client_reconnect_policy(&self) -> Option { + self.inner().policy.read().clone() + } + + /// Read the connection config used to initialize the client. + fn connection_config(&self) -> &ConnectionConfig { + self.inner().connection.as_ref() + } + + /// Read the RESP version used by the client when communicating with the server. + fn protocol_version(&self) -> RespVersion { + if self.inner().is_resp3() { + RespVersion::RESP3 + } else { + RespVersion::RESP2 + } + } + + /// Whether the client has a reconnection policy. + fn has_reconnect_policy(&self) -> bool { + self.inner().policy.read().is_some() + } + + /// Whether the client will automatically pipeline commands. + fn is_pipelined(&self) -> bool { + self.inner().is_pipelined() + } + + /// Whether the client is connected to a cluster. + fn is_clustered(&self) -> bool { + self.inner().config.server.is_clustered() + } + + /// Whether the client uses the sentinel interface. + fn uses_sentinels(&self) -> bool { + self.inner().config.server.is_sentinel() + } + + /// Update the internal [PerformanceConfig](crate::types::PerformanceConfig) in place with new values. + fn update_perf_config(&self, config: PerformanceConfig) { + self.inner().update_performance_config(config); + } + + /// Read the [PerformanceConfig](crate::types::PerformanceConfig) associated with this client. + fn perf_config(&self) -> PerformanceConfig { + self.inner().performance_config() + } + + /// Read the state of the underlying connection(s). + /// + /// If running against a cluster the underlying state will reflect the state of the least healthy connection. + fn state(&self) -> ClientState { + self.inner().state.read().clone() + } + + /// Whether all underlying connections are healthy. + fn is_connected(&self) -> bool { + *self.inner().state.read() == ClientState::Connected + } + + /// Read the set of active connections managed by the client. + fn active_connections(&self) -> impl Future, RedisError>> + Send { + commands::server::active_connections(self) + } + + /// Read the server version, if known. + fn server_version(&self) -> Option { + self.inner().server_state.read().kind.server_version() + } + + /// Override the DNS resolution logic for the client. + #[cfg(feature = "dns")] + #[cfg_attr(docsrs, doc(cfg(feature = "dns")))] + fn set_resolver(&self, resolver: Arc) -> impl Future + Send { + async move { self.inner().set_resolver(resolver).await } + } + + /// Connect to the server. + /// + /// This function returns a `JoinHandle` to a task that drives the connection. It will not resolve until the + /// connection closes, or if a reconnection policy with unlimited attempts is provided then it will + /// run until `QUIT` is called. Callers should avoid calling [abort](tokio::task::JoinHandle::abort) on the returned + /// `JoinHandle` unless the client will no longer be used. + /// + /// **Calling this function more than once will drop all state associated with the previous connection(s).** Any + /// pending commands on the old connection(s) will either finish or timeout, but they will not be retried on the + /// new connection(s). + /// + /// See [init](Self::init) for an alternative shorthand. + fn connect(&self) -> ConnectHandle { + let inner = self.inner().clone(); + utils::reset_router_task(&inner); + + tokio::spawn(async move { + utils::clear_backchannel_state(&inner).await; + let result = router_commands::start(&inner).await; + // a canceled error means we intentionally closed the client + _trace!(inner, "Ending connection task with {:?}", result); + + if let Err(ref error) = result { + if !error.is_canceled() { + inner.notifications.broadcast_connect(Err(error.clone())); + } + } + + utils::check_and_set_client_state(&inner.state, ClientState::Disconnecting, ClientState::Disconnected); + result + }) + } + + /// Force a reconnection to the server(s). + /// + /// When running against a cluster this function will also refresh the cached cluster routing table. + fn force_reconnection(&self) -> impl Future> + Send { + async move { commands::server::force_reconnection(self.inner()).await } + } + + /// Wait for the result of the next connection attempt. + /// + /// This can be used with `on_reconnect` to separate initialization logic that needs to occur only on the next + /// connection attempt vs all subsequent attempts. + fn wait_for_connect(&self) -> impl Future> + Send { + async move { + if utils::read_locked(&self.inner().state) == ClientState::Connected { + debug!("{}: Client is already connected.", self.inner().id); + Ok(()) + } else { + self.inner().notifications.connect.load().subscribe().recv().await? + } + } + } + + /// Initialize a new routing and connection task and wait for it to connect successfully. + /// + /// The returned [ConnectHandle](crate::types::ConnectHandle) refers to the task that drives the routing and + /// connection layer. It will not finish until the max reconnection count is reached. Callers should avoid calling + /// [abort](tokio::task::JoinHandle::abort) on the returned `JoinHandle` unless the client will no longer be used. + /// + /// Callers can also use [connect](Self::connect) and [wait_for_connect](Self::wait_for_connect) separately if + /// needed. + /// + /// ```rust + /// use fred::prelude::*; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), RedisError> { + /// let client = RedisClient::default(); + /// let connection_task = client.init().await?; + /// + /// // ... + /// + /// client.quit().await?; + /// connection_task.await? + /// } + /// ``` + fn init(&self) -> impl Future> + Send { + async move { + let mut rx = { self.inner().notifications.connect.load().subscribe() }; + let task = self.connect(); + let error = rx.recv().await.map_err(RedisError::from).and_then(|r| r).err(); + + if let Some(error) = error { + // the initial connection failed, so we should gracefully close the routing task + utils::reset_router_task(self.inner()); + Err(error) + } else { + Ok(task) + } + } + } + + /// Close the connection to the Redis server. The returned future resolves when the command has been written to the + /// socket, not when the connection has been fully closed. Some time after this future resolves the future + /// returned by [connect](Self::connect) will resolve which indicates that the connection has been fully closed. + /// + /// This function will also close all error, pubsub message, and reconnection event streams. + fn quit(&self) -> impl Future> + Send { + async move { commands::server::quit(self).await } + } + + /// Shut down the server and quit the client. + /// + /// + #[cfg(feature = "i-server")] + #[cfg_attr(docsrs, doc(cfg(feature = "i-server")))] + fn shutdown(&self, flags: Option) -> impl Future> + Send { + async move { commands::server::shutdown(self, flags).await } + } + + /// Delete the keys in all databases. + /// + /// + fn flushall(&self, r#async: bool) -> impl Future> + Send + where + R: FromRedis, + { + async move { commands::server::flushall(self, r#async).await?.convert() } + } + + /// Delete the keys on all nodes in the cluster. This is a special function that does not map directly to the Redis + /// interface. + fn flushall_cluster(&self) -> impl Future> + Send { + async move { commands::server::flushall_cluster(self).await } + } + + /// Ping the Redis server. + /// + /// + fn ping(&self) -> impl Future> + Send + where + R: FromRedis, + { + async move { commands::server::ping(self).await?.convert() } + } + + /// Read info about the server. + /// + /// + fn info(&self, section: Option) -> impl Future> + Send + where + R: FromRedis, + { + async move { commands::server::info(self, section).await?.convert() } + } + + /// Run a custom command that is not yet supported via another interface on this client. This is most useful when + /// interacting with third party modules or extensions. + /// + /// Callers should use the re-exported [redis_keyslot](crate::util::redis_keyslot) function to hash the command's + /// key, if necessary. + /// + /// This interface should be used with caution as it may break the automatic pipeline features in the client if + /// command flags are not properly configured. + fn custom(&self, cmd: CustomCommand, args: Vec) -> impl Future> + Send + where + R: FromRedis, + T: TryInto + Send, + T::Error: Into + Send, + { + async move { + let args = utils::try_into_vec(args)?; + commands::server::custom(self, cmd, args).await?.convert() + } + } + + /// Run a custom command similar to [custom](Self::custom), but return the response frame directly without any + /// parsing. + /// + /// Note: RESP2 frames from the server are automatically converted to the RESP3 format when parsed by the client. + fn custom_raw(&self, cmd: CustomCommand, args: Vec) -> impl Future> + Send + where + T: TryInto + Send, + T::Error: Into + Send, + { + async move { + let args = utils::try_into_vec(args)?; + commands::server::custom_raw(self, cmd, args).await + } + } + + /// Customize various configuration options on commands. + fn with_options(&self, options: &Options) -> WithOptions { + WithOptions { + client: self.clone(), + options: options.clone(), + } + } +} + +pub fn spawn_event_listener(mut rx: BroadcastReceiver, func: F) -> JoinHandle> +where + T: Clone + Send + 'static, + F: Fn(T) -> RedisResult<()> + Send + 'static, +{ + tokio::spawn(async move { + let mut result = Ok(()); + + while let Ok(val) = rx.recv().await { + if let Err(err) = func(val) { + result = Err(err); + break; + } + } + + result + }) +} diff --git a/src/clients/mod.rs b/src/clients/mod.rs index 0c502d89..2906a02e 100644 --- a/src/clients/mod.rs +++ b/src/clients/mod.rs @@ -5,9 +5,12 @@ mod redis; pub use options::WithOptions; pub use pipeline::Pipeline; -pub use pool::{ExclusivePool, RedisPool}; +pub use pool::RedisPool; pub use redis::RedisClient; +#[cfg(not(feature = "glommio"))] +pub use pool::ExclusivePool; + #[cfg(feature = "sentinel-client")] mod sentinel; #[cfg(feature = "sentinel-client")] diff --git a/src/clients/options.rs b/src/clients/options.rs index 10efbe4f..19a89fed 100644 --- a/src/clients/options.rs +++ b/src/clients/options.rs @@ -3,9 +3,10 @@ use crate::{ interfaces::*, modules::inner::RedisClientInner, protocol::command::RedisCommand, + runtime::RefCount, types::Options, }; -use std::{fmt, ops::Deref, sync::Arc}; +use std::{fmt, ops::Deref}; /// A client interface used to customize command configuration options. /// @@ -73,7 +74,7 @@ impl fmt::Debug for WithOptions { impl ClientLike for WithOptions { #[doc(hidden)] - fn inner(&self) -> &Arc { + fn inner(&self) -> &RefCount { self.client.inner() } diff --git a/src/clients/pipeline.rs b/src/clients/pipeline.rs index f28a2e46..952092b7 100644 --- a/src/clients/pipeline.rs +++ b/src/clients/pipeline.rs @@ -8,11 +8,10 @@ use crate::{ responders::ResponseKind, utils as protocol_utils, }, + runtime::{oneshot_channel, Mutex, OneshotReceiver, RefCount}, utils, }; -use parking_lot::Mutex; -use std::{collections::VecDeque, fmt, fmt::Formatter, sync::Arc}; -use tokio::sync::oneshot::{channel as oneshot_channel, Receiver as OneshotReceiver}; +use std::{collections::VecDeque, fmt, fmt::Formatter}; fn clone_buffered_commands(buffer: &Mutex>) -> VecDeque { let guard = buffer.lock(); @@ -55,7 +54,7 @@ fn prepare_all_commands( /// /// See the [all](Self::all), [last](Self::last), and [try_all](Self::try_all) functions for more information. pub struct Pipeline { - commands: Arc>>, + commands: RefCount>>, client: C, } @@ -83,14 +82,14 @@ impl From for Pipeline { fn from(client: C) -> Self { Pipeline { client, - commands: Arc::new(Mutex::new(VecDeque::new())), + commands: RefCount::new(Mutex::new(VecDeque::new())), } } } impl ClientLike for Pipeline { #[doc(hidden)] - fn inner(&self) -> &Arc { + fn inner(&self) -> &RefCount { self.client.inner() } @@ -100,6 +99,7 @@ impl ClientLike for Pipeline { } #[doc(hidden)] + #[allow(unused_mut)] fn send_command(&self, command: T) -> Result<(), RedisError> where T: Into, @@ -107,7 +107,7 @@ impl ClientLike for Pipeline { let mut command: RedisCommand = command.into(); self.change_command(&mut command); - if let Some(tx) = command.take_responder() { + if let Some(mut tx) = command.take_responder() { trace!( "{}: Respond early to {} command in pipeline.", &self.client.inner().id, @@ -225,7 +225,7 @@ impl Pipeline { /// let _: () = pipeline.hgetall("bar").await?; // this will error since `bar` is an integer /// /// let results = pipeline.try_all::().await; - /// assert_eq!(results[0].clone().unwrap().convert::(), 1); + /// assert_eq!(results[0].clone()?.convert::()?, 1); /// assert!(results[1].is_err()); /// /// Ok(()) @@ -268,7 +268,7 @@ impl Pipeline { } async fn try_send_all( - inner: &Arc, + inner: &RefCount, commands: VecDeque, ) -> Vec> { if commands.is_empty() { @@ -297,7 +297,10 @@ async fn try_send_all( } } -async fn send_all(inner: &Arc, commands: VecDeque) -> Result { +async fn send_all( + inner: &RefCount, + commands: VecDeque, +) -> Result { if commands.is_empty() { return Ok(RedisValue::Array(Vec::new())); } @@ -312,7 +315,7 @@ async fn send_all(inner: &Arc, commands: VecDeque, + inner: &RefCount, commands: VecDeque, ) -> Result { if commands.is_empty() { diff --git a/src/clients/pool.rs b/src/clients/pool.rs index d7a084bd..f427e5eb 100644 --- a/src/clients/pool.rs +++ b/src/clients/pool.rs @@ -3,27 +3,20 @@ use crate::{ error::{RedisError, RedisErrorKind}, interfaces::*, modules::inner::RedisClientInner, + runtime::{sleep, spawn, AtomicBool, AtomicUsize, RefCount}, types::{ConnectHandle, ConnectionConfig, PerformanceConfig, ReconnectPolicy, RedisConfig, Server}, utils, }; +use fred_macros::rm_send_if; use futures::future::{join_all, try_join_all}; -use std::{ - fmt, - sync::{ - atomic::{AtomicBool, AtomicUsize}, - Arc, - }, - time::Duration, -}; -use tokio::{ - sync::{Mutex as AsyncMutex, OwnedMutexGuard}, - time::interval as tokio_interval, -}; +use std::{fmt, future::Future, time::Duration}; #[cfg(feature = "replicas")] use crate::clients::Replicas; #[cfg(feature = "dns")] use crate::protocol::types::Resolve; +#[cfg(not(feature = "glommio"))] +pub use tokio::sync::{Mutex as AsyncMutex, OwnedMutexGuard}; struct RedisPoolInner { clients: Vec, @@ -48,7 +41,7 @@ struct RedisPoolInner { /// [clients](Self::clients), [next](Self::next), or [last](Self::last) to operate on individual clients if needed. #[derive(Clone)] pub struct RedisPool { - inner: Arc, + inner: RefCount, } impl fmt::Debug for RedisPool { @@ -70,7 +63,7 @@ impl RedisPool { Err(RedisError::new(RedisErrorKind::Config, "Pool cannot be empty.")) } else { Ok(RedisPool { - inner: Arc::new(RedisPoolInner { + inner: RefCount::new(RedisPoolInner { clients, counter: AtomicUsize::new(0), prefer_connected: AtomicBool::new(true), @@ -103,7 +96,7 @@ impl RedisPool { } Ok(RedisPool { - inner: Arc::new(RedisPoolInner { + inner: RefCount::new(RedisPoolInner { clients, counter: AtomicUsize::new(0), prefer_connected: AtomicBool::new(true), @@ -168,9 +161,10 @@ impl RedisPool { } } +#[rm_send_if(feature = "glommio")] impl ClientLike for RedisPool { #[doc(hidden)] - fn inner(&self) -> &Arc { + fn inner(&self) -> &RefCount { if utils::read_bool_atomic(&self.inner.prefer_connected) { &self.next_connected().inner } else { @@ -187,28 +181,32 @@ impl ClientLike for RedisPool { } /// Read the set of active connections across all clients in the pool. - async fn active_connections(&self) -> Result, RedisError> { - let all_connections = try_join_all(self.inner.clients.iter().map(|c| c.active_connections())).await?; - let total_size = if all_connections.is_empty() { - return Ok(Vec::new()); - } else { - all_connections.len() * all_connections[0].len() - }; - let mut out = Vec::with_capacity(total_size); + fn active_connections(&self) -> impl Future, RedisError>> + Send { + async move { + let all_connections = try_join_all(self.inner.clients.iter().map(|c| c.active_connections())).await?; + let total_size = if all_connections.is_empty() { + return Ok(Vec::new()); + } else { + all_connections.len() * all_connections[0].len() + }; + let mut out = Vec::with_capacity(total_size); - for connections in all_connections.into_iter() { - out.extend(connections); + for connections in all_connections.into_iter() { + out.extend(connections); + } + Ok(out) } - Ok(out) } /// Override the DNS resolution logic for all clients in the pool. #[cfg(feature = "dns")] #[cfg_attr(docsrs, doc(cfg(feature = "dns")))] #[allow(refining_impl_trait)] - async fn set_resolver(&self, resolver: Arc) { - for client in self.inner.clients.iter() { - client.set_resolver(resolver.clone()).await; + fn set_resolver(&self, resolver: RefCount) -> impl Future + Send { + async move { + for client in self.inner.clients.iter() { + client.set_resolver(resolver.clone()).await; + } } } @@ -222,30 +220,32 @@ impl ClientLike for RedisPool { /// See [init](Self::init) for an alternative shorthand. fn connect(&self) -> ConnectHandle { let clients = self.inner.clients.clone(); - tokio::spawn(async move { + spawn(async move { let tasks: Vec<_> = clients.iter().map(|c| c.connect()).collect(); for result in join_all(tasks).await.into_iter() { result??; } - Ok(()) + Ok::<(), RedisError>(()) }) } /// Force a reconnection to the server(s) for each client. /// /// When running against a cluster this function will also refresh the cached cluster routing table. - async fn force_reconnection(&self) -> RedisResult<()> { - try_join_all(self.inner.clients.iter().map(|c| c.force_reconnection())).await?; - - Ok(()) + fn force_reconnection(&self) -> impl Future> + Send { + async move { + try_join_all(self.inner.clients.iter().map(|c| c.force_reconnection())).await?; + Ok(()) + } } /// Wait for all the clients to connect to the server. - async fn wait_for_connect(&self) -> RedisResult<()> { - try_join_all(self.inner.clients.iter().map(|c| c.wait_for_connect())).await?; - - Ok(()) + fn wait_for_connect(&self) -> impl Future> + Send { + async move { + try_join_all(self.inner.clients.iter().map(|c| c.wait_for_connect())).await?; + Ok(()) + } } /// Initialize a new routing and connection task for each client and wait for them to connect successfully. @@ -270,20 +270,35 @@ impl ClientLike for RedisPool { /// connection_task.await? /// } /// ``` - async fn init(&self) -> RedisResult { - let rxs: Vec<_> = self.inner.clients.iter().map(|c| c.wait_for_connect()).collect(); - - let connect_task = self.connect(); - let init_err = futures::future::join_all(rxs).await.into_iter().find_map(|r| r.err()); + fn init(&self) -> impl Future> + Send { + #[allow(unused_mut)] + async move { + let mut rxs: Vec<_> = self + .inner + .clients + .iter() + .map(|c| c.inner().notifications.connect.load().subscribe()) + .collect(); + + let connect_task = self.connect(); + let init_err = futures::future::join_all(rxs.iter_mut().map(|rx| rx.recv())) + .await + .into_iter() + .find_map(|result| match result { + Ok(Err(e)) => Some(e), + Err(e) => Some(e.into()), + Ok(Ok(())) => None, + }); + + if let Some(err) = init_err { + for client in self.inner.clients.iter() { + utils::reset_router_task(client.inner()); + } - if let Some(err) = init_err { - for client in self.inner.clients.iter() { - utils::reset_router_task(client.inner()); + Err(err) + } else { + Ok(connect_task) } - - Err(err) - } else { - Ok(connect_task) } } @@ -294,23 +309,30 @@ impl ClientLike for RedisPool { /// /// This function will also close all error, pubsub message, and reconnection event streams on all clients in the /// pool. - async fn quit(&self) -> RedisResult<()> { - join_all(self.inner.clients.iter().map(|c| c.quit())).await; + fn quit(&self) -> impl Future> + Send { + async move { + join_all(self.inner.clients.iter().map(|c| c.quit())).await; - Ok(()) + Ok(()) + } } } +#[rm_send_if(feature = "glommio")] impl HeartbeatInterface for RedisPool { - async fn enable_heartbeat(&self, interval: Duration, break_on_error: bool) -> RedisResult<()> { - let mut interval = tokio_interval(interval); - - loop { - interval.tick().await; - - if let Err(error) = try_join_all(self.inner.clients.iter().map(|c| c.ping::<()>())).await { - if break_on_error { - return Err(error); + fn enable_heartbeat( + &self, + interval: Duration, + break_on_error: bool, + ) -> impl Future> + Send { + async move { + loop { + sleep(interval).await; + + if let Err(error) = try_join_all(self.inner.clients.iter().map(|c| c.ping::<()>())).await { + if break_on_error { + return Err(error); + } } } } @@ -381,8 +403,9 @@ impl TimeSeriesInterface for RedisPool {} #[cfg_attr(docsrs, doc(cfg(feature = "i-redisearch")))] impl RediSearchInterface for RedisPool {} +#[cfg(not(feature = "glommio"))] struct PoolInner { - clients: Vec>>, + clients: Vec>>, counter: AtomicUsize, } @@ -443,11 +466,13 @@ struct PoolInner { /// ``` /// /// Callers should avoid cloning the inner clients, if possible. +#[cfg(not(feature = "glommio"))] #[derive(Clone)] pub struct ExclusivePool { - inner: Arc, + inner: RefCount, } +#[cfg(not(feature = "glommio"))] impl fmt::Debug for ExclusivePool { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("ExclusivePool") @@ -456,6 +481,7 @@ impl fmt::Debug for ExclusivePool { } } +#[cfg(not(feature = "glommio"))] impl ExclusivePool { /// Create a new pool without connecting to the server. /// @@ -472,7 +498,7 @@ impl ExclusivePool { } else { let mut clients = Vec::with_capacity(size); for _ in 0 .. size { - clients.push(Arc::new(AsyncMutex::new(RedisClient::new( + clients.push(RefCount::new(AsyncMutex::new(RedisClient::new( config.clone(), perf.clone(), connection.clone(), @@ -481,7 +507,7 @@ impl ExclusivePool { } Ok(ExclusivePool { - inner: Arc::new(PoolInner { + inner: RefCount::new(PoolInner { clients, counter: AtomicUsize::new(0), }), @@ -490,7 +516,7 @@ impl ExclusivePool { } /// Read the clients in the pool. - pub fn clients(&self) -> &[Arc>] { + pub fn clients(&self) -> &[RefCount>] { &self.inner.clients } @@ -639,7 +665,7 @@ impl ExclusivePool { #[cfg(feature = "dns")] #[cfg_attr(docsrs, doc(cfg(feature = "dns")))] #[allow(refining_impl_trait)] - pub async fn set_resolver(&self, resolver: Arc) { + pub async fn set_resolver(&self, resolver: RefCount) { for client in self.inner.clients.iter() { client.lock().await.set_resolver(resolver.clone()).await; } diff --git a/src/clients/pubsub.rs b/src/clients/pubsub.rs index 20b5cc46..bb109309 100644 --- a/src/clients/pubsub.rs +++ b/src/clients/pubsub.rs @@ -4,15 +4,15 @@ use crate::{ interfaces::*, modules::inner::RedisClientInner, prelude::RedisClient, + runtime::{spawn, JoinHandle, RefCount, RwLock}, types::{ConnectionConfig, MultipleStrings, PerformanceConfig, ReconnectPolicy, RedisConfig, RedisKey}, util::group_by_hash_slot, }; use bytes_utils::Str; -use parking_lot::RwLock; -use std::{collections::BTreeSet, fmt, fmt::Formatter, mem, sync::Arc}; -use tokio::task::JoinHandle; +use fred_macros::rm_send_if; +use std::{collections::BTreeSet, fmt, fmt::Formatter, future::Future, mem}; -type ChannelSet = Arc>>; +type ChannelSet = RefCount>>; /// A subscriber client that will manage subscription state to any [pubsub](https://redis.io/docs/manual/pubsub/) channels or patterns for the caller. /// @@ -58,7 +58,7 @@ pub struct SubscriberClient { channels: ChannelSet, patterns: ChannelSet, shard_channels: ChannelSet, - inner: Arc, + inner: RefCount, } impl fmt::Debug for SubscriberClient { @@ -74,7 +74,7 @@ impl fmt::Debug for SubscriberClient { impl ClientLike for SubscriberClient { #[doc(hidden)] - fn inner(&self) -> &Arc { + fn inner(&self) -> &RefCount { &self.inner } } @@ -154,132 +154,145 @@ impl RediSearchInterface for SubscriberClient {} #[cfg(feature = "i-pubsub")] #[cfg_attr(docsrs, doc(cfg(feature = "i-pubsub")))] +#[rm_send_if(feature = "glommio")] impl PubsubInterface for SubscriberClient { - async fn subscribe(&self, channels: S) -> RedisResult<()> + fn subscribe(&self, channels: S) -> impl Future> + Send where S: Into + Send, { into!(channels); - let result = commands::pubsub::subscribe(self, channels.clone()).await; - if result.is_ok() { - let mut guard = self.channels.write(); + async move { + let result = commands::pubsub::subscribe(self, channels.clone()).await; + if result.is_ok() { + let mut guard = self.channels.write(); - for channel in channels.inner().into_iter() { - if let Some(channel) = channel.as_bytes_str() { - guard.insert(channel); + for channel in channels.inner().into_iter() { + if let Some(channel) = channel.as_bytes_str() { + guard.insert(channel); + } } } - } - result + result + } } - async fn unsubscribe(&self, channels: S) -> RedisResult<()> + fn unsubscribe(&self, channels: S) -> impl Future> + Send where S: Into + Send, { into!(channels); - let result = commands::pubsub::unsubscribe(self, channels.clone()).await; - if result.is_ok() { - let mut guard = self.channels.write(); - - if channels.len() == 0 { - guard.clear(); - } else { - for channel in channels.inner().into_iter() { - if let Some(channel) = channel.as_bytes_str() { - let _ = guard.remove(&channel); + async move { + let result = commands::pubsub::unsubscribe(self, channels.clone()).await; + if result.is_ok() { + let mut guard = self.channels.write(); + + if channels.len() == 0 { + guard.clear(); + } else { + for channel in channels.inner().into_iter() { + if let Some(channel) = channel.as_bytes_str() { + let _ = guard.remove(&channel); + } } } } + result } - result } - async fn psubscribe(&self, patterns: S) -> RedisResult<()> + fn psubscribe(&self, patterns: S) -> impl Future> + Send where S: Into + Send, { into!(patterns); - let result = commands::pubsub::psubscribe(self, patterns.clone()).await; - if result.is_ok() { - let mut guard = self.patterns.write(); + async move { + let result = commands::pubsub::psubscribe(self, patterns.clone()).await; + if result.is_ok() { + let mut guard = self.patterns.write(); - for pattern in patterns.inner().into_iter() { - if let Some(pattern) = pattern.as_bytes_str() { - guard.insert(pattern); + for pattern in patterns.inner().into_iter() { + if let Some(pattern) = pattern.as_bytes_str() { + guard.insert(pattern); + } } } + result } - result } - async fn punsubscribe(&self, patterns: S) -> RedisResult<()> + fn punsubscribe(&self, patterns: S) -> impl Future> + Send where S: Into + Send, { into!(patterns); - let result = commands::pubsub::punsubscribe(self, patterns.clone()).await; - if result.is_ok() { - let mut guard = self.patterns.write(); - - if patterns.len() == 0 { - guard.clear(); - } else { - for pattern in patterns.inner().into_iter() { - if let Some(pattern) = pattern.as_bytes_str() { - let _ = guard.remove(&pattern); + async move { + let result = commands::pubsub::punsubscribe(self, patterns.clone()).await; + if result.is_ok() { + let mut guard = self.patterns.write(); + + if patterns.len() == 0 { + guard.clear(); + } else { + for pattern in patterns.inner().into_iter() { + if let Some(pattern) = pattern.as_bytes_str() { + let _ = guard.remove(&pattern); + } } } } + result } - result } - async fn ssubscribe(&self, channels: C) -> RedisResult<()> + fn ssubscribe(&self, channels: C) -> impl Future> + Send where C: Into + Send, { into!(channels); - let result = commands::pubsub::ssubscribe(self, channels.clone()).await; - if result.is_ok() { - let mut guard = self.shard_channels.write(); + async move { + let result = commands::pubsub::ssubscribe(self, channels.clone()).await; + if result.is_ok() { + let mut guard = self.shard_channels.write(); - for channel in channels.inner().into_iter() { - if let Some(channel) = channel.as_bytes_str() { - guard.insert(channel); + for channel in channels.inner().into_iter() { + if let Some(channel) = channel.as_bytes_str() { + guard.insert(channel); + } } } + result } - result } - async fn sunsubscribe(&self, channels: C) -> RedisResult<()> + fn sunsubscribe(&self, channels: C) -> impl Future> + Send where C: Into + Send, { into!(channels); - let result = commands::pubsub::sunsubscribe(self, channels.clone()).await; - if result.is_ok() { - let mut guard = self.shard_channels.write(); - - if channels.len() == 0 { - guard.clear(); - } else { - for channel in channels.inner().into_iter() { - if let Some(channel) = channel.as_bytes_str() { - let _ = guard.remove(&channel); + async move { + let result = commands::pubsub::sunsubscribe(self, channels.clone()).await; + if result.is_ok() { + let mut guard = self.shard_channels.write(); + + if channels.len() == 0 { + guard.clear(); + } else { + for channel in channels.inner().into_iter() { + if let Some(channel) = channel.as_bytes_str() { + let _ = guard.remove(&channel); + } } } } + result } - result } } @@ -294,9 +307,9 @@ impl SubscriberClient { policy: Option, ) -> SubscriberClient { SubscriberClient { - channels: Arc::new(RwLock::new(BTreeSet::new())), - patterns: Arc::new(RwLock::new(BTreeSet::new())), - shard_channels: Arc::new(RwLock::new(BTreeSet::new())), + channels: RefCount::new(RwLock::new(BTreeSet::new())), + patterns: RefCount::new(RwLock::new(BTreeSet::new())), + shard_channels: RefCount::new(RwLock::new(BTreeSet::new())), inner: RedisClientInner::new(config, perf.unwrap_or_default(), connection.unwrap_or_default(), policy), } } @@ -315,16 +328,17 @@ impl SubscriberClient { SubscriberClient { inner, - channels: Arc::new(RwLock::new(self.channels.read().clone())), - patterns: Arc::new(RwLock::new(self.patterns.read().clone())), - shard_channels: Arc::new(RwLock::new(self.shard_channels.read().clone())), + channels: RefCount::new(RwLock::new(self.channels.read().clone())), + patterns: RefCount::new(RwLock::new(self.patterns.read().clone())), + shard_channels: RefCount::new(RwLock::new(self.shard_channels.read().clone())), } } /// Spawn a task that will automatically re-subscribe to any channels or channel patterns used by the client. pub fn manage_subscriptions(&self) -> JoinHandle<()> { let _self = self.clone(); - tokio::spawn(async move { + spawn(async move { + #[allow(unused_mut)] let mut stream = _self.reconnect_rx(); while let Ok(_) = stream.recv().await { diff --git a/src/clients/redis.rs b/src/clients/redis.rs index 083954c6..1057e212 100644 --- a/src/clients/redis.rs +++ b/src/clients/redis.rs @@ -5,11 +5,12 @@ use crate::{ interfaces::*, modules::inner::RedisClientInner, prelude::ClientLike, + runtime::RefCount, types::*, }; use bytes_utils::Str; use futures::Stream; -use std::{fmt, fmt::Formatter, sync::Arc}; +use std::{fmt, fmt::Formatter}; #[cfg(feature = "replicas")] use crate::clients::Replicas; @@ -19,7 +20,7 @@ use crate::interfaces::TrackingInterface; /// A cheaply cloneable Redis client struct. #[derive(Clone)] pub struct RedisClient { - pub(crate) inner: Arc, + pub(crate) inner: RefCount, } impl Default for RedisClient { @@ -44,15 +45,15 @@ impl fmt::Display for RedisClient { } #[doc(hidden)] -impl<'a> From<&'a Arc> for RedisClient { - fn from(inner: &'a Arc) -> RedisClient { +impl<'a> From<&'a RefCount> for RedisClient { + fn from(inner: &'a RefCount) -> RedisClient { RedisClient { inner: inner.clone() } } } impl ClientLike for RedisClient { #[doc(hidden)] - fn inner(&self) -> &Arc { + fn inner(&self) -> &RefCount { &self.inner } } diff --git a/src/clients/replica.rs b/src/clients/replica.rs index c6612723..5a07590b 100644 --- a/src/clients/replica.rs +++ b/src/clients/replica.rs @@ -4,10 +4,10 @@ use crate::{ interfaces::{self, *}, modules::inner::RedisClientInner, protocol::command::{RedisCommand, RouterCommand}, + runtime::{oneshot_channel, RefCount}, types::Server, }; -use std::{collections::HashMap, fmt, fmt::Formatter, sync::Arc}; -use tokio::sync::oneshot::channel as oneshot_channel; +use std::{collections::HashMap, fmt, fmt::Formatter}; /// A struct for interacting with cluster replica nodes. /// @@ -20,7 +20,7 @@ use tokio::sync::oneshot::channel as oneshot_channel; #[derive(Clone)] #[cfg_attr(docsrs, doc(cfg(feature = "replicas")))] pub struct Replicas { - inner: Arc, + inner: RefCount, } impl fmt::Debug for Replicas { @@ -30,15 +30,15 @@ impl fmt::Debug for Replicas { } #[doc(hidden)] -impl From<&Arc> for Replicas { - fn from(inner: &Arc) -> Self { +impl From<&RefCount> for Replicas { + fn from(inner: &RefCount) -> Self { Replicas { inner: inner.clone() } } } impl ClientLike for Replicas { #[doc(hidden)] - fn inner(&self) -> &Arc { + fn inner(&self) -> &RefCount { &self.inner } diff --git a/src/clients/sentinel.rs b/src/clients/sentinel.rs index bbc4c5ec..a11f15a7 100644 --- a/src/clients/sentinel.rs +++ b/src/clients/sentinel.rs @@ -1,9 +1,10 @@ use crate::{ interfaces::*, modules::inner::RedisClientInner, + runtime::RefCount, types::{ConnectionConfig, PerformanceConfig, ReconnectPolicy, SentinelConfig}, }; -use std::{fmt, sync::Arc}; +use std::fmt; /// A struct for interacting directly with Sentinel nodes. /// @@ -16,12 +17,12 @@ use std::{fmt, sync::Arc}; #[derive(Clone)] #[cfg_attr(docsrs, doc(cfg(feature = "sentinel-client")))] pub struct SentinelClient { - inner: Arc, + inner: RefCount, } impl ClientLike for SentinelClient { #[doc(hidden)] - fn inner(&self) -> &Arc { + fn inner(&self) -> &RefCount { &self.inner } } @@ -36,8 +37,8 @@ impl fmt::Debug for SentinelClient { } #[doc(hidden)] -impl<'a> From<&'a Arc> for SentinelClient { - fn from(inner: &'a Arc) -> Self { +impl<'a> From<&'a RefCount> for SentinelClient { + fn from(inner: &'a RefCount) -> Self { SentinelClient { inner: inner.clone() } } } diff --git a/src/clients/transaction.rs b/src/clients/transaction.rs index bf869642..2566613b 100644 --- a/src/clients/transaction.rs +++ b/src/clients/transaction.rs @@ -10,39 +10,43 @@ use crate::{ responders::ResponseKind, utils as protocol_utils, }, + runtime::{oneshot_channel, AtomicBool, Mutex, RefCount}, types::{FromRedis, MultipleKeys, Options, RedisKey, Server}, utils, }; -use parking_lot::Mutex; -use std::{collections::VecDeque, fmt, sync::Arc}; -use tokio::sync::oneshot::channel as oneshot_channel; +use std::{collections::VecDeque, fmt}; + +struct State { + id: u64, + commands: Mutex>, + watched: Mutex>, + hash_slot: Mutex>, + pipelined: AtomicBool, +} /// A cheaply cloneable transaction block. #[derive(Clone)] -#[cfg(feature = "transactions")] #[cfg_attr(docsrs, doc(cfg(feature = "transactions")))] pub struct Transaction { - id: u64, - inner: Arc, - commands: Arc>>, - watched: Arc>>, - hash_slot: Arc>>, + inner: RefCount, + state: RefCount, } impl fmt::Debug for Transaction { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Transaction") .field("client", &self.inner.id) - .field("id", &self.id) - .field("length", &self.commands.lock().len()) - .field("hash_slot", &self.hash_slot.lock()) + .field("id", &self.state.id) + .field("length", &self.state.commands.lock().len()) + .field("hash_slot", &self.state.hash_slot.lock()) + .field("pipelined", &utils::read_bool_atomic(&self.state.pipelined)) .finish() } } impl PartialEq for Transaction { fn eq(&self, other: &Self) -> bool { - self.id == other.id + self.state.id == other.state.id } } @@ -50,7 +54,7 @@ impl Eq for Transaction {} impl ClientLike for Transaction { #[doc(hidden)] - fn inner(&self) -> &Arc { + fn inner(&self) -> &RefCount { &self.inner } @@ -64,7 +68,8 @@ impl ClientLike for Transaction { // check cluster slot mappings as commands are added self.update_hash_slot(&command)?; - if let Some(tx) = command.take_responder() { + #[allow(unused_mut)] + if let Some(mut tx) = command.take_responder() { trace!( "{}: Respond early to {} command in transaction.", &self.inner.id, @@ -73,7 +78,7 @@ impl ClientLike for Transaction { let _ = tx.send(Ok(protocol_utils::queued_frame())); } - self.commands.lock().push_back(command); + self.state.commands.lock().push_back(command); Ok(()) } } @@ -136,13 +141,16 @@ impl RediSearchInterface for Transaction {} impl Transaction { /// Create a new transaction. - pub(crate) fn from_inner(inner: &Arc) -> Self { + pub(crate) fn from_inner(inner: &RefCount) -> Self { Transaction { - inner: inner.clone(), - commands: Arc::new(Mutex::new(VecDeque::new())), - watched: Arc::new(Mutex::new(VecDeque::new())), - hash_slot: Arc::new(Mutex::new(None)), - id: utils::random_u64(u64::MAX), + inner: inner.clone(), + state: RefCount::new(State { + commands: Mutex::new(VecDeque::new()), + watched: Mutex::new(VecDeque::new()), + hash_slot: Mutex::new(None), + pipelined: AtomicBool::new(false), + id: utils::random_u64(u64::MAX), + }), } } @@ -153,7 +161,7 @@ impl Transaction { } if let Some(slot) = command.cluster_hash() { - if let Some(old_slot) = utils::read_mutex(&self.hash_slot) { + if let Some(old_slot) = utils::read_mutex(&self.state.hash_slot) { let (old_server, server) = self.inner.with_cluster_state(|state| { debug!( "{}: Checking transaction hash slots: {}, {}", @@ -170,7 +178,7 @@ impl Transaction { )); } } else { - utils::set_mutex(&self.hash_slot, Some(slot)); + utils::set_mutex(&self.state.hash_slot, Some(slot)); } } @@ -190,24 +198,35 @@ impl Transaction { /// An ID identifying the underlying transaction state. pub fn id(&self) -> u64 { - self.id + self.state.id } /// Clear the internal command buffer and watched keys. pub fn reset(&self) { - self.commands.lock().clear(); - self.watched.lock().clear(); - self.hash_slot.lock().take(); + self.state.commands.lock().clear(); + self.state.watched.lock().clear(); + self.state.hash_slot.lock().take(); } /// Read the number of commands queued to run. pub fn len(&self) -> usize { - self.commands.lock().len() + self.state.commands.lock().len() + } + + /// Whether to pipeline commands in the transaction. + /// + /// Note: pipelined transactions should only be used with Redis version >=2.6.5. + pub fn pipeline(&self, val: bool) { + utils::set_bool_atomic(&self.state.pipelined, val); } /// Read the number of keys to `WATCH` before the starting the transaction. + #[deprecated( + since = "9.2.0", + note = "Please use `WATCH` with clients from an `ExclusivePool` instead." + )] pub fn watched_len(&self) -> usize { - self.watched.lock().len() + self.state.watched.lock().len() } /// Executes all previously queued commands in a transaction. @@ -239,35 +258,48 @@ impl Transaction { { let commands = { self + .state .commands .lock() .iter() .map(|cmd| cmd.duplicate(ResponseKind::Skip)) .collect() }; - let watched = { self.watched.lock().iter().cloned().collect() }; - let hash_slot = utils::read_mutex(&self.hash_slot); - exec(&self.inner, commands, watched, hash_slot, abort_on_error, self.id) - .await? - .convert() + let pipelined = utils::read_bool_atomic(&self.state.pipelined); + let hash_slot = utils::read_mutex(&self.state.hash_slot); + + exec( + &self.inner, + commands, + hash_slot, + abort_on_error, + pipelined, + self.state.id, + ) + .await? + .convert() } /// Send the `WATCH` command with the provided keys before starting the transaction. + #[deprecated( + since = "9.2.0", + note = "Please use `WATCH` with clients from an `ExclusivePool` instead." + )] pub fn watch_before(&self, keys: K) where K: Into, { - self.watched.lock().extend(keys.into().inner()); + self.state.watched.lock().extend(keys.into().inner()); } /// Read the hash slot against which this transaction will run, if known. pub fn hash_slot(&self) -> Option { - utils::read_mutex(&self.hash_slot) + utils::read_mutex(&self.state.hash_slot) } /// Read the server ID against which this transaction will run, if known. pub fn cluster_node(&self) -> Option { - utils::read_mutex(&self.hash_slot).and_then(|slot| { + utils::read_mutex(&self.state.hash_slot).and_then(|slot| { self .inner .with_cluster_state(|state| Ok(state.get_server(slot).cloned())) @@ -278,11 +310,11 @@ impl Transaction { } async fn exec( - inner: &Arc, + inner: &RefCount, commands: VecDeque, - watched: VecDeque, hash_slot: Option, abort_on_error: bool, + pipelined: bool, id: u64, ) -> Result { if commands.is_empty() { @@ -310,33 +342,18 @@ async fn exec( command }) .collect(); - // collapse the watched keys into one command - let watched = if watched.is_empty() { - None - } else { - let args: Vec = watched.into_iter().map(|k| k.into()).collect(); - let mut watch_cmd = RedisCommand::new(RedisCommandKind::Watch, args); - watch_cmd.can_pipeline = false; - watch_cmd.skip_backpressure = true; - watch_cmd.transaction_id = Some(id); - if let Some(hash_slot) = hash_slot.as_ref() { - watch_cmd.hasher = ClusterHash::Custom(*hash_slot); - } - Some(watch_cmd) - }; _trace!( inner, - "Sending transaction {} with {} commands ({} watched) to router.", + "Sending transaction {} with {} commands to router.", id, commands.len(), - watched.as_ref().map(|c| c.args().len()).unwrap_or(0) ); let command = RouterCommand::Transaction { id, tx, commands, - watched, + pipelined, abort_on_error, }; let timeout_dur = trx_options.timeout.unwrap_or_else(|| inner.default_command_timeout()); diff --git a/src/commands/impls/cluster.rs b/src/commands/impls/cluster.rs index 1a29121f..e0dc093c 100644 --- a/src/commands/impls/cluster.rs +++ b/src/commands/impls/cluster.rs @@ -5,12 +5,12 @@ use crate::{ command::{RedisCommandKind, RouterCommand}, utils as protocol_utils, }, + runtime::oneshot_channel, types::*, utils, }; use bytes_utils::Str; use std::convert::TryInto; -use tokio::sync::oneshot::channel as oneshot_channel; value_cmd!(cluster_bumpepoch, ClusterBumpEpoch); ok_cmd!(cluster_flushslots, ClusterFlushSlots); diff --git a/src/commands/impls/lua.rs b/src/commands/impls/lua.rs index ddff850b..55815a42 100644 --- a/src/commands/impls/lua.rs +++ b/src/commands/impls/lua.rs @@ -10,18 +10,18 @@ use crate::{ responders::ResponseKind, utils as protocol_utils, }, + runtime::{oneshot_channel, RefCount}, types::*, utils, }; use bytes::Bytes; use bytes_utils::Str; use redis_protocol::resp3::types::BytesFrame as Resp3Frame; -use std::{convert::TryInto, str, sync::Arc}; -use tokio::sync::oneshot::channel as oneshot_channel; +use std::{convert::TryInto, str}; /// Check that all the keys in an EVAL* command belong to the same server, returning a key slot that maps to that /// server. -pub fn check_key_slot(inner: &Arc, keys: &[RedisKey]) -> Result, RedisError> { +pub fn check_key_slot(inner: &RefCount, keys: &[RedisKey]) -> Result, RedisError> { if inner.config.server.is_clustered() { inner.with_cluster_state(|state| { let (mut cmd_server, mut cmd_slot) = (None, None); diff --git a/src/commands/impls/scan.rs b/src/commands/impls/scan.rs index bbe4f9bf..9a31050e 100644 --- a/src/commands/impls/scan.rs +++ b/src/commands/impls/scan.rs @@ -8,14 +8,15 @@ use crate::{ responders::ResponseKind, types::*, }, + runtime::{rx_stream, unbounded_channel, RefCount}, types::*, utils, }; use bytes_utils::Str; use futures::stream::{Stream, TryStreamExt}; -use std::sync::Arc; -use tokio::sync::mpsc::{unbounded_channel, UnboundedSender}; -use tokio_stream::wrappers::UnboundedReceiverStream; + +#[cfg(feature = "glommio")] +use crate::runtime::UnboundedSender; static STARTING_CURSOR: &str = "0"; @@ -34,27 +35,22 @@ fn values_args(key: RedisKey, pattern: Str, count: Option) -> Vec(tx: &UnboundedSender>, error: RedisError) { - let tx = tx.clone(); - tokio::spawn(async move { - let _ = tx.send(Err(error)); - }); -} - pub fn scan_cluster( - inner: &Arc, + inner: &RefCount, pattern: Str, count: Option, r#type: Option, ) -> impl Stream> { let (tx, rx) = unbounded_channel(); + #[cfg(feature = "glommio")] + let tx: UnboundedSender<_> = tx.into(); let hash_slots = inner.with_cluster_state(|state| Ok(state.unique_hash_slots())); let hash_slots = match hash_slots { Ok(slots) => slots, Err(e) => { - early_error(&tx, e); - return UnboundedReceiverStream::new(rx); + let _ = tx.send(Err(e)); + return rx_stream(rx); }, }; @@ -84,22 +80,24 @@ pub fn scan_cluster( let command: RedisCommand = (RedisCommandKind::Scan, Vec::new(), response).into(); if let Err(e) = interfaces::default_send_command(inner, command) { - early_error(&tx, e); + let _ = tx.send(Err(e)); break; } } - UnboundedReceiverStream::new(rx) + rx_stream(rx) } pub fn scan( - inner: &Arc, + inner: &RefCount, pattern: Str, count: Option, r#type: Option, server: Option, ) -> impl Stream> { let (tx, rx) = unbounded_channel(); + #[cfg(feature = "glommio")] + let tx: UnboundedSender<_> = tx.into(); let hash_slot = if inner.config.server.is_clustered() { if utils::clustered_scan_pattern_has_hash_tag(inner, &pattern) { @@ -135,14 +133,14 @@ pub fn scan( let command: RedisCommand = (RedisCommandKind::Scan, Vec::new(), response).into(); if let Err(e) = interfaces::default_send_command(inner, command) { - early_error(&tx, e); + let _ = tx.send(Err(e)); } - UnboundedReceiverStream::new(rx) + rx_stream(rx) } pub fn hscan( - inner: &Arc, + inner: &RefCount, key: RedisKey, pattern: Str, count: Option, @@ -150,6 +148,8 @@ pub fn hscan( let (tx, rx) = unbounded_channel(); let args = values_args(key, pattern, count); + #[cfg(feature = "glommio")] + let tx: UnboundedSender<_> = tx.into(); let response = ResponseKind::ValueScan(ValueScanInner { tx: tx.clone(), cursor_idx: 1, @@ -158,10 +158,10 @@ pub fn hscan( let command: RedisCommand = (RedisCommandKind::Hscan, Vec::new(), response).into(); if let Err(e) = interfaces::default_send_command(inner, command) { - early_error(&tx, e); + let _ = tx.send(Err(e)); } - UnboundedReceiverStream::new(rx).try_filter_map(|result| async move { + rx_stream(rx).try_filter_map(|result| async move { match result { ValueScanResult::HScan(res) => Ok(Some(res)), _ => Err(RedisError::new(RedisErrorKind::Protocol, "Expected HSCAN result.")), @@ -170,7 +170,7 @@ pub fn hscan( } pub fn sscan( - inner: &Arc, + inner: &RefCount, key: RedisKey, pattern: Str, count: Option, @@ -178,6 +178,8 @@ pub fn sscan( let (tx, rx) = unbounded_channel(); let args = values_args(key, pattern, count); + #[cfg(feature = "glommio")] + let tx: UnboundedSender<_> = tx.into(); let response = ResponseKind::ValueScan(ValueScanInner { tx: tx.clone(), cursor_idx: 1, @@ -186,10 +188,10 @@ pub fn sscan( let command: RedisCommand = (RedisCommandKind::Sscan, Vec::new(), response).into(); if let Err(e) = interfaces::default_send_command(inner, command) { - early_error(&tx, e); + let _ = tx.send(Err(e)); } - UnboundedReceiverStream::new(rx).try_filter_map(|result| async move { + rx_stream(rx).try_filter_map(|result| async move { match result { ValueScanResult::SScan(res) => Ok(Some(res)), _ => Err(RedisError::new(RedisErrorKind::Protocol, "Expected SSCAN result.")), @@ -198,7 +200,7 @@ pub fn sscan( } pub fn zscan( - inner: &Arc, + inner: &RefCount, key: RedisKey, pattern: Str, count: Option, @@ -206,6 +208,8 @@ pub fn zscan( let (tx, rx) = unbounded_channel(); let args = values_args(key, pattern, count); + #[cfg(feature = "glommio")] + let tx: UnboundedSender<_> = tx.into(); let response = ResponseKind::ValueScan(ValueScanInner { tx: tx.clone(), cursor_idx: 1, @@ -214,10 +218,10 @@ pub fn zscan( let command: RedisCommand = (RedisCommandKind::Zscan, Vec::new(), response).into(); if let Err(e) = interfaces::default_send_command(inner, command) { - early_error(&tx, e); + let _ = tx.send(Err(e)); } - UnboundedReceiverStream::new(rx).try_filter_map(|result| async move { + rx_stream(rx).try_filter_map(|result| async move { match result { ValueScanResult::ZScan(res) => Ok(Some(res)), _ => Err(RedisError::new(RedisErrorKind::Protocol, "Expected ZSCAN result.")), diff --git a/src/commands/impls/server.rs b/src/commands/impls/server.rs index 9813ff1b..d9ffb0d4 100644 --- a/src/commands/impls/server.rs +++ b/src/commands/impls/server.rs @@ -10,12 +10,11 @@ use crate::{ responders::ResponseKind, utils as protocol_utils, }, + runtime::{oneshot_channel, RefCount}, types::*, utils, }; use bytes_utils::Str; -use std::sync::Arc; -use tokio::sync::oneshot::channel as oneshot_channel; pub async fn active_connections(client: &C) -> Result, RedisError> { let (tx, rx) = oneshot_channel(); @@ -83,7 +82,7 @@ pub async fn shutdown(client: &C, flags: Option) - } /// Create a new client struct for each unique primary cluster node based on the cached cluster state. -pub fn split(inner: &Arc) -> Result, RedisError> { +pub fn split(inner: &RefCount) -> Result, RedisError> { if !inner.config.server.is_clustered() { return Err(RedisError::new( RedisErrorKind::Config, @@ -109,7 +108,7 @@ pub fn split(inner: &Arc) -> Result, RedisErr ) } -pub async fn force_reconnection(inner: &Arc) -> Result<(), RedisError> { +pub async fn force_reconnection(inner: &RefCount) -> Result<(), RedisError> { let (tx, rx) = oneshot_channel(); let command = RouterCommand::Reconnect { server: None, diff --git a/src/commands/impls/tracking.rs b/src/commands/impls/tracking.rs index df853141..9b088ddf 100644 --- a/src/commands/impls/tracking.rs +++ b/src/commands/impls/tracking.rs @@ -6,11 +6,11 @@ use crate::{ responders::ResponseKind, utils as protocol_utils, }, + runtime::oneshot_channel, types::{ClusterHash, MultipleStrings, RedisValue, Toggle}, utils, }; use redis_protocol::redis_keyslot; -use tokio::sync::oneshot::channel as oneshot_channel; pub static PREFIX: &str = "PREFIX"; pub static REDIRECT: &str = "REDIRECT"; diff --git a/src/commands/interfaces/acl.rs b/src/commands/interfaces/acl.rs index 1d734e47..86253098 100644 --- a/src/commands/interfaces/acl.rs +++ b/src/commands/interfaces/acl.rs @@ -5,9 +5,11 @@ use crate::{ types::{FromRedis, MultipleStrings, MultipleValues}, }; use bytes_utils::Str; +use fred_macros::rm_send_if; use futures::Future; -/// Functions that implement the [ACL](https://redis.io/commands#server) interface. +/// Functions that implement the [ACL](https://redis.io/commandserver) interface. +#[rm_send_if(feature = "glommio")] pub trait AclInterface: ClientLike + Sized { /// Create an ACL user with the specified rules or modify the rules of an existing user. /// @@ -43,7 +45,7 @@ pub trait AclInterface: ClientLike + Sized { /// The command shows the currently active ACL rules in the Redis server. /// - /// + /// \ fn acl_list(&self) -> impl Future> + Send where R: FromRedis, diff --git a/src/commands/interfaces/client.rs b/src/commands/interfaces/client.rs index 64df7a9a..76da299c 100644 --- a/src/commands/interfaces/client.rs +++ b/src/commands/interfaces/client.rs @@ -12,17 +12,18 @@ use crate::{ Server, }, }; -use bytes_utils::Str; -use futures::Future; -use std::collections::HashMap; - #[cfg(feature = "i-tracking")] use crate::{ error::RedisError, types::{MultipleStrings, Toggle}, }; +use bytes_utils::Str; +use fred_macros::rm_send_if; +use futures::Future; +use std::collections::HashMap; /// Functions that implement the [client](https://redis.io/commands#connection) interface. +#[rm_send_if(feature = "glommio")] pub trait ClientInterface: ClientLike + Sized { /// Return the ID of the current connection. /// @@ -43,7 +44,7 @@ pub trait ClientInterface: ClientLike + Sized { /// /// Note: despite being async this function will return cached information from the client if possible. fn connection_ids(&self) -> impl Future> + Send { - async move { self.inner().backchannel.read().await.connection_ids.clone() } + async move { self.inner().backchannel.write().await.connection_ids.clone() } } /// The command returns information and statistics about the current client connection in a mostly human readable @@ -201,6 +202,7 @@ pub trait ClientInterface: ClientLike + Sized { /// caching feature. /// /// + #[cfg(feature = "i-tracking")] #[cfg_attr(docsrs, doc(cfg(feature = "i-tracking")))] fn client_trackinginfo(&self) -> impl Future> + Send diff --git a/src/commands/interfaces/cluster.rs b/src/commands/interfaces/cluster.rs index 244303d2..115d9626 100644 --- a/src/commands/interfaces/cluster.rs +++ b/src/commands/interfaces/cluster.rs @@ -6,9 +6,11 @@ use crate::{ types::{ClusterFailoverFlag, ClusterResetFlag, ClusterSetSlotState, FromRedis, MultipleHashSlots, RedisKey}, }; use bytes_utils::Str; +use fred_macros::rm_send_if; use futures::Future; /// Functions that implement the [cluster](https://redis.io/commands#cluster) interface. +#[rm_send_if(feature = "glommio")] pub trait ClusterInterface: ClientLike + Sized { /// Read the cached cluster state used for routing commands to the correct cluster nodes. fn cached_cluster_state(&self) -> Option { @@ -262,7 +264,7 @@ pub trait ClusterInterface: ClientLike + Sized { async move { commands::cluster::cluster_set_config_epoch(self, epoch).await } } - /// CLUSTER SETSLOT is responsible of changing the state of a hash slot in the receiving node in different ways. + /// CLUSTER SETSLOT is responsible for changing the state of a hash slot in the receiving node in different ways. /// /// fn cluster_setslot(&self, slot: u16, state: ClusterSetSlotState) -> impl Future> + Send { diff --git a/src/commands/interfaces/config.rs b/src/commands/interfaces/config.rs index fd8a3906..f483519e 100644 --- a/src/commands/interfaces/config.rs +++ b/src/commands/interfaces/config.rs @@ -5,10 +5,12 @@ use crate::{ types::{FromRedis, RedisValue}, }; use bytes_utils::Str; +use fred_macros::rm_send_if; use futures::Future; use std::convert::TryInto; /// Functions that implement the [config](https://redis.io/commands#server) interface. +#[rm_send_if(feature = "glommio")] pub trait ConfigInterface: ClientLike + Sized { /// Resets the statistics reported by Redis using the INFO command. /// diff --git a/src/commands/interfaces/geo.rs b/src/commands/interfaces/geo.rs index 7d0be181..7ca69b62 100644 --- a/src/commands/interfaces/geo.rs +++ b/src/commands/interfaces/geo.rs @@ -1,5 +1,3 @@ -use futures::Future; - use crate::{ commands, error::RedisError, @@ -17,9 +15,12 @@ use crate::{ SortOrder, }, }; +use fred_macros::rm_send_if; +use futures::Future; use std::convert::TryInto; /// Functions that implement the [geo](https://redis.io/commands#geo) interface. +#[rm_send_if(feature = "glommio")] pub trait GeoInterface: ClientLike + Sized { /// Adds the specified geospatial items (longitude, latitude, name) to the specified key. /// diff --git a/src/commands/interfaces/hashes.rs b/src/commands/interfaces/hashes.rs index fafca543..a0009f5c 100644 --- a/src/commands/interfaces/hashes.rs +++ b/src/commands/interfaces/hashes.rs @@ -1,14 +1,15 @@ -use futures::Future; - use crate::{ commands, error::RedisError, interfaces::{ClientLike, RedisResult}, types::{FromRedis, MultipleKeys, RedisKey, RedisMap, RedisValue}, }; +use fred_macros::rm_send_if; +use futures::Future; use std::convert::TryInto; /// Functions that implement the [hashes](https://redis.io/commands#hashes) interface. +#[rm_send_if(feature = "glommio")] pub trait HashesInterface: ClientLike + Sized { /// Returns all fields and values of the hash stored at `key`. /// diff --git a/src/commands/interfaces/hyperloglog.rs b/src/commands/interfaces/hyperloglog.rs index 6ccb7a4f..81d34776 100644 --- a/src/commands/interfaces/hyperloglog.rs +++ b/src/commands/interfaces/hyperloglog.rs @@ -1,14 +1,15 @@ -use futures::Future; - use crate::{ commands, error::RedisError, interfaces::{ClientLike, RedisResult}, types::{FromRedis, MultipleKeys, MultipleValues, RedisKey}, }; +use fred_macros::rm_send_if; +use futures::Future; use std::convert::TryInto; /// Functions that implement the [HyperLogLog](https://redis.io/commands#hyperloglog) interface. +#[rm_send_if(feature = "glommio")] pub trait HyperloglogInterface: ClientLike + Sized { /// Adds all the element arguments to the HyperLogLog data structure stored at the variable name specified as first /// argument. diff --git a/src/commands/interfaces/keys.rs b/src/commands/interfaces/keys.rs index 9bd6181b..cd0c625c 100644 --- a/src/commands/interfaces/keys.rs +++ b/src/commands/interfaces/keys.rs @@ -1,14 +1,15 @@ -use futures::Future; - use crate::{ commands, error::RedisError, interfaces::{ClientLike, RedisResult}, types::{Expiration, ExpireOptions, FromRedis, MultipleKeys, RedisKey, RedisMap, RedisValue, SetOptions}, }; +use fred_macros::rm_send_if; +use futures::Future; use std::convert::TryInto; /// Functions that implement the generic [keys](https://redis.io/commands#generic) interface. +#[rm_send_if(feature = "glommio")] pub trait KeysInterface: ClientLike + Sized { /// Marks the given keys to be watched for conditional execution of a transaction. /// diff --git a/src/commands/interfaces/lists.rs b/src/commands/interfaces/lists.rs index 59cf5c49..86a7078e 100644 --- a/src/commands/interfaces/lists.rs +++ b/src/commands/interfaces/lists.rs @@ -1,5 +1,3 @@ -use futures::Future; - use crate::{ commands, error::RedisError, @@ -18,9 +16,12 @@ use crate::{ }, }; use bytes_utils::Str; +use fred_macros::rm_send_if; +use futures::Future; use std::convert::TryInto; /// Functions that implement the [lists](https://redis.io/commands#lists) interface. +#[rm_send_if(feature = "glommio")] pub trait ListInterface: ClientLike + Sized { /// The blocking variant of [Self::lmpop]. /// @@ -148,7 +149,7 @@ pub trait ListInterface: ClientLike + Sized { } } - /// Returns the element at index index in the list stored at key. + /// Returns the element at index in the list stored at key. /// /// fn lindex(&self, key: K, index: i64) -> impl Future> + Send diff --git a/src/commands/interfaces/lua.rs b/src/commands/interfaces/lua.rs index 2f71adbe..f36e088a 100644 --- a/src/commands/interfaces/lua.rs +++ b/src/commands/interfaces/lua.rs @@ -6,10 +6,12 @@ use crate::{ }; use bytes::Bytes; use bytes_utils::Str; +use fred_macros::rm_send_if; use futures::Future; use std::convert::TryInto; /// Functions that implement the [lua](https://redis.io/commands#lua) interface. +#[rm_send_if(feature = "glommio")] pub trait LuaInterface: ClientLike + Sized { /// Load a script into the scripts cache, without executing it. After the specified command is loaded into the /// script cache it will be callable using EVALSHA with the correct SHA1 digest of the script. @@ -133,6 +135,7 @@ pub trait LuaInterface: ClientLike + Sized { } /// Functions that implement the [function](https://redis.io/docs/manual/programmability/functions-intro/) interface. +#[rm_send_if(feature = "glommio")] pub trait FunctionInterface: ClientLike + Sized { /// Invoke a function. /// diff --git a/src/commands/interfaces/memory.rs b/src/commands/interfaces/memory.rs index 047ddff0..98466166 100644 --- a/src/commands/interfaces/memory.rs +++ b/src/commands/interfaces/memory.rs @@ -4,9 +4,11 @@ use crate::{ prelude::FromRedis, types::RedisKey, }; +use fred_macros::rm_send_if; use futures::Future; /// Functions that implement the [memory](https://redis.io/commands#server) interface. +#[rm_send_if(feature = "glommio")] pub trait MemoryInterface: ClientLike + Sized { /// The MEMORY DOCTOR command reports about different memory-related issues that the Redis server experiences, and /// advises about possible remedies. diff --git a/src/commands/interfaces/pubsub.rs b/src/commands/interfaces/pubsub.rs index 310f5412..deafe879 100644 --- a/src/commands/interfaces/pubsub.rs +++ b/src/commands/interfaces/pubsub.rs @@ -5,10 +5,12 @@ use crate::{ types::{FromRedis, MultipleStrings, RedisValue}, }; use bytes_utils::Str; +use fred_macros::rm_send_if; use futures::Future; use std::convert::TryInto; /// Functions that implement the [pubsub](https://redis.io/commands#pubsub) interface. +#[rm_send_if(feature = "glommio")] pub trait PubsubInterface: ClientLike + Sized + Send { /// Subscribe to a channel on the publish-subscribe interface. /// @@ -16,7 +18,6 @@ pub trait PubsubInterface: ClientLike + Sized + Send { fn subscribe(&self, channels: S) -> impl Future> + Send where S: Into + Send, - Self: Send + Sync, { async move { into!(channels); @@ -30,7 +31,6 @@ pub trait PubsubInterface: ClientLike + Sized + Send { fn unsubscribe(&self, channels: S) -> impl Future> + Send where S: Into + Send, - Self: Sync, { async move { into!(channels); diff --git a/src/commands/interfaces/redis_json.rs b/src/commands/interfaces/redis_json.rs index fdfe3834..75dd3266 100644 --- a/src/commands/interfaces/redis_json.rs +++ b/src/commands/interfaces/redis_json.rs @@ -4,6 +4,7 @@ use crate::{ types::{FromRedis, MultipleKeys, MultipleStrings, RedisKey, SetOptions}, }; use bytes_utils::Str; +use fred_macros::rm_send_if; use futures::Future; use serde_json::Value; @@ -45,6 +46,7 @@ use serde_json::Value; /// } /// ``` #[cfg_attr(docsrs, doc(cfg(feature = "i-redis-json")))] +#[rm_send_if(feature = "glommio")] pub trait RedisJsonInterface: ClientLike + Sized { /// Append the json values into the array at path after the last element in it. /// diff --git a/src/commands/interfaces/redisearch.rs b/src/commands/interfaces/redisearch.rs index 0040b6c1..3639585b 100644 --- a/src/commands/interfaces/redisearch.rs +++ b/src/commands/interfaces/redisearch.rs @@ -17,10 +17,12 @@ use crate::{ }; use bytes::Bytes; use bytes_utils::Str; +use fred_macros::rm_send_if; use std::future::Future; /// A [RediSearch](https://github.com/RediSearch/RediSearch) interface. #[cfg_attr(docsrs, doc(cfg(feature = "i-redisearch")))] +#[rm_send_if(feature = "glommio")] pub trait RediSearchInterface: ClientLike + Sized { /// Returns a list of all existing indexes. /// diff --git a/src/commands/interfaces/sentinel.rs b/src/commands/interfaces/sentinel.rs index 61876309..d42c1fb9 100644 --- a/src/commands/interfaces/sentinel.rs +++ b/src/commands/interfaces/sentinel.rs @@ -5,10 +5,12 @@ use crate::{ types::{FromRedis, RedisMap, RedisValue, SentinelFailureKind}, }; use bytes_utils::Str; +use fred_macros::rm_send_if; use futures::Future; use std::{convert::TryInto, net::IpAddr}; /// Functions that implement the [sentinel](https://redis.io/topics/sentinel#sentinel-commands) interface. +#[rm_send_if(feature = "glommio")] pub trait SentinelInterface: ClientLike + Sized { /// Check if the current Sentinel configuration is able to reach the quorum needed to failover a master, and the /// majority needed to authorize the failover. diff --git a/src/commands/interfaces/server.rs b/src/commands/interfaces/server.rs index 18c15514..0f5e8166 100644 --- a/src/commands/interfaces/server.rs +++ b/src/commands/interfaces/server.rs @@ -4,9 +4,11 @@ use crate::{ interfaces::{ClientLike, RedisResult}, types::{FromRedis, Server}, }; +use fred_macros::rm_send_if; use futures::Future; /// Functions that implement the [server](https://redis.io/commands#server) interface. +#[rm_send_if(feature = "glommio")] pub trait ServerInterface: ClientLike { /// Instruct Redis to start an Append Only File rewrite process. /// diff --git a/src/commands/interfaces/sets.rs b/src/commands/interfaces/sets.rs index eefa38b1..643bc588 100644 --- a/src/commands/interfaces/sets.rs +++ b/src/commands/interfaces/sets.rs @@ -1,14 +1,15 @@ -use futures::Future; - use crate::{ commands, error::RedisError, interfaces::{ClientLike, RedisResult}, types::{FromRedis, MultipleKeys, MultipleValues, RedisKey, RedisValue}, }; +use fred_macros::rm_send_if; +use futures::Future; use std::convert::TryInto; /// Functions that implement the [sets](https://redis.io/commands#set) interface. +#[rm_send_if(feature = "glommio")] pub trait SetsInterface: ClientLike + Sized { /// Add the specified members to the set stored at `key`. /// diff --git a/src/commands/interfaces/slowlog.rs b/src/commands/interfaces/slowlog.rs index 83851c30..d02bdbf1 100644 --- a/src/commands/interfaces/slowlog.rs +++ b/src/commands/interfaces/slowlog.rs @@ -3,9 +3,11 @@ use crate::{ interfaces::{ClientLike, RedisResult}, types::FromRedis, }; +use fred_macros::rm_send_if; use futures::Future; /// Functions that implement the [slowlog](https://redis.io/commands#server) interface. +#[rm_send_if(feature = "glommio")] pub trait SlowlogInterface: ClientLike + Sized { /// This command is used to read the slow queries log. /// diff --git a/src/commands/interfaces/sorted_sets.rs b/src/commands/interfaces/sorted_sets.rs index 21af62cf..2ce8f98f 100644 --- a/src/commands/interfaces/sorted_sets.rs +++ b/src/commands/interfaces/sorted_sets.rs @@ -1,5 +1,3 @@ -use futures::Future; - use crate::{ commands, error::RedisError, @@ -21,9 +19,12 @@ use crate::{ ZSort, }, }; +use fred_macros::rm_send_if; +use futures::Future; use std::convert::TryInto; /// Functions that implement the [sorted sets](https://redis.io/commands#sorted_set) interface. +#[rm_send_if(feature = "glommio")] pub trait SortedSetsInterface: ClientLike + Sized { /// The blocking variant of [Self::zmpop]. /// diff --git a/src/commands/interfaces/streams.rs b/src/commands/interfaces/streams.rs index 176f0da7..8526a057 100644 --- a/src/commands/interfaces/streams.rs +++ b/src/commands/interfaces/streams.rs @@ -19,6 +19,7 @@ use crate::{ }, }; use bytes_utils::Str; +use fred_macros::rm_send_if; use futures::Future; use std::{convert::TryInto, hash::Hash}; @@ -29,6 +30,7 @@ use std::{convert::TryInto, hash::Hash}; /// [xreadgroup_map](Self::xreadgroup_map), [xrange_values](Self::xrange_values), etc exist to make this easier for /// callers. These functions apply an additional layer of parsing logic that can make declaring response types easier, /// as well as automatically handling any differences between RESP2 and RESP3 return value types. +#[rm_send_if(feature = "glommio")] pub trait StreamsInterface: ClientLike + Sized { /// This command returns the list of consumers that belong to the `groupname` consumer group of the stream stored at /// `key`. diff --git a/src/commands/interfaces/timeseries.rs b/src/commands/interfaces/timeseries.rs index e79d37ee..9e907d27 100644 --- a/src/commands/interfaces/timeseries.rs +++ b/src/commands/interfaces/timeseries.rs @@ -16,10 +16,12 @@ use crate::{ }, }; use bytes_utils::Str; +use fred_macros::rm_send_if; use futures::Future; /// A [Redis Timeseries](https://github.com/RedisTimeSeries/RedisTimeSeries/) interface. #[cfg_attr(docsrs, doc(cfg(feature = "i-time-series")))] +#[rm_send_if(feature = "glommio")] pub trait TimeSeriesInterface: ClientLike { /// Append a sample to a time series. /// diff --git a/src/commands/interfaces/tracking.rs b/src/commands/interfaces/tracking.rs index a852a23b..666b4b2a 100644 --- a/src/commands/interfaces/tracking.rs +++ b/src/commands/interfaces/tracking.rs @@ -2,13 +2,15 @@ use crate::{ commands, interfaces::ClientLike, prelude::RedisResult, + runtime::{spawn, BroadcastReceiver, JoinHandle}, types::{Invalidation, MultipleStrings}, }; +use fred_macros::rm_send_if; use futures::Future; -use tokio::{sync::broadcast::Receiver as BroadcastReceiver, task::JoinHandle}; /// A high level interface that supports [client side caching](https://redis.io/docs/manual/client-side-caching/) via the [client tracking](https://redis.io/commands/client-tracking/) interface. #[cfg_attr(docsrs, doc(cfg(feature = "i-tracking")))] +#[rm_send_if(feature = "glommio")] pub trait TrackingInterface: ClientLike + Sized { /// Send the [CLIENT TRACKING](https://redis.io/commands/client-tracking/) command to all connected servers, subscribing to [invalidation messages](Self::on_invalidation) on the same connection. /// @@ -48,7 +50,7 @@ pub trait TrackingInterface: ClientLike + Sized { { let mut invalidation_rx = self.invalidation_rx(); - tokio::spawn(async move { + spawn(async move { let mut result = Ok(()); while let Ok(invalidation) = invalidation_rx.recv().await { diff --git a/src/commands/mod.rs b/src/commands/mod.rs index b10e1f19..b5596797 100644 --- a/src/commands/mod.rs +++ b/src/commands/mod.rs @@ -1,4 +1,3 @@ mod impls; pub mod interfaces; - pub use impls::*; diff --git a/src/error.rs b/src/error.rs index 59fe4bbf..43eac000 100644 --- a/src/error.rs +++ b/src/error.rs @@ -13,7 +13,6 @@ use std::{ str::Utf8Error, string::FromUtf8Error, }; -use tokio::task::JoinError; use url::ParseError; /// An enum representing the type of error from Redis. @@ -243,12 +242,29 @@ impl From for RedisError { } #[doc(hidden)] -impl From for RedisError { - fn from(e: JoinError) -> Self { +#[cfg(not(feature = "glommio"))] +impl From for RedisError { + fn from(e: tokio::task::JoinError) -> Self { RedisError::new(RedisErrorKind::Unknown, format!("Spawn Error: {:?}", e)) } } +#[doc(hidden)] +#[cfg(feature = "glommio")] +impl From> for RedisError { + fn from(e: glommio::GlommioError) -> Self { + RedisError::new(RedisErrorKind::Unknown, format!("{:?}", e)) + } +} + +#[doc(hidden)] +#[cfg(feature = "glommio")] +impl From for RedisError { + fn from(_: oneshot::RecvError) -> Self { + RedisError::new_canceled() + } +} + #[doc(hidden)] impl From for RedisError { fn from(e: SemverError) -> Self { diff --git a/src/glommio/README.md b/src/glommio/README.md new file mode 100644 index 00000000..cc3c7959 --- /dev/null +++ b/src/glommio/README.md @@ -0,0 +1,45 @@ +# Glommio + +See the [Glommio Introduction](https://www.datadoghq.com/blog/engineering/introducing-glommio/) for more info. + +Tokio and Glommio have an important difference in their scheduling interfaces: + +* [tokio::spawn](https://docs.rs/tokio/latest/tokio/task/fn.spawn.html) requires a `Send` bound on the spawned + future so that the Tokio scheduler can implement work-stealing across threads. +* Glommio's scheduling interface is intended to be used in cases where runtime threads do not need to share or + synchronize any state. Both the [spawn_local](https://docs.rs/glommio/latest/glommio/fn.spawn_local.html) + and [spawn_local_into](https://docs.rs/glommio/latest/glommio/fn.spawn_local_into.html) functions + spawn tasks on the same thread and therefore do not have a `Send` bound. + +`fred` was originally written with message-passing design patterns targeting a Tokio runtime and therefore the `Send` +bound from `tokio::spawn` leaked into all the public interfaces that send messages across tasks. This includes nearly +all the public command traits, including the base `ClientLike` trait. + +When building with `--features glommio` the public interface will change in several ways: + +* The `Send + Sync` bounds will be removed from all generic input parameters, where clause predicates, and `impl Trait` + return types. +* Internal `Arc` usages will change to `Rc`. +* Internal `RwLock` and `Mutex` usages will change to `RefCell`. +* Internal usages of `std::sync::atomic` types will change to thin wrappers around a `RefCell`. +* Any Tokio message passing interfaces (`BroadcastSender`, etc) will change to the closest Glommio equivalent. +* A Tokio compatability layer will be used to map between the two runtime's versions of `AsyncRead` and + `AsyncWrite`. This enables the existing codec interface (`Encoder` + `Decoder`) to work with Glommio's network types. + As a result, for now some Tokio dependencies are still required when using Glommio features. + +[Glommio Example](https://github.com/aembke/fred.rs/blob/main/examples/glommio.rs) + +The public docs +on [docs.rs](https://docs.rs/fred/latest) will continue to show the Tokio interfaces that require `Send` bounds, but +callers can find the latest rustdocs for both runtimes on the +`gh-pages` branch: + +[Glommio Documentation](https://aembke.github.io/fred.rs/glommio/fred/index.html) + +[Tokio Documentation](https://aembke.github.io/fred.rs/tokio/fred/index.html) + +Callers can rebuild Glommio docs via the [doc-glommio.sh](../../tests/doc-glommio.sh) script: + +``` +path/to/fred/tests/doc-glommio.sh --open +``` \ No newline at end of file diff --git a/src/glommio/broadcast.rs b/src/glommio/broadcast.rs new file mode 100644 index 00000000..93424103 --- /dev/null +++ b/src/glommio/broadcast.rs @@ -0,0 +1,93 @@ +use crate::error::RedisError; +use glommio::{ + channels::local_channel::{new_unbounded, LocalReceiver, LocalSender}, + GlommioError, + ResourceType, +}; +use std::{cell::RefCell, collections::BTreeMap, rc::Rc}; + +struct Inner { + pub counter: u64, + pub senders: BTreeMap>, +} + +/// A multi-producer multi-consumer channel receiver. +/// +/// See [LocalReceiver](glommio::channels::local_channel::LocalReceiver) for more information. +pub struct BroadcastReceiver { + id: u64, + inner: Rc>>, + rx: LocalReceiver, +} + +impl BroadcastReceiver { + /// Receives data from this channel. + /// + /// See [recv](glommio::channels::local_channel::LocalReceiver::recv) for more information. + pub async fn recv(&self) -> Result { + match self.rx.recv().await { + Some(v) => Ok(v), + None => Err(RedisError::new_canceled()), + } + } +} + +impl Drop for BroadcastReceiver { + fn drop(&mut self) { + self.inner.as_ref().borrow_mut().senders.remove(&self.id); + } +} + +#[derive(Clone)] +pub struct BroadcastSender { + inner: Rc>>, +} + +impl BroadcastSender { + pub fn new() -> Self { + BroadcastSender { + inner: Rc::new(RefCell::new(Inner { + counter: 0, + senders: BTreeMap::new(), + })), + } + } + + pub fn subscribe(&self) -> BroadcastReceiver { + let (tx, rx) = new_unbounded(); + let id = { + let mut guard = self.inner.as_ref().borrow_mut(); + let count = guard.counter.wrapping_add(1); + guard.counter = count; + guard.senders.insert(count, tx); + guard.counter + }; + + BroadcastReceiver { + id, + rx, + inner: self.inner.clone(), + } + } + + pub fn send(&self, msg: &T, func: F) { + let mut guard = self.inner.as_ref().borrow_mut(); + + let to_remove: Vec = guard + .senders + .iter() + .filter_map(|(id, tx)| { + if let Err(GlommioError::Closed(ResourceType::Channel(val))) = tx.try_send(msg.clone()) { + func(&val); + Some(*id) + } else { + None + } + }) + .collect(); + + for id in to_remove.into_iter() { + guard.senders.remove(&id); + } + } +} diff --git a/src/glommio/interfaces.rs b/src/glommio/interfaces.rs new file mode 100644 index 00000000..63d32f27 --- /dev/null +++ b/src/glommio/interfaces.rs @@ -0,0 +1,361 @@ +#[cfg(feature = "i-server")] +use crate::types::ShutdownFlags; +use crate::{ + clients::WithOptions, + commands, + error::RedisError, + glommio::compat::spawn_into, + interfaces::{RedisResult, Resp3Frame}, + modules::inner::RedisClientInner, + prelude::default_send_command, + protocol::command::RedisCommand, + router::commands as router_commands, + runtime::{spawn, BroadcastReceiver, JoinHandle, RefCount}, + types::{ + ClientState, + ConnectHandle, + ConnectionConfig, + CustomCommand, + FromRedis, + InfoKind, + Options, + PerformanceConfig, + ReconnectPolicy, + RedisConfig, + RedisValue, + Server, + }, + utils, +}; +use redis_protocol::resp3::types::RespVersion; +use semver::Version; +use std::future::Future; + +#[cfg(any(feature = "dns", feature = "trust-dns-resolver"))] +use crate::protocol::types::Resolve; + +pub trait ClientLike: Clone + Sized { + #[doc(hidden)] + fn inner(&self) -> &RefCount; + + /// Helper function to intercept and modify a command without affecting how it is sent to the connection layer. + #[doc(hidden)] + fn change_command(&self, _: &mut RedisCommand) {} + + /// Helper function to intercept and customize how a command is sent to the connection layer. + #[doc(hidden)] + fn send_command(&self, command: C) -> Result<(), RedisError> + where + C: Into, + { + let mut command: RedisCommand = command.into(); + self.change_command(&mut command); + default_send_command(self.inner(), command) + } + + /// The unique ID identifying this client and underlying connections. + fn id(&self) -> &str { + &self.inner().id + } + + /// Read the config used to initialize the client. + fn client_config(&self) -> RedisConfig { + self.inner().config.as_ref().clone() + } + + /// Read the reconnect policy used to initialize the client. + fn client_reconnect_policy(&self) -> Option { + self.inner().policy.read().clone() + } + + /// Read the connection config used to initialize the client. + fn connection_config(&self) -> &ConnectionConfig { + self.inner().connection.as_ref() + } + + /// Read the RESP version used by the client when communicating with the server. + fn protocol_version(&self) -> RespVersion { + if self.inner().is_resp3() { + RespVersion::RESP3 + } else { + RespVersion::RESP2 + } + } + + /// Whether the client has a reconnection policy. + fn has_reconnect_policy(&self) -> bool { + self.inner().policy.read().is_some() + } + + /// Whether the client will automatically pipeline commands. + fn is_pipelined(&self) -> bool { + self.inner().is_pipelined() + } + + /// Whether the client is connected to a cluster. + fn is_clustered(&self) -> bool { + self.inner().config.server.is_clustered() + } + + /// Whether the client uses the sentinel interface. + fn uses_sentinels(&self) -> bool { + self.inner().config.server.is_sentinel() + } + + /// Update the internal [PerformanceConfig](crate::types::PerformanceConfig) in place with new values. + fn update_perf_config(&self, config: PerformanceConfig) { + self.inner().update_performance_config(config); + } + + /// Read the [PerformanceConfig](crate::types::PerformanceConfig) associated with this client. + fn perf_config(&self) -> PerformanceConfig { + self.inner().performance_config() + } + + /// Read the state of the underlying connection(s). + /// + /// If running against a cluster the underlying state will reflect the state of the least healthy connection. + fn state(&self) -> ClientState { + self.inner().state.read().clone() + } + + /// Whether all underlying connections are healthy. + fn is_connected(&self) -> bool { + *self.inner().state.read() == ClientState::Connected + } + + /// Read the set of active connections managed by the client. + fn active_connections(&self) -> impl Future, RedisError>> { + commands::server::active_connections(self) + } + + /// Read the server version, if known. + fn server_version(&self) -> Option { + self.inner().server_state.read().kind.server_version() + } + + /// Override the DNS resolution logic for the client. + #[cfg(feature = "dns")] + #[cfg_attr(docsrs, doc(cfg(feature = "dns")))] + fn set_resolver(&self, resolver: RefCount) -> impl Future { + async move { self.inner().set_resolver(resolver).await } + } + + /// Connect to the server. + /// + /// This function returns a `JoinHandle` to a task that drives the connection. It will not resolve until the + /// connection closes, or if a reconnection policy with unlimited attempts is provided then it will + /// run until `QUIT` is called. Callers should avoid calling [abort](tokio::task::JoinHandle::abort) on the returned + /// `JoinHandle` unless the client will no longer be used. + /// + /// **Calling this function more than once will drop all state associated with the previous connection(s).** Any + /// pending commands on the old connection(s) will either finish or timeout, but they will not be retried on the + /// new connection(s). + /// + /// See [init](Self::init) for an alternative shorthand. + fn connect(&self) -> ConnectHandle { + let inner = self.inner().clone(); + let tq = inner.connection.router_task_queue; + utils::reset_router_task(&inner); + + let connection_ft = async move { + utils::clear_backchannel_state(&inner).await; + let result = router_commands::start(&inner).await; + // a canceled error means we intentionally closed the client + _trace!(inner, "Ending connection task with {:?}", result); + + if let Err(ref error) = result { + if !error.is_canceled() { + inner.notifications.broadcast_connect(Err(error.clone())); + } + } + + utils::check_and_set_client_state(&inner.state, ClientState::Disconnecting, ClientState::Disconnected); + result + }; + + if let Some(tq) = tq { + spawn_into(connection_ft, tq) + } else { + spawn(connection_ft) + } + } + + /// Force a reconnection to the server(s). + /// + /// When running against a cluster this function will also refresh the cached cluster routing table. + fn force_reconnection(&self) -> impl Future> { + async move { commands::server::force_reconnection(self.inner()).await } + } + + /// Wait for the result of the next connection attempt. + /// + /// This can be used with `on_reconnect` to separate initialization logic that needs to occur only on the next + /// connection attempt vs all subsequent attempts. + fn wait_for_connect(&self) -> impl Future> { + async move { + if { utils::read_locked(&self.inner().state) } == ClientState::Connected { + debug!("{}: Client is already connected.", self.inner().id); + Ok(()) + } else { + let rx = { self.inner().notifications.connect.load().subscribe() }; + rx.recv().await? + } + } + } + + /// Initialize a new routing and connection task and wait for it to connect successfully. + /// + /// The returned [ConnectHandle](crate::types::ConnectHandle) refers to the task that drives the routing and + /// connection layer. It will not finish until the max reconnection count is reached. Callers should avoid calling + /// [abort](tokio::task::JoinHandle::abort) on the returned `JoinHandle` unless the client will no longer be used. + /// + /// Callers can also use [connect](Self::connect) and [wait_for_connect](Self::wait_for_connect) separately if + /// needed. + /// + /// ```rust + /// use fred::prelude::*; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), RedisError> { + /// let client = RedisClient::default(); + /// let connection_task = client.init().await?; + /// + /// // ... + /// + /// client.quit().await?; + /// connection_task.await? + /// } + /// ``` + fn init(&self) -> impl Future> { + async move { + let rx = { self.inner().notifications.connect.load().subscribe() }; + let task = self.connect(); + let error = rx.recv().await.map_err(RedisError::from).and_then(|r| r).err(); + + if let Some(error) = error { + // the initial connection failed, so we should gracefully close the routing task + utils::reset_router_task(self.inner()); + Err(error) + } else { + Ok(task) + } + } + } + + /// Close the connection to the Redis server. The returned future resolves when the command has been written to the + /// socket, not when the connection has been fully closed. Some time after this future resolves the future + /// returned by [connect](Self::connect) will resolve which indicates that the connection has been fully closed. + /// + /// This function will also close all error, pubsub message, and reconnection event streams. + fn quit(&self) -> impl Future> { + async move { commands::server::quit(self).await } + } + + /// Shut down the server and quit the client. + /// + /// + #[cfg(feature = "i-server")] + #[cfg_attr(docsrs, doc(cfg(feature = "i-server")))] + fn shutdown(&self, flags: Option) -> impl Future> { + async move { commands::server::shutdown(self, flags).await } + } + + /// Delete the keys in all databases. + /// + /// + fn flushall(&self, r#async: bool) -> impl Future> + where + R: FromRedis, + { + async move { commands::server::flushall(self, r#async).await?.convert() } + } + + /// Delete the keys on all nodes in the cluster. This is a special function that does not map directly to the Redis + /// interface. + fn flushall_cluster(&self) -> impl Future> { + async move { commands::server::flushall_cluster(self).await } + } + + /// Ping the Redis server. + /// + /// + fn ping(&self) -> impl Future> + where + R: FromRedis, + { + async move { commands::server::ping(self).await?.convert() } + } + + /// Read info about the server. + /// + /// + fn info(&self, section: Option) -> impl Future> + where + R: FromRedis, + { + async move { commands::server::info(self, section).await?.convert() } + } + + /// Run a custom command that is not yet supported via another interface on this client. This is most useful when + /// interacting with third party modules or extensions. + /// + /// Callers should use the re-exported [redis_keyslot](crate::util::redis_keyslot) function to hash the command's + /// key, if necessary. + /// + /// This interface should be used with caution as it may break the automatic pipeline features in the client if + /// command flags are not properly configured. + fn custom(&self, cmd: CustomCommand, args: Vec) -> impl Future> + where + R: FromRedis, + T: TryInto, + T::Error: Into, + { + async move { + let args = utils::try_into_vec(args)?; + commands::server::custom(self, cmd, args).await?.convert() + } + } + + /// Run a custom command similar to [custom](Self::custom), but return the response frame directly without any + /// parsing. + /// + /// Note: RESP2 frames from the server are automatically converted to the RESP3 format when parsed by the client. + fn custom_raw(&self, cmd: CustomCommand, args: Vec) -> impl Future> + where + T: TryInto, + T::Error: Into, + { + async move { + let args = utils::try_into_vec(args)?; + commands::server::custom_raw(self, cmd, args).await + } + } + + /// Customize various configuration options on commands. + fn with_options(&self, options: &Options) -> WithOptions { + WithOptions { + client: self.clone(), + options: options.clone(), + } + } +} + +pub fn spawn_event_listener(rx: BroadcastReceiver, func: F) -> JoinHandle> +where + T: Clone + 'static, + F: Fn(T) -> RedisResult<()> + 'static, +{ + spawn(async move { + let mut result = Ok(()); + + while let Ok(val) = rx.recv().await { + if let Err(err) = func(val) { + result = Err(err); + break; + } + } + + result + }) +} diff --git a/src/glommio/io_compat.rs b/src/glommio/io_compat.rs new file mode 100644 index 00000000..c03d1e2a --- /dev/null +++ b/src/glommio/io_compat.rs @@ -0,0 +1,68 @@ +/// Reuse the same approach used by gmf (https://github.com/EtaCassiopeia/gmf/blob/591037476e6a17f83954a20558ff0e1920d94301/gmf/src/server/tokio_interop.rs#L1). +/// +/// The `Framed` codec interface used by the `Connection` struct requires that `T: AsyncRead+AsyncWrite`. +/// These traits are defined in the tokio and futures_io/futures_lite crates, but the tokio_util::codec interface +/// uses the versions re-implemented in tokio. However, glommio's network interfaces implement +/// `AsyncRead+AsyncWrite` from the futures_io crate. There are several ways to work around this, including +/// either a re-implementation of the codec traits `Encoder+Decoder`, or a compatibility layer for the different +/// versions of `AsyncRead+AsyncWrite`. The `gmf` project used the second approach, which seems much easier than +/// re-implementing the `Framed` traits (https://github.com/tokio-rs/tokio/blob/1ac8dff213937088616dc84de9adc92b4b68c49a/tokio-util/src/codec/framed_impl.rs#L125). + +// ------------------- https://github.com/EtaCassiopeia/gmf/blob/591037476e6a17f83954a20558ff0e1920d94301/gmf/src/server/tokio_interop.rs + +/// This module provides interoperability with the Tokio async runtime. +/// It contains utilities to bridge between futures_lite and Tokio. +use std::io::{self}; +use std::{ + pin::Pin, + task::{Context, Poll}, +}; + +use futures_io::{AsyncRead, AsyncWrite}; +use tokio::io::ReadBuf; + +/// A wrapper type for AsyncRead + AsyncWrite + Unpin types, providing +/// interoperability with Tokio's AsyncRead and AsyncWrite traits. +#[pin_project::pin_project] // This generates a projection for the inner type. +pub struct TokioIO(#[pin] pub T) +where + T: AsyncRead + AsyncWrite + Unpin; + +impl tokio::io::AsyncWrite for TokioIO +where + T: AsyncRead + AsyncWrite + Unpin, +{ + /// Write some data into the inner type, returning how many bytes were written. + fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll> { + // This is the same as Pin::new(&mut self.0).poll_write(cx, buf) with the source type of `mut self` + // using projection makes it easier to read. + let this = self.project(); + this.0.poll_write(cx, buf) + } + + /// Flushes the inner type. + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + self.project().0.poll_flush(cx) + } + + /// Shuts down the inner type, flushing any buffered data. + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + self.project().0.poll_close(cx) + } +} + +impl tokio::io::AsyncRead for TokioIO +where + T: AsyncRead + AsyncWrite + Unpin, +{ + /// Reads some data from the inner type, returning how many bytes were read. + fn poll_read(self: Pin<&mut Self>, cx: &mut Context, buf: &mut ReadBuf<'_>) -> Poll> { + self.project().0.poll_read(cx, buf.initialize_unfilled()).map(|n| { + if let Ok(n) = n { + buf.advance(n); + } + + Ok(()) + }) + } +} diff --git a/src/glommio/mod.rs b/src/glommio/mod.rs new file mode 100644 index 00000000..3d21b683 --- /dev/null +++ b/src/glommio/mod.rs @@ -0,0 +1,118 @@ +#[cfg(all(feature = "glommio", feature = "unix-sockets"))] +compile_error!("Cannot use glommio and unix-sockets features together."); + +pub(crate) mod broadcast; +pub(crate) mod interfaces; +pub(crate) mod io_compat; +pub(crate) mod mpsc; +pub(crate) mod sync; + +pub(crate) mod compat { + pub use super::{ + broadcast::{BroadcastReceiver, BroadcastSender}, + mpsc::{rx_stream, UnboundedReceiver, UnboundedSender}, + sync::*, + }; + use crate::error::RedisError; + use futures::Future; + use glommio::TaskQueueHandle; + pub use glommio::{ + channels::local_channel::new_unbounded as unbounded_channel, + task::JoinHandle as GlommioJoinHandle, + timer::sleep, + }; + pub use oneshot::{channel as oneshot_channel, Receiver as OneshotReceiver, Sender as OneshotSender}; + use std::{ + cell::RefCell, + pin::Pin, + rc::Rc, + task::{Context, Poll}, + }; + + /// The reference counting container type. + /// + /// This type may change based on the runtime feature flags used. + pub type RefCount = Rc; + + pub fn broadcast_send(tx: &BroadcastSender, msg: &T, func: F) { + tx.send(msg, func); + } + + pub fn broadcast_channel(_: usize) -> (BroadcastSender, BroadcastReceiver) { + let tx = BroadcastSender::new(); + let rx = tx.subscribe(); + (tx, rx) + } + + /// A wrapper type around [JoinHandle](glommio::task::JoinHandle) with an interface similar to Tokio's + /// [JoinHandle](tokio::task::JoinHandle) + pub struct JoinHandle { + pub(crate) inner: GlommioJoinHandle, + pub(crate) finished: Rc>, + } + + pub fn spawn(ft: impl Future + 'static) -> JoinHandle { + let finished = Rc::new(RefCell::new(false)); + let _finished = finished.clone(); + let inner = glommio::spawn_local(async move { + let result = ft.await; + _finished.replace(true); + result + }) + .detach(); + + JoinHandle { inner, finished } + } + + pub fn spawn_into(ft: impl Future + 'static, tq: TaskQueueHandle) -> JoinHandle { + let finished = Rc::new(RefCell::new(false)); + let _finished = finished.clone(); + let inner = glommio::spawn_local_into( + async move { + let result = ft.await; + _finished.replace(true); + result + }, + tq, + ) + .unwrap_or_else(|e| panic!("Failed to spawn task into task queue {tq:?}: {e:?}")) + .detach(); + + JoinHandle { inner, finished } + } + + impl Future for JoinHandle { + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + use futures_lite::FutureExt; + + let finished = self.finished.clone(); + let result = self + .get_mut() + .inner + .poll(cx) + .map(|result| result.ok_or(RedisError::new_canceled())); + + if let Poll::Ready(_) = result { + finished.replace(true); + } + result + } + } + + impl JoinHandle { + pub(crate) fn set_finished(&self) { + self.finished.replace(true); + } + + pub fn is_finished(&self) -> bool { + *self.finished.as_ref().borrow() + } + + pub fn abort(&self) { + self.inner.cancel(); + self.set_finished(); + } + } +} diff --git a/src/glommio/mpsc.rs b/src/glommio/mpsc.rs new file mode 100644 index 00000000..17b4122d --- /dev/null +++ b/src/glommio/mpsc.rs @@ -0,0 +1,83 @@ +use crate::error::{RedisError, RedisErrorKind}; +use futures::Stream; +use glommio::{ + channels::local_channel::{LocalReceiver, LocalSender}, + GlommioError, +}; +use std::{ + ops::Deref, + pin::Pin, + rc::Rc, + task::{Context, Poll}, +}; + +pub type UnboundedReceiver = LocalReceiver; + +pub struct UnboundedReceiverStream { + rx: LocalReceiver, +} + +impl From> for UnboundedReceiverStream { + fn from(rx: LocalReceiver) -> Self { + UnboundedReceiverStream { rx } + } +} + +impl UnboundedReceiverStream { + #[allow(dead_code)] + pub async fn recv(&mut self) -> Option { + self.rx.recv().await + } +} + +impl Stream for UnboundedReceiverStream { + type Item = T; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + use futures_lite::stream::StreamExt; + + // TODO make sure this is cancellation-safe. it's a bit unclear why the internal impl of ChannelStream does what + // it does. + self.rx.stream().poll_next(cx) + } +} + +pub struct UnboundedSender { + tx: Rc>, +} + +// https://github.com/rust-lang/rust/issues/26925 +impl Clone for UnboundedSender { + fn clone(&self) -> Self { + UnboundedSender { tx: self.tx.clone() } + } +} + +impl From> for UnboundedSender { + fn from(tx: LocalSender) -> Self { + UnboundedSender { tx: Rc::new(tx) } + } +} + +impl UnboundedSender { + pub fn try_send(&self, msg: T) -> Result<(), GlommioError> { + self.tx.try_send(msg) + } + + pub fn send(&self, msg: T) -> Result<(), RedisError> { + if let Err(_e) = self.tx.deref().try_send(msg) { + // shouldn't happen since we use unbounded channels + Err(RedisError::new( + RedisErrorKind::Canceled, + "Failed to send message on channel.", + )) + } else { + Ok(()) + } + } +} + +pub fn rx_stream(rx: LocalReceiver) -> impl Stream + 'static { + // what happens if we `join` the futures from `recv()` and `rx.stream().next()`? + UnboundedReceiverStream::from(rx) +} diff --git a/src/glommio/notes.md b/src/glommio/notes.md new file mode 100644 index 00000000..e69de29b diff --git a/src/glommio/sync.rs b/src/glommio/sync.rs new file mode 100644 index 00000000..3ab17edc --- /dev/null +++ b/src/glommio/sync.rs @@ -0,0 +1,163 @@ +use std::{ + cell::{Ref, RefCell, RefMut}, + fmt, + mem, + sync::atomic::Ordering, +}; + +pub struct RefSwap { + inner: RefCell, +} + +impl RefSwap { + pub fn new(val: T) -> Self { + RefSwap { + inner: RefCell::new(val), + } + } + + pub fn swap(&self, other: T) -> T { + mem::replace(&mut self.inner.borrow_mut(), other) + } + + pub fn store(&self, other: T) { + self.swap(other); + } + + pub fn load(&self) -> Ref<'_, T> { + self.inner.borrow() + } +} + +pub struct AsyncRwLock { + inner: glommio::sync::RwLock, +} + +impl AsyncRwLock { + pub fn new(val: T) -> Self { + AsyncRwLock { + inner: glommio::sync::RwLock::new(val), + } + } + + pub async fn write(&self) -> glommio::sync::RwLockWriteGuard { + self.inner.write().await.unwrap() + } +} + +#[derive(Debug)] +pub struct AtomicUsize { + inner: RefCell, +} + +impl AtomicUsize { + pub fn new(val: usize) -> Self { + AtomicUsize { + inner: RefCell::new(val), + } + } + + pub fn fetch_add(&self, val: usize, _: Ordering) -> usize { + let mut guard = self.inner.borrow_mut(); + + let new = guard.saturating_add(val); + *guard = new; + new + } + + pub fn fetch_sub(&self, val: usize, _: Ordering) -> usize { + let mut guard = self.inner.borrow_mut(); + + let new = guard.saturating_sub(val); + *guard = new; + new + } + + pub fn load(&self, _: Ordering) -> usize { + *self.inner.borrow() + } + + pub fn swap(&self, val: usize, _: Ordering) -> usize { + let mut guard = self.inner.borrow_mut(); + let old = *guard; + *guard = val; + old + } +} + +#[derive(Debug)] +pub struct AtomicBool { + inner: RefCell, +} + +impl AtomicBool { + pub fn new(val: bool) -> Self { + AtomicBool { + inner: RefCell::new(val), + } + } + + pub fn load(&self, _: Ordering) -> bool { + *self.inner.borrow() + } + + pub fn swap(&self, val: bool, _: Ordering) -> bool { + let mut guard = self.inner.borrow_mut(); + let old = *guard; + *guard = val; + old + } +} + +pub type MutexGuard<'a, T> = RefMut<'a, T>; + +pub struct Mutex { + inner: RefCell, +} + +impl fmt::Debug for Mutex { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{:?}", self.inner) + } +} + +impl Mutex { + pub fn new(val: T) -> Self { + Mutex { + inner: RefCell::new(val), + } + } + + pub fn lock(&self) -> MutexGuard { + self.inner.borrow_mut() + } +} + +pub type RwLockReadGuard<'a, T> = Ref<'a, T>; +pub type RwLockWriteGuard<'a, T> = RefMut<'a, T>; + +pub struct RwLock { + inner: RefCell, +} + +impl fmt::Debug for RwLock { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{:?}", self.inner) + } +} + +impl RwLock { + pub fn new(val: T) -> Self { + RwLock { + inner: RefCell::new(val), + } + } + + pub fn read(&self) -> RwLockReadGuard { + self.inner.borrow() + } + + pub fn write(&self) -> RwLockWriteGuard { + self.inner.borrow_mut() + } +} diff --git a/src/interfaces.rs b/src/interfaces.rs index c9c2118d..136674a7 100644 --- a/src/interfaces.rs +++ b/src/interfaces.rs @@ -1,49 +1,23 @@ use crate::{ - clients::WithOptions, commands, error::{RedisError, RedisErrorKind}, modules::inner::RedisClientInner, protocol::command::{RedisCommand, RouterCommand}, - router::commands as router_commands, - types::{ - ClientState, - ClusterStateChange, - ConnectHandle, - ConnectionConfig, - CustomCommand, - FromRedis, - InfoKind, - KeyspaceEvent, - Message, - Options, - PerformanceConfig, - ReconnectPolicy, - RedisConfig, - RedisValue, - RespVersion, - Server, - }, + runtime::{sleep, spawn, BroadcastReceiver, JoinHandle, RefCount}, + types::{ClientState, ClusterStateChange, KeyspaceEvent, Message, RespVersion, Server}, utils, }; use bytes_utils::Str; +use fred_macros::rm_send_if; use futures::Future; -use semver::Version; -use std::{convert::TryInto, sync::Arc, time::Duration}; -use tokio::{sync::broadcast::Receiver as BroadcastReceiver, task::JoinHandle}; - pub use redis_protocol::resp3::types::BytesFrame as Resp3Frame; - -#[cfg(feature = "i-server")] -use crate::types::ShutdownFlags; +use std::time::Duration; /// Type alias for `Result`. pub type RedisResult = Result; -#[cfg(any(feature = "dns", feature = "trust-dns-resolver"))] -use crate::protocol::types::Resolve; - /// Send a single `RedisCommand` to the router. -pub(crate) fn default_send_command(inner: &Arc, command: C) -> Result<(), RedisError> +pub(crate) fn default_send_command(inner: &RefCount, command: C) -> Result<(), RedisError> where C: Into, { @@ -60,7 +34,7 @@ where } /// Send a `RouterCommand` to the router. -pub(crate) fn send_to_router(inner: &Arc, command: RouterCommand) -> Result<(), RedisError> { +pub(crate) fn send_to_router(inner: &RefCount, command: RouterCommand) -> Result<(), RedisError> { #[allow(clippy::collapsible_if)] if command.should_check_fail_fast() { if utils::read_locked(&inner.state) != ClientState::Connected { @@ -87,11 +61,11 @@ pub(crate) fn send_to_router(inner: &Arc, command: RouterComma return Ok(()); } - if let Err(e) = inner.command_tx.load().send(command) { + if let Err(e) = inner.send_command(command) { // usually happens if the caller tries to send a command before calling `connect` or after calling `quit` inner.counters.decr_cmd_buffer_len(); - if let RouterCommand::Command(mut command) = e.0 { + if let RouterCommand::Command(mut command) = e { _warn!( inner, "Fatal error sending {} command to router. Client may be stopped or not yet initialized.", @@ -118,326 +92,18 @@ pub(crate) fn send_to_router(inner: &Arc, command: RouterComma } } -/// Any Redis client that implements any part of the Redis interface. -pub trait ClientLike: Clone + Send + Sync + Sized { - #[doc(hidden)] - fn inner(&self) -> &Arc; - - /// Helper function to intercept and modify a command without affecting how it is sent to the connection layer. - #[doc(hidden)] - fn change_command(&self, _: &mut RedisCommand) {} - - /// Helper function to intercept and customize how a command is sent to the connection layer. - #[doc(hidden)] - fn send_command(&self, command: C) -> Result<(), RedisError> - where - C: Into, - { - let mut command: RedisCommand = command.into(); - self.change_command(&mut command); - default_send_command(self.inner(), command) - } - - /// The unique ID identifying this client and underlying connections. - fn id(&self) -> &str { - &self.inner().id - } - - /// Read the config used to initialize the client. - fn client_config(&self) -> RedisConfig { - self.inner().config.as_ref().clone() - } - - /// Read the reconnect policy used to initialize the client. - fn client_reconnect_policy(&self) -> Option { - self.inner().policy.read().clone() - } - - /// Read the connection config used to initialize the client. - fn connection_config(&self) -> &ConnectionConfig { - self.inner().connection.as_ref() - } - - /// Read the RESP version used by the client when communicating with the server. - fn protocol_version(&self) -> RespVersion { - if self.inner().is_resp3() { - RespVersion::RESP3 - } else { - RespVersion::RESP2 - } - } - - /// Whether the client has a reconnection policy. - fn has_reconnect_policy(&self) -> bool { - self.inner().policy.read().is_some() - } - - /// Whether the client will automatically pipeline commands. - fn is_pipelined(&self) -> bool { - self.inner().is_pipelined() - } - - /// Whether the client is connected to a cluster. - fn is_clustered(&self) -> bool { - self.inner().config.server.is_clustered() - } - - /// Whether the client uses the sentinel interface. - fn uses_sentinels(&self) -> bool { - self.inner().config.server.is_sentinel() - } - - /// Update the internal [PerformanceConfig](crate::types::PerformanceConfig) in place with new values. - fn update_perf_config(&self, config: PerformanceConfig) { - self.inner().update_performance_config(config); - } - - /// Read the [PerformanceConfig](crate::types::PerformanceConfig) associated with this client. - fn perf_config(&self) -> PerformanceConfig { - self.inner().performance_config() - } - - /// Read the state of the underlying connection(s). - /// - /// If running against a cluster the underlying state will reflect the state of the least healthy connection. - fn state(&self) -> ClientState { - self.inner().state.read().clone() - } - - /// Whether all underlying connections are healthy. - fn is_connected(&self) -> bool { - *self.inner().state.read() == ClientState::Connected - } - - /// Read the set of active connections managed by the client. - fn active_connections(&self) -> impl Future, RedisError>> + Send { - commands::server::active_connections(self) - } - - /// Read the server version, if known. - fn server_version(&self) -> Option { - self.inner().server_state.read().kind.server_version() - } - - /// Override the DNS resolution logic for the client. - #[cfg(feature = "dns")] - #[cfg_attr(docsrs, doc(cfg(feature = "dns")))] - fn set_resolver(&self, resolver: Arc) -> impl Future + Send { - async move { self.inner().set_resolver(resolver).await } - } +#[cfg(not(feature = "glommio"))] +pub use crate::_tokio::ClientLike; +#[cfg(feature = "glommio")] +pub use crate::glommio::interfaces::ClientLike; - /// Connect to the server. - /// - /// This function returns a `JoinHandle` to a task that drives the connection. It will not resolve until the - /// connection closes, or if a reconnection policy with unlimited attempts is provided then it will - /// run until `QUIT` is called. Callers should avoid calling [abort](tokio::task::JoinHandle::abort) on the returned - /// `JoinHandle` unless the client will no longer be used. - /// - /// **Calling this function more than once will drop all state associated with the previous connection(s).** Any - /// pending commands on the old connection(s) will either finish or timeout, but they will not be retried on the - /// new connection(s). - /// - /// See [init](Self::init) for an alternative shorthand. - fn connect(&self) -> ConnectHandle { - let inner = self.inner().clone(); - utils::reset_router_task(&inner); - - tokio::spawn(async move { - utils::clear_backchannel_state(&inner).await; - let result = router_commands::start(&inner).await; - // a canceled error means we intentionally closed the client - _trace!(inner, "Ending connection task with {:?}", result); - - if let Err(ref error) = result { - if !error.is_canceled() { - inner.notifications.broadcast_connect(Err(error.clone())); - } - } - - utils::check_and_set_client_state(&inner.state, ClientState::Disconnecting, ClientState::Disconnected); - result - }) - } - - /// Force a reconnection to the server(s). - /// - /// When running against a cluster this function will also refresh the cached cluster routing table. - fn force_reconnection(&self) -> impl Future> + Send { - async move { commands::server::force_reconnection(self.inner()).await } - } - - /// Wait for the result of the next connection attempt. - /// - /// This can be used with `on_reconnect` to separate initialization logic that needs to occur only on the next - /// connection attempt vs all subsequent attempts. - fn wait_for_connect(&self) -> impl Future> + Send { - async move { - if utils::read_locked(&self.inner().state) == ClientState::Connected { - debug!("{}: Client is already connected.", self.inner().id); - Ok(()) - } else { - self.inner().notifications.connect.load().subscribe().recv().await? - } - } - } - - /// Initialize a new routing and connection task and wait for it to connect successfully. - /// - /// The returned [ConnectHandle](crate::types::ConnectHandle) refers to the task that drives the routing and - /// connection layer. It will not finish until the max reconnection count is reached. Callers should avoid calling - /// [abort](tokio::task::JoinHandle::abort) on the returned `JoinHandle` unless the client will no longer be used. - /// - /// Callers can also use [connect](Self::connect) and [wait_for_connect](Self::wait_for_connect) separately if - /// needed. - /// - /// ```rust - /// use fred::prelude::*; - /// - /// #[tokio::main] - /// async fn main() -> Result<(), RedisError> { - /// let client = RedisClient::default(); - /// let connection_task = client.init().await?; - /// - /// // ... - /// - /// client.quit().await?; - /// connection_task.await? - /// } - /// ``` - fn init(&self) -> impl Future> + Send { - async move { - let mut rx = { self.inner().notifications.connect.load().subscribe() }; - let task = self.connect(); - let error = rx.recv().await.map_err(RedisError::from).and_then(|r| r).err(); - - if let Some(error) = error { - // the initial connection failed, so we should gracefully close the routing task - utils::reset_router_task(self.inner()); - Err(error) - } else { - Ok(task) - } - } - } - - /// Close the connection to the Redis server. The returned future resolves when the command has been written to the - /// socket, not when the connection has been fully closed. Some time after this future resolves the future - /// returned by [connect](Self::connect) will resolve which indicates that the connection has been fully closed. - /// - /// This function will also close all error, pubsub message, and reconnection event streams. - fn quit(&self) -> impl Future> + Send { - async move { commands::server::quit(self).await } - } - - /// Shut down the server and quit the client. - /// - /// - #[cfg(feature = "i-server")] - #[cfg_attr(docsrs, doc(cfg(feature = "i-server")))] - fn shutdown(&self, flags: Option) -> impl Future> + Send { - async move { commands::server::shutdown(self, flags).await } - } - - /// Delete the keys in all databases. - /// - /// - fn flushall(&self, r#async: bool) -> impl Future> + Send - where - R: FromRedis, - { - async move { commands::server::flushall(self, r#async).await?.convert() } - } - - /// Delete the keys on all nodes in the cluster. This is a special function that does not map directly to the Redis - /// interface. - fn flushall_cluster(&self) -> impl Future> + Send { - async move { commands::server::flushall_cluster(self).await } - } - - /// Ping the Redis server. - /// - /// - fn ping(&self) -> impl Future> + Send - where - R: FromRedis, - { - async move { commands::server::ping(self).await?.convert() } - } - - /// Read info about the server. - /// - /// - fn info(&self, section: Option) -> impl Future> + Send - where - R: FromRedis, - { - async move { commands::server::info(self, section).await?.convert() } - } - - /// Run a custom command that is not yet supported via another interface on this client. This is most useful when - /// interacting with third party modules or extensions. - /// - /// Callers should use the re-exported [redis_keyslot](crate::util::redis_keyslot) function to hash the command's - /// key, if necessary. - /// - /// This interface should be used with caution as it may break the automatic pipeline features in the client if - /// command flags are not properly configured. - fn custom(&self, cmd: CustomCommand, args: Vec) -> impl Future> + Send - where - R: FromRedis, - T: TryInto + Send, - T::Error: Into + Send, - { - async move { - let args = utils::try_into_vec(args)?; - commands::server::custom(self, cmd, args).await?.convert() - } - } - - /// Run a custom command similar to [custom](Self::custom), but return the response frame directly without any - /// parsing. - /// - /// Note: RESP2 frames from the server are automatically converted to the RESP3 format when parsed by the client. - fn custom_raw(&self, cmd: CustomCommand, args: Vec) -> impl Future> + Send - where - T: TryInto + Send, - T::Error: Into + Send, - { - async move { - let args = utils::try_into_vec(args)?; - commands::server::custom_raw(self, cmd, args).await - } - } - - /// Customize various configuration options on commands. - fn with_options(&self, options: &Options) -> WithOptions { - WithOptions { - client: self.clone(), - options: options.clone(), - } - } -} - -fn spawn_event_listener(mut rx: BroadcastReceiver, func: F) -> JoinHandle> -where - T: Clone + Send + 'static, - F: Fn(T) -> RedisResult<()> + Send + 'static, -{ - tokio::spawn(async move { - let mut result = Ok(()); - - while let Ok(val) = rx.recv().await { - if let Err(err) = func(val) { - result = Err(err); - break; - } - } - - result - }) -} +#[cfg(not(feature = "glommio"))] +pub use crate::_tokio::spawn_event_listener; +#[cfg(feature = "glommio")] +pub use crate::glommio::interfaces::spawn_event_listener; /// Functions that provide a connection heartbeat interface. +#[rm_send_if(feature = "glommio")] pub trait HeartbeatInterface: ClientLike { /// Return a future that will ping the server on an interval. #[allow(unreachable_code)] @@ -448,10 +114,9 @@ pub trait HeartbeatInterface: ClientLike { ) -> impl Future> + Send { async move { let _self = self.clone(); - let mut interval = tokio::time::interval(interval); loop { - interval.tick().await; + sleep(interval).await; if break_on_error { let _: () = _self.ping().await?; @@ -466,6 +131,7 @@ pub trait HeartbeatInterface: ClientLike { } /// Functions for authenticating clients. +#[rm_send_if(feature = "glommio")] pub trait AuthInterface: ClientLike { /// Request for authentication in a password-protected Redis server. Returns ok if successful. /// @@ -503,6 +169,7 @@ pub trait AuthInterface: ClientLike { /// An interface that exposes various client and connection events. /// /// Calling [quit](crate::interfaces::ClientLike::quit) will close all event streams. +#[rm_send_if(feature = "glommio")] pub trait EventInterface: ClientLike { /// Spawn a task that runs the provided function on each publish-subscribe message. /// @@ -581,7 +248,7 @@ pub trait EventInterface: ClientLike { let mut reconnect_rx = self.reconnect_rx(); let mut cluster_rx = self.cluster_change_rx(); - tokio::spawn(async move { + spawn(async move { #[allow(unused_assignments)] let mut result = Ok(()); diff --git a/src/lib.rs b/src/lib.rs index 1ba4a914..eb920c44 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -11,6 +11,7 @@ #![allow(clippy::too_many_arguments)] #![allow(clippy::new_without_default)] #![allow(clippy::assigning_clones)] +#![allow(clippy::manual_async_fn)] #![warn(clippy::large_types_passed_by_value)] #![warn(clippy::large_stack_frames)] #![warn(clippy::large_futures)] @@ -71,6 +72,16 @@ pub mod monitor; /// The structs and enums used by the Redis client. pub mod types; +#[cfg(feature = "glommio")] +mod glommio; +#[cfg(feature = "glommio")] +pub(crate) use glommio::compat as runtime; + +#[cfg(not(feature = "glommio"))] +mod _tokio; +#[cfg(not(feature = "glommio"))] +pub(crate) use _tokio as runtime; + /// Various client utility functions. pub mod util { pub use crate::utils::{f64_to_redis_string, redis_string_to_f64, static_bytes, static_str}; diff --git a/src/modules/backchannel.rs b/src/modules/backchannel.rs index ede036c5..a6d10e6b 100644 --- a/src/modules/backchannel.rs +++ b/src/modules/backchannel.rs @@ -3,17 +3,18 @@ use crate::{ modules::inner::RedisClientInner, protocol::{command::RedisCommand, connection, connection::RedisTransport, types::Server}, router::Connections, + runtime::RefCount, utils, }; use redis_protocol::resp3::types::BytesFrame as Resp3Frame; -use std::{collections::HashMap, sync::Arc}; +use std::collections::HashMap; /// Check if an existing connection can be used to the provided `server`, otherwise create a new one. /// /// Returns whether a new connection was created. async fn check_and_create_transport( backchannel: &mut Backchannel, - inner: &Arc, + inner: &RefCount, server: &Server, ) -> Result { if let Some(ref mut transport) = backchannel.transport { @@ -45,7 +46,7 @@ pub struct Backchannel { impl Backchannel { /// Check if the current server matches the provided server, and disconnect. // TODO does this need to disconnect whenever the caller manually changes the RESP protocol mode? - pub async fn check_and_disconnect(&mut self, inner: &Arc, server: Option<&Server>) { + pub async fn check_and_disconnect(&mut self, inner: &RefCount, server: Option<&Server>) { let should_close = self .current_server() .map(|current| server.map(|server| *server == current).unwrap_or(true)) @@ -60,7 +61,7 @@ impl Backchannel { } /// Clear all local state that depends on the associated `Router` instance. - pub async fn clear_router_state(&mut self, inner: &Arc) { + pub async fn clear_router_state(&mut self, inner: &RefCount) { self.connection_ids.clear(); self.blocked = None; @@ -149,7 +150,7 @@ impl Backchannel { /// If a new connection is created this function also sets it on `self` before returning. pub async fn request_response( &mut self, - inner: &Arc, + inner: &RefCount, server: &Server, command: RedisCommand, ) -> Result { @@ -188,7 +189,7 @@ impl Backchannel { /// * Failing all of the above a random server will be used. pub fn find_server( &self, - inner: &Arc, + inner: &RefCount, command: &RedisCommand, use_blocked: bool, ) -> Result { diff --git a/src/modules/inner.rs b/src/modules/inner.rs index a1831a37..088e1990 100644 --- a/src/modules/inner.rs +++ b/src/modules/inner.rs @@ -7,33 +7,39 @@ use crate::{ connection::RedisTransport, types::{ClusterRouting, DefaultResolver, Resolve, Server}, }, + runtime::{ + broadcast_channel, + broadcast_send, + sleep, + unbounded_channel, + AsyncRwLock, + AtomicBool, + AtomicUsize, + BroadcastSender, + Mutex, + RefCount, + RefSwap, + RwLock, + UnboundedReceiver, + UnboundedSender, + }, types::*, utils, }; -use arc_swap::ArcSwap; use bytes_utils::Str; use futures::future::{select, Either}; -use parking_lot::{Mutex, RwLock}; use semver::Version; -use std::{ - ops::DerefMut, - sync::{ - atomic::{AtomicBool, AtomicUsize}, - Arc, - }, - time::Duration, -}; -use tokio::{ - sync::{ - broadcast::{self, error::SendError as BroadcastSendError, Sender as BroadcastSender}, - mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender}, - RwLock as AsyncRwLock, - }, - time::sleep, -}; +use std::{ops::DerefMut, time::Duration}; #[cfg(feature = "metrics")] use crate::modules::metrics::MovingStats; +#[cfg(feature = "credential-provider")] +use crate::{ + clients::RedisClient, + interfaces::RedisResult, + interfaces::{AuthInterface, ClientLike}, + runtime::{spawn, JoinHandle}, +}; #[cfg(feature = "replicas")] use std::collections::HashMap; @@ -47,42 +53,42 @@ pub struct Notifications { /// The client ID. pub id: Str, /// A broadcast channel for the `on_error` interface. - pub errors: ArcSwap>, + pub errors: RefSwap>>, /// A broadcast channel for the `on_message` interface. - pub pubsub: ArcSwap>, + pub pubsub: RefSwap>>, /// A broadcast channel for the `on_keyspace_event` interface. - pub keyspace: ArcSwap>, + pub keyspace: RefSwap>>, /// A broadcast channel for the `on_reconnect` interface. - pub reconnect: ArcSwap>, + pub reconnect: RefSwap>>, /// A broadcast channel for the `on_cluster_change` interface. - pub cluster_change: ArcSwap>>, + pub cluster_change: RefSwap>>>, /// A broadcast channel for the `on_connect` interface. - pub connect: ArcSwap>>, + pub connect: RefSwap>>>, /// A channel for events that should close all client tasks with `Canceled` errors. /// /// Emitted when QUIT, SHUTDOWN, etc are called. pub close: BroadcastSender<()>, /// A broadcast channel for the `on_invalidation` interface. #[cfg(feature = "i-tracking")] - pub invalidations: ArcSwap>, + pub invalidations: RefSwap>>, /// A broadcast channel for notifying callers when servers go unresponsive. - pub unresponsive: ArcSwap>, + pub unresponsive: RefSwap>>, } impl Notifications { pub fn new(id: &Str, capacity: usize) -> Self { Notifications { id: id.clone(), - close: broadcast::channel(capacity).0, - errors: ArcSwap::new(Arc::new(broadcast::channel(capacity).0)), - pubsub: ArcSwap::new(Arc::new(broadcast::channel(capacity).0)), - keyspace: ArcSwap::new(Arc::new(broadcast::channel(capacity).0)), - reconnect: ArcSwap::new(Arc::new(broadcast::channel(capacity).0)), - cluster_change: ArcSwap::new(Arc::new(broadcast::channel(capacity).0)), - connect: ArcSwap::new(Arc::new(broadcast::channel(capacity).0)), + close: broadcast_channel(capacity).0, + errors: RefSwap::new(RefCount::new(broadcast_channel(capacity).0)), + pubsub: RefSwap::new(RefCount::new(broadcast_channel(capacity).0)), + keyspace: RefSwap::new(RefCount::new(broadcast_channel(capacity).0)), + reconnect: RefSwap::new(RefCount::new(broadcast_channel(capacity).0)), + cluster_change: RefSwap::new(RefCount::new(broadcast_channel(capacity).0)), + connect: RefSwap::new(RefCount::new(broadcast_channel(capacity).0)), #[cfg(feature = "i-tracking")] - invalidations: ArcSwap::new(Arc::new(broadcast::channel(capacity).0)), - unresponsive: ArcSwap::new(Arc::new(broadcast::channel(capacity).0)), + invalidations: RefSwap::new(RefCount::new(broadcast_channel(capacity).0)), + unresponsive: RefSwap::new(RefCount::new(broadcast_channel(capacity).0)), } } @@ -100,74 +106,74 @@ impl Notifications { } pub fn broadcast_error(&self, error: RedisError) { - if let Err(BroadcastSendError(err)) = self.errors.load().send(error) { + broadcast_send(self.errors.load().as_ref(), &error, |err| { debug!("{}: No `on_error` listener. The error was: {err:?}", self.id); - } + }); } pub fn broadcast_pubsub(&self, message: Message) { - if let Err(_) = self.pubsub.load().send(message) { + broadcast_send(self.pubsub.load().as_ref(), &message, |_| { debug!("{}: No `on_message` listeners.", self.id); - } + }); } pub fn broadcast_keyspace(&self, event: KeyspaceEvent) { - if let Err(_) = self.keyspace.load().send(event) { + broadcast_send(self.keyspace.load().as_ref(), &event, |_| { debug!("{}: No `on_keyspace_event` listeners.", self.id); - } + }); } pub fn broadcast_reconnect(&self, server: Server) { - if let Err(_) = self.reconnect.load().send(server) { + broadcast_send(self.reconnect.load().as_ref(), &server, |_| { debug!("{}: No `on_reconnect` listeners.", self.id); - } + }); } pub fn broadcast_cluster_change(&self, changes: Vec) { - if let Err(_) = self.cluster_change.load().send(changes) { + broadcast_send(self.cluster_change.load().as_ref(), &changes, |_| { debug!("{}: No `on_cluster_change` listeners.", self.id); - } + }); } pub fn broadcast_connect(&self, result: Result<(), RedisError>) { - if let Err(_) = self.connect.load().send(result) { + broadcast_send(self.connect.load().as_ref(), &result, |_| { debug!("{}: No `on_connect` listeners.", self.id); - } + }); } /// Interrupt any tokio `sleep` calls. //`RedisClientInner::wait_with_interrupt` hides the subscription part from callers. pub fn broadcast_close(&self) { - if let Err(_) = self.close.send(()) { + broadcast_send(&self.close, &(), |_| { debug!("{}: No `close` listeners.", self.id); - } + }); } #[cfg(feature = "i-tracking")] pub fn broadcast_invalidation(&self, msg: Invalidation) { - if let Err(_) = self.invalidations.load().send(msg) { + broadcast_send(self.invalidations.load().as_ref(), &msg, |_| { debug!("{}: No `on_invalidation` listeners.", self.id); - } + }); } pub fn broadcast_unresponsive(&self, server: Server) { - if let Err(_) = self.unresponsive.load().send(server) { + broadcast_send(self.unresponsive.load().as_ref(), &server, |_| { debug!("{}: No unresponsive listeners", self.id); - } + }); } } #[derive(Clone)] pub struct ClientCounters { - pub cmd_buffer_len: Arc, - pub redelivery_count: Arc, + pub cmd_buffer_len: RefCount, + pub redelivery_count: RefCount, } impl Default for ClientCounters { fn default() -> Self { ClientCounters { - cmd_buffer_len: Arc::new(AtomicUsize::new(0)), - redelivery_count: Arc::new(AtomicUsize::new(0)), + cmd_buffer_len: RefCount::new(AtomicUsize::new(0)), + redelivery_count: RefCount::new(AtomicUsize::new(0)), } } } @@ -370,8 +376,53 @@ impl ServerKind { } // TODO make a config option for other defaults and extend this -fn create_resolver(id: &Str) -> Arc { - Arc::new(DefaultResolver::new(id)) +fn create_resolver(id: &Str) -> RefCount { + RefCount::new(DefaultResolver::new(id)) +} + +#[cfg(feature = "credential-provider")] +fn spawn_credential_refresh(client: RedisClient, interval: Duration) -> JoinHandle> { + spawn(async move { + loop { + trace!( + "{}: Waiting {} ms before refreshing credentials.", + client.inner.id, + interval.as_millis() + ); + client.inner.wait_with_interrupt(interval).await?; + + let (username, password) = match client.inner.config.credential_provider { + Some(ref provider) => match provider.fetch(None).await { + Ok(creds) => creds, + Err(e) => { + warn!("{}: Failed to fetch and refresh credentials: {e:?}", client.inner.id); + continue; + }, + }, + None => (None, None), + }; + + if client.state() != ClientState::Connected { + debug!("{}: Skip credential refresh when disconnected", client.inner.id); + continue; + } + + if let Some(password) = password { + if client.inner.config.version == RespVersion::RESP3 { + let username = username.unwrap_or("default".into()); + let result = client + .hello(RespVersion::RESP3, Some((username.into(), password.into())), None) + .await; + + if let Err(err) = result { + warn!("{}: Failed to refresh credentials: {err}", client.inner.id); + } + } else if let Err(err) = client.auth(username, password).await { + warn!("{}: Failed to refresh credentials: {err}", client.inner.id); + } + } + } + }) } pub struct RedisClientInner { @@ -380,32 +431,36 @@ pub struct RedisClientInner { /// The client ID used for logging and the default `CLIENT SETNAME` value. pub id: Str, /// Whether the client uses RESP3. - pub resp3: Arc, + pub resp3: RefCount, /// The state of the underlying connection. pub state: RwLock, /// Client configuration options. - pub config: Arc, + pub config: RefCount, /// Connection configuration options. - pub connection: Arc, + pub connection: RefCount, /// Performance config options for the client. - pub performance: ArcSwap, + pub performance: RefSwap>, /// An optional reconnect policy. pub policy: RwLock>, /// Notification channels for the event interfaces. - pub notifications: Arc, - /// An mpsc sender for commands to the router. - pub command_tx: ArcSwap, - /// Temporary storage for the receiver half of the router command channel. - pub command_rx: RwLock>, + pub notifications: RefCount, /// Shared counters. pub counters: ClientCounters, /// The DNS resolver to use when establishing new connections. - pub resolver: AsyncRwLock>, + pub resolver: AsyncRwLock>, /// A backchannel that can be used to control the router connections even while the connections are blocked. - pub backchannel: Arc>, + pub backchannel: RefCount>, /// Server state cache for various deployment types. pub server_state: RwLock, + /// An mpsc sender for commands to the router. + pub command_tx: RefSwap>, + /// Temporary storage for the receiver half of the router command channel. + pub command_rx: RwLock>, + + /// A handle to a task that refreshes credentials on an interval. + #[cfg(feature = "credential-provider")] + pub credentials_task: RwLock>>>, /// Command latency metrics. #[cfg(feature = "metrics")] pub latency_stats: RwLock, @@ -414,10 +469,17 @@ pub struct RedisClientInner { pub network_latency_stats: RwLock, /// Payload size metrics tracking for requests. #[cfg(feature = "metrics")] - pub req_size_stats: Arc>, + pub req_size_stats: RefCount>, /// Payload size metrics tracking for responses #[cfg(feature = "metrics")] - pub res_size_stats: Arc>, + pub res_size_stats: RefCount>, +} + +#[cfg(feature = "credential-provider")] +impl Drop for RedisClientInner { + fn drop(&mut self) { + self.abort_credential_refresh_task(); + } } impl RedisClientInner { @@ -426,35 +488,39 @@ impl RedisClientInner { perf: PerformanceConfig, connection: ConnectionConfig, policy: Option, - ) -> Arc { + ) -> RefCount { let id = Str::from(format!("fred-{}", utils::random_string(10))); let resolver = AsyncRwLock::new(create_resolver(&id)); let (command_tx, command_rx) = unbounded_channel(); - let notifications = Arc::new(Notifications::new(&id, perf.broadcast_channel_capacity)); - let (config, policy) = (Arc::new(config), RwLock::new(policy)); - let performance = ArcSwap::new(Arc::new(perf)); + let notifications = RefCount::new(Notifications::new(&id, perf.broadcast_channel_capacity)); + let (config, policy) = (RefCount::new(config), RwLock::new(policy)); + let performance = RefSwap::new(RefCount::new(perf)); let (counters, state) = (ClientCounters::default(), RwLock::new(ClientState::Disconnected)); let command_rx = RwLock::new(Some(command_rx)); - let backchannel = Arc::new(AsyncRwLock::new(Backchannel::default())); + let backchannel = RefCount::new(AsyncRwLock::new(Backchannel::default())); let server_state = RwLock::new(ServerState::new(&config)); let resp3 = if config.version == RespVersion::RESP3 { - Arc::new(AtomicBool::new(true)) + RefCount::new(AtomicBool::new(true)) } else { - Arc::new(AtomicBool::new(false)) + RefCount::new(AtomicBool::new(false)) }; - let connection = Arc::new(connection); - let command_tx = ArcSwap::new(Arc::new(command_tx)); + let connection = RefCount::new(connection); + #[cfg(feature = "glommio")] + let command_tx = command_tx.into(); + let command_tx = RefSwap::new(RefCount::new(command_tx)); - Arc::new(RedisClientInner { + RefCount::new(RedisClientInner { _lock: Mutex::new(()), #[cfg(feature = "metrics")] latency_stats: RwLock::new(MovingStats::default()), #[cfg(feature = "metrics")] network_latency_stats: RwLock::new(MovingStats::default()), #[cfg(feature = "metrics")] - req_size_stats: Arc::new(RwLock::new(MovingStats::default())), + req_size_stats: RefCount::new(RwLock::new(MovingStats::default())), #[cfg(feature = "metrics")] - res_size_stats: Arc::new(RwLock::new(MovingStats::default())), + res_size_stats: RefCount::new(RwLock::new(MovingStats::default())), + #[cfg(feature = "credential-provider")] + credentials_task: RwLock::new(None), backchannel, command_rx, @@ -488,8 +554,8 @@ impl RedisClientInner { } /// Swap the command channel sender, returning the old one. - pub fn swap_command_tx(&self, tx: CommandSender) -> Arc { - self.command_tx.swap(Arc::new(tx)) + pub fn swap_command_tx(&self, tx: CommandSender) -> RefCount { + self.command_tx.swap(RefCount::new(tx)) } /// Whether the client has the command channel receiver stored. If not then the caller can assume another @@ -503,7 +569,7 @@ impl RedisClientInner { self.server_state.write().replicas.clear() } - pub fn shared_resp3(&self) -> Arc { + pub fn shared_resp3(&self) -> RefCount { self.resp3.clone() } @@ -516,7 +582,7 @@ impl RedisClientInner { } } - pub async fn set_resolver(&self, resolver: Arc) { + pub async fn set_resolver(&self, resolver: RefCount) { let mut guard = self.resolver.write().await; *guard = resolver; } @@ -528,8 +594,8 @@ impl RedisClientInner { } } - pub async fn get_resolver(&self) -> Arc { - self.resolver.read().await.clone() + pub async fn get_resolver(&self) -> RefCount { + self.resolver.write().await.clone() } pub fn client_name(&self) -> &str { @@ -598,7 +664,7 @@ impl RedisClientInner { } pub fn update_performance_config(&self, config: PerformanceConfig) { - self.performance.store(Arc::new(config)); + self.performance.store(RefCount::new(config)); } pub fn performance_config(&self) -> PerformanceConfig { @@ -666,7 +732,7 @@ impl RedisClientInner { } pub fn send_reconnect( - self: &Arc, + self: &RefCount, server: Option, force: bool, tx: Option, @@ -686,7 +752,7 @@ impl RedisClientInner { } #[cfg(feature = "replicas")] - pub fn send_replica_reconnect(self: &Arc, server: &Server) { + pub fn send_replica_reconnect(self: &RefCount, server: &Server) { debug!( "{}: Sending replica reconnect message to router for {:?}", self.id, server @@ -718,6 +784,7 @@ impl RedisClientInner { } pub async fn wait_with_interrupt(&self, duration: Duration) -> Result<(), RedisError> { + #[allow(unused_mut)] let mut rx = self.notifications.close.subscribe(); debug!("{}: Sleeping for {} ms", self.id, duration.as_millis()); let (sleep_ft, recv_ft) = (sleep(duration), rx.recv()); @@ -730,4 +797,57 @@ impl RedisClientInner { Ok(()) } } + + #[cfg(not(feature = "glommio"))] + pub fn send_command(&self, command: RouterCommand) -> Result<(), RouterCommand> { + self.command_tx.load().send(command).map_err(|e| e.0) + } + + #[cfg(feature = "glommio")] + pub fn send_command(&self, command: RouterCommand) -> Result<(), RouterCommand> { + self.command_tx.load().try_send(command).map_err(|e| match e { + glommio::GlommioError::Closed(glommio::ResourceType::Channel(v)) => v, + glommio::GlommioError::WouldBlock(glommio::ResourceType::Channel(v)) => v, + _ => unreachable!(), + }) + } + + #[cfg(not(feature = "credential-provider"))] + pub async fn read_credentials(&self, _: &Server) -> Result<(Option, Option), RedisError> { + Ok((self.config.username.clone(), self.config.password.clone())) + } + + #[cfg(feature = "credential-provider")] + pub async fn read_credentials(&self, server: &Server) -> Result<(Option, Option), RedisError> { + Ok(if let Some(ref provider) = self.config.credential_provider { + provider.fetch(Some(server)).await? + } else { + (self.config.username.clone(), self.config.password.clone()) + }) + } + + #[cfg(feature = "credential-provider")] + pub fn reset_credential_refresh_task(self: &RefCount) { + let mut guard = self.credentials_task.write(); + + if let Some(task) = guard.take() { + task.abort(); + } + let refresh_interval = self + .config + .credential_provider + .as_ref() + .and_then(|provider| provider.refresh_interval()); + + if let Some(interval) = refresh_interval { + *guard = Some(spawn_credential_refresh(self.into(), interval)); + } + } + + #[cfg(feature = "credential-provider")] + pub fn abort_credential_refresh_task(&self) { + if let Some(task) = self.credentials_task.write().take() { + task.abort(); + } + } } diff --git a/src/modules/mocks.rs b/src/modules/mocks.rs index 559742bc..057cc963 100644 --- a/src/modules/mocks.rs +++ b/src/modules/mocks.rs @@ -14,10 +14,11 @@ use crate::{ error::{RedisError, RedisErrorKind}, + runtime::Mutex, types::{RedisKey, RedisValue}, }; use bytes_utils::Str; -use parking_lot::Mutex; +use fred_macros::rm_send_if; use std::{ collections::{HashMap, VecDeque}, fmt::Debug, @@ -42,6 +43,7 @@ pub struct MockCommand { /// An interface for intercepting and processing Redis commands in a mocking layer. #[allow(unused_variables)] +#[rm_send_if(feature = "glommio")] pub trait Mocks: Debug + Send + Sync + 'static { /// Intercept and process a Redis command, returning any `RedisValue`. /// @@ -326,10 +328,10 @@ mod tests { interfaces::{ClientLike, KeysInterface}, mocks::{Buffer, Echo, Mocks, SimpleMap}, prelude::Expiration, + runtime::JoinHandle, types::{RedisConfig, RedisValue, SetOptions}, }; use std::sync::Arc; - use tokio::task::JoinHandle; async fn create_mock_client(mocks: Arc) -> (RedisClient, JoinHandle>) { let config = RedisConfig { @@ -403,4 +405,20 @@ mod tests { ]; assert_eq!(buffer.take(), expected); } + + #[tokio::test] + async fn should_mock_pipelines() { + let (client, _) = create_mock_client(Arc::new(Echo)).await; + + let pipeline = client.pipeline(); + pipeline.get::<(), _>("foo").await.unwrap(); + pipeline.get::<(), _>("bar").await.unwrap(); + + let all: Vec> = pipeline.all().await.unwrap(); + assert_eq!(all, vec![vec!["foo"], vec!["bar"]]); + let try_all = pipeline.try_all::>().await; + assert_eq!(try_all, vec![Ok(vec!["foo".to_string()]), Ok(vec!["bar".to_string()])]); + let last: Vec = pipeline.last().await.unwrap(); + assert_eq!(last, vec!["bar"]); + } } diff --git a/src/monitor/parser.rs b/src/monitor/parser.rs index 0b61f52a..f7021527 100644 --- a/src/monitor/parser.rs +++ b/src/monitor/parser.rs @@ -1,4 +1,4 @@ -use crate::{modules::inner::RedisClientInner, monitor::Command, types::RedisValue}; +use crate::{modules::inner::RedisClientInner, monitor::Command, runtime::RefCount, types::RedisValue}; use nom::{ bytes::complete::{escaped as nom_escaped, tag as nom_tag, take as nom_take, take_until as nom_take_until}, character::complete::none_of as nom_none_of, @@ -11,7 +11,7 @@ use redis_protocol::{ error::RedisParseError, resp3::types::{BytesFrame as Resp3Frame, Resp3Frame as _Resp3Frame}, }; -use std::{str, sync::Arc}; +use std::str; const EMPTY_SPACE: &str = " "; const RIGHT_BRACKET: &str = "]"; @@ -110,7 +110,7 @@ fn d_parse_frame(input: &[u8]) -> Result> { } #[cfg(feature = "network-logs")] -fn log_frame(inner: &Arc, frame: &[u8]) { +fn log_frame(inner: &RefCount, frame: &[u8]) { if let Ok(s) = str::from_utf8(frame) { _trace!(inner, "Monitor frame: {}", s); } else { @@ -119,9 +119,9 @@ fn log_frame(inner: &Arc, frame: &[u8]) { } #[cfg(not(feature = "network-logs"))] -fn log_frame(_: &Arc, _: &[u8]) {} +fn log_frame(_: &RefCount, _: &[u8]) {} -pub fn parse(inner: &Arc, frame: Resp3Frame) -> Option { +pub fn parse(inner: &RefCount, frame: Resp3Frame) -> Option { let frame_bytes = match frame { Resp3Frame::SimpleString { ref data, .. } => data, Resp3Frame::BlobString { ref data, .. } => data, diff --git a/src/monitor/utils.rs b/src/monitor/utils.rs index 66051b2f..c43befb9 100644 --- a/src/monitor/utils.rs +++ b/src/monitor/utils.rs @@ -9,21 +9,21 @@ use crate::{ types::ProtocolFrame, utils as protocol_utils, }, + runtime::{spawn, unbounded_channel, RefCount, UnboundedSender}, types::{ConnectionConfig, PerformanceConfig, RedisConfig, ServerConfig}, }; use futures::stream::{Stream, StreamExt}; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio_util::codec::Framed; + +#[cfg(all(feature = "blocking-encoding", not(feature = "glommio")))] use redis_protocol::resp3::types::Resp3Frame; -use std::sync::Arc; -use tokio::{ - io::{AsyncRead, AsyncWrite}, - sync::mpsc::{unbounded_channel, UnboundedSender}, -}; +#[cfg(not(feature = "glommio"))] use tokio_stream::wrappers::UnboundedReceiverStream; -use tokio_util::codec::Framed; -#[cfg(feature = "blocking-encoding")] +#[cfg(all(feature = "blocking-encoding", not(feature = "glommio")))] async fn handle_monitor_frame( - inner: &Arc, + inner: &RefCount, frame: Result, ) -> Option { let frame = match frame { @@ -53,9 +53,9 @@ async fn handle_monitor_frame( } } -#[cfg(not(feature = "blocking-encoding"))] +#[cfg(any(not(feature = "blocking-encoding"), feature = "glommio"))] async fn handle_monitor_frame( - inner: &Arc, + inner: &RefCount, frame: Result, ) -> Option { let frame = match frame { @@ -70,7 +70,7 @@ async fn handle_monitor_frame( } async fn send_monitor_command( - inner: &Arc, + inner: &RefCount, mut connection: RedisTransport, ) -> Result { _debug!(inner, "Sending MONITOR command."); @@ -85,7 +85,7 @@ async fn send_monitor_command( } async fn forward_results( - inner: &Arc, + inner: &RefCount, tx: UnboundedSender, mut framed: Framed, ) where @@ -103,7 +103,11 @@ async fn forward_results( } } -async fn process_stream(inner: &Arc, tx: UnboundedSender, connection: RedisTransport) { +async fn process_stream( + inner: &RefCount, + tx: UnboundedSender, + connection: RedisTransport, +) { _debug!(inner, "Starting monitor stream processing..."); match connection.transport { @@ -144,9 +148,14 @@ pub async fn start(config: RedisConfig) -> Result, R // background task with a channel to process the frames so that the server can keep sending data even if the // stream consumer slows down processing the frames. let (tx, rx) = unbounded_channel(); - tokio::spawn(async move { + #[cfg(feature = "glommio")] + let tx = tx.into(); + spawn(async move { process_stream(&inner, tx, connection).await; }); - Ok(UnboundedReceiverStream::new(rx)) + #[cfg(feature = "glommio")] + return Ok(crate::runtime::rx_stream(rx)); + #[cfg(not(feature = "glommio"))] + return Ok(UnboundedReceiverStream::new(rx)); } diff --git a/src/protocol/cluster.rs b/src/protocol/cluster.rs index bff57400..6fcf2679 100644 --- a/src/protocol/cluster.rs +++ b/src/protocol/cluster.rs @@ -2,11 +2,12 @@ use crate::{ error::{RedisError, RedisErrorKind}, modules::inner::RedisClientInner, protocol::types::{Server, SlotRange}, + runtime::RefCount, types::RedisValue, utils, }; use bytes_utils::Str; -use std::{collections::HashMap, net::IpAddr, str::FromStr, sync::Arc}; +use std::{collections::HashMap, net::IpAddr, str::FromStr}; #[cfg(any( feature = "enable-native-tls", @@ -234,7 +235,11 @@ fn replace_tls_server_names(policy: &TlsHostMapping, ranges: &mut [SlotRange], d feature = "enable-native-tls", feature = "enable-rustls-ring" ))] -pub fn modify_cluster_slot_hostnames(inner: &Arc, ranges: &mut [SlotRange], default_host: &Str) { +pub fn modify_cluster_slot_hostnames( + inner: &RefCount, + ranges: &mut [SlotRange], + default_host: &Str, +) { let policy = match inner.config.tls { Some(ref config) => &config.hostnames, None => { @@ -255,7 +260,7 @@ pub fn modify_cluster_slot_hostnames(inner: &Arc, ranges: &mut feature = "enable-native-tls", feature = "enable-rustls-ring" )))] -pub fn modify_cluster_slot_hostnames(inner: &Arc, _: &mut Vec, _: &Str) { +pub fn modify_cluster_slot_hostnames(inner: &RefCount, _: &mut Vec, _: &Str) { _trace!(inner, "Skip modifying TLS hostnames.") } @@ -445,7 +450,7 @@ mod tests { feature = "enable-rustls-ring" ))] fn should_modify_cluster_slot_hostnames_custom() { - let policy = TlsHostMapping::Custom(Arc::new(FakeHostMapper)); + let policy = TlsHostMapping::Custom(RefCount::new(FakeHostMapper)); let fake_data = fake_cluster_slots_without_metadata(); let mut ranges = parse_cluster_slots(fake_data, &Str::from("default-host")).unwrap(); replace_tls_server_names(&policy, &mut ranges, &Str::from("default-host")); diff --git a/src/protocol/codec.rs b/src/protocol/codec.rs index d6cba0cd..a69315f7 100644 --- a/src/protocol/codec.rs +++ b/src/protocol/codec.rs @@ -5,6 +5,7 @@ use crate::{ types::{ProtocolFrame, Server}, utils as protocol_utils, }, + runtime::{AtomicBool, RefCount}, utils, }; use bytes::BytesMut; @@ -21,13 +22,12 @@ use redis_protocol::{ types::{BytesFrame as Resp3Frame, Resp3Frame as _Resp3Frame, StreamedFrame}, }, }; -use std::sync::{atomic::AtomicBool, Arc}; use tokio_util::codec::{Decoder, Encoder}; #[cfg(feature = "metrics")] use crate::modules::metrics::MovingStats; #[cfg(feature = "metrics")] -use parking_lot::RwLock; +use crate::runtime::RwLock; #[cfg(not(feature = "network-logs"))] fn log_resp2_frame(_: &str, _: &Resp2Frame, _: bool) {} @@ -192,16 +192,16 @@ fn resp2_decode_with_fallback( pub struct RedisCodec { pub name: Str, pub server: Server, - pub resp3: Arc, + pub resp3: RefCount, pub streaming_state: Option>, #[cfg(feature = "metrics")] - pub req_size_stats: Arc>, + pub req_size_stats: RefCount>, #[cfg(feature = "metrics")] - pub res_size_stats: Arc>, + pub res_size_stats: RefCount>, } impl RedisCodec { - pub fn new(inner: &Arc, server: &Server) -> Self { + pub fn new(inner: &RefCount, server: &Server) -> Self { RedisCodec { server: server.clone(), name: inner.id.clone(), diff --git a/src/protocol/command.rs b/src/protocol/command.rs index af9ade78..ffa2535b 100644 --- a/src/protocol/command.rs +++ b/src/protocol/command.rs @@ -8,13 +8,13 @@ use crate::{ types::{ProtocolFrame, Server}, utils as protocol_utils, }, + runtime::{oneshot_channel, AtomicBool, Mutex, OneshotReceiver, OneshotSender, RefCount}, trace, types::{CustomCommand, RedisValue}, utils as client_utils, utils, }; use bytes_utils::Str; -use parking_lot::Mutex; use redis_protocol::resp3::types::RespVersion; use std::{ convert::TryFrom, @@ -22,10 +22,8 @@ use std::{ fmt::Formatter, mem, str, - sync::{atomic::AtomicBool, Arc}, time::{Duration, Instant}, }; -use tokio::sync::oneshot::{channel as oneshot_channel, Receiver as OneshotReceiver, Sender as OneshotSender}; #[cfg(feature = "mocks")] use crate::modules::mocks::MockCommand; @@ -33,9 +31,7 @@ use crate::modules::mocks::MockCommand; use crate::trace::CommandTraces; #[cfg(feature = "debug-ids")] -use std::sync::atomic::AtomicUsize; -#[cfg(feature = "debug-ids")] -static COMMAND_COUNTER: AtomicUsize = AtomicUsize::new(0); +static COMMAND_COUNTER: std::sync::atomic::AtomicUsize = std::sync::atomic::AtomicUsize::new(0); #[cfg(feature = "debug-ids")] pub fn command_counter() -> usize { COMMAND_COUNTER @@ -1544,7 +1540,7 @@ pub struct RedisCommand { /// Some commands store arguments differently. Callers should use `self.args()` to account for this. pub arguments: Vec, /// A oneshot sender used to communicate with the router. - pub router_tx: Arc>>, + pub router_tx: RefCount>>, /// The number of times the command has been written to a socket. pub write_attempts: u32, /// The number of write attempts remaining. @@ -1564,7 +1560,7 @@ pub struct RedisCommand { /// The timeout duration provided by the `with_options` interface. pub timeout_dur: Option, /// Whether the command has timed out from the perspective of the caller. - pub timed_out: Arc, + pub timed_out: RefCount, /// A timestamp of when the command was last written to the socket. pub network_start: Option, /// Whether to route the command to a replica, if possible. @@ -1625,11 +1621,11 @@ impl From<(RedisCommandKind, Vec)> for RedisCommand { RedisCommand { kind, arguments, - timed_out: Arc::new(AtomicBool::new(false)), + timed_out: RefCount::new(AtomicBool::new(false)), timeout_dur: None, response: ResponseKind::Respond(None), hasher: ClusterHash::default(), - router_tx: Arc::new(Mutex::new(None)), + router_tx: RefCount::new(Mutex::new(None)), attempts_remaining: 0, redirections_remaining: 0, can_pipeline: true, @@ -1658,10 +1654,10 @@ impl From<(RedisCommandKind, Vec, ResponseSender)> for RedisCommand kind, arguments, response: ResponseKind::Respond(Some(tx)), - timed_out: Arc::new(AtomicBool::new(false)), + timed_out: RefCount::new(AtomicBool::new(false)), timeout_dur: None, hasher: ClusterHash::default(), - router_tx: Arc::new(Mutex::new(None)), + router_tx: RefCount::new(Mutex::new(None)), attempts_remaining: 0, redirections_remaining: 0, can_pipeline: true, @@ -1690,10 +1686,10 @@ impl From<(RedisCommandKind, Vec, ResponseKind)> for RedisCommand { kind, arguments, response, - timed_out: Arc::new(AtomicBool::new(false)), + timed_out: RefCount::new(AtomicBool::new(false)), timeout_dur: None, hasher: ClusterHash::default(), - router_tx: Arc::new(Mutex::new(None)), + router_tx: RefCount::new(Mutex::new(None)), attempts_remaining: 0, redirections_remaining: 0, can_pipeline: true, @@ -1722,11 +1718,11 @@ impl RedisCommand { RedisCommand { kind, arguments, - timed_out: Arc::new(AtomicBool::new(false)), + timed_out: RefCount::new(AtomicBool::new(false)), timeout_dur: None, response: ResponseKind::Respond(None), hasher: ClusterHash::default(), - router_tx: Arc::new(Mutex::new(None)), + router_tx: RefCount::new(Mutex::new(None)), attempts_remaining: 0, redirections_remaining: 0, can_pipeline: true, @@ -1754,10 +1750,10 @@ impl RedisCommand { kind: RedisCommandKind::Asking, hasher: ClusterHash::Custom(hash_slot), arguments: Vec::new(), - timed_out: Arc::new(AtomicBool::new(false)), + timed_out: RefCount::new(AtomicBool::new(false)), timeout_dur: None, response: ResponseKind::Respond(None), - router_tx: Arc::new(Mutex::new(None)), + router_tx: RefCount::new(Mutex::new(None)), attempts_remaining: 0, redirections_remaining: 0, can_pipeline: true, @@ -1780,7 +1776,7 @@ impl RedisCommand { } /// Whether to pipeline the command. - pub fn should_auto_pipeline(&self, inner: &Arc, force: bool) -> bool { + pub fn should_auto_pipeline(&self, inner: &RefCount, force: bool) -> bool { let should_pipeline = force || (inner.is_pipelined() && self.can_pipeline @@ -1810,7 +1806,7 @@ impl RedisCommand { } /// Whether errors writing the command should be returned to the caller. - pub fn should_finish_with_error(&self, inner: &Arc) -> bool { + pub fn should_finish_with_error(&self, inner: &RefCount) -> bool { self.fail_fast || self.attempts_remaining == 0 || inner.policy.read().is_none() } @@ -1827,6 +1823,14 @@ impl RedisCommand { } } + pub fn in_pipelined_transaction(&self) -> bool { + self.transaction_id.is_some() && self.response.is_buffer() + } + + pub fn in_non_pipelined_transaction(&self) -> bool { + self.transaction_id.is_some() && !self.response.is_buffer() + } + pub fn decr_check_redirections(&mut self) -> Result<(), RedisError> { if self.redirections_remaining == 0 { Err(RedisError::new(RedisErrorKind::Unknown, "Too many redirections.")) @@ -1895,8 +1899,9 @@ impl RedisCommand { } /// Send a message to unblock the router loop, if necessary. - pub fn respond_to_router(&self, inner: &Arc, cmd: RouterResponse) { - if let Some(tx) = self.router_tx.lock().take() { + pub fn respond_to_router(&self, inner: &RefCount, cmd: RouterResponse) { + #[allow(unused_mut)] + if let Some(mut tx) = self.router_tx.lock().take() { if tx.send(cmd).is_err() { _debug!(inner, "Failed to unblock router loop."); } @@ -1918,7 +1923,7 @@ impl RedisCommand { /// Note: this will **not** clone the router channel. pub fn duplicate(&self, response: ResponseKind) -> Self { RedisCommand { - timed_out: Arc::new(AtomicBool::new(false)), + timed_out: RefCount::new(AtomicBool::new(false)), kind: self.kind.clone(), arguments: self.arguments.clone(), hasher: self.hasher.clone(), @@ -1947,7 +1952,7 @@ impl RedisCommand { } /// Inherit connection and perf settings from the client. - pub fn inherit_options(&mut self, inner: &Arc) { + pub fn inherit_options(&mut self, inner: &RefCount) { if self.attempts_remaining == 0 { self.attempts_remaining = inner.connection.max_command_attempts; } @@ -1996,13 +2001,14 @@ impl RedisCommand { /// Respond to the caller, taking the response channel in the process. pub fn respond_to_caller(&mut self, result: Result) { - if let Some(tx) = self.take_responder() { + #[allow(unused_mut)] + if let Some(mut tx) = self.take_responder() { let _ = tx.send(result); } } /// Finish the command, responding to both the caller and router. - pub fn finish(mut self, inner: &Arc, result: Result) { + pub fn finish(mut self, inner: &RefCount, result: Result) { self.respond_to_caller(result); self.respond_to_router(inner, RouterResponse::Continue); } @@ -2035,7 +2041,7 @@ impl RedisCommand { } /// Convert to a single frame with an array of bulk strings (or null), using a blocking task. - #[cfg(feature = "blocking-encoding")] + #[cfg(all(feature = "blocking-encoding", not(feature = "glommio")))] pub fn to_frame_blocking(&self, is_resp3: bool, blocking_threshold: usize) -> Result { let cmd_size = protocol_utils::args_size(self.args()); @@ -2093,8 +2099,8 @@ pub enum RouterCommand { Transaction { id: u64, commands: Vec, - watched: Option, abort_on_error: bool, + pipelined: bool, tx: ResponseSender, }, /// Retry a command after a `MOVED` error. @@ -2156,6 +2162,7 @@ impl RouterCommand { } /// Finish the command early with the provided error. + #[allow(unused_mut)] pub fn finish_with_error(self, error: RedisError) { match self { RouterCommand::Command(mut command) => { @@ -2167,12 +2174,12 @@ impl RouterCommand { } }, #[cfg(feature = "transactions")] - RouterCommand::Transaction { tx, .. } => { + RouterCommand::Transaction { mut tx, .. } => { if let Err(_) = tx.send(Err(error)) { warn!("Error responding early to transaction."); } }, - RouterCommand::Reconnect { tx: Some(tx), .. } => { + RouterCommand::Reconnect { tx: Some(mut tx), .. } => { if let Err(_) = tx.send(Err(error)) { warn!("Error responding early to reconnect command."); } @@ -2182,7 +2189,7 @@ impl RouterCommand { } /// Inherit settings from the configuration structs on `inner`. - pub fn inherit_options(&mut self, inner: &Arc) { + pub fn inherit_options(&mut self, inner: &RefCount) { match self { RouterCommand::Command(ref mut cmd) => { cmd.inherit_options(inner); diff --git a/src/protocol/connection.rs b/src/protocol/connection.rs index 904f9279..386182dc 100644 --- a/src/protocol/connection.rs +++ b/src/protocol/connection.rs @@ -7,6 +7,7 @@ use crate::{ types::{ProtocolFrame, Server}, utils as protocol_utils, }, + runtime::{AtomicBool, AtomicUsize, JoinHandle, RefCount}, types::InfoKind, utils as client_utils, utils, @@ -21,19 +22,29 @@ use futures::{ }; use redis_protocol::resp3::types::{BytesFrame as Resp3Frame, Resp3Frame as _Resp3Frame, RespVersion}; use semver::Version; -use socket2::SockRef; use std::{ fmt, net::SocketAddr, pin::Pin, str, - sync::{atomic::AtomicUsize, Arc}, task::{Context, Poll}, time::Duration, }; -use tokio::{net::TcpStream, task::JoinHandle}; use tokio_util::codec::Framed; +#[cfg(not(feature = "glommio"))] +use socket2::SockRef; + +#[cfg(feature = "glommio")] +use glommio::net::TcpStream as BaseTcpStream; +#[cfg(feature = "glommio")] +pub type TcpStream = crate::glommio::io_compat::TokioIO; + +#[cfg(not(feature = "glommio"))] +use tokio::net::TcpStream; +#[cfg(not(feature = "glommio"))] +use tokio::net::TcpStream as BaseTcpStream; + #[cfg(feature = "unix-sockets")] use crate::prelude::ServerConfig; #[cfg(any( @@ -43,19 +54,15 @@ use crate::prelude::ServerConfig; ))] use crate::protocol::tls::TlsConnector; #[cfg(feature = "replicas")] -use crate::{ - protocol::{connection, responders::ResponseKind}, - types::RedisValue, -}; +use crate::runtime::oneshot_channel; +#[cfg(feature = "replicas")] +use crate::{protocol::responders::ResponseKind, types::RedisValue}; #[cfg(feature = "unix-sockets")] use std::path::Path; -use std::sync::atomic::AtomicBool; #[cfg(any(feature = "enable-rustls", feature = "enable-rustls-ring"))] use std::{convert::TryInto, ops::Deref}; #[cfg(feature = "unix-sockets")] use tokio::net::UnixStream; -#[cfg(feature = "replicas")] -use tokio::sync::oneshot::channel as oneshot_channel; #[cfg(feature = "enable-native-tls")] use tokio_native_tls::TlsStream as NativeTlsStream; #[cfg(any(feature = "enable-rustls", feature = "enable-rustls-ring"))] @@ -71,15 +78,15 @@ pub type CommandBuffer = Vec; /// A shared buffer across tasks. #[derive(Clone, Debug)] pub struct SharedBuffer { - inner: Arc>, - blocked: Arc, + inner: RefCount>, + blocked: RefCount, } impl SharedBuffer { pub fn new() -> Self { SharedBuffer { - inner: Arc::new(SegQueue::new()), - blocked: Arc::new(AtomicBool::new(false)), + inner: RefCount::new(SegQueue::new()), + blocked: RefCount::new(AtomicBool::new(false)), } } @@ -122,7 +129,7 @@ pub type SplitRedisStream = SplitStream>; /// Connect to each socket addr and return the first successful connection. async fn tcp_connect_any( - inner: &Arc, + inner: &RefCount, server: &Server, addrs: &Vec, ) -> Result<(TcpStream, SocketAddr), RedisError> { @@ -136,7 +143,7 @@ async fn tcp_connect_any( addr.ip(), addr.port() ); - let socket = match TcpStream::connect(addr).await { + let socket = match BaseTcpStream::connect(addr).await { Ok(socket) => socket, Err(e) => { _debug!(inner, "Error connecting to {}: {:?}", addr, e); @@ -147,16 +154,24 @@ async fn tcp_connect_any( if let Some(val) = inner.connection.tcp.nodelay { socket.set_nodelay(val)?; } - if let Some(dur) = inner.connection.tcp.linger { - socket.set_linger(Some(dur))?; + if let Some(_dur) = inner.connection.tcp.linger { + #[cfg(not(feature = "glommio"))] + socket.set_linger(Some(_dur))?; + #[cfg(feature = "glommio")] + _warn!(inner, "TCP Linger is not yet supported with Glommio features."); } if let Some(ttl) = inner.connection.tcp.ttl { socket.set_ttl(ttl)?; } - if let Some(ref keepalive) = inner.connection.tcp.keepalive { - SockRef::from(&socket).set_tcp_keepalive(keepalive)?; + if let Some(ref _keepalive) = inner.connection.tcp.keepalive { + #[cfg(not(feature = "glommio"))] + SockRef::from(&socket).set_tcp_keepalive(_keepalive)?; + #[cfg(feature = "glommio")] + _warn!(inner, "TCP keepalive is not yet supported with Glommio features."); } + #[cfg(feature = "glommio")] + let socket = crate::glommio::io_compat::TokioIO(socket); return Ok((socket, *addr)); } @@ -382,24 +397,25 @@ impl Sink for SplitSinkKind { } /// Atomic counters stored with connection state. +// TODO with glommio these don't need to be atomics #[derive(Clone, Debug)] pub struct Counters { - pub cmd_buffer_len: Arc, - pub in_flight: Arc, - pub feed_count: Arc, + pub cmd_buffer_len: RefCount, + pub in_flight: RefCount, + pub feed_count: RefCount, } impl Counters { - pub fn new(cmd_buffer_len: &Arc) -> Self { + pub fn new(cmd_buffer_len: &RefCount) -> Self { Counters { cmd_buffer_len: cmd_buffer_len.clone(), - in_flight: Arc::new(AtomicUsize::new(0)), - feed_count: Arc::new(AtomicUsize::new(0)), + in_flight: RefCount::new(AtomicUsize::new(0)), + feed_count: RefCount::new(AtomicUsize::new(0)), } } /// Flush the sink if the max feed count is reached or no commands are queued following the current command. - pub fn should_send(&self, inner: &Arc) -> bool { + pub fn should_send(&self, inner: &RefCount) -> bool { client_utils::read_atomic(&self.feed_count) as u64 > inner.max_feed_count() || client_utils::read_atomic(&self.cmd_buffer_len) == 0 } @@ -443,7 +459,7 @@ pub struct RedisTransport { } impl RedisTransport { - pub async fn new_tcp(inner: &Arc, server: &Server) -> Result { + pub async fn new_tcp(inner: &RefCount, server: &Server) -> Result { let counters = Counters::new(&inner.counters.cmd_buffer_len); let (id, version) = (None, None); let default_host = server.host.clone(); @@ -468,7 +484,7 @@ impl RedisTransport { } #[cfg(feature = "unix-sockets")] - pub async fn new_unix(inner: &Arc, path: &Path) -> Result { + pub async fn new_unix(inner: &RefCount, path: &Path) -> Result { _debug!(inner, "Connecting via unix socket to {}", utils::path_to_string(path)); let server = Server::new(utils::path_to_string(path), 0); let counters = Counters::new(&inner.counters.cmd_buffer_len); @@ -491,7 +507,10 @@ impl RedisTransport { #[cfg(feature = "enable-native-tls")] #[allow(unreachable_patterns)] - pub async fn new_native_tls(inner: &Arc, server: &Server) -> Result { + pub async fn new_native_tls( + inner: &RefCount, + server: &Server, + ) -> Result { let connector = match inner.config.tls { Some(ref config) => match config.connector { TlsConnector::Native(ref connector) => connector.clone(), @@ -529,13 +548,16 @@ impl RedisTransport { } #[cfg(not(feature = "enable-native-tls"))] - pub async fn new_native_tls(inner: &Arc, server: &Server) -> Result { + pub async fn new_native_tls( + inner: &RefCount, + server: &Server, + ) -> Result { RedisTransport::new_tcp(inner, server).await } #[cfg(any(feature = "enable-rustls", feature = "enable-rustls-ring"))] #[allow(unreachable_patterns)] - pub async fn new_rustls(inner: &Arc, server: &Server) -> Result { + pub async fn new_rustls(inner: &RefCount, server: &Server) -> Result { use rustls::pki_types::ServerName; let connector = match inner.config.tls { @@ -576,7 +598,7 @@ impl RedisTransport { } #[cfg(not(any(feature = "enable-rustls", feature = "enable-rustls-ring")))] - pub async fn new_rustls(inner: &Arc, server: &Server) -> Result { + pub async fn new_rustls(inner: &RefCount, server: &Server) -> Result { RedisTransport::new_tcp(inner, server).await } @@ -592,7 +614,7 @@ impl RedisTransport { } /// Set the client name with `CLIENT SETNAME`. - pub async fn set_client_name(&mut self, inner: &Arc) -> Result<(), RedisError> { + pub async fn set_client_name(&mut self, inner: &RefCount) -> Result<(), RedisError> { _debug!(inner, "Setting client name."); let name = &inner.id; let command = RedisCommand::new(RedisCommandKind::ClientSetname, vec![name.clone().into()]); @@ -608,7 +630,7 @@ impl RedisTransport { } /// Read and cache the server version. - pub async fn cache_server_version(&mut self, inner: &Arc) -> Result<(), RedisError> { + pub async fn cache_server_version(&mut self, inner: &RefCount) -> Result<(), RedisError> { let command = RedisCommand::new(RedisCommandKind::Info, vec![InfoKind::Server.to_str().into()]); let result = self.request_response(command, inner.is_resp3()).await?; let result = match result { @@ -679,11 +701,13 @@ impl RedisTransport { } /// Authenticate via HELLO in RESP3 mode or AUTH in RESP2 mode, then set the client name. - pub async fn switch_protocols_and_authenticate(&mut self, inner: &Arc) -> Result<(), RedisError> { + pub async fn switch_protocols_and_authenticate( + &mut self, + inner: &RefCount, + ) -> Result<(), RedisError> { // reset the protocol version to the one specified by the config when we create new connections inner.reset_protocol_version(); - let username = inner.config.username.clone(); - let password = inner.config.password.clone(); + let (username, password) = inner.read_credentials(&self.server).await?; if inner.is_resp3() { _debug!(inner, "Switching to RESP3 protocol with HELLO..."); @@ -710,7 +734,7 @@ impl RedisTransport { } /// Read and cache the connection ID. - pub async fn cache_connection_id(&mut self, inner: &Arc) -> Result<(), RedisError> { + pub async fn cache_connection_id(&mut self, inner: &RefCount) -> Result<(), RedisError> { let command = (RedisCommandKind::ClientID, vec![]).into(); let result = self.request_response(command, inner.is_resp3()).await; _debug!(inner, "Read client ID: {:?}", result); @@ -723,7 +747,7 @@ impl RedisTransport { } /// Send `PING` to the server. - pub async fn ping(&mut self, inner: &Arc) -> Result<(), RedisError> { + pub async fn ping(&mut self, inner: &RefCount) -> Result<(), RedisError> { let command = RedisCommandKind::Ping.into(); let response = self.request_response(command, inner.is_resp3()).await?; @@ -735,7 +759,7 @@ impl RedisTransport { } /// Send `QUIT` and close the connection. - pub async fn disconnect(&mut self, inner: &Arc) -> Result<(), RedisError> { + pub async fn disconnect(&mut self, inner: &RefCount) -> Result<(), RedisError> { if let Err(e) = self.transport.close().await { _warn!(inner, "Error closing connection to {}: {:?}", self.server, e); } @@ -743,7 +767,7 @@ impl RedisTransport { } /// Select the database provided in the `RedisConfig`. - pub async fn select_database(&mut self, inner: &Arc) -> Result<(), RedisError> { + pub async fn select_database(&mut self, inner: &RefCount) -> Result<(), RedisError> { if inner.config.server.is_clustered() { return Ok(()); } @@ -767,7 +791,7 @@ impl RedisTransport { /// Check the `cluster_state` via `CLUSTER INFO`. /// /// Returns an error if the state is not `ok`. - pub async fn check_cluster_state(&mut self, inner: &Arc) -> Result<(), RedisError> { + pub async fn check_cluster_state(&mut self, inner: &RefCount) -> Result<(), RedisError> { if !inner.config.server.is_clustered() { return Ok(()); } @@ -792,12 +816,19 @@ impl RedisTransport { /// Authenticate, set the protocol version, set the client name, select the provided database, cache the /// connection ID and server version, and check the cluster state (if applicable). - pub async fn setup(&mut self, inner: &Arc, timeout: Option) -> Result<(), RedisError> { + pub async fn setup( + &mut self, + inner: &RefCount, + timeout: Option, + ) -> Result<(), RedisError> { let timeout = timeout.unwrap_or(inner.internal_command_timeout()); + let has_credentials = inner.config.password.is_some() || inner.config.version == RespVersion::RESP3; + #[cfg(feature = "credential-provider")] + let has_credentials = has_credentials || inner.config.credential_provider.is_some(); utils::timeout( async { - if inner.config.password.is_some() || inner.config.version == RespVersion::RESP3 { + if has_credentials { self.switch_protocols_and_authenticate(inner).await?; } else { self.ping(inner).await?; @@ -823,7 +854,7 @@ impl RedisTransport { #[cfg(feature = "replicas")] pub async fn readonly( &mut self, - inner: &Arc, + inner: &RefCount, timeout: Option, ) -> Result<(), RedisError> { if !inner.config.server.is_clustered() { @@ -849,7 +880,7 @@ impl RedisTransport { #[cfg(feature = "replicas")] pub async fn role( &mut self, - inner: &Arc, + inner: &RefCount, timeout: Option, ) -> Result { let timeout = timeout.unwrap_or(inner.internal_command_timeout()); @@ -869,7 +900,7 @@ impl RedisTransport { /// Discover connected replicas via the ROLE command. #[cfg(feature = "replicas")] - pub async fn discover_replicas(&mut self, inner: &Arc) -> Result, RedisError> { + pub async fn discover_replicas(&mut self, inner: &RefCount) -> Result, RedisError> { self .role(inner, None) .await @@ -878,7 +909,7 @@ impl RedisTransport { /// Discover connected replicas via the ROLE command. #[cfg(not(feature = "replicas"))] - pub async fn discover_replicas(&mut self, _: &Arc) -> Result, RedisError> { + pub async fn discover_replicas(&mut self, _: &RefCount) -> Result, RedisError> { Ok(Vec::new()) } @@ -980,9 +1011,9 @@ impl RedisWriter { } #[cfg(feature = "replicas")] - pub async fn discover_replicas(&mut self, inner: &Arc) -> Result, RedisError> { + pub async fn discover_replicas(&mut self, inner: &RefCount) -> Result, RedisError> { let command = RedisCommand::new(RedisCommandKind::Role, vec![]); - let role = connection::request_response(inner, self, command, None) + let role = request_response(inner, self, command, None) .await .and_then(protocol_utils::frame_to_results)?; @@ -1031,7 +1062,7 @@ impl RedisWriter { } /// Put a command at the back of the command queue. - pub fn push_command(&self, inner: &Arc, mut cmd: RedisCommand) { + pub fn push_command(&self, inner: &RefCount, mut cmd: RedisCommand) { if cmd.has_no_responses() { _trace!( inner, @@ -1085,7 +1116,7 @@ impl RedisWriter { /// /// The returned connection will not be initialized. pub async fn create( - inner: &Arc, + inner: &RefCount, server: &Server, timeout: Option, ) -> Result { @@ -1112,14 +1143,14 @@ pub async fn create( /// Split a connection, spawn a reader task, and link the reader and writer halves. pub fn split( - inner: &Arc, + inner: &RefCount, transport: RedisTransport, is_replica: bool, func: F, ) -> Result<(Server, RedisWriter), RedisError> where F: FnOnce( - &Arc, + &RefCount, SplitStreamKind, &Server, &SharedBuffer, @@ -1154,7 +1185,7 @@ where /// Send a command to the server and wait for a response. #[cfg(feature = "replicas")] pub async fn request_response( - inner: &Arc, + inner: &RefCount, writer: &mut RedisWriter, mut command: RedisCommand, timeout: Option, diff --git a/src/protocol/responders.rs b/src/protocol/responders.rs index 0d55f329..7e25c0ce 100644 --- a/src/protocol/responders.rs +++ b/src/protocol/responders.rs @@ -7,26 +7,18 @@ use crate::{ types::{KeyScanInner, Server, ValueScanInner, ValueScanResult}, utils as protocol_utils, }, + runtime::{AtomicUsize, Mutex, RefCount}, types::{HScanResult, RedisKey, RedisValue, SScanResult, ScanResult, ZScanResult}, utils as client_utils, }; use bytes_utils::Str; -use parking_lot::Mutex; -use redis_protocol::resp3::types::Resp3Frame as _Resp3Frame; -use std::{ - fmt, - fmt::Formatter, - iter::repeat, - mem, - ops::DerefMut, - sync::{atomic::AtomicUsize, Arc}, -}; +use redis_protocol::resp3::types::{FrameKind, Resp3Frame as _Resp3Frame}; +use std::{fmt, fmt::Formatter, iter::repeat, mem, ops::DerefMut}; #[cfg(feature = "metrics")] use crate::modules::metrics::MovingStats; #[cfg(feature = "metrics")] -use parking_lot::RwLock; -use redis_protocol::resp3::types::FrameKind; +use crate::runtime::RwLock; #[cfg(feature = "metrics")] use std::{cmp, time::Instant}; @@ -48,13 +40,13 @@ pub enum ResponseKind { /// cluster connections. Buffer { /// A shared buffer for response frames. - frames: Arc>>, + frames: RefCount>>, /// The expected number of response frames. expected: usize, /// The number of response frames received. - received: Arc, + received: RefCount, /// A shared oneshot channel to the caller. - tx: Arc>>, + tx: RefCount>>, /// A local field for tracking the expected index of the response in the `frames` array. index: usize, /// Whether errors should be returned early to the caller. @@ -123,9 +115,9 @@ impl ResponseKind { pub fn new_buffer(tx: ResponseSender) -> Self { ResponseKind::Buffer { - frames: Arc::new(Mutex::new(vec![])), - tx: Arc::new(Mutex::new(Some(tx))), - received: Arc::new(AtomicUsize::new(0)), + frames: RefCount::new(Mutex::new(vec![])), + tx: RefCount::new(Mutex::new(Some(tx))), + received: RefCount::new(AtomicUsize::new(0)), index: 0, expected: 0, error_early: true, @@ -135,9 +127,9 @@ impl ResponseKind { pub fn new_buffer_with_size(expected: usize, tx: ResponseSender) -> Self { let frames = repeat(Resp3Frame::Null).take(expected).collect(); ResponseKind::Buffer { - frames: Arc::new(Mutex::new(frames)), - tx: Arc::new(Mutex::new(Some(tx))), - received: Arc::new(AtomicUsize::new(0)), + frames: RefCount::new(Mutex::new(frames)), + tx: RefCount::new(Mutex::new(Some(tx))), + received: RefCount::new(AtomicUsize::new(0)), index: 0, error_early: true, expected, @@ -154,7 +146,7 @@ impl ResponseKind { } /// Clone the shared response sender for `Buffer` or `Multiple` variants. - pub fn clone_shared_response_tx(&self) -> Option>>> { + pub fn clone_shared_response_tx(&self) -> Option>>> { match self { ResponseKind::Buffer { tx, .. } => Some(tx.clone()), _ => None, @@ -176,6 +168,11 @@ impl ResponseKind { ResponseKind::ValueScan(_) | ResponseKind::KeyScan(_) => 1, } } + + /// Whether the responder is a `ResponseKind::Buffer`. + pub fn is_buffer(&self) -> bool { + matches!(self, ResponseKind::Buffer { .. }) + } } #[cfg(feature = "metrics")] @@ -187,7 +184,7 @@ fn sample_latency(latency_stats: &RwLock, sent: Instant) { /// Sample overall and network latency values for a command. #[cfg(feature = "metrics")] -fn sample_command_latencies(inner: &Arc, command: &mut RedisCommand) { +fn sample_command_latencies(inner: &RefCount, command: &mut RedisCommand) { if let Some(sent) = command.network_start.take() { sample_latency(&inner.network_latency_stats, sent); } @@ -195,10 +192,10 @@ fn sample_command_latencies(inner: &Arc, command: &mut RedisCo } #[cfg(not(feature = "metrics"))] -fn sample_command_latencies(_: &Arc, _: &mut RedisCommand) {} +fn sample_command_latencies(_: &RefCount, _: &mut RedisCommand) {} /// Update the client's protocol version codec version after receiving a non-error response to HELLO. -fn update_protocol_version(inner: &Arc, command: &RedisCommand, frame: &Resp3Frame) { +fn update_protocol_version(inner: &RefCount, command: &RedisCommand, frame: &Resp3Frame) { if !matches!(frame.kind(), FrameKind::SimpleError | FrameKind::BlobError) { let version = match command.kind { RedisCommandKind::_Hello(ref version) => version, @@ -213,8 +210,8 @@ fn update_protocol_version(inner: &Arc, command: &RedisCommand } fn respond_locked( - inner: &Arc, - tx: &Arc>>, + inner: &RefCount, + tx: &RefCount>>, result: Result, ) { if let Some(tx) = tx.lock().take() { @@ -226,7 +223,7 @@ fn respond_locked( fn add_buffered_frame( server: &Server, - buffer: &Arc>>, + buffer: &RefCount>>, index: usize, frame: Resp3Frame, ) -> Result<(), RedisError> { @@ -366,7 +363,7 @@ fn parse_value_scan_frame(frame: Resp3Frame) -> Result<(Str, Vec), R /// Send the output to the caller of a command that scans values. fn send_value_scan_result( - inner: &Arc, + inner: &RefCount, scanner: ValueScanInner, command: &RedisCommand, result: Vec, @@ -430,7 +427,7 @@ fn send_value_scan_result( /// Respond to the caller with the default response policy. pub fn respond_to_caller( - inner: &Arc, + inner: &RefCount, server: &Server, mut command: RedisCommand, tx: ResponseSender, @@ -456,15 +453,15 @@ pub fn respond_to_caller( /// Respond to the caller, assuming multiple response frames from the last command, storing intermediate responses in /// the shared buffer. pub fn respond_buffer( - inner: &Arc, + inner: &RefCount, server: &Server, command: RedisCommand, - received: Arc, + received: RefCount, expected: usize, error_early: bool, - frames: Arc>>, + frames: RefCount>>, index: usize, - tx: Arc>>, + tx: RefCount>>, frame: Resp3Frame, ) -> Result<(), RedisError> { _trace!( @@ -476,22 +473,30 @@ pub fn respond_buffer( index, command.debug_id() ); + let closes_connection = command.kind.closes_connection(); // errors are buffered like normal frames and are not returned early if let Err(e) = add_buffered_frame(server, &frames, index, frame) { - respond_locked(inner, &tx, Err(e)); - command.respond_to_router(inner, RouterResponse::Continue); - _error!( - inner, - "Exiting early after unexpected buffer response index from {} with command {}, ID {}", - server, - command.kind.to_str_debug(), - command.debug_id() - ); - return Err(RedisError::new( - RedisErrorKind::Unknown, - "Invalid buffer response index.", - )); + if closes_connection { + _debug!(inner, "Ignoring unexpected buffer response index from QUIT or SHUTDOWN"); + respond_locked(inner, &tx, Err(RedisError::new_canceled())); + command.respond_to_router(inner, RouterResponse::Continue); + return Err(RedisError::new_canceled()); + } else { + respond_locked(inner, &tx, Err(e)); + command.respond_to_router(inner, RouterResponse::Continue); + _error!( + inner, + "Exiting early after unexpected buffer response index from {} with command {}, ID {}", + server, + command.kind.to_str_debug(), + command.debug_id() + ); + return Err(RedisError::new( + RedisErrorKind::Unknown, + "Invalid buffer response index.", + )); + } } // this must come after adding the buffered frame. there's a potential race condition if this task is interrupted @@ -537,7 +542,7 @@ pub fn respond_buffer( /// Respond to the caller of a key scanning operation. pub fn respond_key_scan( - inner: &Arc, + inner: &RefCount, server: &Server, command: RedisCommand, mut scanner: KeyScanInner, @@ -577,7 +582,7 @@ pub fn respond_key_scan( /// Respond to the caller of a value scanning operation. pub fn respond_value_scan( - inner: &Arc, + inner: &RefCount, server: &Server, command: RedisCommand, mut scanner: ValueScanInner, diff --git a/src/protocol/types.rs b/src/protocol/types.rs index ef714912..194fd013 100644 --- a/src/protocol/types.rs +++ b/src/protocol/types.rs @@ -4,6 +4,7 @@ use crate::{ modules::inner::RedisClientInner, prelude::RedisResult, protocol::{cluster, utils::server_to_parts}, + runtime::{RefCount, UnboundedSender}, types::*, utils, }; @@ -18,9 +19,7 @@ use std::{ fmt::{Display, Formatter}, hash::{Hash, Hasher}, net::{SocketAddr, ToSocketAddrs}, - sync::Arc, }; -use tokio::sync::mpsc::UnboundedSender; #[cfg(any( feature = "enable-rustls", @@ -514,7 +513,7 @@ impl ClusterRouting { /// Rebuild the cache in place with the output of a `CLUSTER SLOTS` command. pub(crate) fn rebuild( &mut self, - inner: &Arc, + inner: &RefCount, cluster_slots: RedisValue, default_host: &Str, ) -> Result<(), RedisError> { @@ -598,16 +597,6 @@ impl ClusterRouting { } } -/// A trait that can be used to override DNS resolution logic. -/// -/// Note: currently this requires [async-trait](https://crates.io/crates/async-trait). -#[async_trait] -#[cfg_attr(docsrs, doc(cfg(feature = "dns")))] -pub trait Resolve: Send + Sync + 'static { - /// Resolve a hostname. - async fn resolve(&self, host: Str, port: u16) -> RedisResult>; -} - /// Default DNS resolver that uses [to_socket_addrs](std::net::ToSocketAddrs::to_socket_addrs). #[derive(Clone, Debug)] pub struct DefaultResolver { @@ -621,6 +610,54 @@ impl DefaultResolver { } } +/// A trait that can be used to override DNS resolution logic. +/// +/// Note: currently this requires [async-trait](https://crates.io/crates/async-trait). +#[cfg(feature = "glommio")] +#[async_trait(?Send)] +#[cfg_attr(docsrs, doc(cfg(feature = "dns")))] +pub trait Resolve: 'static { + /// Resolve a hostname. + async fn resolve(&self, host: Str, port: u16) -> RedisResult>; +} + +#[cfg(feature = "glommio")] +#[async_trait(?Send)] +impl Resolve for DefaultResolver { + async fn resolve(&self, host: Str, port: u16) -> RedisResult> { + let client_id = self.id.clone(); + + // glommio users should probably use a non-blocking impl such as hickory-dns + crate::runtime::spawn(async move { + let addr = format!("{}:{}", host, port); + let ips: Vec = addr.to_socket_addrs()?.collect(); + + if ips.is_empty() { + Err(RedisError::new( + RedisErrorKind::IO, + format!("Failed to resolve {}:{}", host, port), + )) + } else { + trace!("{}: Found {} addresses for {}", client_id, ips.len(), addr); + Ok(ips) + } + }) + .await? + } +} + +/// A trait that can be used to override DNS resolution logic. +/// +/// Note: currently this requires [async-trait](https://crates.io/crates/async-trait). +#[cfg(not(feature = "glommio"))] +#[async_trait] +#[cfg_attr(docsrs, doc(cfg(feature = "dns")))] +pub trait Resolve: Send + Sync + 'static { + /// Resolve a hostname. + async fn resolve(&self, host: Str, port: u16) -> RedisResult>; +} + +#[cfg(not(feature = "glommio"))] #[async_trait] impl Resolve for DefaultResolver { async fn resolve(&self, host: Str, port: u16) -> RedisResult> { diff --git a/src/protocol/utils.rs b/src/protocol/utils.rs index 0bf4a5e0..9538d48c 100644 --- a/src/protocol/utils.rs +++ b/src/protocol/utils.rs @@ -7,6 +7,7 @@ use crate::{ connection::OK, types::{ProtocolFrame, *}, }, + runtime::RefCount, types::*, utils, }; @@ -17,7 +18,7 @@ use redis_protocol::{ resp3::types::{BytesFrame as Resp3Frame, Resp3Frame as _Resp3Frame}, types::{PUBSUB_PUSH_PREFIX, REDIS_CLUSTER_SLOTS}, }; -use std::{borrow::Cow, collections::HashMap, convert::TryInto, ops::Deref, str, sync::Arc}; +use std::{borrow::Cow, collections::HashMap, convert::TryInto, ops::Deref, str}; #[cfg(any(feature = "i-lists", feature = "i-sorted-sets"))] use redis_protocol::resp3::types::FrameKind; @@ -947,13 +948,17 @@ pub fn command_to_frame(command: &RedisCommand, is_resp3: bool) -> Result, command: &RedisCommand) -> Result { - #[cfg(feature = "blocking-encoding")] +pub fn encode_frame(inner: &RefCount, command: &RedisCommand) -> Result { + #[cfg(all(feature = "blocking-encoding", not(feature = "glommio")))] return command.to_frame_blocking( inner.is_resp3(), inner.with_perf_config(|c| c.blocking_encode_threshold), ); - #[cfg(not(feature = "blocking-encoding"))] + + #[cfg(any( + not(feature = "blocking-encoding"), + all(feature = "blocking-encoding", feature = "glommio") + ))] return command.to_frame(inner.is_resp3()); } diff --git a/src/router/centralized.rs b/src/router/centralized.rs index d3053d06..70af33b5 100644 --- a/src/router/centralized.rs +++ b/src/router/centralized.rs @@ -11,14 +11,14 @@ use crate::{ utils as protocol_utils, }, router::{responses, utils, Connections, Written}, + runtime::{spawn, JoinHandle, RefCount}, types::ServerConfig, }; use redis_protocol::resp3::types::{BytesFrame as Resp3Frame, Resp3Frame as _Resp3Frame}; -use std::{collections::VecDeque, sync::Arc}; -use tokio::task::JoinHandle; +use std::collections::VecDeque; pub async fn write( - inner: &Arc, + inner: &RefCount, writer: &mut Option, command: RedisCommand, force_flush: bool, @@ -38,7 +38,7 @@ pub async fn write( /// Spawn a task to read response frames from the reader half of the socket. #[allow(unused_assignments)] pub fn spawn_reader_task( - inner: &Arc, + inner: &RefCount, mut reader: SplitStreamKind, server: &Server, buffer: &SharedBuffer, @@ -47,8 +47,10 @@ pub fn spawn_reader_task( ) -> JoinHandle> { let (inner, server) = (inner.clone(), server.clone()); let (buffer, counters) = (buffer.clone(), counters.clone()); + #[cfg(feature = "glommio")] + let tq = inner.connection.connection_task_queue; - tokio::spawn(async move { + let reader_ft = async move { let mut last_error = None; loop { @@ -91,14 +93,23 @@ pub fn spawn_reader_task( _debug!(inner, "Ending reader task from {}", server); Ok(()) - }) + }; + + #[cfg(feature = "glommio")] + if let Some(tq) = tq { + crate::runtime::spawn_into(reader_ft, tq) + } else { + spawn(reader_ft) + } + #[cfg(not(feature = "glommio"))] + spawn(reader_ft) } /// Process the response frame in the context of the last command. /// /// Errors returned here will be logged, but will not close the socket or initiate a reconnect. pub async fn process_response_frame( - inner: &Arc, + inner: &RefCount, server: &Server, buffer: &SharedBuffer, counters: &Counters, @@ -129,9 +140,13 @@ pub async fn process_response_frame( } responses::check_and_set_unblocked_flag(inner, &command).await; - if command.transaction_id.is_some() { + // non-pipelined transactions use ResponseKind::Skip, pipelined ones use a buffer. non-pipelined transactions + // need to retry commands in a special way so this logic forwards the result via the latest command's router + // response channel and exits early. pipelined transactions use the normal buffered response process below. + if command.in_non_pipelined_transaction() { if let Some(error) = protocol_utils::frame_to_error(&frame) { - if let Some(tx) = command.take_router_tx() { + #[allow(unused_mut)] + if let Some(mut tx) = command.take_router_tx() { let _ = tx.send(RouterResponse::TransactionError((error, command))); } return Ok(()); @@ -144,7 +159,6 @@ pub async fn process_response_frame( } } - // TODO clean this up _trace!(inner, "Handling centralized response kind: {:?}", command.response); match command.take_response() { ResponseKind::Skip | ResponseKind::Respond(None) => { @@ -179,7 +193,7 @@ pub async fn process_response_frame( /// Initialize fresh connections to the server, dropping any old connections and saving in-flight commands on /// `buffer`. pub async fn initialize_connection( - inner: &Arc, + inner: &RefCount, connections: &mut Connections, buffer: &mut VecDeque, ) -> Result<(), RedisError> { diff --git a/src/router/clustered.rs b/src/router/clustered.rs index 34f2a1fb..8fe7ea93 100644 --- a/src/router/clustered.rs +++ b/src/router/clustered.rs @@ -11,22 +11,20 @@ use crate::{ utils as protocol_utils, }, router::{responses, types::ClusterChange, utils, Connections, Written}, + runtime::{spawn, JoinHandle, Mutex, RefCount}, types::{ClusterDiscoveryPolicy, ClusterStateChange}, utils as client_utils, }; use futures::future::try_join_all; -use parking_lot::Mutex; use redis_protocol::resp3::types::{BytesFrame as Resp3Frame, FrameKind, Resp3Frame as _Resp3Frame}; use std::{ collections::{BTreeSet, HashMap, VecDeque}, iter::repeat, - sync::Arc, }; -use tokio::task::JoinHandle; /// Find the cluster node that should receive the command. pub fn route_command<'a>( - inner: &Arc, + inner: &RefCount, state: &'a ClusterRouting, command: &RedisCommand, ) -> Option<&'a Server> { @@ -66,7 +64,7 @@ pub fn route_command<'a>( /// Write a command to the cluster according to the [cluster hashing](https://redis.io/docs/reference/cluster-spec/) interface. pub async fn write( - inner: &Arc, + inner: &RefCount, writers: &mut HashMap, state: &ClusterRouting, command: RedisCommand, @@ -129,7 +127,7 @@ pub async fn write( // There's probably a much cleaner way to express this. Most of the complexity here comes from the need to // pre-allocate and assign response locations in the buffer ahead of time. This is done to avoid any race conditions. pub async fn send_all_cluster_command( - inner: &Arc, + inner: &RefCount, writers: &mut HashMap, mut command: RedisCommand, ) -> Result<(), RedisError> { @@ -210,7 +208,7 @@ pub fn parse_cluster_changes( ClusterChange { add, remove } } -pub fn broadcast_cluster_change(inner: &Arc, changes: &ClusterChange) { +pub fn broadcast_cluster_change(inner: &RefCount, changes: &ClusterChange) { let mut added: Vec = changes .add .iter() @@ -235,7 +233,7 @@ pub fn broadcast_cluster_change(inner: &Arc, changes: &Cluster /// Spawn a task to read response frames from the reader half of the socket. #[allow(unused_assignments)] pub fn spawn_reader_task( - inner: &Arc, + inner: &RefCount, mut reader: SplitStreamKind, server: &Server, buffer: &SharedBuffer, @@ -244,8 +242,10 @@ pub fn spawn_reader_task( ) -> JoinHandle> { let (inner, server) = (inner.clone(), server.clone()); let (buffer, counters) = (buffer.clone(), counters.clone()); + #[cfg(feature = "glommio")] + let tq = inner.connection.connection_task_queue; - tokio::spawn(async move { + let reader_ft = async move { let mut last_error = None; loop { @@ -290,14 +290,47 @@ pub fn spawn_reader_task( _debug!(inner, "Ending reader task from {}", server); Ok(()) - }) + }; + + #[cfg(feature = "glommio")] + if let Some(tq) = tq { + crate::runtime::spawn_into(reader_ft, tq) + } else { + spawn(reader_ft) + } + #[cfg(not(feature = "glommio"))] + spawn(reader_ft) +} + +/// Parse a cluster redirection frame from the provided server, returning the new destination node info. +pub fn parse_cluster_error_frame( + inner: &RefCount, + frame: &Resp3Frame, + server: &Server, +) -> Result<(ClusterErrorKind, u16, Server), RedisError> { + let (kind, slot, server_str) = match frame.as_str() { + Some(data) => protocol_utils::parse_cluster_error(data)?, + None => return Err(RedisError::new(RedisErrorKind::Protocol, "Invalid cluster error.")), + }; + let server = match Server::from_parts(&server_str, &server.host) { + Some(server) => server, + None => { + _warn!(inner, "Invalid server field in cluster error: {}", server_str); + return Err(RedisError::new( + RedisErrorKind::Protocol, + "Invalid cluster redirection error.", + )); + }, + }; + + Ok((kind, slot, server)) } /// Send a MOVED or ASK command to the router, using the router channel if possible and falling back on the /// command queue if appropriate. -// Cluster errors within a transaction can only be handled via the blocking router channel. +// Cluster errors within a non-pipelined transaction can only be handled via the blocking router channel. fn process_cluster_error( - inner: &Arc, + inner: &RefCount, server: &Server, mut command: RedisCommand, frame: Resp3Frame, @@ -305,35 +338,17 @@ fn process_cluster_error( // commands are not redirected to replica nodes command.use_replica = false; - let (kind, slot, server_str) = match frame.as_str() { - Some(data) => match protocol_utils::parse_cluster_error(data) { - Ok(result) => result, - Err(e) => { - command.respond_to_router(inner, RouterResponse::Continue); - command.respond_to_caller(Err(e)); - return; - }, - }, - None => { - command.respond_to_router(inner, RouterResponse::Continue); - command.respond_to_caller(Err(RedisError::new(RedisErrorKind::Protocol, "Invalid cluster error."))); - return; - }, - }; - let server = match Server::from_parts(&server_str, &server.host) { - Some(server) => server, - None => { - _warn!(inner, "Invalid server field in cluster error: {}", server_str); + let (kind, slot, server) = match parse_cluster_error_frame(inner, &frame, server) { + Ok(results) => results, + Err(e) => { command.respond_to_router(inner, RouterResponse::Continue); - command.respond_to_caller(Err(RedisError::new( - RedisErrorKind::Cluster, - "Invalid cluster redirection error.", - ))); + command.respond_to_caller(Err(e)); return; }, }; - if let Some(tx) = command.take_router_tx() { + #[allow(unused_mut)] + if let Some(mut tx) = command.take_router_tx() { let response = match kind { ClusterErrorKind::Ask => RouterResponse::Ask((slot, server, command)), ClusterErrorKind::Moved => RouterResponse::Moved((slot, server, command)), @@ -341,6 +356,9 @@ fn process_cluster_error( _debug!(inner, "Sending cluster error to router channel."); if let Err(response) = tx.send(response) { + #[cfg(feature = "glommio")] + let response = response.into_inner(); + // if it could not be sent on the router tx then send it on the command channel let command = match response { RouterResponse::Ask((slot, server, command)) => { @@ -397,7 +415,7 @@ fn process_cluster_error( /// /// Errors returned here will be logged, but will not close the socket or initiate a reconnect. pub async fn process_response_frame( - inner: &Arc, + inner: &RefCount, server: &Server, buffer: &SharedBuffer, counters: &Counters, @@ -428,7 +446,8 @@ pub async fn process_response_frame( } responses::check_and_set_unblocked_flag(inner, &command).await; - if frame.is_redirection() { + // pipelined transactions defer cluster redirections until after `EXECABORT` is received + if frame.is_redirection() && !command.in_pipelined_transaction() { _debug!( inner, "Recv MOVED or ASK error for `{}` from {}: {:?}", @@ -440,9 +459,13 @@ pub async fn process_response_frame( return Ok(()); } - if command.transaction_id.is_some() { + // non-pipelined transactions use ResponseKind::Skip, pipelined ones use a buffer. non-pipelined transactions + // need to retry commands in a special way so this logic forwards the result via the latest command's router + // response channel and exits early. pipelined transactions use the normal buffered response process below. + if command.in_non_pipelined_transaction() { if let Some(error) = protocol_utils::frame_to_error(&frame) { - if let Some(tx) = command.take_router_tx() { + #[allow(unused_mut)] + if let Some(mut tx) = command.take_router_tx() { let _ = tx.send(RouterResponse::TransactionError((error, command))); } return Ok(()); @@ -488,7 +511,7 @@ pub async fn process_response_frame( /// Try connecting to any node in the provided `RedisConfig` or `old_servers`. pub async fn connect_any( - inner: &Arc, + inner: &RefCount, old_cache: Option<&[SlotRange]>, ) -> Result { let mut all_servers: BTreeSet = if let Some(old_cache) = old_cache { @@ -537,7 +560,7 @@ pub async fn connect_any( /// /// If this returns an error then all known cluster nodes are unreachable. pub async fn cluster_slots_backchannel( - inner: &Arc, + inner: &RefCount, cache: Option<&ClusterRouting>, force_disconnect: bool, ) -> Result { @@ -650,7 +673,7 @@ pub async fn drop_broken_connections(writers: &mut HashMap) /// Run `CLUSTER SLOTS`, update the cached routing table, and modify the connection map. pub async fn sync( - inner: &Arc, + inner: &RefCount, connections: &mut Connections, buffer: &mut VecDeque, ) -> Result<(), RedisError> { @@ -693,7 +716,7 @@ pub async fn sync( } let mut connections_ft = Vec::with_capacity(changes.add.len()); - let new_writers = Arc::new(Mutex::new(HashMap::with_capacity(changes.add.len()))); + let new_writers = RefCount::new(Mutex::new(HashMap::with_capacity(changes.add.len()))); // connect to each of the new nodes for server in changes.add.into_iter() { let _inner = inner.clone(); @@ -732,7 +755,7 @@ pub async fn sync( /// Initialize fresh connections to the server, dropping any old connections and saving in-flight commands on /// `buffer`. pub async fn initialize_connections( - inner: &Arc, + inner: &RefCount, connections: &mut Connections, buffer: &mut VecDeque, ) -> Result<(), RedisError> { diff --git a/src/router/commands.rs b/src/router/commands.rs index 28b8b440..91c06019 100644 --- a/src/router/commands.rs +++ b/src/router/commands.rs @@ -10,12 +10,11 @@ use crate::{ RouterResponse, }, router::{utils, Backpressure, Router, Written}, + runtime::{OneshotSender, RefCount}, types::{Blocking, ClientState, ClientUnblockFlag, ClusterHash, Server}, utils as client_utils, }; use redis_protocol::resp3::types::BytesFrame as Resp3Frame; -use std::sync::Arc; -use tokio::sync::oneshot::Sender as OneshotSender; #[cfg(feature = "transactions")] use crate::router::transactions; @@ -28,7 +27,7 @@ use tracing_futures::Instrument; /// /// Errors from this function should end the connection task. async fn handle_router_response( - inner: &Arc, + inner: &RefCount, router: &mut Router, rx: Option, ) -> Result, RedisError> { @@ -99,7 +98,7 @@ async fn handle_router_response( /// Continuously write the command until it is sent, queued to try later, or fails with a fatal error. async fn write_with_backpressure( - inner: &Arc, + inner: &RefCount, router: &mut Router, command: RedisCommand, force_pipeline: bool, @@ -286,7 +285,7 @@ async fn write_with_backpressure( #[cfg(feature = "full-tracing")] async fn write_with_backpressure_t( - inner: &Arc, + inner: &RefCount, router: &mut Router, mut command: RedisCommand, force_pipeline: bool, @@ -304,7 +303,7 @@ async fn write_with_backpressure_t( #[cfg(not(feature = "full-tracing"))] async fn write_with_backpressure_t( - inner: &Arc, + inner: &RefCount, router: &mut Router, command: RedisCommand, force_pipeline: bool, @@ -314,7 +313,7 @@ async fn write_with_backpressure_t( /// Run a pipelined series of commands, queueing commands to run later if needed. async fn process_pipeline( - inner: &Arc, + inner: &RefCount, router: &mut Router, commands: Vec, ) -> Result<(), RedisError> { @@ -346,7 +345,7 @@ async fn process_pipeline( /// Send ASKING to the provided server, then retry the provided command. async fn process_ask( - inner: &Arc, + inner: &RefCount, router: &mut Router, server: Server, slot: u16, @@ -373,7 +372,7 @@ async fn process_ask( /// Sync the cluster state then retry the command. async fn process_moved( - inner: &Arc, + inner: &RefCount, router: &mut Router, server: Server, slot: u16, @@ -402,16 +401,17 @@ async fn process_moved( #[cfg(feature = "replicas")] async fn process_replica_reconnect( - inner: &Arc, + inner: &RefCount, router: &mut Router, server: Option, force: bool, tx: Option, replica: bool, ) -> Result<(), RedisError> { + #[allow(unused_mut)] if replica { let result = utils::sync_replicas_with_policy(inner, router, false).await; - if let Some(tx) = tx { + if let Some(mut tx) = tx { let _ = tx.send(result.map(|_| Resp3Frame::Null)); } @@ -422,8 +422,9 @@ async fn process_replica_reconnect( } /// Reconnect to the server(s). +#[allow(unused_mut)] async fn process_reconnect( - inner: &Arc, + inner: &RefCount, router: &mut Router, server: Option, force: bool, @@ -437,7 +438,7 @@ async fn process_reconnect( if has_connection && !force { _debug!(inner, "Skip reconnecting to {}", server); - if let Some(tx) = tx { + if let Some(mut tx) = tx { let _ = tx.send(Ok(Resp3Frame::Null)); } @@ -447,7 +448,7 @@ async fn process_reconnect( if !force && router.has_healthy_centralized_connection() { _debug!(inner, "Skip reconnecting to centralized host"); - if let Some(tx) = tx { + if let Some(mut tx) = tx { let _ = tx.send(Ok(Resp3Frame::Null)); } return Ok(()); @@ -455,13 +456,13 @@ async fn process_reconnect( _debug!(inner, "Starting reconnection loop..."); if let Err(e) = utils::reconnect_with_policy(inner, router).await { - if let Some(tx) = tx { + if let Some(mut tx) = tx { let _ = tx.send(Err(e.clone())); } Err(e) } else { - if let Some(tx) = tx { + if let Some(mut tx) = tx { let _ = tx.send(Ok(Resp3Frame::Null)); } @@ -470,10 +471,11 @@ async fn process_reconnect( } #[cfg(feature = "replicas")] +#[allow(unused_mut)] async fn process_sync_replicas( - inner: &Arc, + inner: &RefCount, router: &mut Router, - tx: OneshotSender>, + mut tx: OneshotSender>, reset: bool, ) -> Result<(), RedisError> { let result = utils::sync_replicas_with_policy(inner, router, reset).await; @@ -482,10 +484,11 @@ async fn process_sync_replicas( } /// Sync and update the cached cluster state. +#[allow(unused_mut)] async fn process_sync_cluster( - inner: &Arc, + inner: &RefCount, router: &mut Router, - tx: OneshotSender>, + mut tx: OneshotSender>, ) -> Result<(), RedisError> { let result = utils::sync_cluster_with_policy(inner, router).await; let _ = tx.send(result.clone()); @@ -494,7 +497,7 @@ async fn process_sync_cluster( /// Send a single command to the server(s). async fn process_normal_command( - inner: &Arc, + inner: &RefCount, router: &mut Router, command: RedisCommand, ) -> Result<(), RedisError> { @@ -502,10 +505,11 @@ async fn process_normal_command( } /// Read the set of active connections managed by the client. +#[allow(unused_mut)] fn process_connections( - inner: &Arc, + inner: &RefCount, router: &Router, - tx: OneshotSender>, + mut tx: OneshotSender>, ) -> Result<(), RedisError> { #[allow(unused_mut)] let mut connections = router.connections.active_connections(); @@ -519,7 +523,7 @@ fn process_connections( /// Process any kind of router command. async fn process_command( - inner: &Arc, + inner: &RefCount, router: &mut Router, command: RouterCommand, ) -> Result<(), RedisError> { @@ -530,11 +534,17 @@ async fn process_command( #[cfg(feature = "transactions")] RouterCommand::Transaction { commands, - watched, + pipelined, id, tx, abort_on_error, - } => transactions::run(inner, router, commands, watched, id, abort_on_error, tx).await, + } => { + if pipelined { + transactions::exec::pipelined(inner, router, commands, id, tx).await + } else { + transactions::exec::non_pipelined(inner, router, commands, id, abort_on_error, tx).await + } + }, RouterCommand::Pipeline { commands } => process_pipeline(inner, router, commands).await, RouterCommand::Command(command) => process_normal_command(inner, router, command).await, RouterCommand::Connections { tx } => process_connections(inner, router, tx), @@ -554,7 +564,7 @@ async fn process_command( /// Start processing commands from the client front end. async fn process_commands( - inner: &Arc, + inner: &RefCount, router: &mut Router, rx: &mut CommandReceiver, ) -> Result<(), RedisError> { @@ -583,7 +593,7 @@ async fn process_commands( } /// Start the command processing stream, initiating new connections in the process. -pub async fn start(inner: &Arc) -> Result<(), RedisError> { +pub async fn start(inner: &RefCount) -> Result<(), RedisError> { #[cfg(feature = "mocks")] if let Some(ref mocks) = inner.config.mocks { return mocking::start(inner, mocks).await; @@ -622,22 +632,33 @@ pub async fn start(inner: &Arc) -> Result<(), RedisError> { inner.store_command_rx(rx, false); Err(error) } else { + #[cfg(feature = "credential-provider")] + inner.reset_credential_refresh_task(); + let result = Box::pin(process_commands(inner, &mut router, &mut rx)).await; inner.store_command_rx(rx, false); + #[cfg(feature = "credential-provider")] + inner.abort_credential_refresh_task(); result } } #[cfg(feature = "mocks")] +#[allow(unused_mut)] mod mocking { use super::*; - use crate::{modules::mocks::Mocks, protocol::utils as protocol_utils}; + use crate::{ + modules::mocks::Mocks, + protocol::{responders::ResponseKind, utils as protocol_utils}, + }; + use redis_protocol::resp3::types::BytesFrame; + use std::sync::Arc; /// Process any kind of router command. pub fn process_command(mocks: &Arc, command: RouterCommand) -> Result<(), RedisError> { match command { #[cfg(feature = "transactions")] - RouterCommand::Transaction { commands, tx, .. } => { + RouterCommand::Transaction { commands, mut tx, .. } => { let mocked = commands.into_iter().skip(1).map(|c| c.to_mocked()).collect(); match mocks.process_transaction(mocked) { @@ -651,12 +672,42 @@ mod mocking { }, } }, - RouterCommand::Pipeline { commands } => { + RouterCommand::Pipeline { mut commands } => { + let mut results = Vec::with_capacity(commands.len()); + let response = commands.last_mut().map(|c| c.take_response()); + let uses_all_results = matches!(response, Some(ResponseKind::Buffer { .. })); + let tx = response.and_then(|mut k| k.take_response_tx()); + for mut command in commands.into_iter() { - let mocked = command.to_mocked(); - let result = mocks.process_command(mocked).map(protocol_utils::mocked_value_to_frame); + let result = mocks + .process_command(command.to_mocked()) + .map(protocol_utils::mocked_value_to_frame); - command.respond_to_caller(result); + results.push(result); + } + if let Some(mut tx) = tx { + let mut frames = Vec::with_capacity(results.len()); + + for frame in results.into_iter() { + match frame { + Ok(frame) => frames.push(frame), + Err(err) => { + frames.push(Resp3Frame::SimpleError { + data: err.details().into(), + attributes: None, + }); + }, + } + } + + if uses_all_results { + let _ = tx.send(Ok(BytesFrame::Array { + data: frames, + attributes: None, + })); + } else { + let _ = tx.send(Ok(frames.pop().unwrap_or(BytesFrame::Null))); + } } Ok(()) @@ -674,7 +725,7 @@ mod mocking { } pub async fn process_commands( - inner: &Arc, + inner: &RefCount, mocks: &Arc, rx: &mut CommandReceiver, ) -> Result<(), RedisError> { @@ -696,9 +747,14 @@ mod mocking { Ok(()) } - pub async fn start(inner: &Arc, mocks: &Arc) -> Result<(), RedisError> { + pub async fn start(inner: &RefCount, mocks: &Arc) -> Result<(), RedisError> { _debug!(inner, "Starting mocking layer"); + + #[cfg(feature = "glommio")] + glommio::yield_if_needed().await; + #[cfg(not(feature = "glommio"))] tokio::task::yield_now().await; + let mut rx = match inner.take_command_rx() { Some(rx) => rx, None => { diff --git a/src/router/mod.rs b/src/router/mod.rs index 5990de2a..6cbf0c9b 100644 --- a/src/router/mod.rs +++ b/src/router/mod.rs @@ -6,6 +6,7 @@ use crate::{ connection::{self, CommandBuffer, Counters, RedisWriter}, types::{ClusterRouting, Server}, }, + runtime::RefCount, trace, utils as client_utils, }; @@ -15,21 +16,19 @@ use std::{ collections::{HashMap, VecDeque}, fmt, fmt::Formatter, - sync::Arc, time::Duration, }; +#[cfg(feature = "transactions")] +use crate::runtime::oneshot_channel; #[cfg(feature = "transactions")] use crate::{protocol::command::ClusterErrorKind, protocol::responders::ResponseKind}; #[cfg(feature = "replicas")] use std::collections::HashSet; -#[cfg(feature = "transactions")] -use tokio::sync::oneshot::channel as oneshot_channel; pub mod centralized; pub mod clustered; pub mod commands; -pub mod reader; pub mod replicas; pub mod responses; pub mod sentinel; @@ -94,7 +93,7 @@ impl Backpressure { /// Apply the backpressure policy. pub async fn wait( self, - inner: &Arc, + inner: &RefCount, command: &mut RedisCommand, ) -> Result, RedisError> { match self { @@ -160,7 +159,10 @@ impl Connections { /// Discover and return a mapping of replica nodes to their associated primary node. #[cfg(feature = "replicas")] - pub async fn replica_map(&mut self, inner: &Arc) -> Result, RedisError> { + pub async fn replica_map( + &mut self, + inner: &RefCount, + ) -> Result, RedisError> { Ok(match self { Connections::Centralized { ref mut writer } | Connections::Sentinel { ref mut writer } => { if let Some(writer) = writer { @@ -238,7 +240,7 @@ impl Connections { /// Initialize the underlying connection(s) and update the cached backchannel information. pub async fn initialize( &mut self, - inner: &Arc, + inner: &RefCount, buffer: &mut VecDeque, ) -> Result<(), RedisError> { let result = if inner.config.server.is_clustered() { @@ -284,7 +286,7 @@ impl Connections { } /// Disconnect from the provided server, using the default centralized connection if `None` is provided. - pub async fn disconnect(&mut self, inner: &Arc, server: Option<&Server>) -> CommandBuffer { + pub async fn disconnect(&mut self, inner: &RefCount, server: Option<&Server>) -> CommandBuffer { match self { Connections::Centralized { ref mut writer } => { if let Some(writer) = writer.take() { @@ -318,7 +320,7 @@ impl Connections { } /// Disconnect and clear local state for all connections, returning all in-flight commands. - pub async fn disconnect_all(&mut self, inner: &Arc) -> CommandBuffer { + pub async fn disconnect_all(&mut self, inner: &RefCount) -> CommandBuffer { match self { Connections::Centralized { ref mut writer } => { if let Some(writer) = writer.take() { @@ -380,7 +382,7 @@ impl Connections { } /// Flush the socket(s) associated with each server if they have pending frames. - pub async fn check_and_flush(&mut self, inner: &Arc) -> Result<(), RedisError> { + pub async fn check_and_flush(&mut self, inner: &RefCount) -> Result<(), RedisError> { _trace!(inner, "Checking and flushing sockets..."); match self { @@ -407,7 +409,12 @@ impl Connections { } /// Send a command to the server(s). - pub async fn write(&mut self, inner: &Arc, command: RedisCommand, force_flush: bool) -> Written { + pub async fn write( + &mut self, + inner: &RefCount, + command: RedisCommand, + force_flush: bool, + ) -> Written { match self { Connections::Clustered { ref mut writers, @@ -419,7 +426,7 @@ impl Connections { } /// Send a command to all servers in a cluster. - pub async fn write_all_cluster(&mut self, inner: &Arc, command: RedisCommand) -> Written { + pub async fn write_all_cluster(&mut self, inner: &RefCount, command: RedisCommand) -> Written { if let Connections::Clustered { ref mut writers, .. } = self { if let Err(error) = clustered::send_all_cluster_command(inner, writers, command).await { Written::Disconnected((None, None, error)) @@ -449,7 +456,11 @@ impl Connections { } /// Connect or reconnect to the provided `host:port`. - pub async fn add_connection(&mut self, inner: &Arc, server: &Server) -> Result<(), RedisError> { + pub async fn add_connection( + &mut self, + inner: &RefCount, + server: &Server, + ) -> Result<(), RedisError> { if let Connections::Clustered { ref mut writers, .. } = self { let mut transport = connection::create(inner, server, None).await?; transport.setup(inner, None).await?; @@ -497,7 +508,7 @@ pub struct Router { /// The connection map for each deployment type. pub connections: Connections, /// The inner client state associated with the router. - pub inner: Arc, + pub inner: RefCount, /// Storage for commands that should be deferred or retried later. pub buffer: VecDeque, /// The replica routing interface. @@ -507,7 +518,7 @@ pub struct Router { impl Router { /// Create a new `Router` without connecting to the server(s). - pub fn new(inner: &Arc) -> Self { + pub fn new(inner: &RefCount) -> Self { let connections = if inner.config.server.is_clustered() { Connections::new_clustered() } else if inner.config.server.is_sentinel() { diff --git a/src/router/reader.rs b/src/router/reader.rs deleted file mode 100644 index 8b137891..00000000 --- a/src/router/reader.rs +++ /dev/null @@ -1 +0,0 @@ - diff --git a/src/router/replicas.rs b/src/router/replicas.rs index 8d718235..9c015913 100644 --- a/src/router/replicas.rs +++ b/src/router/replicas.rs @@ -10,6 +10,7 @@ use crate::{ connection::{CommandBuffer, RedisWriter}, }, router::{centralized, clustered, utils, Written}, + runtime::RefCount, types::Server, }; #[cfg(feature = "replicas")] @@ -17,7 +18,6 @@ use std::{ collections::{HashMap, VecDeque}, fmt, fmt::Formatter, - sync::Arc, }; /// An interface used to filter the list of available replica nodes. @@ -44,7 +44,7 @@ pub struct ReplicaConfig { /// An optional interface for filtering available replica nodes. /// /// Default: `None` - pub filter: Option>, + pub filter: Option>, /// Whether the client should ignore errors from replicas that occur when the max reconnection count is reached. /// /// Default: `true` @@ -258,7 +258,7 @@ impl Replicas { } /// Sync the connection map in place based on the cached routing table. - pub async fn sync_connections(&mut self, inner: &Arc) -> Result<(), RedisError> { + pub async fn sync_connections(&mut self, inner: &RefCount) -> Result<(), RedisError> { for (_, writer) in self.writers.drain() { let commands = writer.graceful_close().await; self.buffer.extend(commands); @@ -272,7 +272,7 @@ impl Replicas { } /// Drop all connections and clear the cached routing table. - pub async fn clear_connections(&mut self, inner: &Arc) -> Result<(), RedisError> { + pub async fn clear_connections(&mut self, inner: &RefCount) -> Result<(), RedisError> { self.routing.clear(); self.sync_connections(inner).await } @@ -285,7 +285,7 @@ impl Replicas { /// Connect to the replica and add it to the cached routing table. pub async fn add_connection( &mut self, - inner: &Arc, + inner: &RefCount, primary: Server, replica: Server, force: bool, @@ -331,7 +331,7 @@ impl Replicas { /// Close the replica connection and optionally remove the replica from the routing table. pub async fn remove_connection( &mut self, - inner: &Arc, + inner: &RefCount, primary: &Server, replica: &Server, keep_routable: bool, @@ -414,7 +414,7 @@ impl Replicas { /// Send a command to one of the replicas associated with the provided primary server. pub async fn write( &mut self, - inner: &Arc, + inner: &RefCount, primary: &Server, mut command: RedisCommand, force_flush: bool, @@ -532,7 +532,7 @@ impl Replicas { } #[cfg(all(feature = "replicas", any(feature = "enable-native-tls", feature = "enable-rustls")))] -pub fn map_replica_tls_names(inner: &Arc, primary: &Server, replica: &mut Server) { +pub fn map_replica_tls_names(inner: &RefCount, primary: &Server, replica: &mut Server) { let policy = match inner.config.tls { Some(ref config) => &config.hostnames, None => { @@ -552,4 +552,4 @@ pub fn map_replica_tls_names(inner: &Arc, primary: &Server, re feature = "replicas", not(any(feature = "enable-native-tls", feature = "enable-rustls")) ))] -pub fn map_replica_tls_names(_: &Arc, _: &Server, _: &mut Server) {} +pub fn map_replica_tls_names(_: &RefCount, _: &Server, _: &mut Server) {} diff --git a/src/router/responses.rs b/src/router/responses.rs index e3c20245..1cde3b52 100644 --- a/src/router/responses.rs +++ b/src/router/responses.rs @@ -2,6 +2,7 @@ use crate::{ error::{RedisError, RedisErrorKind}, modules::inner::RedisClientInner, protocol::{command::RedisCommand, types::Server, utils as protocol_utils, utils::pretty_error}, + runtime::RefCount, trace, types::{ClientState, KeyspaceEvent, Message, RedisKey, RedisValue}, utils, @@ -10,7 +11,7 @@ use redis_protocol::{ resp3::types::{BytesFrame as Resp3Frame, FrameKind, Resp3Frame as _Resp3Frame}, types::PUBSUB_PUSH_PREFIX, }; -use std::{str, sync::Arc}; +use std::str; #[cfg(feature = "i-tracking")] use crate::types::Invalidation; @@ -59,7 +60,7 @@ fn parse_keyspace_notification(channel: &str, message: &RedisValue) -> Option, message: Message, server: &Server) { +fn broadcast_pubsub_invalidation(inner: &RefCount, message: Message, server: &Server) { if let Some(invalidation) = Invalidation::from_message(message, server) { inner.notifications.broadcast_invalidation(invalidation); } else { @@ -71,7 +72,7 @@ fn broadcast_pubsub_invalidation(inner: &Arc, message: Message } #[cfg(not(feature = "i-tracking"))] -fn broadcast_pubsub_invalidation(_: &Arc, _: Message, _: &Server) {} +fn broadcast_pubsub_invalidation(_: &RefCount, _: Message, _: &Server) {} #[cfg(feature = "i-tracking")] fn is_pubsub_invalidation(message: &Message) -> bool { @@ -84,7 +85,7 @@ fn is_pubsub_invalidation(_: &Message) -> bool { } #[cfg(feature = "i-tracking")] -fn broadcast_resp3_invalidation(inner: &Arc, server: &Server, frame: Resp3Frame) { +fn broadcast_resp3_invalidation(inner: &RefCount, server: &Server, frame: Resp3Frame) { if let Resp3Frame::Push { mut data, .. } = frame { if data.len() != 2 { return; @@ -105,7 +106,7 @@ fn broadcast_resp3_invalidation(inner: &Arc, server: &Server, } #[cfg(not(feature = "i-tracking"))] -fn broadcast_resp3_invalidation(_: &Arc, _: &Server, _: Resp3Frame) {} +fn broadcast_resp3_invalidation(_: &RefCount, _: &Server, _: Resp3Frame) {} #[cfg(feature = "i-tracking")] fn is_resp3_invalidation(frame: &Resp3Frame) -> bool { @@ -163,7 +164,11 @@ fn is_resp3_invalidation(_: &Resp3Frame) -> bool { /// Check if the frame is part of a pubsub message, and if so route it to any listeners. /// /// If not then return it to the caller for further processing. -pub fn check_pubsub_message(inner: &Arc, server: &Server, frame: Resp3Frame) -> Option { +pub fn check_pubsub_message( + inner: &RefCount, + server: &Server, + frame: Resp3Frame, +) -> Option { if is_subscription_response(&frame) { _debug!(inner, "Dropping unused subscription response."); return None; @@ -213,7 +218,7 @@ pub fn check_pubsub_message(inner: &Arc, server: &Server, fram // TODO cleanup and rename // this is called by the reader task after a blocking command finishes in order to mark the connection as unblocked -pub async fn check_and_set_unblocked_flag(inner: &Arc, command: &RedisCommand) { +pub async fn check_and_set_unblocked_flag(inner: &RefCount, command: &RedisCommand) { if command.blocks_connection() { inner.backchannel.write().await.set_unblocked(); } @@ -235,7 +240,7 @@ fn parse_redis_auth_error(frame: &Resp3Frame) -> Option { } #[cfg(feature = "custom-reconnect-errors")] -fn check_global_reconnect_errors(inner: &Arc, frame: &Resp3Frame) -> Option { +fn check_global_reconnect_errors(inner: &RefCount, frame: &Resp3Frame) -> Option { if let Resp3Frame::SimpleError { ref data, .. } = frame { for prefix in inner.connection.reconnect_errors.iter() { if data.starts_with(prefix.to_str()) { @@ -253,7 +258,7 @@ fn check_global_reconnect_errors(inner: &Arc, frame: &Resp3Fra } #[cfg(not(feature = "custom-reconnect-errors"))] -fn check_global_reconnect_errors(_: &Arc, _: &Resp3Frame) -> Option { +fn check_global_reconnect_errors(_: &RefCount, _: &Resp3Frame) -> Option { None } @@ -283,7 +288,7 @@ fn is_clusterdown_error(frame: &Resp3Frame) -> Option<&str> { } /// Check for special errors configured by the caller to initiate a reconnection process. -pub fn check_special_errors(inner: &Arc, frame: &Resp3Frame) -> Option { +pub fn check_special_errors(inner: &RefCount, frame: &Resp3Frame) -> Option { if inner.connection.reconnect_on_auth_error { if let Some(auth_error) = parse_redis_auth_error(frame) { return Some(auth_error); @@ -297,7 +302,7 @@ pub fn check_special_errors(inner: &Arc, frame: &Resp3Frame) - } /// Handle an error in the reader task that should end the connection. -pub fn broadcast_reader_error(inner: &Arc, server: &Server, error: Option) { +pub fn broadcast_reader_error(inner: &RefCount, server: &Server, error: Option) { _warn!(inner, "Ending reader task from {} due to {:?}", server, error); if inner.should_reconnect() { @@ -311,12 +316,12 @@ pub fn broadcast_reader_error(inner: &Arc, server: &Server, er } #[cfg(not(feature = "replicas"))] -pub fn broadcast_replica_error(inner: &Arc, server: &Server, error: Option) { +pub fn broadcast_replica_error(inner: &RefCount, server: &Server, error: Option) { broadcast_reader_error(inner, server, error); } #[cfg(feature = "replicas")] -pub fn broadcast_replica_error(inner: &Arc, server: &Server, error: Option) { +pub fn broadcast_replica_error(inner: &RefCount, server: &Server, error: Option) { _warn!(inner, "Ending replica reader task from {} due to {:?}", server, error); if inner.should_reconnect() { diff --git a/src/router/sentinel.rs b/src/router/sentinel.rs index c53d4873..4be44292 100644 --- a/src/router/sentinel.rs +++ b/src/router/sentinel.rs @@ -9,14 +9,12 @@ use crate::{ utils as protocol_utils, }, router::{centralized, Connections}, + runtime::RefCount, types::{RedisValue, Server, ServerConfig}, utils, }; use bytes_utils::Str; -use std::{ - collections::{HashMap, HashSet, VecDeque}, - sync::Arc, -}; +use std::collections::{HashMap, HashSet, VecDeque}; pub static CONFIG: &str = "CONFIG"; pub static SET: &str = "SET"; @@ -48,7 +46,7 @@ macro_rules! stry ( ); fn parse_sentinel_nodes_response( - inner: &Arc, + inner: &RefCount, value: RedisValue, ) -> Result, RedisError> { let result_maps: Vec> = stry!(value.convert()); @@ -96,7 +94,7 @@ fn has_different_sentinel_nodes(old: &[(String, u16)], new: &[(String, u16)]) -> } #[cfg(feature = "sentinel-auth")] -fn read_sentinel_auth(inner: &Arc) -> Result<(Option, Option), RedisError> { +fn read_sentinel_auth(inner: &RefCount) -> Result<(Option, Option), RedisError> { match inner.config.server { ServerConfig::Sentinel { ref username, @@ -111,31 +109,48 @@ fn read_sentinel_auth(inner: &Arc) -> Result<(Option, } #[cfg(not(feature = "sentinel-auth"))] -fn read_sentinel_auth(inner: &Arc) -> Result<(Option, Option), RedisError> { +fn read_sentinel_auth(inner: &RefCount) -> Result<(Option, Option), RedisError> { Ok((inner.config.username.clone(), inner.config.password.clone())) } +fn read_sentinel_hosts(inner: &RefCount) -> Result, RedisError> { + inner + .server_state + .read() + .kind + .read_sentinel_nodes(&inner.config.server) + .ok_or(RedisError::new( + RedisErrorKind::Sentinel, + "Failed to read cached sentinel nodes.", + )) +} + /// Read the `(host, port)` tuples for the known sentinel nodes, and the credentials to use when connecting. -fn read_sentinel_nodes_and_auth( - inner: &Arc, -) -> Result<(Vec, (Option, Option)), RedisError> { - let (username, password) = read_sentinel_auth(inner)?; - let hosts = match inner.server_state.read().kind.read_sentinel_nodes(&inner.config.server) { - Some(hosts) => hosts, - None => { - return Err(RedisError::new( - RedisErrorKind::Sentinel, - "Failed to read cached sentinel nodes.", - )) - }, +#[cfg(feature = "credential-provider")] +async fn read_sentinel_credentials( + inner: &RefCount, + server: &Server, +) -> Result<(Option, Option), RedisError> { + let (username, password) = if let Some(ref provider) = inner.config.credential_provider { + provider.fetch(Some(server)).await? + } else { + read_sentinel_auth(inner)? }; - Ok((hosts, (username, password))) + Ok((username, password)) +} + +#[cfg(not(feature = "credential-provider"))] +async fn read_sentinel_credentials( + inner: &RefCount, + _: &Server, +) -> Result<(Option, Option), RedisError> { + read_sentinel_auth(inner) } /// Read the set of sentinel nodes via `SENTINEL sentinels`. async fn read_sentinels( - inner: &Arc, + inner: &RefCount, sentinel: &mut RedisTransport, ) -> Result, RedisError> { let service_name = read_service_name(inner)?; @@ -157,10 +172,12 @@ async fn read_sentinels( } /// Connect to any of the sentinel nodes provided on the associated `RedisConfig`. -async fn connect_to_sentinel(inner: &Arc) -> Result { - let (hosts, (username, password)) = read_sentinel_nodes_and_auth(inner)?; +async fn connect_to_sentinel(inner: &RefCount) -> Result { + let hosts = read_sentinel_hosts(inner)?; for server in hosts.into_iter() { + let (username, password) = read_sentinel_credentials(inner, &server).await?; + _debug!(inner, "Connecting to sentinel {}", server); let mut transport = try_or_continue!(connection::create(inner, &server, None).await); try_or_continue!( @@ -180,7 +197,7 @@ async fn connect_to_sentinel(inner: &Arc) -> Result) -> Result { +fn read_service_name(inner: &RefCount) -> Result { match inner.config.server { ServerConfig::Sentinel { ref service_name, .. } => Ok(service_name.to_owned()), _ => Err(RedisError::new( @@ -193,7 +210,7 @@ fn read_service_name(inner: &Arc) -> Result, + inner: &RefCount, sentinel: &mut RedisTransport, ) -> Result { let service_name = read_service_name(inner)?; @@ -233,7 +250,7 @@ async fn discover_primary_node( /// Verify that the Redis server is a primary node and not a replica. async fn check_primary_node_role( - inner: &Arc, + inner: &RefCount, transport: &mut RedisTransport, ) -> Result<(), RedisError> { let command = RedisCommand::new(RedisCommandKind::Role, Vec::new()); @@ -268,7 +285,7 @@ async fn check_primary_node_role( /// Update the cached backchannel state with the new connection information, disconnecting the old connection if /// needed. async fn update_sentinel_backchannel( - inner: &Arc, + inner: &RefCount, transport: &RedisTransport, ) -> Result<(), RedisError> { let mut backchannel = inner.backchannel.write().await; @@ -290,7 +307,7 @@ async fn update_sentinel_backchannel( /// * Update the cached backchannel information. /// * Split and store the primary node transport on `writer`. async fn update_cached_client_state( - inner: &Arc, + inner: &RefCount, writer: &mut Option, mut sentinel: RedisTransport, transport: RedisTransport, @@ -313,7 +330,7 @@ async fn update_cached_client_state( /// /// pub async fn initialize_connection( - inner: &Arc, + inner: &RefCount, connections: &mut Connections, buffer: &mut VecDeque, ) -> Result<(), RedisError> { diff --git a/src/router/transactions.rs b/src/router/transactions.rs index a7bb0c4e..53fbe942 100644 --- a/src/router/transactions.rs +++ b/src/router/transactions.rs @@ -5,12 +5,15 @@ use crate::{ protocol::{ command::{ClusterErrorKind, RedisCommand, RedisCommandKind, ResponseSender, RouterReceiver, RouterResponse}, responders::ResponseKind, + utils::pretty_error, }, - router::{utils, Router, Written}, + router::{clustered::parse_cluster_error_frame, utils, Router, Written}, + runtime::{oneshot_channel, AtomicUsize, Mutex, RefCount}, types::{ClusterHash, Server}, utils as client_utils, }; -use std::sync::Arc; +use redis_protocol::resp3::types::{FrameKind, Resp3Frame as _Resp3Frame}; +use std::iter::repeat; /// An internal enum describing the result of an attempt to send a transaction command. #[derive(Debug)] @@ -34,12 +37,12 @@ enum TransactionResponse { /// /// Returns the command result policy or a fatal error that should end the transaction. async fn write_command( - inner: &Arc, + inner: &RefCount, router: &mut Router, server: &Server, command: RedisCommand, abort_on_error: bool, - rx: RouterReceiver, + rx: Option, ) -> Result { _trace!( inner, @@ -62,31 +65,35 @@ async fn write_command( return Ok(TransactionResponse::Retry(e)); } - match client_utils::timeout(rx, timeout_dur).await? { - RouterResponse::Continue => Ok(TransactionResponse::Continue), - RouterResponse::Ask((slot, server, _)) => { - Ok(TransactionResponse::Redirection((ClusterErrorKind::Ask, slot, server))) - }, - RouterResponse::Moved((slot, server, _)) => Ok(TransactionResponse::Redirection(( - ClusterErrorKind::Moved, - slot, - server, - ))), - RouterResponse::ConnectionClosed((err, _)) => Ok(TransactionResponse::Retry(err)), - RouterResponse::TransactionError((err, _)) => { - if abort_on_error { - Err(err) - } else { - Ok(TransactionResponse::Continue) - } - }, - RouterResponse::TransactionResult(frame) => Ok(TransactionResponse::Finished(frame)), + if let Some(rx) = rx { + match client_utils::timeout(rx, timeout_dur).await? { + RouterResponse::Continue => Ok(TransactionResponse::Continue), + RouterResponse::Ask((slot, server, _)) => { + Ok(TransactionResponse::Redirection((ClusterErrorKind::Ask, slot, server))) + }, + RouterResponse::Moved((slot, server, _)) => Ok(TransactionResponse::Redirection(( + ClusterErrorKind::Moved, + slot, + server, + ))), + RouterResponse::ConnectionClosed((err, _)) => Ok(TransactionResponse::Retry(err)), + RouterResponse::TransactionError((err, _)) => { + if abort_on_error { + Err(err) + } else { + Ok(TransactionResponse::Continue) + } + }, + RouterResponse::TransactionResult(frame) => Ok(TransactionResponse::Finished(frame)), + } + } else { + Ok(TransactionResponse::Continue) } } /// Send EXEC to the provided server. -async fn send_exec( - inner: &Arc, +async fn send_non_pipelined_exec( + inner: &RefCount, router: &mut Router, server: &Server, id: u64, @@ -97,12 +104,12 @@ async fn send_exec( command.transaction_id = Some(id); let rx = command.create_router_channel(); - write_command(inner, router, server, command, true, rx).await + write_command(inner, router, server, command, true, Some(rx)).await } /// Send DISCARD to the provided server. -async fn send_discard( - inner: &Arc, +async fn send_non_pipelined_discard( + inner: &RefCount, router: &mut Router, server: &Server, id: u64, @@ -113,7 +120,7 @@ async fn send_discard( command.transaction_id = Some(id); let rx = command.create_router_channel(); - write_command(inner, router, server, command, true, rx).await + write_command(inner, router, server, command, true, Some(rx)).await } fn update_hash_slot(commands: &mut [RedisCommand], slot: u16) { @@ -122,78 +129,178 @@ fn update_hash_slot(commands: &mut [RedisCommand], slot: u16) { } } -/// Run the transaction, following cluster redirects and reconnecting as needed. -// this would be a lot cleaner with GATs if we could abstract the inner loops with async closures -pub async fn run( - inner: &Arc, +/// Find the server that should receive the transaction, creating connections if needed. +async fn find_or_create_connection( + inner: &RefCount, router: &mut Router, - mut commands: Vec, - watched: Option, + command: &RedisCommand, +) -> Result, RedisError> { + if let Some(server) = command.cluster_node.as_ref() { + Ok(Some(server.clone())) + } else { + match router.find_connection(command) { + Some(server) => Ok(Some(server.clone())), + None => { + if inner.config.server.is_clustered() { + // optimistically sync the cluster, then fall back to a full reconnect + if router.sync_cluster().await.is_err() { + utils::delay_cluster_sync(inner).await?; + utils::reconnect_with_policy(inner, router).await? + } + } else { + utils::reconnect_with_policy(inner, router).await? + }; + + Ok(None) + }, + } + } +} + +fn build_pipeline( + commands: &[RedisCommand], + response: ResponseKind, id: u64, - abort_on_error: bool, - tx: ResponseSender, -) -> Result<(), RedisError> { - if commands.is_empty() { - let _ = tx.send(Ok(Resp3Frame::Null)); - return Ok(()); +) -> Result, RedisError> { + let mut pipeline = Vec::with_capacity(commands.len() + 1); + let mut exec = RedisCommand::new(RedisCommandKind::Exec, vec![]); + exec.can_pipeline = true; + exec.skip_backpressure = true; + exec.fail_fast = true; + exec.transaction_id = Some(id); + exec.response = response + .duplicate() + .ok_or_else(|| RedisError::new(RedisErrorKind::Unknown, "Invalid pipelined transaction response."))?; + exec.response.set_expected_index(commands.len()); + + for (idx, command) in commands.iter().enumerate() { + let mut response = response + .duplicate() + .ok_or_else(|| RedisError::new(RedisErrorKind::Unknown, "Invalid pipelined transaction response."))?; + response.set_expected_index(idx); + let mut command = command.duplicate(response); + command.fail_fast = true; + command.skip_backpressure = true; + command.can_pipeline = true; + + pipeline.push(command); } - // each of the commands should have the same options - let max_attempts = if commands[0].attempts_remaining == 0 { - inner.max_command_attempts() - } else { - commands[0].attempts_remaining - }; - let max_redirections = if commands[0].redirections_remaining == 0 { - inner.connection.max_redirections - } else { - commands[0].redirections_remaining - }; + pipeline.push(exec); + Ok(pipeline) +} - let mut attempted = 0; - let mut redirections = 0; - 'outer: loop { - _debug!(inner, "Starting transaction {} (attempted: {})", id, attempted); +pub mod exec { + use super::*; + // TODO find a better way to combine these functions - let server = if let Some(server) = commands[0].cluster_node.as_ref() { - server.clone() + /// Run the transaction, following cluster redirects and reconnecting as needed. + #[allow(unused_mut)] + pub async fn non_pipelined( + inner: &RefCount, + router: &mut Router, + mut commands: Vec, + id: u64, + abort_on_error: bool, + mut tx: ResponseSender, + ) -> Result<(), RedisError> { + if commands.is_empty() { + let _ = tx.send(Ok(Resp3Frame::Null)); + return Ok(()); + } + // each of the commands should have the same options + let max_attempts = if commands[0].attempts_remaining == 0 { + inner.max_command_attempts() } else { - match router.find_connection(&commands[0]) { - Some(server) => server.clone(), - None => { - if inner.config.server.is_clustered() { - // optimistically sync the cluster, then fall back to a full reconnect - if router.sync_cluster().await.is_err() { - utils::delay_cluster_sync(inner).await?; - utils::reconnect_with_policy(inner, router).await? + commands[0].attempts_remaining + }; + let max_redirections = if commands[0].redirections_remaining == 0 { + inner.connection.max_redirections + } else { + commands[0].redirections_remaining + }; + + let mut attempted = 0; + let mut redirections = 0; + 'outer: loop { + _debug!(inner, "Starting transaction {} (attempted: {})", id, attempted); + let server = match find_or_create_connection(inner, router, &commands[0]).await? { + Some(server) => server, + None => continue, + }; + + let mut idx = 0; + if attempted > 0 { + inner.counters.incr_redelivery_count(); + } + // send each of the commands. the first one is always MULTI + 'inner: while idx < commands.len() { + let command = commands[idx].duplicate(ResponseKind::Skip); + let rx = command.create_router_channel(); + + // wait on each response before sending the next command in order to handle errors or follow cluster + // redirections as quickly as possible. + match write_command(inner, router, &server, command, abort_on_error, Some(rx)).await { + Ok(TransactionResponse::Continue) => { + idx += 1; + continue 'inner; + }, + Ok(TransactionResponse::Retry(error)) => { + _debug!(inner, "Retrying trx {} after error: {:?}", id, error); + if let Err(e) = send_non_pipelined_discard(inner, router, &server, id).await { + _warn!(inner, "Error sending DISCARD in trx {}: {:?}", id, e); } - } else { - utils::reconnect_with_policy(inner, router).await? - }; - continue; - }, + attempted += 1; + if attempted >= max_attempts { + let _ = tx.send(Err(error)); + return Ok(()); + } else { + utils::reconnect_with_policy(inner, router).await?; + } + + continue 'outer; + }, + Ok(TransactionResponse::Redirection((kind, slot, server))) => { + redirections += 1; + if redirections > max_redirections { + let _ = tx.send(Err(RedisError::new( + RedisErrorKind::Cluster, + "Too many cluster redirections.", + ))); + return Ok(()); + } + + update_hash_slot(&mut commands, slot); + if let Err(e) = send_non_pipelined_discard(inner, router, &server, id).await { + _warn!(inner, "Error sending DISCARD in trx {}: {:?}", id, e); + } + utils::cluster_redirect_with_policy(inner, router, kind, slot, &server).await?; + + continue 'outer; + }, + Ok(TransactionResponse::Finished(frame)) => { + let _ = tx.send(Ok(frame)); + return Ok(()); + }, + Err(error) => { + // fatal errors that end the transaction + let _ = send_non_pipelined_discard(inner, router, &server, id).await; + let _ = tx.send(Err(error)); + return Ok(()); + }, + } } - }; - let mut idx = 0; - - // send the WATCH command before any of the trx commands - if let Some(watch) = watched.as_ref() { - let watch = watch.duplicate(ResponseKind::Skip); - let rx = watch.create_router_channel(); - - _debug!( - inner, - "Sending WATCH for {} keys in trx {} to {}", - watch.args().len(), - id, - server - ); - match write_command(inner, router, &server, watch, false, rx).await { - Ok(TransactionResponse::Continue) => { - _debug!(inner, "Successfully sent WATCH command before transaction {}.", id); + + match send_non_pipelined_exec(inner, router, &server, id).await { + Ok(TransactionResponse::Finished(frame)) => { + let _ = tx.send(Ok(frame)); + return Ok(()); }, Ok(TransactionResponse::Retry(error)) => { - _debug!(inner, "Retrying trx {} after WATCH error: {:?}.", id, error); + _debug!(inner, "Retrying trx {} after error: {:?}", id, error); + if let Err(e) = send_non_pipelined_discard(inner, router, &server, id).await { + _warn!(inner, "Error sending DISCARD in trx {}: {:?}", id, e); + } attempted += 1; if attempted >= max_attempts { @@ -205,56 +312,131 @@ pub async fn run( continue 'outer; }, - Ok(TransactionResponse::Redirection((kind, slot, server))) => { - redirections += 1; - if redirections > max_redirections { - let _ = tx.send(Err(RedisError::new( - RedisErrorKind::Cluster, - "Too many cluster redirections.", - ))); - return Ok(()); - } - - _debug!(inner, "Recv {} redirection to {} for WATCH in trx {}", kind, server, id); - update_hash_slot(&mut commands, slot); - utils::delay_cluster_sync(inner).await?; - utils::cluster_redirect_with_policy(inner, router, kind, slot, &server).await?; - continue 'outer; + Ok(TransactionResponse::Redirection((kind, slot, dest))) => { + // doesn't make sense on EXEC, but return it as an error so it isn't lost + let _ = send_non_pipelined_discard(inner, router, &server, id).await; + let _ = tx.send(Err(RedisError::new( + RedisErrorKind::Cluster, + format!("{} {} {}", kind, slot, dest), + ))); + return Ok(()); }, - Ok(TransactionResponse::Finished(frame)) => { - _warn!(inner, "Unexpected trx finished frame after WATCH."); - let _ = tx.send(Ok(frame)); + Ok(TransactionResponse::Continue) => { + _warn!(inner, "Invalid final response to transaction {}", id); + let _ = send_non_pipelined_discard(inner, router, &server, id).await; + let _ = tx.send(Err(RedisError::new_canceled())); return Ok(()); }, Err(error) => { + let _ = send_non_pipelined_discard(inner, router, &server, id).await; let _ = tx.send(Err(error)); return Ok(()); }, }; } + } - if attempted > 0 { - inner.counters.incr_redelivery_count(); + #[allow(unused_mut)] + pub async fn pipelined( + inner: &RefCount, + router: &mut Router, + mut commands: Vec, + id: u64, + mut tx: ResponseSender, + ) -> Result<(), RedisError> { + if commands.is_empty() { + let _ = tx.send(Ok(Resp3Frame::Null)); + return Ok(()); } - // send each of the commands. the first one is always MULTI - 'inner: while idx < commands.len() { - let command = commands[idx].duplicate(ResponseKind::Skip); - let rx = command.create_router_channel(); + // each of the commands should have the same options + let max_attempts = if commands[0].attempts_remaining == 0 { + inner.max_command_attempts() + } else { + commands[0].attempts_remaining + }; + let max_redirections = if commands[0].redirections_remaining == 0 { + inner.connection.max_redirections + } else { + commands[0].redirections_remaining + }; - match write_command(inner, router, &server, command, abort_on_error, rx).await { - Ok(TransactionResponse::Continue) => { - idx += 1; - continue 'inner; - }, - Ok(TransactionResponse::Retry(error)) => { - _debug!(inner, "Retrying trx {} after error: {:?}", id, error); - if let Err(e) = send_discard(inner, router, &server, id).await { - _warn!(inner, "Error sending DISCARD in trx {}: {:?}", id, e); - } + let mut attempted = 0; + let mut redirections = 0; + 'outer: loop { + _debug!(inner, "Starting transaction {} (attempted: {})", id, attempted); + let server = match find_or_create_connection(inner, router, &commands[0]).await? { + Some(server) => server, + None => continue, + }; + + if attempted > 0 { + inner.counters.incr_redelivery_count(); + } + let (exec_tx, exec_rx) = oneshot_channel(); + let buf: Vec<_> = repeat(Resp3Frame::Null).take(commands.len() + 1).collect(); + // pipelined transactions buffer their results until a response to EXEC is received + let response = ResponseKind::Buffer { + error_early: false, + expected: commands.len() + 1, + received: RefCount::new(AtomicUsize::new(0)), + tx: RefCount::new(Mutex::new(Some(exec_tx))), + frames: RefCount::new(Mutex::new(buf)), + index: 0, + }; + + // write each command in the pipeline + let pipeline = build_pipeline(&commands, response, id)?; + for command in pipeline.into_iter() { + match write_command(inner, router, &server, command, false, None).await? { + TransactionResponse::Continue => continue, + TransactionResponse::Retry(error) => { + _debug!(inner, "Retrying pipelined trx {} after error: {:?}", id, error); + if let Err(e) = send_non_pipelined_discard(inner, router, &server, id).await { + _warn!(inner, "Error sending pipelined discard: {:?}", e); + } + + attempted += 1; + if attempted >= max_attempts { + let _ = tx.send(Err(error)); + return Ok(()); + } else { + utils::reconnect_with_policy(inner, router).await?; + } + continue 'outer; + }, + _ => { + _error!(inner, "Unexpected pipelined write response."); + let _ = tx.send(Err(RedisError::new( + RedisErrorKind::Protocol, + "Unexpected pipeline write response.", + ))); + return Ok(()); + }, + } + } + // wait on the response and deconstruct the output frames + let mut response = match exec_rx.await.map_err(RedisError::from) { + Ok(Ok(frame)) => match frame { + Resp3Frame::Array { data, .. } => data, + _ => { + _error!(inner, "Unexpected pipelined exec response."); + let _ = tx.send(Err(RedisError::new( + RedisErrorKind::Protocol, + "Unexpected pipeline exec response.", + ))); + return Ok(()); + }, + }, + Ok(Err(err)) | Err(err) => { + _debug!( + inner, + "Reconnecting and retrying pipelined transaction after error: {:?}", + err + ); attempted += 1; if attempted >= max_attempts { - let _ = tx.send(Err(error)); + let _ = tx.send(Err(err)); return Ok(()); } else { utils::reconnect_with_policy(inner, router).await?; @@ -262,78 +444,76 @@ pub async fn run( continue 'outer; }, - Ok(TransactionResponse::Redirection((kind, slot, server))) => { - redirections += 1; - if redirections > max_redirections { - let _ = tx.send(Err(RedisError::new( - RedisErrorKind::Cluster, - "Too many cluster redirections.", - ))); - return Ok(()); - } + }; + if response.is_empty() { + let _ = tx.send(Err(RedisError::new( + RedisErrorKind::Protocol, + "Unexpected empty pipeline exec response.", + ))); + return Ok(()); + } - update_hash_slot(&mut commands, slot); - if let Err(e) = send_discard(inner, router, &server, id).await { - _warn!(inner, "Error sending DISCARD in trx {}: {:?}", id, e); + // check the last result for EXECABORT + let execabort = response + .last() + .and_then(|f| f.as_str()) + .map(|s| s.starts_with("EXECABORT")) + .unwrap_or(false); + + if execabort { + // find the first error, if it's a redirection then follow it and retry, otherwise return to the caller + let first_error = response.iter().enumerate().find_map(|(idx, frame)| { + if matches!(frame.kind(), FrameKind::SimpleError | FrameKind::BlobError) { + Some(idx) + } else { + None } - utils::cluster_redirect_with_policy(inner, router, kind, slot, &server).await?; + }); - continue 'outer; - }, - Ok(TransactionResponse::Finished(frame)) => { - let _ = tx.send(Ok(frame)); - return Ok(()); - }, - Err(error) => { - // fatal errors that end the transaction - let _ = send_discard(inner, router, &server, id).await; - let _ = tx.send(Err(error)); - return Ok(()); - }, - } - } + if let Some(idx) = first_error { + let first_error_frame = response[idx].take(); + // check if error is a cluster redirection, otherwise return the error to the caller + if first_error_frame.is_redirection() { + redirections += 1; + if redirections > max_redirections { + let _ = tx.send(Err(RedisError::new( + RedisErrorKind::Cluster, + "Too many cluster redirections.", + ))); + return Ok(()); + } - match send_exec(inner, router, &server, id).await { - Ok(TransactionResponse::Finished(frame)) => { - let _ = tx.send(Ok(frame)); - return Ok(()); - }, - Ok(TransactionResponse::Retry(error)) => { - _debug!(inner, "Retrying trx {} after error: {:?}", id, error); - if let Err(e) = send_discard(inner, router, &server, id).await { - _warn!(inner, "Error sending DISCARD in trx {}: {:?}", id, e); - } + let (kind, slot, dest) = parse_cluster_error_frame(inner, &first_error_frame, &server)?; + update_hash_slot(&mut commands, slot); + utils::cluster_redirect_with_policy(inner, router, kind, slot, &dest).await?; + continue 'outer; + } else { + // these errors are typically from the server, not from the connection layer + let error = first_error_frame.as_str().map(pretty_error).unwrap_or_else(|| { + RedisError::new( + RedisErrorKind::Protocol, + "Unexpected response to pipelined transaction.", + ) + }); - attempted += 1; - if attempted >= max_attempts { + let _ = tx.send(Err(error)); + return Ok(()); + } + } else { + // return the EXECABORT error to the caller if there's no other error + let error = response + .pop() + .and_then(|f| f.as_str().map(pretty_error)) + .unwrap_or_else(|| RedisError::new(RedisErrorKind::Protocol, "Invalid pipelined transaction response.")); let _ = tx.send(Err(error)); return Ok(()); - } else { - utils::reconnect_with_policy(inner, router).await?; } - - continue 'outer; - }, - Ok(TransactionResponse::Redirection((kind, slot, dest))) => { - // doesn't make sense on EXEC, but return it as an error so it isn't lost - let _ = send_discard(inner, router, &server, id).await; - let _ = tx.send(Err(RedisError::new( - RedisErrorKind::Cluster, - format!("{} {} {}", kind, slot, dest), - ))); - return Ok(()); - }, - Ok(TransactionResponse::Continue) => { - _warn!(inner, "Invalid final response to transaction {}", id); - let _ = send_discard(inner, router, &server, id).await; - let _ = tx.send(Err(RedisError::new_canceled())); - return Ok(()); - }, - Err(error) => { - let _ = send_discard(inner, router, &server, id).await; - let _ = tx.send(Err(error)); + } else { + // return the last frame to the caller + let last = response.pop().unwrap_or(Resp3Frame::Null); + let _ = tx.send(Ok(last)); return Ok(()); - }, - }; + } + } } } diff --git a/src/router/utils.rs b/src/router/utils.rs index caf36ccb..8a7d110c 100644 --- a/src/router/utils.rs +++ b/src/router/utils.rs @@ -10,23 +10,22 @@ use crate::{ utils as protocol_utils, }, router::{utils, Backpressure, Counters, Router, Written}, + runtime::{oneshot_channel, sleep, RefCount}, types::*, utils as client_utils, }; use futures::TryStreamExt; use std::{ cmp, - sync::Arc, time::{Duration, Instant}, }; -use tokio::{self, sync::oneshot::channel as oneshot_channel, time::sleep}; #[cfg(feature = "transactions")] use crate::protocol::command::ClusterErrorKind; /// Check the connection state and command flags to determine the backpressure policy to apply, if any. pub fn check_backpressure( - inner: &Arc, + inner: &RefCount, counters: &Counters, command: &RedisCommand, ) -> Result, RedisError> { @@ -64,20 +63,20 @@ pub fn check_backpressure( } #[cfg(feature = "partial-tracing")] -fn set_command_trace(inner: &Arc, command: &mut RedisCommand) { +fn set_command_trace(inner: &RefCount, command: &mut RedisCommand) { if inner.should_trace() { crate::trace::set_network_span(inner, command, true); } } #[cfg(not(feature = "partial-tracing"))] -fn set_command_trace(_inner: &Arc, _: &mut RedisCommand) {} +fn set_command_trace(_inner: &RefCount, _: &mut RedisCommand) {} /// Prepare the command, updating flags in place. /// /// Returns the RESP frame and whether the socket should be flushed. pub fn prepare_command( - inner: &Arc, + inner: &RefCount, counters: &Counters, command: &mut RedisCommand, ) -> Result<(ProtocolFrame, bool), RedisError> { @@ -104,7 +103,7 @@ pub fn prepare_command( /// Write a command on the provided writer half of a socket. pub async fn write_command( - inner: &Arc, + inner: &RefCount, writer: &mut RedisWriter, mut command: RedisCommand, force_flush: bool, @@ -175,13 +174,14 @@ pub async fn write_command( /// Check the shared connection command buffer to see if the oldest command blocks the router task on a /// response (not pipelined). -pub fn check_blocked_router(inner: &Arc, buffer: &SharedBuffer, error: &Option) { +pub fn check_blocked_router(inner: &RefCount, buffer: &SharedBuffer, error: &Option) { let command = match buffer.pop() { Some(cmd) => cmd, None => return, }; if command.has_router_channel() { - let tx = match command.take_router_tx() { + #[allow(unused_mut)] + let mut tx = match command.take_router_tx() { Some(tx) => tx, None => return, }; @@ -201,7 +201,11 @@ pub fn check_blocked_router(inner: &Arc, buffer: &SharedBuffer /// Filter the shared buffer, removing commands that reached the max number of attempts and responding to each caller /// with the underlying error. -pub fn check_final_write_attempt(inner: &Arc, buffer: &SharedBuffer, error: &Option) { +pub fn check_final_write_attempt( + inner: &RefCount, + buffer: &SharedBuffer, + error: &Option, +) { buffer .drain() .into_iter() @@ -227,7 +231,7 @@ pub fn check_final_write_attempt(inner: &Arc, buffer: &SharedB } /// Read the next reconnection delay for the client. -pub fn next_reconnection_delay(inner: &Arc) -> Result { +pub fn next_reconnection_delay(inner: &RefCount) -> Result { inner .policy .write() @@ -238,7 +242,7 @@ pub fn next_reconnection_delay(inner: &Arc) -> Result, router: &mut Router) -> Result<(), RedisError> { +pub async fn reconnect_once(inner: &RefCount, router: &mut Router) -> Result<(), RedisError> { client_utils::set_client_state(&inner.state, ClientState::Connecting); if let Err(e) = Box::pin(router.connect()).await { _debug!(inner, "Failed reconnecting with error: {:?}", e); @@ -268,7 +272,10 @@ pub async fn reconnect_once(inner: &Arc, router: &mut Router) /// Reconnect to the server(s) until the max reconnect policy attempts are reached. /// /// Errors from this function should end the connection task. -pub async fn reconnect_with_policy(inner: &Arc, router: &mut Router) -> Result<(), RedisError> { +pub async fn reconnect_with_policy( + inner: &RefCount, + router: &mut Router, +) -> Result<(), RedisError> { let mut delay = utils::next_reconnection_delay(inner)?; loop { @@ -299,7 +306,7 @@ pub async fn reconnect_with_policy(inner: &Arc, router: &mut R /// Attempt to follow a cluster redirect, reconnecting as needed until the max reconnections attempts is reached. #[cfg(feature = "transactions")] pub async fn cluster_redirect_with_policy( - inner: &Arc, + inner: &RefCount, router: &mut Router, kind: ClusterErrorKind, slot: u16, @@ -329,7 +336,7 @@ pub async fn cluster_redirect_with_policy( /// /// Errors from this function should end the connection task. pub async fn send_asking_with_policy( - inner: &Arc, + inner: &RefCount, router: &mut Router, server: &Server, slot: u16, @@ -407,7 +414,7 @@ pub async fn send_asking_with_policy( #[cfg(feature = "replicas")] async fn sync_cluster_replicas( - inner: &Arc, + inner: &RefCount, router: &mut Router, reset: bool, ) -> Result<(), RedisError> { @@ -425,7 +432,7 @@ async fn sync_cluster_replicas( /// Repeatedly try to sync the cluster state, reconnecting as needed until the max reconnection attempts is reached. #[cfg(feature = "replicas")] pub async fn sync_replicas_with_policy( - inner: &Arc, + inner: &RefCount, router: &mut Router, reset: bool, ) -> Result<(), RedisError> { @@ -460,7 +467,7 @@ pub async fn sync_replicas_with_policy( } /// Wait for `inner.connection.cluster_cache_update_delay`. -pub async fn delay_cluster_sync(inner: &Arc) -> Result<(), RedisError> { +pub async fn delay_cluster_sync(inner: &RefCount) -> Result<(), RedisError> { if inner.config.server.is_clustered() && !inner.connection.cluster_cache_update_delay.is_zero() { inner .wait_with_interrupt(inner.connection.cluster_cache_update_delay) @@ -473,7 +480,10 @@ pub async fn delay_cluster_sync(inner: &Arc) -> Result<(), Red /// Repeatedly try to sync the cluster state, reconnecting as needed until the max reconnection attempts is reached. /// /// Errors from this function should end the connection task. -pub async fn sync_cluster_with_policy(inner: &Arc, router: &mut Router) -> Result<(), RedisError> { +pub async fn sync_cluster_with_policy( + inner: &RefCount, + router: &mut Router, +) -> Result<(), RedisError> { let mut delay = Duration::from_millis(0); loop { @@ -504,7 +514,7 @@ pub async fn sync_cluster_with_policy(inner: &Arc, router: &mu Ok(()) } -pub fn defer_reconnect(inner: &Arc) { +pub fn defer_reconnect(inner: &RefCount) { if inner.config.server.is_clustered() { let (tx, _) = oneshot_channel(); let cmd = RouterCommand::SyncCluster { tx }; @@ -527,7 +537,7 @@ pub fn defer_reconnect(inner: &Arc) { /// Attempt to read the next frame from the reader half of a connection. pub async fn next_frame( - inner: &Arc, + inner: &RefCount, conn: &mut SplitStreamKind, server: &Server, buffer: &SharedBuffer, @@ -546,7 +556,7 @@ pub async fn next_frame( // complicated. // // The approach here implements a ~~hack~~ heuristic where we measure the time since first noticing a new - // frame in the shared buffer from the reader task perspective. This only works because we use `Stream::next` + // frame in the shared buffer from the reader task's perspective. This only works because we use `Stream::next` // which is noted to be cancellation-safe in the tokio::select! docs. With this implementation the worst case // error margin is an extra `interval`. diff --git a/src/trace/disabled.rs b/src/trace/disabled.rs index 566c3029..c5af3b78 100644 --- a/src/trace/disabled.rs +++ b/src/trace/disabled.rs @@ -5,9 +5,9 @@ use crate::modules::inner::RedisClientInner; #[cfg(not(any(feature = "full-tracing", feature = "partial-tracing")))] use crate::protocol::command::RedisCommand; #[cfg(not(any(feature = "full-tracing", feature = "partial-tracing")))] -use redis_protocol::resp3::types::BytesFrame as Frame; +use crate::runtime::RefCount; #[cfg(not(any(feature = "full-tracing", feature = "partial-tracing")))] -use std::sync::Arc; +use redis_protocol::resp3::types::BytesFrame as Frame; /// Fake span for mocking tracing functions. #[cfg(not(feature = "full-tracing"))] @@ -23,10 +23,10 @@ impl Span { } #[cfg(not(any(feature = "full-tracing", feature = "partial-tracing")))] -pub fn set_network_span(_inner: &Arc, _command: &mut RedisCommand, _flush: bool) {} +pub fn set_network_span(_inner: &RefCount, _command: &mut RedisCommand, _flush: bool) {} #[cfg(not(any(feature = "full-tracing", feature = "partial-tracing")))] -pub fn create_pubsub_span(_inner: &Arc, _frame: &Frame) -> Option { +pub fn create_pubsub_span(_inner: &RefCount, _frame: &Frame) -> Option { Some(Span {}) } diff --git a/src/trace/enabled.rs b/src/trace/enabled.rs index a1544044..719c6d9e 100644 --- a/src/trace/enabled.rs +++ b/src/trace/enabled.rs @@ -1,6 +1,6 @@ -use crate::{modules::inner::RedisClientInner, protocol::command::RedisCommand}; +use crate::{modules::inner::RedisClientInner, protocol::command::RedisCommand, runtime::RefCount}; use redis_protocol::resp3::types::{BytesFrame as Resp3Frame, Resp3Frame as _Resp3Frame}; -use std::{fmt, ops::Deref, sync::Arc}; +use std::{fmt, ops::Deref}; pub use tracing::span::Span; use tracing::{event, field::Empty, Id as TraceId, Level}; @@ -42,7 +42,7 @@ impl fmt::Debug for CommandTraces { } } -pub fn set_network_span(inner: &Arc, command: &mut RedisCommand, flush: bool) { +pub fn set_network_span(inner: &RefCount, command: &mut RedisCommand, flush: bool) { trace!("Setting network span from command {}", command.debug_id()); let span = fspan!(command, inner.tracing_span_level(), "fred.rtt", "cmd.flush" = flush); span.in_scope(|| {}); @@ -54,7 +54,7 @@ pub fn record_response_size(span: &Span, frame: &Resp3Frame) { span.record("cmd.res", &frame.encode_len()); } -pub fn create_command_span(inner: &Arc) -> Span { +pub fn create_command_span(inner: &RefCount) -> Span { span_lvl!( inner.tracing_span_level(), "fred.command", @@ -67,28 +67,28 @@ pub fn create_command_span(inner: &Arc) -> Span { } #[cfg(feature = "full-tracing")] -pub fn create_args_span(parent: Option, inner: &Arc) -> Span { +pub fn create_args_span(parent: Option, inner: &RefCount) -> Span { span_lvl!(inner.full_tracing_span_level(), parent: parent, "fred.prepare", "cmd.args" = Empty) } #[cfg(not(feature = "full-tracing"))] -pub fn create_args_span(_parent: Option, _inner: &Arc) -> FakeSpan { +pub fn create_args_span(_parent: Option, _inner: &RefCount) -> FakeSpan { FakeSpan {} } #[cfg(feature = "full-tracing")] -pub fn create_queued_span(parent: Option, inner: &Arc) -> Span { +pub fn create_queued_span(parent: Option, inner: &RefCount) -> Span { let buf_len = inner.counters.read_cmd_buffer_len(); span_lvl!(inner.full_tracing_span_level(), parent: parent, "fred.queued", buf_len) } #[cfg(not(feature = "full-tracing"))] -pub fn create_queued_span(_parent: Option, _inner: &Arc) -> FakeSpan { +pub fn create_queued_span(_parent: Option, _inner: &RefCount) -> FakeSpan { FakeSpan {} } #[cfg(feature = "full-tracing")] -pub fn create_pubsub_span(inner: &Arc, frame: &Resp3Frame) -> Option { +pub fn create_pubsub_span(inner: &RefCount, frame: &Resp3Frame) -> Option { if inner.should_trace() { let span = span_lvl!( inner.full_tracing_span_level(), @@ -107,7 +107,7 @@ pub fn create_pubsub_span(inner: &Arc, frame: &Resp3Frame) -> } #[cfg(not(feature = "full-tracing"))] -pub fn create_pubsub_span(_inner: &Arc, _frame: &Resp3Frame) -> Option { +pub fn create_pubsub_span(_inner: &RefCount, _frame: &Resp3Frame) -> Option { Some(FakeSpan {}) } diff --git a/src/types/builder.rs b/src/types/builder.rs index ffb21341..2f3d9ace 100644 --- a/src/types/builder.rs +++ b/src/types/builder.rs @@ -5,6 +5,7 @@ use crate::{ types::{ConnectionConfig, PerformanceConfig, RedisConfig, ServerConfig}, }; +#[cfg(not(feature = "glommio"))] use crate::clients::ExclusivePool; #[cfg(feature = "subscriber-client")] use crate::clients::SubscriberClient; @@ -249,6 +250,7 @@ impl Builder { } /// Create a new exclusive client pool. + #[cfg(not(feature = "glommio"))] pub fn build_exclusive_pool(&self, size: usize) -> Result { if let Some(config) = self.config.as_ref() { ExclusivePool::new( diff --git a/src/types/config.rs b/src/types/config.rs index b5621c67..0a9ed057 100644 --- a/src/types/config.rs +++ b/src/types/config.rs @@ -1,15 +1,21 @@ pub use crate::protocol::types::Server; -use crate::{error::RedisError, protocol::command::RedisCommand, types::RespVersion, utils}; +use crate::{ + error::{RedisError, RedisErrorKind}, + protocol::command::RedisCommand, + types::{ClusterHash, RespVersion}, + utils, +}; use socket2::TcpKeepalive; -use std::{cmp, time::Duration}; +use std::{cmp, fmt::Debug, time::Duration}; use url::Url; -use crate::error::RedisErrorKind; #[cfg(feature = "mocks")] use crate::mocks::Mocks; +#[cfg(feature = "credential-provider")] +use async_trait::async_trait; #[cfg(feature = "unix-sockets")] use std::path::PathBuf; -#[cfg(feature = "mocks")] +#[cfg(any(feature = "mocks", feature = "credential-provider"))] use std::sync::Arc; #[cfg(any( @@ -30,7 +36,6 @@ pub use crate::protocol::tls::{HostMapping, TlsConfig, TlsConnector, TlsHostMapp #[cfg(feature = "replicas")] #[cfg_attr(docsrs, doc(cfg(feature = "replicas")))] pub use crate::router::replicas::{ReplicaConfig, ReplicaFilter}; -use crate::types::ClusterHash; /// The default amount of jitter when waiting to reconnect. pub const DEFAULT_JITTER_MS: u32 = 100; @@ -499,6 +504,20 @@ pub struct ConnectionConfig { #[cfg(feature = "custom-reconnect-errors")] #[cfg_attr(docsrs, doc(cfg(feature = "custom-reconnect-errors")))] pub reconnect_errors: Vec, + + /// The task queue onto which routing tasks will be spawned. + /// + /// May cause a panic if [spawn_local_into](glommio::spawn_local_into) fails. + #[cfg(feature = "glommio")] + #[cfg_attr(docsrs, doc(cfg(feature = "glommio")))] + pub router_task_queue: Option, + + /// The task queue onto which connection reader tasks will be spawned. + /// + /// May cause a panic if [spawn_local_into](glommio::spawn_local_into) fails. + #[cfg(feature = "glommio")] + #[cfg_attr(docsrs, doc(cfg(feature = "glommio")))] + pub connection_task_queue: Option, } impl Default for ConnectionConfig { @@ -523,6 +542,10 @@ impl Default for ConnectionConfig { ReconnectError::Loading, ReconnectError::ReadOnly, ], + #[cfg(feature = "glommio")] + router_task_queue: None, + #[cfg(feature = "glommio")] + connection_task_queue: None, } } } @@ -580,6 +603,34 @@ impl Default for PerformanceConfig { } } +/// A trait that can be used to override the credentials used in each `AUTH` or `HELLO` command. +#[async_trait] +#[cfg(all(feature = "credential-provider", not(feature = "glommio")))] +#[cfg_attr(docsrs, doc(cfg(feature = "credential-provider")))] +pub trait CredentialProvider: Debug + Send + Sync + 'static { + /// Read the username and password that should be used in the next `AUTH` or `HELLO` command. + async fn fetch(&self, server: Option<&Server>) -> Result<(Option, Option), RedisError>; + + /// Configure the client to call [fetch](Self::fetch) and send `AUTH` or `HELLO` on some interval. + fn refresh_interval(&self) -> Option { + None + } +} + +/// A trait that can be used to override the credentials used in each `AUTH` or `HELLO` command. +#[async_trait(?Send)] +#[cfg(all(feature = "credential-provider", feature = "glommio"))] +#[cfg_attr(docsrs, doc(cfg(feature = "credential-provider")))] +pub trait CredentialProvider: Debug + 'static { + /// Read the username and password that should be used in the next `AUTH` or `HELLO` command. + async fn fetch(&self, server: Option<&Server>) -> Result<(Option, Option), RedisError>; + + /// Configure the client to call [fetch](Self::fetch) and send `AUTH` or `HELLO` on some interval. + fn refresh_interval(&self) -> Option { + None + } +} + /// Configuration options for a `RedisClient`. #[derive(Clone, Debug)] pub struct RedisConfig { @@ -611,10 +662,11 @@ pub struct RedisConfig { /// /// Default: `None` pub password: Option, + /// Connection configuration for the server(s). /// /// Default: `Centralized(localhost, 6379)` - pub server: ServerConfig, + pub server: ServerConfig, /// The protocol version to use when communicating with the server(s). /// /// If RESP3 is specified the client will automatically use `HELLO` when authenticating. **This requires Redis @@ -625,7 +677,7 @@ pub struct RedisConfig { /// has a slightly different type system than RESP2. /// /// Default: `RESP2` - pub version: RespVersion, + pub version: RespVersion, /// An optional database number that the client will automatically `SELECT` after connecting or reconnecting. /// /// It is recommended that callers use this field instead of putting a `select()` call inside the `on_reconnect` @@ -633,7 +685,7 @@ pub struct RedisConfig { /// the `on_reconnect` block. /// /// Default: `None` - pub database: Option, + pub database: Option, /// TLS configuration options. /// /// Default: `None` @@ -650,17 +702,26 @@ pub struct RedisConfig { feature = "enable-rustls-ring" ))) )] - pub tls: Option, + pub tls: Option, /// Tracing configuration options. #[cfg(feature = "partial-tracing")] #[cfg_attr(docsrs, doc(cfg(feature = "partial-tracing")))] - pub tracing: TracingConfig, + pub tracing: TracingConfig, /// An optional [mocking layer](crate::mocks) to intercept and process commands. /// /// Default: `None` #[cfg(feature = "mocks")] #[cfg_attr(docsrs, doc(cfg(feature = "mocks")))] - pub mocks: Option>, + pub mocks: Option>, + /// An optional credential provider callback interface. + /// + /// Default: `None` + /// + /// When used with the `sentinel-auth` feature this interface will take precedence over all `username` and + /// `password` fields for both sentinel nodes and Redis servers. + #[cfg(feature = "credential-provider")] + #[cfg_attr(docsrs, doc(cfg(feature = "credential-provider")))] + pub credential_provider: Option>, } impl PartialEq for RedisConfig { @@ -680,27 +741,30 @@ impl Eq for RedisConfig {} impl Default for RedisConfig { fn default() -> Self { RedisConfig { - fail_fast: true, - blocking: Blocking::default(), - username: None, - password: None, - server: ServerConfig::default(), - version: RespVersion::RESP2, - database: None, + fail_fast: true, + blocking: Blocking::default(), + username: None, + password: None, + server: ServerConfig::default(), + version: RespVersion::RESP2, + database: None, #[cfg(any( feature = "enable-native-tls", feature = "enable-rustls", feature = "enable-rustls-ring" ))] - tls: None, + tls: None, #[cfg(feature = "partial-tracing")] - tracing: TracingConfig::default(), + tracing: TracingConfig::default(), #[cfg(feature = "mocks")] - mocks: None, + mocks: None, + #[cfg(feature = "credential-provider")] + credential_provider: None, } } } +#[cfg_attr(docsrs, allow(rustdoc::broken_intra_doc_links))] impl RedisConfig { /// Whether the client uses TLS. #[cfg(any( @@ -1299,25 +1363,27 @@ impl Default for SentinelConfig { impl From for RedisConfig { fn from(config: SentinelConfig) -> Self { RedisConfig { - server: ServerConfig::Centralized { + server: ServerConfig::Centralized { server: Server::new(config.host, config.port), }, - fail_fast: true, - database: None, - blocking: Blocking::Block, - username: config.username, - password: config.password, - version: RespVersion::RESP2, + fail_fast: true, + database: None, + blocking: Blocking::Block, + username: config.username, + password: config.password, + version: RespVersion::RESP2, #[cfg(any( feature = "enable-native-tls", feature = "enable-rustls", feature = "enable-rustls-ring" ))] - tls: config.tls, + tls: config.tls, #[cfg(feature = "partial-tracing")] - tracing: config.tracing, + tracing: config.tracing, #[cfg(feature = "mocks")] - mocks: None, + mocks: None, + #[cfg(feature = "credential-provider")] + credential_provider: None, } } } diff --git a/src/types/mod.rs b/src/types/mod.rs index 904d3b0d..1148edea 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -1,7 +1,6 @@ -use crate::error::RedisError; pub use crate::modules::response::{FromRedis, FromRedisKey}; +use crate::{error::RedisError, runtime::JoinHandle}; pub use redis_protocol::resp3::types::{BytesFrame as Resp3Frame, RespVersion}; -use tokio::task::JoinHandle; mod args; mod builder; diff --git a/src/types/scan.rs b/src/types/scan.rs index 45c08779..f5d4f62f 100644 --- a/src/types/scan.rs +++ b/src/types/scan.rs @@ -8,11 +8,12 @@ use crate::{ responders::ResponseKind, types::{KeyScanInner, ValueScanInner}, }, + runtime::RefCount, types::{RedisKey, RedisMap, RedisValue}, utils, }; use bytes_utils::Str; -use std::{borrow::Cow, sync::Arc}; +use std::borrow::Cow; /// The types of values supported by the [type](https://redis.io/commands/type) command. #[derive(Clone, Debug, Eq, PartialEq)] @@ -82,7 +83,7 @@ pub trait Scanner { /// The result of a SCAN operation. pub struct ScanResult { pub(crate) results: Option>, - pub(crate) inner: Arc, + pub(crate) inner: RefCount, pub(crate) scan_state: KeyScanInner, pub(crate) can_continue: bool, } @@ -129,7 +130,7 @@ impl Scanner for ScanResult { /// The result of a HSCAN operation. pub struct HScanResult { pub(crate) results: Option, - pub(crate) inner: Arc, + pub(crate) inner: RefCount, pub(crate) scan_state: ValueScanInner, pub(crate) can_continue: bool, } @@ -173,7 +174,7 @@ impl Scanner for HScanResult { /// The result of a SSCAN operation. pub struct SScanResult { pub(crate) results: Option>, - pub(crate) inner: Arc, + pub(crate) inner: RefCount, pub(crate) scan_state: ValueScanInner, pub(crate) can_continue: bool, } @@ -217,7 +218,7 @@ impl Scanner for SScanResult { /// The result of a ZSCAN operation. pub struct ZScanResult { pub(crate) results: Option>, - pub(crate) inner: Arc, + pub(crate) inner: RefCount, pub(crate) scan_state: ValueScanInner, pub(crate) can_continue: bool, } diff --git a/src/utils.rs b/src/utils.rs index 35f4995d..221e18dd 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -7,9 +7,20 @@ use crate::{ responders::ResponseKind, utils as protocol_utils, }, + runtime::{ + broadcast_channel, + oneshot_channel, + sleep, + unbounded_channel, + AtomicBool, + AtomicUsize, + BroadcastSender, + RefCount, + RefSwap, + RwLock, + }, types::*, }; -use arc_swap::ArcSwap; use bytes::Bytes; use bytes_utils::Str; use float_cmp::approx_eq; @@ -19,27 +30,9 @@ use futures::{ Future, TryFutureExt, }; -use parking_lot::RwLock; use rand::{self, distributions::Alphanumeric, Rng}; use redis_protocol::resp3::types::BytesFrame as Resp3Frame; -use std::{ - collections::HashMap, - convert::TryInto, - f64, - sync::{ - atomic::{AtomicBool, AtomicUsize, Ordering}, - Arc, - }, - time::Duration, -}; -use tokio::{ - sync::{ - broadcast::{channel as broadcast_channel, Sender as BroadcastSender}, - mpsc::unbounded_channel, - oneshot::channel as oneshot_channel, - }, - time::sleep, -}; +use std::{collections::HashMap, convert::TryInto, f64, sync::atomic::Ordering, time::Duration}; use url::Url; use urlencoding::decode as percent_decode; @@ -49,11 +42,11 @@ use urlencoding::decode as percent_decode; feature = "enable-rustls-ring" ))] use crate::protocol::tls::{TlsConfig, TlsConnector}; +#[cfg(feature = "transactions")] +use crate::runtime::Mutex; #[cfg(any(feature = "full-tracing", feature = "partial-tracing"))] use crate::trace; #[cfg(feature = "transactions")] -use parking_lot::Mutex; -#[cfg(feature = "transactions")] use std::mem; #[cfg(feature = "unix-sockets")] use std::path::{Path, PathBuf}; @@ -330,7 +323,7 @@ where } /// Disconnect any state shared with the last router task spawned by the client. -pub fn reset_router_task(inner: &Arc) { +pub fn reset_router_task(inner: &RefCount) { let _guard = inner._lock.lock(); if !inner.has_command_rx() { @@ -338,6 +331,9 @@ pub fn reset_router_task(inner: &Arc) { // another connection task is running. this will let the command channel drain, then it'll drop everything on // the old connection/router interface. let (tx, rx) = unbounded_channel(); + #[cfg(feature = "glommio")] + let tx = tx.into(); + let old_command_tx = inner.swap_command_tx(tx); inner.store_command_rx(rx, true); close_router_channel(inner, old_command_tx); @@ -345,12 +341,12 @@ pub fn reset_router_task(inner: &Arc) { } /// Whether the router should check and interrupt the blocked command. -async fn should_enforce_blocking_policy(inner: &Arc, command: &RedisCommand) -> bool { +async fn should_enforce_blocking_policy(inner: &RefCount, command: &RedisCommand) -> bool { if command.kind.closes_connection() { return false; } if matches!(inner.config.blocking, Blocking::Error | Blocking::Interrupt) { - inner.backchannel.read().await.is_blocked() + inner.backchannel.write().await.is_blocked() } else { false } @@ -358,11 +354,11 @@ async fn should_enforce_blocking_policy(inner: &Arc, command: /// Interrupt the currently blocked connection (if found) with the provided flag. pub async fn interrupt_blocked_connection( - inner: &Arc, + inner: &RefCount, flag: ClientUnblockFlag, ) -> Result<(), RedisError> { let connection_id = { - let backchannel = inner.backchannel.read().await; + let backchannel = inner.backchannel.write().await; let server = match backchannel.blocked_server() { Some(server) => server, None => return Err(RedisError::new(RedisErrorKind::Unknown, "Connection is not blocked.")), @@ -391,7 +387,7 @@ pub async fn interrupt_blocked_connection( /// Check the status of the connection (usually before sending a command) to determine whether the connection should /// be unblocked automatically. -async fn check_blocking_policy(inner: &Arc, command: &RedisCommand) -> Result<(), RedisError> { +async fn check_blocking_policy(inner: &RefCount, command: &RedisCommand) -> Result<(), RedisError> { if should_enforce_blocking_policy(inner, command).await { _debug!( inner, @@ -527,7 +523,7 @@ where /// /// A new connection may be created. pub async fn backchannel_request_response( - inner: &Arc, + inner: &RefCount, command: RedisCommand, use_blocked: bool, ) -> Result { @@ -539,7 +535,7 @@ pub async fn backchannel_request_response( /// Check for a scan pattern without a hash tag, or with a wildcard in the hash tag. /// /// These patterns will result in scanning a random node if used against a clustered redis. -pub fn clustered_scan_pattern_has_hash_tag(inner: &Arc, pattern: &str) -> bool { +pub fn clustered_scan_pattern_has_hash_tag(inner: &RefCount, pattern: &str) -> bool { let (mut i, mut j, mut has_wildcard) = (None, None, false); for (idx, c) in pattern.chars().enumerate() { if c == '{' && i.is_none() { @@ -717,9 +713,9 @@ pub fn tls_config_from_url(tls: bool) -> Result, RedisError> { } } -pub fn swap_new_broadcast_channel(old: &ArcSwap>, capacity: usize) { +pub fn swap_new_broadcast_channel(old: &RefSwap>>, capacity: usize) { let new = broadcast_channel(capacity).0; - old.swap(Arc::new(new)); + old.swap(RefCount::new(new)); } pub fn url_uses_tls(url: &Url) -> bool { @@ -860,12 +856,12 @@ pub fn parse_url_sentinel_password(url: &Url) -> Option { }) } -pub async fn clear_backchannel_state(inner: &Arc) { +pub async fn clear_backchannel_state(inner: &RefCount) { inner.backchannel.write().await.clear_router_state(inner).await; } /// Send QUIT to the servers and clean up the old router task's state. -fn close_router_channel(inner: &Arc, command_tx: Arc) { +fn close_router_channel(inner: &RefCount, command_tx: RefCount) { inner.notifications.broadcast_close(); inner.reset_server_state(); diff --git a/tests/doc-glommio.sh b/tests/doc-glommio.sh new file mode 100755 index 00000000..6888feba --- /dev/null +++ b/tests/doc-glommio.sh @@ -0,0 +1,7 @@ +#!/bin/bash + +FEATURES="network-logs custom-reconnect-errors serde-json blocking-encoding credential-provider mocks + full-tracing monitor metrics sentinel-client subscriber-client dns debug-ids sentinel-auth + replicas sha-1 transactions i-all glommio i-redis-stack enable-rustls enable-native-tls" + +RUSTDOCFLAGS="" cargo +nightly rustdoc --features "$FEATURES" "$@" -- --cfg docsrs \ No newline at end of file diff --git a/tests/doc.sh b/tests/doc.sh index f4e91448..e4a64963 100755 --- a/tests/doc.sh +++ b/tests/doc.sh @@ -1,3 +1,8 @@ #!/bin/bash -cargo +nightly rustdoc --all-features "$@" -- --cfg docsrs \ No newline at end of file + +FEATURES="network-logs custom-reconnect-errors serde-json blocking-encoding credential-provider + full-tracing monitor metrics sentinel-client subscriber-client dns debug-ids sentinel-auth + replicas sha-1 transactions i-all unix-sockets i-redis-stack enable-rustls enable-native-tls" + +cargo +nightly rustdoc --features "$FEATURES" "$@" -- --cfg docsrs \ No newline at end of file diff --git a/tests/docker/compose/base.yml b/tests/docker/compose/base.yml index 997f06ce..a5df970b 100644 --- a/tests/docker/compose/base.yml +++ b/tests/docker/compose/base.yml @@ -1,5 +1,9 @@ version: '2' +networks: + fred-tests: + driver: bridge + services: debug: depends_on: diff --git a/tests/docker/compose/glommio.yml b/tests/docker/compose/glommio.yml new file mode 100644 index 00000000..79d6be79 --- /dev/null +++ b/tests/docker/compose/glommio.yml @@ -0,0 +1,25 @@ +version: '2' + +networks: + fred-tests: + driver: bridge + +services: + glommio: + container_name: "check-glommio" + build: + context: ../../../ + dockerfile: tests/docker/runners/images/debug.dockerfile + args: + REDIS_VERSION: "${REDIS_VERSION}" + networks: + - fred-tests + command: + - "/project/tests/docker/runners/bash/check-glommio.sh" + - "${TEST_ARGV}" + environment: + RUST_LOG: "${RUST_LOG}" + RUST_BACKTRACE: "${RUST_BACKTRACE}" + volumes: + - "../../..:/project" + - "~/.cargo/registry:/usr/local/cargo/registry" \ No newline at end of file diff --git a/tests/docker/runners/bash/all-features.sh b/tests/docker/runners/bash/all-features.sh index 0e12ef68..5ab31598 100755 --- a/tests/docker/runners/bash/all-features.sh +++ b/tests/docker/runners/bash/all-features.sh @@ -15,7 +15,7 @@ done # those features individually. FEATURES="network-logs custom-reconnect-errors serde-json blocking-encoding full-tracing monitor metrics sentinel-client subscriber-client dns debug-ids - replicas sha-1 transactions i-all" + replicas sha-1 transactions i-all credential-provider" if [ -z "$FRED_CI_NEXTEST" ]; then cargo test --release --lib --tests --features "$FEATURES" -- --test-threads=1 "$@" diff --git a/tests/docker/runners/bash/check-glommio.sh b/tests/docker/runners/bash/check-glommio.sh new file mode 100755 index 00000000..b8027635 --- /dev/null +++ b/tests/docker/runners/bash/check-glommio.sh @@ -0,0 +1,7 @@ +#!/bin/bash + +FEATURES="network-logs custom-reconnect-errors serde-json blocking-encoding credential-provider + full-tracing monitor metrics sentinel-client subscriber-client dns debug-ids mocks + replicas sha-1 transactions i-all glommio i-redis-stack enable-rustls enable-native-tls" + +cargo clippy --features "$FEATURES" -- "$@" \ No newline at end of file diff --git a/tests/docker/runners/bash/mocks.sh b/tests/docker/runners/bash/mocks.sh index 585fd118..18f70900 100755 --- a/tests/docker/runners/bash/mocks.sh +++ b/tests/docker/runners/bash/mocks.sh @@ -1,7 +1,7 @@ #!/bin/bash if [ -z "$FRED_CI_NEXTEST" ]; then - cargo test --release --lib --features "mocks i-keys" + cargo test --release --lib --features "mocks i-keys" "$@" else - cargo nextest run --release --lib --features "mocks i-keys" + cargo nextest run --release --lib --features "mocks i-keys" "$@" fi \ No newline at end of file diff --git a/tests/docker/runners/images/base.dockerfile b/tests/docker/runners/images/base.dockerfile index 86832148..ca2e8f98 100644 --- a/tests/docker/runners/images/base.dockerfile +++ b/tests/docker/runners/images/base.dockerfile @@ -1,6 +1,6 @@ # https://github.com/docker/for-mac/issues/5548#issuecomment-1029204019 # FROM rust:1.77-slim-buster -FROM rust:1.78-slim-bullseye +FROM rust:1.80-slim-bullseye WORKDIR /project diff --git a/tests/docker/runners/images/ci.dockerfile b/tests/docker/runners/images/ci.dockerfile index 449bb34c..c787a76a 100644 --- a/tests/docker/runners/images/ci.dockerfile +++ b/tests/docker/runners/images/ci.dockerfile @@ -1,4 +1,4 @@ -FROM rust:1.78-slim-buster +FROM rust:1.80-slim-buster WORKDIR /project # circleci doesn't mount volumes with a remote docker engine so we have to copy everything diff --git a/tests/docker/runners/images/debug.dockerfile b/tests/docker/runners/images/debug.dockerfile index 86832148..1986bd83 100644 --- a/tests/docker/runners/images/debug.dockerfile +++ b/tests/docker/runners/images/debug.dockerfile @@ -1,6 +1,6 @@ # https://github.com/docker/for-mac/issues/5548#issuecomment-1029204019 # FROM rust:1.77-slim-buster -FROM rust:1.78-slim-bullseye +FROM rust:1.80-slim-bullseye WORKDIR /project @@ -19,8 +19,9 @@ ARG FRED_REDIS_SENTINEL_HOST ARG FRED_REDIS_SENTINEL_PORT ARG CIRCLECI_TESTS -RUN USER=root apt-get update && apt-get install -y build-essential libssl-dev dnsutils curl pkg-config cmake +RUN USER=root apt-get update && apt-get install -y build-essential libssl-dev dnsutils curl pkg-config cmake git RUN echo "REDIS_VERSION=$REDIS_VERSION" # For debugging -RUN cargo --version && rustc --version \ No newline at end of file +RUN cargo --version && rustc --version +RUN rustup component add clippy && rustup install nightly \ No newline at end of file diff --git a/tests/integration/centralized.rs b/tests/integration/centralized.rs index d6668bb2..aedec39a 100644 --- a/tests/integration/centralized.rs +++ b/tests/integration/centralized.rs @@ -33,7 +33,6 @@ mod keys { #[cfg(all(feature = "transactions", feature = "i-keys"))] mod multi { - centralized_test!(multi, should_run_get_set_trx); centralized_test_panic!(multi, should_run_error_get_set_trx); } @@ -48,7 +47,15 @@ mod other { centralized_test!(other, pool_should_connect_correctly_via_wait_interface); centralized_test!(other, pool_should_fail_with_bad_host_via_wait_interface); centralized_test!(other, should_fail_on_centralized_connect); + centralized_test!(other, should_safely_change_protocols_repeatedly); + centralized_test!(other, should_gracefully_quit); + #[cfg(all(feature = "transactions", feature = "i-keys", feature = "i-hashes"))] + centralized_test!(other, should_fail_pipeline_transaction_error); + #[cfg(all(feature = "transactions", feature = "i-keys"))] + centralized_test!(other, should_pipeline_transaction); + #[cfg(feature = "credential-provider")] + centralized_test!(other, should_use_credential_provider); #[cfg(feature = "metrics")] centralized_test!(other, should_track_size_stats); #[cfg(all(feature = "i-client", feature = "i-lists"))] @@ -59,7 +66,6 @@ mod other { centralized_test!(other, should_error_when_blocked); #[cfg(all(feature = "i-keys", feature = "i-hashes"))] centralized_test!(other, should_smoke_test_from_redis_impl); - centralized_test!(other, should_safely_change_protocols_repeatedly); #[cfg(feature = "i-keys")] centralized_test!(other, should_pipeline_all); #[cfg(all(feature = "i-keys", feature = "i-hashes"))] @@ -70,7 +76,6 @@ mod other { centralized_test!(other, should_pipeline_try_all); #[cfg(feature = "i-server")] centralized_test!(other, should_use_all_cluster_nodes_repeatedly); - centralized_test!(other, should_gracefully_quit); #[cfg(feature = "i-lists")] centralized_test!(other, should_support_options_with_pipeline); #[cfg(feature = "i-keys")] @@ -138,14 +143,12 @@ mod pubsub { #[cfg(feature = "i-hyperloglog")] mod hyperloglog { - centralized_test!(hyperloglog, should_pfadd_elements); centralized_test!(hyperloglog, should_pfcount_elements); centralized_test!(hyperloglog, should_pfmerge_elements); } mod scanning { - #[cfg(feature = "i-keys")] cluster_test!(scanning, should_scan_keyspace); #[cfg(feature = "i-hashes")] @@ -158,7 +161,6 @@ mod scanning { #[cfg(feature = "i-slowlog")] mod slowlog { - centralized_test!(slowlog, should_read_slowlog_length); centralized_test!(slowlog, should_read_slowlog_entries); centralized_test!(slowlog, should_reset_slowlog); @@ -166,7 +168,6 @@ mod slowlog { #[cfg(feature = "i-server")] mod server { - centralized_test!(server, should_flushall); centralized_test!(server, should_read_server_info); centralized_test!(server, should_ping_server); @@ -179,7 +180,6 @@ mod server { #[cfg(feature = "i-sets")] mod sets { - centralized_test!(sets, should_sadd_elements); centralized_test!(sets, should_scard_elements); centralized_test!(sets, should_sdiff_elements); @@ -199,7 +199,6 @@ mod sets { #[cfg(feature = "i-memory")] pub mod memory { - centralized_test!(memory, should_run_memory_doctor); centralized_test!(memory, should_run_memory_malloc_stats); centralized_test!(memory, should_run_memory_purge); @@ -209,7 +208,6 @@ pub mod memory { #[cfg(feature = "i-scripts")] pub mod lua { - #[cfg(feature = "sha-1")] centralized_test!(lua, should_load_script); centralized_test!(lua, should_eval_echo_script); @@ -243,7 +241,6 @@ pub mod lua { #[cfg(feature = "i-sorted-sets")] pub mod sorted_sets { - centralized_test!(sorted_sets, should_bzpopmin); centralized_test!(sorted_sets, should_bzpopmax); centralized_test!(sorted_sets, should_zadd_values); @@ -308,7 +305,6 @@ pub mod lists { #[cfg(feature = "i-geo")] pub mod geo { - centralized_test!(geo, should_geoadd_values); centralized_test!(geo, should_geohash_values); centralized_test!(geo, should_geopos_values); diff --git a/tests/integration/clustered.rs b/tests/integration/clustered.rs index 24153a00..3cd553cd 100644 --- a/tests/integration/clustered.rs +++ b/tests/integration/clustered.rs @@ -34,7 +34,6 @@ mod keys { #[cfg(all(feature = "transactions", feature = "i-keys"))] mod multi { - cluster_test!(multi, should_run_get_set_trx); cluster_test_panic!(multi, should_fail_with_hashslot_error); cluster_test_panic!(multi, should_run_error_get_set_trx); @@ -49,11 +48,18 @@ mod other { cluster_test!(other, pool_should_fail_with_bad_host_via_init_interface); cluster_test!(other, pool_should_connect_correctly_via_wait_interface); cluster_test!(other, pool_should_fail_with_bad_host_via_wait_interface); + cluster_test!(other, should_split_clustered_connection); + cluster_test!(other, should_safely_change_protocols_repeatedly); + cluster_test!(other, should_gracefully_quit); + #[cfg(all(feature = "transactions", feature = "i-keys", feature = "i-hashes"))] + cluster_test!(other, should_fail_pipeline_transaction_error); + #[cfg(all(feature = "transactions", feature = "i-keys"))] + cluster_test!(other, should_pipeline_transaction); + #[cfg(feature = "credential-provider")] + cluster_test!(other, should_use_credential_provider); #[cfg(feature = "metrics")] cluster_test!(other, should_track_size_stats); - - cluster_test!(other, should_split_clustered_connection); #[cfg(feature = "i-server")] cluster_test!(other, should_run_flushall_cluster); #[cfg(all(feature = "i-client", feature = "i-lists"))] @@ -62,7 +68,6 @@ mod other { cluster_test!(other, should_manually_unblock); #[cfg(all(feature = "i-client", feature = "i-lists"))] cluster_test!(other, should_error_when_blocked); - cluster_test!(other, should_safely_change_protocols_repeatedly); #[cfg(feature = "i-keys")] cluster_test!(other, should_pipeline_all); #[cfg(all(feature = "i-keys", feature = "i-hashes"))] @@ -73,7 +78,6 @@ mod other { cluster_test!(other, should_pipeline_try_all); #[cfg(feature = "i-server")] cluster_test!(other, should_use_all_cluster_nodes_repeatedly); - cluster_test!(other, should_gracefully_quit); #[cfg(feature = "i-lists")] cluster_test!(other, should_support_options_with_pipeline); #[cfg(feature = "i-keys")] @@ -113,7 +117,6 @@ mod pool { #[cfg(feature = "i-hashes")] mod hashes { - cluster_test!(hashes, should_hset_and_hget); cluster_test!(hashes, should_hset_and_hdel); cluster_test!(hashes, should_hexists); @@ -131,7 +134,6 @@ mod hashes { #[cfg(feature = "i-pubsub")] mod pubsub { - cluster_test!(pubsub, should_publish_and_recv_messages); cluster_test!(pubsub, should_ssubscribe_and_recv_messages); cluster_test!(pubsub, should_psubscribe_and_recv_messages); @@ -147,14 +149,12 @@ mod pubsub { #[cfg(feature = "i-hyperloglog")] mod hyperloglog { - cluster_test!(hyperloglog, should_pfadd_elements); cluster_test!(hyperloglog, should_pfcount_elements); cluster_test!(hyperloglog, should_pfmerge_elements); } mod scanning { - #[cfg(feature = "i-keys")] cluster_test!(scanning, should_scan_keyspace); #[cfg(feature = "i-hashes")] @@ -169,7 +169,6 @@ mod scanning { #[cfg(feature = "i-slowlog")] mod slowlog { - cluster_test!(slowlog, should_read_slowlog_length); cluster_test!(slowlog, should_read_slowlog_entries); cluster_test!(slowlog, should_reset_slowlog); @@ -177,7 +176,6 @@ mod slowlog { #[cfg(feature = "i-server")] mod server { - cluster_test!(server, should_flushall); cluster_test!(server, should_read_server_info); cluster_test!(server, should_ping_server); @@ -190,7 +188,6 @@ mod server { #[cfg(feature = "i-sets")] mod sets { - cluster_test!(sets, should_sadd_elements); cluster_test!(sets, should_scard_elements); cluster_test!(sets, should_sdiff_elements); @@ -210,7 +207,6 @@ mod sets { #[cfg(feature = "i-memory")] pub mod memory { - cluster_test!(memory, should_run_memory_doctor); cluster_test!(memory, should_run_memory_malloc_stats); cluster_test!(memory, should_run_memory_purge); @@ -220,7 +216,6 @@ pub mod memory { #[cfg(feature = "i-scripts")] pub mod lua { - #[cfg(feature = "sha-1")] cluster_test!(lua, should_load_script); #[cfg(feature = "sha-1")] @@ -256,7 +251,6 @@ pub mod lua { #[cfg(feature = "i-sorted-sets")] pub mod sorted_sets { - cluster_test!(sorted_sets, should_bzpopmin); cluster_test!(sorted_sets, should_bzpopmax); cluster_test!(sorted_sets, should_zadd_values); @@ -291,7 +285,6 @@ pub mod sorted_sets { #[cfg(feature = "i-lists")] pub mod lists { - cluster_test!(lists, should_blpop_values); cluster_test!(lists, should_brpop_values); cluster_test!(lists, should_brpoplpush_values); @@ -321,7 +314,6 @@ pub mod lists { #[cfg(feature = "i-geo")] pub mod geo { - cluster_test!(geo, should_geoadd_values); cluster_test!(geo, should_geohash_values); cluster_test!(geo, should_geopos_values); diff --git a/tests/integration/other/mod.rs b/tests/integration/other/mod.rs index 99a01c5f..d4bebd70 100644 --- a/tests/integration/other/mod.rs +++ b/tests/integration/other/mod.rs @@ -10,6 +10,7 @@ use fred::{ BackpressureConfig, Builder, ClientUnblockFlag, + ClusterDiscoveryPolicy, ClusterHash, Options, PerformanceConfig, @@ -36,7 +37,9 @@ use tokio::time::sleep; #[cfg(feature = "subscriber-client")] use fred::clients::SubscriberClient; -use fred::types::ClusterDiscoveryPolicy; +use fred::prelude::Server; +#[cfg(feature = "credential-provider")] +use fred::types::CredentialProvider; #[cfg(feature = "replicas")] use fred::types::ReplicaConfig; #[cfg(feature = "dns")] @@ -599,6 +602,41 @@ pub async fn should_support_options_with_trx(client: RedisClient, _: RedisConfig Ok(()) } +#[cfg(all(feature = "transactions", feature = "i-keys"))] +pub async fn should_pipeline_transaction(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { + client.incr("foo{1}").await?; + client.incr("bar{1}").await?; + + let trx = client.multi(); + trx.pipeline(true); + trx.get("foo{1}").await?; + trx.incr("bar{1}").await?; + let (foo, bar): (i64, i64) = trx.exec(true).await?; + assert_eq!((foo, bar), (1, 2)); + + Ok(()) +} + +#[cfg(all(feature = "transactions", feature = "i-keys", feature = "i-hashes"))] +pub async fn should_fail_pipeline_transaction_error(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { + client.incr("foo{1}").await?; + client.incr("bar{1}").await?; + + let trx = client.multi(); + trx.pipeline(true); + trx.get("foo{1}").await?; + trx.hgetall("bar{1}").await?; + trx.get("foo{1}").await?; + + if let Err(e) = trx.exec::(false).await { + assert_eq!(*e.kind(), RedisErrorKind::InvalidArgument); + } else { + panic!("Expected error from transaction."); + } + + Ok(()) +} + #[cfg(all(feature = "i-keys", feature = "i-lists"))] pub async fn should_manually_connect_twice(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { let client = client.clone_new(); @@ -797,3 +835,32 @@ pub async fn should_fail_on_centralized_connect(_: RedisClient, mut config: Redi Err(RedisError::new(RedisErrorKind::Unknown, "Expected a config error.")) } + +#[derive(Debug, Default)] +#[cfg(feature = "credential-provider")] +pub struct FakeCreds {} + +#[async_trait] +#[cfg(feature = "credential-provider")] +impl CredentialProvider for FakeCreds { + async fn fetch(&self, _: Option<&Server>) -> Result<(Option, Option), RedisError> { + use super::utils::{read_redis_password, read_redis_username}; + Ok((Some(read_redis_username()), Some(read_redis_password()))) + } +} +#[cfg(feature = "credential-provider")] +pub async fn should_use_credential_provider(_client: RedisClient, mut config: RedisConfig) -> Result<(), RedisError> { + let (perf, connection) = (_client.perf_config(), _client.connection_config().clone()); + config.username = None; + config.password = None; + config.credential_provider = Some(Arc::new(FakeCreds::default())); + let client = Builder::from_config(config) + .set_connection_config(connection) + .set_performance_config(perf) + .build()?; + + client.init().await?; + client.ping().await?; + client.quit().await?; + Ok(()) +} diff --git a/tests/runners/check-glommio.sh b/tests/runners/check-glommio.sh new file mode 100755 index 00000000..4ffab163 --- /dev/null +++ b/tests/runners/check-glommio.sh @@ -0,0 +1,4 @@ +#!/bin/bash + +docker-compose -f tests/docker/compose/glommio.yml run -u $(id -u ${USER}):$(id -g ${USER}) --rm glommio + diff --git a/tests/runners/mocks.sh b/tests/runners/mocks.sh index 1c3ae2aa..c4562a15 100755 --- a/tests/runners/mocks.sh +++ b/tests/runners/mocks.sh @@ -1,3 +1,3 @@ #!/bin/bash -tests/docker/runners/bash/mocks.sh "$0" \ No newline at end of file +tests/docker/runners/bash/mocks.sh "$1" \ No newline at end of file diff --git a/tests/scripts/build-gh-pages.sh b/tests/scripts/build-gh-pages.sh new file mode 100755 index 00000000..c5b3c448 --- /dev/null +++ b/tests/scripts/build-gh-pages.sh @@ -0,0 +1,19 @@ +#!/bin/bash + +rm -rf .doc +mkdir -p .doc .doc/tokio .doc/glommio + +FEATURES="network-logs custom-reconnect-errors serde-json blocking-encoding unix-sockets mocks + full-tracing monitor metrics sentinel-client subscriber-client dns debug-ids sentinel-auth + replicas sha-1 transactions i-all i-redis-stack enable-rustls enable-native-tls credential-provider" + +cargo +nightly rustdoc --features "$FEATURES" "$@" -- --cfg docsrs +mv target/doc/* .doc/tokio/ + +FEATURES="network-logs custom-reconnect-errors serde-json blocking-encoding mocks sentinel-auth + full-tracing monitor metrics sentinel-client subscriber-client dns debug-ids credential-provider + replicas sha-1 transactions i-all glommio i-redis-stack enable-rustls enable-native-tls" + +cargo +nightly rustdoc --features "$FEATURES" "$@" -- --cfg docsrs +mv target/doc/* .doc/glommio/ + diff --git a/tests/scripts/check_glommio_features.sh b/tests/scripts/check_glommio_features.sh new file mode 100755 index 00000000..08a0d591 --- /dev/null +++ b/tests/scripts/check_glommio_features.sh @@ -0,0 +1,14 @@ +#!/bin/bash -e + +all_features=`yq -oy '.features["i-all"]' Cargo.toml | tr -d '\n' | sed -e 's/- / /g' | cut -c 2-` +redis_stack_features=`yq -oy '.features["i-redis-stack"]' Cargo.toml | tr -d '\n' | sed -e 's/- / /g' | cut -c 2-` + +for feature in $all_features; do + echo "Checking $feature" + cargo clippy --lib -p fred --no-default-features --features "glommio $feature" -- -Dwarnings +done + +for feature in $redis_stack_features; do + echo "Checking $feature" + cargo clippy --lib -p fred --no-default-features --features "glommio $feature" -- -Dwarnings +done \ No newline at end of file