From 9bbb9221cbb1881b9ce2ff2e3b2754d6c0c68ab5 Mon Sep 17 00:00:00 2001 From: Alec Embke Date: Thu, 5 Oct 2023 18:13:42 -0700 Subject: [PATCH] 7.0.0 (#172) * feat: added a new client builder and configuration interface * feat: reworked or removed the majority of the `globals` interface * feat: support multiple IP addresses in the `Resolve` interface * feat: add `with_options` command configuration interface * feat: replaced the `no-client-setname` feature flag with `auto-client-setname` * fix: redesign the connection timeout configuration interface * feat: add an interface to configure TCP socket options * fix: removed the automatic `serde_json::Value` -> `RedisValue` type conversion logic * fix: implement `ClientLike` for `RedisPool` * feat: moved and refactored the `on_*` functions into a new `EventInterface` * fix: fixed several bugs with the `Replica` routing implementation * fix: fixed several bugs and inconsistencies related to parsing single-element arrays * fix: changed several `FromRedis` type conversion rules * feat: add a RedisJSON interface * feat: add a RESP2 and RESP3 codec interface * fix: all commands now return generic types --------- Co-authored-by: Rob Day Co-authored-by: Alexander May <115396695+amay2k@users.noreply.github.com> --- .circleci/Dockerfile.sentinel | 2 +- .circleci/config.yml | 125 +- CHANGELOG.md | 47 +- CONTRIBUTING.md | 21 +- Cargo.toml | 48 +- LICENSE-APACHE | 2 +- LICENSE-MIT | 2 +- README.md | 92 +- bin/inf_loop/docker-compose.yml | 2 +- bin/inf_loop/rustfmt.toml | 8 +- bin/pipeline_test/Cargo.toml | 2 +- bin/pipeline_test/cli.yml | 10 + bin/pipeline_test/docker-compose.yml | 2 +- bin/pipeline_test/run.sh | 2 +- bin/pipeline_test/rustfmt.toml | 8 +- bin/pipeline_test/src/main.rs | 56 +- examples/README.md | 6 +- examples/basic.rs | 47 +- examples/blocking.rs | 26 +- examples/client_tracking.rs | 48 +- examples/custom.rs | 52 +- examples/dns.rs | 31 +- examples/globals.rs | 8 +- examples/lua.rs | 12 +- examples/misc.rs | 96 +- 
examples/monitor.rs | 3 + examples/pipeline.rs | 19 +- examples/pool.rs | 32 +- examples/prometheus.rs | 42 - examples/pubsub.rs | 12 +- examples/redis_json.rs | 42 + examples/scan.rs | 23 +- examples/sentinel.rs | 11 +- examples/{serde.rs => serde_json.rs} | 27 +- examples/tls.rs | 7 +- examples/transactions.rs | 8 +- src/clients/caching.rs | 100 -- src/clients/mod.rs | 12 +- src/clients/node.rs | 211 ---- src/clients/options.rs | 118 ++ src/clients/pipeline.rs | 72 +- src/clients/pool.rs | 269 ++++ src/clients/pubsub.rs | 41 +- src/clients/redis.rs | 133 +- src/clients/replica.rs | 25 +- src/clients/sentinel.rs | 13 +- src/clients/transaction.rs | 129 +- src/commands/impls/acl.rs | 33 +- src/commands/impls/client.rs | 24 +- src/commands/impls/cluster.rs | 22 +- src/commands/impls/geo.rs | 40 +- src/commands/impls/hashes.rs | 14 +- src/commands/impls/hyperloglog.rs | 7 +- src/commands/impls/keys.rs | 46 +- src/commands/impls/lists.rs | 34 +- src/commands/impls/lua.rs | 100 +- src/commands/impls/memory.rs | 14 +- src/commands/impls/mod.rs | 110 +- src/commands/impls/pubsub.rs | 48 +- src/commands/impls/redis_json.rs | 358 ++++++ src/commands/impls/scan.rs | 6 +- src/commands/impls/server.rs | 52 +- src/commands/impls/sets.rs | 19 +- src/commands/impls/slowlog.rs | 22 +- src/commands/impls/sorted_sets.rs | 44 +- src/commands/impls/streams.rs | 16 +- src/commands/impls/tracking.rs | 6 +- src/commands/interfaces/acl.rs | 35 +- src/commands/interfaces/client.rs | 25 +- src/commands/interfaces/cluster.rs | 33 +- src/commands/interfaces/geo.rs | 30 +- src/commands/interfaces/keys.rs | 16 +- src/commands/interfaces/lua.rs | 2 + src/commands/interfaces/memory.rs | 23 +- src/commands/interfaces/mod.rs | 3 + src/commands/interfaces/redis_json.rs | 399 ++++++ src/commands/interfaces/server.rs | 6 +- src/commands/interfaces/slowlog.rs | 16 +- src/commands/interfaces/sorted_sets.rs | 9 +- src/commands/interfaces/streams.rs | 33 +- src/commands/interfaces/tracking.rs | 32 +- 
src/commands/interfaces/transactions.rs | 2 +- src/error.rs | 27 +- src/interfaces.rs | 238 +++- src/lib.rs | 50 +- src/macros.rs | 83 +- src/modules/backchannel.rs | 93 +- src/modules/globals.rs | 51 +- src/modules/inner.rs | 121 +- src/modules/metrics.rs | 2 +- src/modules/mocks.rs | 18 +- src/modules/mod.rs | 2 - src/modules/pool.rs | 149 --- src/modules/response.rs | 415 +++++-- src/monitor/utils.rs | 12 +- src/protocol/cluster.rs | 207 ++-- src/protocol/codec.rs | 16 +- src/protocol/command.rs | 614 ++++----- src/protocol/connection.rs | 305 +++-- src/protocol/hashers.rs | 18 +- src/protocol/mod.rs | 3 + src/protocol/public.rs | 210 ++++ src/protocol/responders.rs | 47 +- src/protocol/types.rs | 78 +- src/protocol/utils.rs | 1094 +++-------------- src/router/centralized.rs | 38 +- src/router/clustered.rs | 228 ++-- src/router/commands.rs | 464 +++---- src/router/mod.rs | 482 +++----- src/router/replicas.rs | 328 ++--- src/router/responses.rs | 47 +- src/router/sentinel.rs | 87 +- src/router/transactions.rs | 117 +- src/router/types.rs | 28 +- src/router/utils.rs | 117 +- src/trace/disabled.rs | 6 +- src/trace/enabled.rs | 7 +- src/types/acl.rs | 132 -- src/types/args.rs | 449 +++---- src/types/builder.rs | 287 +++++ src/types/cluster.rs | 68 +- src/types/config.rs | 302 +++-- src/types/geo.rs | 89 +- src/types/misc.rs | 181 ++- src/types/mod.rs | 8 +- src/types/multiple.rs | 126 +- src/types/scripts.rs | 16 +- src/types/streams.rs | 8 +- src/utils.rs | 175 +-- tests/README.md | 4 +- tests/docker/compose/base.yml | 5 +- tests/docker/compose/centralized.yml | 4 +- tests/docker/compose/cluster-tls.yml | 14 +- tests/docker/compose/cluster.yml | 14 +- tests/docker/compose/redis-stack.yml | 23 + tests/docker/compose/sentinel.yml | 13 +- tests/docker/runners/bash/all-features.sh | 4 +- .../docker/runners/bash/default-nil-types.sh | 19 + tests/docker/runners/bash/redis-stack.sh | 19 + .../docker/runners/bash/sentinel-features.sh | 2 +- 
tests/docker/runners/compose/all-features.yml | 3 +- .../runners/compose/cluster-native-tls.yml | 3 +- .../docker/runners/compose/cluster-rustls.yml | 3 +- .../runners/compose/default-features.yml | 3 +- .../runners/compose/default-nil-types.yml | 32 + tests/docker/runners/compose/no-features.yml | 3 +- tests/docker/runners/compose/redis-stack.yml | 28 + .../runners/compose/sentinel-features.yml | 4 +- tests/docker/runners/images/base.dockerfile | 4 +- tests/docker/runners/images/ci.dockerfile | 2 +- tests/docker/runners/images/debug.dockerfile | 2 +- tests/environ | 5 +- tests/integration/acl/mod.rs | 36 +- tests/integration/centralized.rs | 36 +- tests/integration/cluster/mod.rs | 4 +- tests/integration/clustered.rs | 20 +- tests/integration/docker.rs | 210 ++++ tests/integration/geo/mod.rs | 124 +- tests/integration/hashes/mod.rs | 34 +- tests/integration/hyperloglog/mod.rs | 2 +- tests/integration/keys/mod.rs | 90 +- tests/integration/lists/mod.rs | 32 +- tests/integration/lua/mod.rs | 76 +- tests/integration/memory/mod.rs | 15 +- tests/integration/mod.rs | 42 +- tests/integration/multi/mod.rs | 16 +- tests/integration/other/mod.rs | 217 +++- tests/integration/pool/mod.rs | 19 +- tests/integration/pubsub/mod.rs | 103 +- tests/integration/redis_json/mod.rs | 199 +++ tests/integration/scanning/mod.rs | 24 +- tests/integration/server/mod.rs | 12 +- tests/integration/sets/mod.rs | 50 +- tests/integration/slowlog/mod.rs | 15 +- tests/integration/sorted_sets/mod.rs | 39 +- tests/integration/streams/mod.rs | 158 +-- tests/integration/tracking/mod.rs | 4 +- tests/integration/utils.rs | 182 ++- tests/lib.rs | 11 + tests/runners/default-features.sh | 3 +- tests/runners/default-nil-types.sh | 7 + tests/runners/docker-bash.sh | 8 +- tests/runners/everything.sh | 4 +- tests/runners/redis-stack.sh | 5 + tests/runners/sentinel-features.sh | 3 +- 185 files changed, 7797 insertions(+), 5642 deletions(-) delete mode 100644 examples/prometheus.rs create mode 100644 
examples/redis_json.rs rename examples/{serde.rs => serde_json.rs} (52%) delete mode 100644 src/clients/caching.rs delete mode 100644 src/clients/node.rs create mode 100644 src/clients/options.rs create mode 100644 src/clients/pool.rs create mode 100644 src/commands/impls/redis_json.rs create mode 100644 src/commands/interfaces/redis_json.rs delete mode 100644 src/modules/pool.rs create mode 100644 src/protocol/public.rs delete mode 100644 src/types/acl.rs create mode 100644 src/types/builder.rs create mode 100644 tests/docker/compose/redis-stack.yml create mode 100755 tests/docker/runners/bash/default-nil-types.sh create mode 100755 tests/docker/runners/bash/redis-stack.sh create mode 100644 tests/docker/runners/compose/default-nil-types.yml create mode 100644 tests/docker/runners/compose/redis-stack.yml create mode 100644 tests/integration/docker.rs create mode 100644 tests/integration/redis_json/mod.rs create mode 100755 tests/runners/default-nil-types.sh create mode 100755 tests/runners/redis-stack.sh diff --git a/.circleci/Dockerfile.sentinel b/.circleci/Dockerfile.sentinel index 9f15ff75..8c40d8c4 100644 --- a/.circleci/Dockerfile.sentinel +++ b/.circleci/Dockerfile.sentinel @@ -6,7 +6,7 @@ # note: the top level target directory must be removed prior to running this -FROM cimg/rust:1.72.0 +FROM cimg/rust:1.72.1 USER circleci ARG REDIS_VERSION diff --git a/.circleci/config.yml b/.circleci/config.yml index 4e4be933..db108e7c 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -34,6 +34,11 @@ commands: - run: name: Build documentation command: tests/doc.sh + test_mocks: + steps: + - run: + name: Run mock tests + command: cargo test --lib --features mocks --no-default-features test_default_features: steps: - checkout @@ -50,6 +55,14 @@ commands: name: Run tests with all features command: source tests/environ && tests/runners/all-features.sh - save-cargo-deps-cache + test_redis_stack: + steps: + - checkout + - restore-cargo-deps-cache + - run: + name: 
Run tests with redis-stack features + command: source tests/environ && tests/runners/redis-stack.sh + - save-cargo-deps-cache test_no_features: steps: - checkout @@ -82,122 +95,122 @@ commands: name: Run cluster tests with rustls features command: source tests/environ && tests/scripts/tls-creds.sh && tests/runners/cluster-rustls.sh - save-cargo-deps-cache + test_default_nil_types_features: + steps: + - checkout + - restore-cargo-deps-cache + - run: + name: Run tests with default-nil-types features + command: source tests/environ && tests/runners/default-nil-types.sh + - save-cargo-deps-cache jobs: - test-default-7_0: + test-default-nil-types-7_2: machine: image: ubuntu-2204:2022.10.2 docker_layer_caching: true resource_class: medium environment: - REDIS_VERSION: 7.0.9 + REDIS_VERSION: 7.2.1 steps: - - test_default_features - test-no-features-7_0: + - test_default_nil_types_features + test-default-7_2: machine: image: ubuntu-2204:2022.10.2 docker_layer_caching: true resource_class: medium environment: - REDIS_VERSION: 7.0.9 + REDIS_VERSION: 7.2.1 steps: - - test_no_features - test-all-features-7_0: + - test_default_features + test-redis-stack-7_2: machine: image: ubuntu-2204:2022.10.2 docker_layer_caching: true resource_class: medium environment: - REDIS_VERSION: 7.0.9 + REDIS_VERSION: 7.2.1 steps: - - test_all_features - test-cluster-tls-features-7_0: + - test_redis_stack + test-no-features-7_2: machine: image: ubuntu-2204:2022.10.2 docker_layer_caching: true resource_class: medium environment: - REDIS_VERSION: 7.0.9 - FRED_CI_TLS: true + REDIS_VERSION: 7.2.1 steps: - - test_tls_cluster - test-cluster-rustls-features-7_0: - machine: - image: ubuntu-2204:2022.10.2 - docker_layer_caching: true - resource_class: medium - environment: - REDIS_VERSION: 7.0.9 - FRED_CI_TLS: true - steps: - - test_rustls_cluster - test-default-6_2: + - test_no_features + test-all-features-7_2: machine: image: ubuntu-2204:2022.10.2 docker_layer_caching: true resource_class: medium 
environment: - REDIS_VERSION: 6.2.2 + REDIS_VERSION: 7.2.1 steps: - - test_default_features - test-no-features-6_2: + - test_all_features + test-cluster-tls-features-7_2: machine: image: ubuntu-2204:2022.10.2 docker_layer_caching: true resource_class: medium environment: - REDIS_VERSION: 6.2.2 + REDIS_VERSION: 7.2.1 + FRED_CI_TLS: true steps: - - test_no_features - test-all-features-6_2: + - test_tls_cluster + test-cluster-rustls-features-7_2: machine: image: ubuntu-2204:2022.10.2 docker_layer_caching: true resource_class: medium environment: - REDIS_VERSION: 6.2.2 + REDIS_VERSION: 7.2.1 + FRED_CI_TLS: true steps: - - test_all_features - test-sentinel-6_2: + - test_rustls_cluster + test-sentinel-7_2: machine: image: ubuntu-2204:2022.10.2 docker_layer_caching: true resource_class: medium environment: - REDIS_VERSION: 6.2.2 + REDIS_VERSION: 7.2.1 steps: - test_sentinel - test-sentinel-7_0: - machine: - image: ubuntu-2204:2022.10.2 - docker_layer_caching: true - resource_class: medium + test-misc: + docker: + - image: cimg/rust:1.72.1 environment: - REDIS_VERSION: 7.0.9 + CARGO_NET_GIT_FETCH_WITH_CLI: true steps: - - test_sentinel - test-docs: + - checkout + - build_docs + - test_mocks + clippy-lint: docker: - - image: cimg/rust:1.72.0 + - image: cimg/rust:1.72.1 environment: CARGO_NET_GIT_FETCH_WITH_CLI: true steps: - checkout - - build_docs + - run: + name: Clippy + command: cargo clippy -- -Dwarnings + workflows: version: 2 build: jobs: - # the older bitnami Redis images require a different process to bootstrap ACL rules... 
- #- test-default-6_2 - #- test-all-features-6_2 - #- test-no-features-6_2 - #- test-sentinel-6_2 - - test-default-7_0 - - test-all-features-7_0 - - test-no-features-7_0 - - test-sentinel-7_0 - - test-docs - - test-cluster-tls-features-7_0 - - test-cluster-rustls-features-7_0 \ No newline at end of file + - test-default-7_2 + - test-all-features-7_2 + - test-no-features-7_2 + - test-default-nil-types-7_2 + - test-redis-stack-7_2 + - test-sentinel-7_2 + - test-misc + - test-cluster-tls-features-7_2 + - test-cluster-rustls-features-7_2 + - clippy-lint \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md index ea23a9bf..d620be37 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,43 @@ +## 7.0.0 + +* Added a new client [builder](src/types/builder.rs) and configuration interface. +* Reworked or removed the majority of the `globals` interface. +* Support multiple IP addresses in the `Resolve` interface. +* Add `with_options` command configuration interface. +* Replaced the `no-client-setname` feature flag with `auto-client-setname`. +* Add an interface to configure TCP socket options. +* Removed the automatic `serde_json::Value` -> `RedisValue` type conversion logic. + * This unintentionally introduced some ambiguity on certain interfaces. + * The `RedisValue` -> `serde_json::Value` type conversion logic was not changed. +* Reworked the majority of the `RedisPool` interface. +* Moved and refactored the `on_*` functions into a new `EventInterface`. +* Fixed bugs with the `Replica` routing implementation. +* Fixed bugs related to parsing single-element arrays. +* Changed several `FromRedis` type conversion rules. See below or the `FromRedis` docs for more information. +* Add a [RedisJSON](https://github.com/RedisJSON/RedisJSON/) interface. +* Add a RESP2 and RESP3 codec interface. + +### Upgrading from 6.x + +Notable interface changes: + +* `ArcStr` has been replaced with `bytes_utils::Str`. 
+* Timeout arguments or fields now all use `std::time::Duration`. +* Many of the old global or performance config values can now be set on individual commands via the `with_options` interface. +* The `RedisPool` interface now directly implements `ClientLike` rather than relying on `Deref` shenanigans. +* The `on_*` event functions were moved and renamed. Reconnection events now include the associated `Server`. +* The `tls_server_name` field on `Server` is now properly gated by the TLS feature flags. +* Mocks are now optional even when the feature flag is enabled. + +Notable implementation Changes: + +* `Pipeline` and `Transaction` structs can now be reused. Calling `exec`, `all`, `last`, or `try_all` no longer drains the inner command buffer. +* Many of the default timeout values have been lowered significantly, often from 60 sec to 10 sec. +* In earlier versions the `FromRedis` trait implemented a few inconsistent or ambiguous type conversions policies. + * Most of these were consolidated under the `default-nil-types` feature flag. + * It is recommended that callers review the updated `FromRedis` docs or see the unit tests in [responses](src/modules/response.rs). +* The `connect` function can now be called more than once to force reset all client state. + ## 6.3.2 * Fix a bug with connection errors unexpectedly ending the connection task. @@ -15,13 +55,6 @@ * Fix compilation error with `full-tracing` * Support `Vec<(T1, T2, ...)>` with `FromRedis` -## 6.2.2 - -* Fix cluster replica discovery with Elasticache -* Fix cluster replica `READONLY` usage -* Fix compilation error with `full-tracing` -* Support `Vec<(T1, T2, ...)>` with `FromRedis` - ## 6.2.1 * Fix cluster failover with paused nodes diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 6c25d4f2..4f64c6d2 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -6,16 +6,10 @@ This document gives some background on how the library is structured and how to # General * Use 2 spaces instead of tabs. 
-* Run rustfmt before submitting any changes. +* Run rustfmt and clippy before submitting any changes. * Clean up any compiler warnings. * Use the `async` syntax rather than `impl Future` where possible. - -## Branches - -* Create external PRs against the `staging` branch. -* Use topic branches with [conventional commits](https://www.conventionalcommits.org/en/v1.0.0/#summary). -* Branching strategy is `` -> (squash) `staging` -> `main` -> `release-` -* Remove `chore` commits when squashing PRs. +* Please use [conventional commits](https://www.conventionalcommits.org/en/v1.0.0/#summary). ## TODO List @@ -54,7 +48,7 @@ impl RedisCommandKind { // .. - pub fn to_str_debug(&self) -> &'static str { + pub fn to_str_debug(&self) -> &str { match *self { // .. RedisCommandKind::Mget => "MGET", @@ -64,7 +58,7 @@ impl RedisCommandKind { // .. - pub fn cmd_str(&self) -> &'static str { + pub fn cmd_str(&self) -> Str { match *self { // .. RedisCommandKind::Mget => "MGET" @@ -80,10 +74,11 @@ impl RedisCommandKind { ```rust pub async fn mget(client: &C, keys: MultipleKeys) -> Result { + // maybe do some kind of validation utils::check_empty_keys(&keys)?; let frame = utils::request_response(client, move || { - // time spent here will show up in traces + // time spent here will show up in traces in the `prepare_command` span Ok((RedisCommandKind::Mget, keys.into_values())) }) .await?; @@ -92,7 +87,7 @@ pub async fn mget(client: &C, keys: MultipleKeys) -> Result(client: &C, keys: MultipleKeys) -> Result { @@ -121,7 +116,7 @@ pub trait KeysInterface: ClientLike { K: Into + Send, { into!(keys); - commands::keys::mget(self, keys).await + commands::keys::mget(self, keys).await?.convert() } // ... } diff --git a/Cargo.toml b/Cargo.toml index 57109d65..9384a33a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "fred" -version = "6.3.2" +version = "7.0.0" authors = ["Alec Embke "] edition = "2021" description = "An async Redis client built on Tokio." 
@@ -29,16 +29,18 @@ features = [ "sentinel-auth", "check-unresponsive", "replicas", - "client-tracking" + "client-tracking", + "default-nil-types", + "codec", + "redis-json", + "sha-1" ] rustdoc-args = ["--cfg", "docsrs"] [dependencies] -arcstr = "1.1" arc-swap = "1.5" tokio = { version = "1.19.0", features = ["net", "sync", "rt", "rt-multi-thread", "macros"] } tokio-util = { version = "0.7.1", features = ["codec"] } -cfg-if = "1.0.0" bytes = "1.1" bytes-utils = "0.1" futures = "0.3" @@ -49,10 +51,11 @@ log = "0.4" float-cmp = "0.9" url = "2.3" tokio-stream = "0.1.1" -sha-1 = "0.10" +sha-1 = { version = "0.10", optional = true } rand = "0.8" -async-trait = "0.1" semver = "1.0" +socket2 = "0.5" +async-trait = { version = "0.1" } rustls = { version = "0.21", optional = true } native-tls = { version = "0.2", optional = true } tokio-native-tls = { version = "0.3", optional = true } @@ -66,11 +69,11 @@ rustls-native-certs = { version = "0.6", optional = true } trust-dns-resolver = { version = "0.23", optional = true } [dev-dependencies] -prometheus = "0.13" base64 = "0.21" -subprocess = "0.2.7" -serde = { version = "1.0", features = ["derive"] } +subprocess = "0.2" pretty_env_logger = "0.5" +bollard = "0.14" +serde = "1.0" [lib] doc = true @@ -81,18 +84,18 @@ test = true name = "monitor" required-features = ["monitor"] -[[example]] -name = "prometheus" -required-features = ["metrics"] - [[example]] name = "pubsub" required-features = ["subscriber-client"] [[example]] -name = "serde" +name = "serde_json" required-features = ["serde-json"] +[[example]] +name = "redis_json" +required-features = ["redis-json"] + [[example]] name = "dns" required-features = ["dns"] @@ -101,9 +104,12 @@ required-features = ["dns"] name = "client_tracking" required-features = ["client-tracking"] +[[example]] +name = "lua" +required-features = ["sha-1"] + [features] default = ["ignore-auth-error", "pool-prefer-active"] -fallback = [] serde-json = ["serde_json"] subscriber-client = [] metrics = 
[] @@ -118,15 +124,19 @@ pool-prefer-active = [] full-tracing = ["partial-tracing", "tracing", "tracing-futures"] partial-tracing = ["tracing", "tracing-futures"] blocking-encoding = ["tokio/rt-multi-thread"] -network-logs = [] custom-reconnect-errors = [] monitor = ["nom"] sentinel-client = [] sentinel-auth = [] -no-client-setname = [] check-unresponsive = [] replicas = [] +auto-client-setname = [] client-tracking = [] -# Testing Features +default-nil-types = [] +codec = [] +# Redis Stack Features +redis-stack = ["redis-json"] +redis-json = ["serde-json"] +# Debugging Features debug-ids = [] -sentinel-tests = [] +network-logs = [] diff --git a/LICENSE-APACHE b/LICENSE-APACHE index fda91dbb..4d8cd551 100644 --- a/LICENSE-APACHE +++ b/LICENSE-APACHE @@ -187,7 +187,7 @@ same "printed page" as the copyright notice for easier identification within third-party archives. - Copyright 2022 Alec Embke + Copyright 2023 Alec Embke Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
diff --git a/LICENSE-MIT b/LICENSE-MIT index 1b8c5577..c80fa4b0 100644 --- a/LICENSE-MIT +++ b/LICENSE-MIT @@ -1,4 +1,4 @@ -Copyright 2022 Alec Embke +Copyright 2023 Alec Embke Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: diff --git a/README.md b/README.md index 0fc4cace..0bb53b43 100644 --- a/README.md +++ b/README.md @@ -16,12 +16,7 @@ use fred::prelude::*; #[tokio::main] async fn main() -> Result<(), RedisError> { - let config = RedisConfig::default(); - let perf = PerformanceConfig::default(); - let policy = ReconnectPolicy::default(); - let client = RedisClient::new(config, Some(perf), Some(policy)); - - // connect to the server, returning a handle to the task that drives the connection + let client = RedisClient::default(); let _ = client.connect(); let _ = client.wait_for_connect().await?; @@ -46,49 +41,52 @@ See the [examples](https://github.com/aembke/fred.rs/tree/main/examples) for mor ## Features -* Supports RESP2 and RESP3 protocol modes. -* Supports clustered, centralized, and sentinel Redis deployments. -* Optional built-in reconnection logic with multiple backoff policies. +* RESP2 and RESP3 protocol modes. +* Clustered, centralized, and sentinel Redis deployments. +* TLS connections via `native-tls` and/or `rustls`. +* Optional reconnection logic with multiple backoff policies. * Publish-Subscribe and keyspace events interfaces. -* Supports transactions. -* Supports Lua scripts. -* Supports streaming results from the `MONITOR` command. -* Supports custom commands provided by third party modules. -* Supports TLS connections via `native-tls` and/or `rustls`. 
-* Supports streaming interfaces for scanning functions. -* Supports [pipelining](https://redis.io/topics/pipelining). -* Automatically retry requests under bad network conditions. -* Optional built-in tracking for network latency and payload size metrics. -* An optional client pooling interface to round-robin requests among a pool of clients. -* An optional sentinel client for interacting directly with sentinel nodes to manually fail over servers, etc. +* A round-robin client pooling interface. +* Lua [scripts](https://redis.io/docs/interact/programmability/eval-intro/) or [functions](https://redis.io/docs/interact/programmability/functions-intro/). +* Streaming results from the `MONITOR` command. +* Custom commands. +* Streaming interfaces for scanning functions. +* [Transactions](https://redis.io/docs/interact/transactions/) +* [Pipelining](https://redis.io/topics/pipelining) +* [Client Tracking](https://redis.io/docs/manual/client-side-caching/) +* An optional [RedisJSON](https://github.com/RedisJSON/RedisJSON) interface. +* A round-robin cluster replica routing interface. * An optional pubsub subscriber client that will automatically manage channel subscriptions. -* An optional interface to override DNS resolution logic. -* Optional support for JSON values. +* [Tracing](https://github.com/tokio-rs/tracing) -## Build Time Features +## Build Features -| Name | Default | Description | -|-------------------------|---------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| enable-native-tls | | Enable TLS support via [native-tls](https://crates.io/crates/native-tls). | -| enable-rustls | | Enable TLS support via [rustls](https://crates.io/crates/rustls). | -| vendored-openssl | | Enable the `native-tls/vendored` feature, if possible. 
| -| ignore-auth-error | x | Ignore auth errors that occur when a password is supplied but not required. | -| metrics | | Enable the metrics interface to track overall latency, network latency, and request/response sizes. | +| Name | Default | Description | +|-------------------------|---------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| enable-native-tls | | Enable TLS support via [native-tls](https://crates.io/crates/native-tls). | +| enable-rustls | | Enable TLS support via [rustls](https://crates.io/crates/rustls). | +| vendored-openssl | | Enable the `native-tls/vendored` feature, if possible. | +| ignore-auth-error | x | Ignore auth errors that occur when a password is supplied but not required. | +| metrics | | Enable the metrics interface to track overall latency, network latency, and request/response sizes. | | reconnect-on-auth-error | | A NOAUTH error is treated the same as a general connection failure and the client will reconnect based on the reconnection policy. This is [recommended](https://github.com/StackExchange/StackExchange.Redis/issues/1273#issuecomment-651823824) if callers are using ElastiCache. | -| pool-prefer-active | x | Prefer connected clients over clients in a disconnected state when using the `RedisPool` interface. | -| full-tracing | | Enable full [tracing](./src/trace/README.md) support. This can emit a lot of data. | -| partial-tracing | | Enable partial [tracing](./src/trace/README.md) support, only emitting traces for top level commands and network latency. | -| blocking-encoding | | Use a blocking task for encoding or decoding frames. This can be useful for clients that send or receive large payloads, but will only work when used with a multi-thread Tokio runtime. 
| -| network-logs | | Enable TRACE level logging statements that will print out all data sent to or received from the server. These are the only logging statements that can ever contain potentially sensitive user data. | -| custom-reconnect-errors | | Enable an interface for callers to customize the types of errors that should automatically trigger reconnection logic. | -| monitor | | Enable an interface for running the `MONITOR` command. | -| sentinel-client | | Enable an interface for communicating directly with Sentinel nodes. This is not necessary to use normal Redis clients behind a sentinel layer. | -| sentinel-auth | | Enable an interface for using different authentication credentials to sentinel nodes. | -| subscriber-client | | Enable an optional subscriber client that manages channel subscription state for callers. | -| serde-json | | Enable an interface to automatically convert Redis types to JSON. | -| no-client-setname | | Disable the automatic `CLIENT SETNAME` command used to associate server logs with client logs. | -| mocks | | Enable a mocking layer interface that can be used to intercept and process commands in tests. | -| dns | | Enable an interface that allows callers to override the DNS lookup logic. | -| check-unresponsive | | Enable additional monitoring to detect unresponsive connections. | -| replicas | | Enable an interface that routes commands to replica nodes. | -| client-tracking | | Enable a [client tracking](https://redis.io/docs/manual/client-side-caching/) interface. | +| pool-prefer-active | x | Prefer connected clients over clients in a disconnected state when using the `RedisPool` interface. | +| full-tracing | | Enable full [tracing](./src/trace/README.md) support. This can emit a lot of data. | +| partial-tracing | | Enable partial [tracing](./src/trace/README.md) support, only emitting traces for top level commands and network latency. | +| blocking-encoding | | Use a blocking task for encoding or decoding frames. 
This can be useful for clients that send or receive large payloads, but will only work when used with a multi-thread Tokio runtime. | +| network-logs | | Enable TRACE level logging statements that will print out all data sent to or received from the server. These are the only logging statements that can ever contain potentially sensitive user data. | +| custom-reconnect-errors | | Enable an interface for callers to customize the types of errors that should automatically trigger reconnection logic. | +| monitor | | Enable an interface for running the `MONITOR` command. | +| sentinel-client | | Enable an interface for communicating directly with Sentinel nodes. This is not necessary to use normal Redis clients behind a sentinel layer. | +| sentinel-auth | | Enable an interface for using different authentication credentials to sentinel nodes. | +| subscriber-client | | Enable an optional subscriber client that manages channel subscription state for callers. | +| serde-json | | Enable an interface to automatically convert Redis types to JSON. | +| auto-client-setname | | Automatically send `CLIENT SETNAME` on each connection associated with a client instance. | +| mocks | | Enable a mocking layer interface that can be used to intercept and process commands in tests. | +| dns | | Enable an interface that allows callers to override the DNS lookup logic. | +| check-unresponsive | | Enable additional monitoring to detect unresponsive connections. | +| replicas | | Enable an interface that routes commands to replica nodes. | +| client-tracking | | Enable a [client tracking](https://redis.io/docs/manual/client-side-caching/) interface. | +| default-nil-types | | Enable a looser parsing interface for `nil` values. | +| redis-json | | Enable an interface for [RedisJSON](https://github.com/RedisJSON/RedisJSON). | +| codec | | Enable a lower level framed codec interface for use with [tokio-util](https://docs.rs/tokio-util/latest/tokio_util/codec/index.html). 
| +| sha-1 | | Enable an interface for hashing Lua scripts. | diff --git a/bin/inf_loop/docker-compose.yml b/bin/inf_loop/docker-compose.yml index 99dee1f0..81d733af 100644 --- a/bin/inf_loop/docker-compose.yml +++ b/bin/inf_loop/docker-compose.yml @@ -12,7 +12,7 @@ services: args: REDIS_VERSION: "${REDIS_VERSION}" networks: - - app-tier + - fred-tests entrypoint: "cargo run --release --features \"replicas\" -- ${TEST_ARGV}" environment: RUST_LOG: "${RUST_LOG}" diff --git a/bin/inf_loop/rustfmt.toml b/bin/inf_loop/rustfmt.toml index 8482db5e..ef2c3648 100644 --- a/bin/inf_loop/rustfmt.toml +++ b/bin/inf_loop/rustfmt.toml @@ -1,7 +1,7 @@ -edition = "2018" +edition = "2021" binop_separator = "Front" blank_lines_upper_bound = 1 -brace_style = "PreferSameLine" +brace_style = "SameLineWhere" combine_control_expr = true comment_width = 125 condense_wildcard_suffixes = false @@ -10,7 +10,7 @@ empty_item_single_line = true error_on_line_overflow = false enum_discrim_align_threshold = 0 error_on_unformatted = false -fn_args_layout = "Tall" +fn_params_layout = "Tall" fn_single_line = false force_explicit_abi = true force_multiline_blocks = false @@ -28,7 +28,7 @@ match_arm_blocks = true match_block_trailing_comma = true max_width = 118 merge_derives = true -merge_imports = true +imports_granularity="Crate" newline_style = "Auto" normalize_comments = true normalize_doc_attributes = true diff --git a/bin/pipeline_test/Cargo.toml b/bin/pipeline_test/Cargo.toml index de553f51..be22cb6c 100644 --- a/bin/pipeline_test/Cargo.toml +++ b/bin/pipeline_test/Cargo.toml @@ -27,7 +27,7 @@ indicatif = "=0.17.1" [dependencies.fred] #path = "../.." 
path = "/fred" -features = ["ignore-auth-error"] +features = ["ignore-auth-error", "replicas"] default-features = false [features] diff --git a/bin/pipeline_test/cli.yml b/bin/pipeline_test/cli.yml index 8576f374..f7cdaac2 100644 --- a/bin/pipeline_test/cli.yml +++ b/bin/pipeline_test/cli.yml @@ -12,6 +12,10 @@ args: long: cluster help: Whether or not to assume a clustered deployment. takes_value: false + - replicas: + long: replicas + help: Whether or not to use `GET` with replica nodes instead of `INCR` with primary nodes. + takes_value: false - quiet: short: q long: quiet @@ -52,6 +56,12 @@ args: help: The number of clients in the redis connection pool. takes_value: true default_value: "1" + - auth: + short: a + long: auth + value_name: "STRING" + help: The password/key to use. + takes_value: true subcommands: - pipeline: about: Run the test with pipelining. diff --git a/bin/pipeline_test/docker-compose.yml b/bin/pipeline_test/docker-compose.yml index 0a9d7bd2..31f8f8db 100644 --- a/bin/pipeline_test/docker-compose.yml +++ b/bin/pipeline_test/docker-compose.yml @@ -12,7 +12,7 @@ services: args: REDIS_VERSION: "${REDIS_VERSION}" networks: - - app-tier + - fred-tests environment: RUST_LOG: "${RUST_LOG}" REDIS_VERSION: "${REDIS_VERSION}" diff --git a/bin/pipeline_test/run.sh b/bin/pipeline_test/run.sh index 01589df6..414837ef 100755 --- a/bin/pipeline_test/run.sh +++ b/bin/pipeline_test/run.sh @@ -1,4 +1,4 @@ #!/bin/bash docker-compose -f ../../tests/docker/compose/cluster.yml -f ../../tests/docker/compose/centralized.yml -f ./docker-compose.yml \ - run -u $(id -u ${USER}):$(id -g ${USER}) --rm pipeline-test cargo run --release -- "$0" \ No newline at end of file + run -u $(id -u ${USER}):$(id -g ${USER}) --rm pipeline-test cargo run --release -- "${@:1}" \ No newline at end of file diff --git a/bin/pipeline_test/rustfmt.toml b/bin/pipeline_test/rustfmt.toml index 8482db5e..ef2c3648 100644 --- a/bin/pipeline_test/rustfmt.toml +++ b/bin/pipeline_test/rustfmt.toml @@ 
-1,7 +1,7 @@ -edition = "2018" +edition = "2021" binop_separator = "Front" blank_lines_upper_bound = 1 -brace_style = "PreferSameLine" +brace_style = "SameLineWhere" combine_control_expr = true comment_width = 125 condense_wildcard_suffixes = false @@ -10,7 +10,7 @@ empty_item_single_line = true error_on_line_overflow = false enum_discrim_align_threshold = 0 error_on_unformatted = false -fn_args_layout = "Tall" +fn_params_layout = "Tall" fn_single_line = false force_explicit_abi = true force_multiline_blocks = false @@ -28,7 +28,7 @@ match_arm_blocks = true match_block_trailing_comma = true max_width = 118 merge_derives = true -merge_imports = true +imports_granularity="Crate" newline_style = "Auto" normalize_comments = true normalize_doc_attributes = true diff --git a/bin/pipeline_test/src/main.rs b/bin/pipeline_test/src/main.rs index 3be850a7..d674adb2 100644 --- a/bin/pipeline_test/src/main.rs +++ b/bin/pipeline_test/src/main.rs @@ -13,11 +13,14 @@ extern crate tracing_subscriber; extern crate log; extern crate pretty_env_logger; +#[cfg(any(feature = "partial-tracing", feature = "full-tracing", feature = "stdout-tracing"))] +use fred::types::TracingConfig; + use clap::{App, ArgMatches}; use fred::{ - pool::RedisPool, + clients::RedisPool, prelude::*, - types::{BackpressureConfig, BackpressurePolicy, PerformanceConfig, TracingConfig}, + types::{BackpressureConfig, BackpressurePolicy, Builder as RedisBuilder, PerformanceConfig, Server}, }; use indicatif::ProgressBar; use opentelemetry::{ @@ -34,6 +37,7 @@ use std::{ default::Default, sync::{atomic::AtomicUsize, Arc}, thread::{self, JoinHandle as ThreadJoinHandle}, + time::Duration, }; use tokio::{runtime::Builder, task::JoinHandle, time::Instant}; use tracing_subscriber::{layer::SubscriberExt, Layer, Registry}; @@ -48,6 +52,7 @@ mod utils; #[derive(Debug)] struct Argv { pub cluster: bool, + pub replicas: bool, pub tracing: bool, pub count: usize, pub tasks: usize, @@ -56,15 +61,21 @@ struct Argv { pub pipeline: 
bool, pub pool: usize, pub quiet: bool, + pub auth: Option, } fn parse_argv() -> Arc { let yaml = load_yaml!("../cli.yml"); let matches = App::from_yaml(yaml).get_matches(); let tracing = matches.is_present("tracing"); - let cluster = matches.is_present("cluster"); + let mut cluster = matches.is_present("cluster"); + let replicas = matches.is_present("replicas"); let quiet = matches.is_present("quiet"); + if replicas { + cluster = true; + } + let count = matches .value_of("count") .map(|v| { @@ -93,6 +104,7 @@ fn parse_argv() -> Arc { .value_of("pool") .map(|v| v.parse::().expect("Invalid pool")) .unwrap_or(1); + let auth = matches.value_of("auth").map(|v| v.to_owned()); let pipeline = matches.subcommand_matches("pipeline").is_some(); Arc::new(Argv { @@ -105,6 +117,8 @@ fn parse_argv() -> Arc { port, pipeline, pool, + replicas, + auth, }) } @@ -185,12 +199,16 @@ fn spawn_client_task( let mut expected = 0; while utils::incr_atomic(&counter) < argv.count { - expected += 1; - let actual: i64 = client.incr(&key).await?; + if argv.replicas { + let _: () = client.replicas().get(&key).await?; + } else { + expected += 1; + let actual: i64 = client.incr(&key).await?; + // assert_eq!(actual, expected); + } if let Some(ref bar) = bar { bar.inc(1); } - // assert_eq!(actual, expected); } Ok::<_, RedisError>(()) @@ -210,28 +228,26 @@ fn main() { let config = RedisConfig { server: if argv.cluster { ServerConfig::Clustered { - hosts: vec![(argv.host.clone(), argv.port)], + hosts: vec![Server::new(&argv.host, argv.port)], } } else { ServerConfig::new_centralized(&argv.host, argv.port) }, + password: argv.auth.clone(), #[cfg(any(feature = "stdout-tracing", feature = "partial-tracing", feature = "full-tracing"))] tracing: TracingConfig::new(argv.tracing), ..Default::default() }; - let perf = PerformanceConfig { - auto_pipeline: argv.pipeline, - default_command_timeout_ms: 5000, - backpressure: BackpressureConfig { - policy: BackpressurePolicy::Drain, - max_in_flight_commands: 
100_000_000, - ..Default::default() - }, - ..Default::default() - }; - let policy = ReconnectPolicy::new_constant(0, 500); - - let pool = RedisPool::new(config, Some(perf), Some(policy), argv.pool)?; + let pool = RedisBuilder::from_config(config) + .with_performance_config(|config| { + config.auto_pipeline = argv.pipeline; + config.backpressure.max_in_flight_commands = 100_000_000; + }) + .with_connection_config(|config| { + config.internal_command_timeout = Duration::from_secs(5); + }) + .set_policy(ReconnectPolicy::new_constant(0, 500)) + .build_pool(argv.pool)?; info!("Connecting to {}:{}...", argv.host, argv.port); let _ = pool.connect(); diff --git a/examples/README.md b/examples/README.md index 9fa172c6..adf48d4d 100644 --- a/examples/README.md +++ b/examples/README.md @@ -10,14 +10,14 @@ Examples * [Pipeline](./pipeline.rs) - Use the manual pipeline interface. * [Lua](./lua.rs) - Use the Lua scripting interface on a client. * [Scan](./scan.rs) - Use the SCAN interface to scan and read keys. -* [Prometheus](./prometheus.rs) - Use the metrics interface with prometheus. * [Pool](./pool.rs) - Use a redis connection pool. * [Monitor](./monitor.rs) - Process a `MONITOR` stream. * [Sentinel](./sentinel.rs) - Connect using a sentinel deployment. -* [Serde](./serde.rs) - Use the `serde-json` feature to convert between Redis types and JSON. +* [Serde JSON](./serde_json.rs) - Use the `serde-json` feature to convert between Redis types and JSON. +* [Redis JSON](./redis_json.rs) - Use the `redis-json` feature with `serde-json` types. * [Custom](./custom.rs) - Send custom commands or operate on RESP frames. * [DNS](./dns.rs) - Customize the DNS resolution logic. * [Client Tracking](./client_tracking.rs) - Implement [client side caching](https://redis.io/docs/manual/client-side-caching/). -* [Misc](./misc.rs) - Miscellaneous features or examples. +* [Misc](./misc.rs) - Miscellaneous or advanced features. Or see the [tests](../tests/integration) for more examples. 
\ No newline at end of file diff --git a/examples/basic.rs b/examples/basic.rs index bebdfc42..3a2747a7 100644 --- a/examples/basic.rs +++ b/examples/basic.rs @@ -1,16 +1,13 @@ -use fred::{prelude::*, types::RespVersion}; +#![allow(clippy::disallowed_names)] +#![allow(clippy::let_underscore_future)] +use fred::{prelude::*, types::RespVersion}; #[cfg(feature = "partial-tracing")] -use fred::tracing::Level; -#[cfg(any(feature = "enable-native-tls", feature = "enable-rustls"))] -use fred::types::TlsConfig; -#[cfg(feature = "partial-tracing")] -use fred::types::TracingConfig; +use fred::{tracing::Level, types::TracingConfig}; #[tokio::main] async fn main() -> Result<(), RedisError> { - pretty_env_logger::init(); - + // create a config from a URL let _ = RedisConfig::from_url("redis://username:password@foo.com:6379/1")?; // full configuration with testing values let config = RedisConfig { @@ -31,42 +28,26 @@ async fn main() -> Result<(), RedisError> { full_tracing_level: Level::DEBUG, }, }; - - // configure exponential backoff when reconnecting, starting at 100 ms, and doubling each time up to 30 sec. 
- let policy = ReconnectPolicy::new_exponential(0, 100, 30_000, 2); - let perf = PerformanceConfig::default(); - let client = RedisClient::new(config, Some(perf), Some(policy)); - - // spawn tasks that listen for connection close or reconnect events - let mut error_rx = client.on_error(); - let mut reconnect_rx = client.on_reconnect(); - - tokio::spawn(async move { - while let Ok(error) = error_rx.recv().await { - println!("Client disconnected with error: {:?}", error); - } - }); - tokio::spawn(async move { - while reconnect_rx.recv().await.is_ok() { - println!("Client reconnected."); - } - }); - + // see the Builder interface for more information + let _client = Builder::from_config(config).build()?; + // or use default values + let client = Builder::default_centralized().build()?; let connection_task = client.connect(); - let _ = client.wait_for_connect().await?; + client.wait_for_connect().await?; // convert response types to most common rust types let foo: Option = client.get("foo").await?; println!("Foo: {:?}", foo); - let _: () = client + client .set("foo", "bar", Some(Expiration::EX(1)), Some(SetOptions::NX), false) .await?; // or use turbofish. the first type is always the response type. 
- println!("Foo: {:?}", client.get::("foo").await?); + println!("Foo: {:?}", client.get::, _>("foo").await?); - let _ = client.quit().await?; + client.quit().await?; + // calling quit ends the connection and event listener tasks let _ = connection_task.await; Ok(()) } diff --git a/examples/blocking.rs b/examples/blocking.rs index 6d09ae54..bf25807f 100644 --- a/examples/blocking.rs +++ b/examples/blocking.rs @@ -1,9 +1,10 @@ +#![allow(clippy::disallowed_names)] +#![allow(clippy::let_underscore_future)] + use fred::prelude::*; use std::time::Duration; use tokio::time::sleep; -static COUNT: i64 = 50; - #[tokio::main] async fn main() -> Result<(), RedisError> { pretty_env_logger::init(); @@ -13,30 +14,25 @@ async fn main() -> Result<(), RedisError> { let _ = publisher_client.connect(); let _ = subscriber_client.connect(); - let _ = publisher_client.wait_for_connect().await?; - let _ = subscriber_client.wait_for_connect().await?; + publisher_client.wait_for_connect().await?; + subscriber_client.wait_for_connect().await?; - #[allow(unreachable_code)] let subscriber_jh = tokio::spawn(async move { loop { - let (key, value): (String, i64) = if let Some(result) = subscriber_client.blpop("foo", 5.0).await? { - result - } else { - // retry after a timeout - continue; + let (key, value): (String, i64) = match subscriber_client.blpop("foo", 5.0).await.ok() { + Some(value) => value, + None => continue, }; println!("BLPOP result on {}: {}", key, value); } - - Ok::<(), RedisError>(()) }); - for idx in 0 .. COUNT { - let _ = publisher_client.rpush("foo", idx).await?; + for idx in 0 .. 
30 { + publisher_client.rpush("foo", idx).await?; sleep(Duration::from_secs(1)).await; } - let _ = subscriber_jh.abort(); + subscriber_jh.abort(); Ok(()) } diff --git a/examples/client_tracking.rs b/examples/client_tracking.rs index f61c6010..a62f0a45 100644 --- a/examples/client_tracking.rs +++ b/examples/client_tracking.rs @@ -1,24 +1,25 @@ +#![allow(clippy::disallowed_names)] +#![allow(clippy::let_underscore_future)] + use fred::{interfaces::TrackingInterface, prelude::*, types::RespVersion}; -// this library exposes 2 interfaces for implementing client-side caching - a high level `TrackingInterface` trait +// this library supports 2 interfaces for implementing client-side caching - a high level `TrackingInterface` trait // that requires RESP3 and works with all deployment types, and a lower level interface that directly exposes the // `CLIENT TRACKING` commands but often requires a centralized server config. async fn resp3_tracking_interface_example() -> Result<(), RedisError> { - let policy = ReconnectPolicy::new_constant(0, 1000); - let mut config = RedisConfig::default(); - config.version = RespVersion::RESP3; - - let client = RedisClient::new(config, None, Some(policy)); + let client = Builder::default_centralized() + .with_config(|config| { + config.version = RespVersion::RESP3; + }) + .build()?; let _ = client.connect(); let _ = client.wait_for_connect().await?; // spawn a task that processes invalidation messages. - let mut invalidations = client.on_invalidation(); - tokio::spawn(async move { - while let Ok(invalidation) = invalidations.recv().await { - println!("{}: Invalidate {:?}", invalidation.server, invalidation.keys); - } + let _ = client.on_invalidation(|invalidation| { + println!("{}: Invalidate {:?}", invalidation.server, invalidation.keys); + Ok(()) }); // enable client tracking on all connections. it's usually a good idea to do this in an `on_reconnect` block. 
@@ -27,10 +28,23 @@ async fn resp3_tracking_interface_example() -> Result<(), RedisError> { // send `CLIENT CACHING yes|no` before subsequent commands. the preceding `CLIENT CACHING yes|no` command will be // sent when the command is retried as well. - println!("foo: {}", client.caching(false).incr::("foo").await?); - println!("foo: {}", client.caching(true).incr::("foo").await?); - let _ = client.stop_tracking().await?; + let foo: i64 = client + .with_options(&Options { + caching: Some(true), + ..Default::default() + }) + .incr("foo") + .await?; + let bar: i64 = client + .with_options(&Options { + caching: Some(false), + ..Default::default() + }) + .incr("bar") + .await?; + println!("foo: {}, bar: {}", foo, bar); + let _ = client.stop_tracking().await?; Ok(()) } @@ -46,7 +60,7 @@ async fn resp2_basic_interface_example() -> Result<(), RedisError> { // the invalidation subscriber interface is the same as above even in RESP2 mode **as long as the `client-tracking` // feature is enabled**. if the feature is disabled then the message will appear on the `on_message` receiver. 
- let mut invalidations = subscriber.on_invalidation(); + let mut invalidations = subscriber.invalidation_rx(); tokio::spawn(async move { while let Ok(invalidation) = invalidations.recv().await { println!("{}: Invalidate {:?}", invalidation.server, invalidation.keys); @@ -80,10 +94,8 @@ async fn resp2_basic_interface_example() -> Result<(), RedisError> { } #[tokio::main] -// see https://redis.io/docs/manual/client-side-caching/ for more information +// see https://redis.io/docs/manual/client-side-caching/ async fn main() -> Result<(), RedisError> { - pretty_env_logger::init(); - resp3_tracking_interface_example().await?; // resp2_basic_interface_example().await?; diff --git a/examples/custom.rs b/examples/custom.rs index 49c5e933..03a542cb 100644 --- a/examples/custom.rs +++ b/examples/custom.rs @@ -1,54 +1,36 @@ +#![allow(clippy::disallowed_names)] +#![allow(clippy::let_underscore_future)] + use fred::{ + cmd, prelude::*, - types::{CustomCommand, RedisKey}, + types::{ClusterHash, CustomCommand}, }; use std::convert::TryInto; -fn get_hash_slot(client: &RedisClient, key: &'static str) -> (RedisKey, Option) { - let key = RedisKey::from_static_str(key); - let hash_slot = if client.is_clustered() { - // or use redis_protocol::redis_keyslot(key.as_bytes()) - Some(key.cluster_hash()) - } else { - None - }; - - (key, hash_slot) -} - #[tokio::main] async fn main() -> Result<(), RedisError> { - pretty_env_logger::init(); - - let client = RedisClient::new(RedisConfig::default(), None, None); + let client = Builder::default_centralized().build()?; let _ = client.connect(); - let _ = client.wait_for_connect().await?; - - let (key, hash_slot) = get_hash_slot(&client, "ts:carbon_monoxide"); - let args: Vec = vec![key.into(), 1112596200.into(), 1112603400.into()]; - let cmd = CustomCommand::new_static("TS.RANGE", hash_slot, false); - // >> TS.RANGE ts:carbon_monoxide 1112596200 1112603400 - // 1) 1) (integer) 1112596200 - // 2) "2.4" - // 2) 1) (integer) 1112599800 - // 2) "2.1" 
- // 3) 1) (integer) 1112603400 - // 2) "2.2" - let values: Vec<(i64, f64)> = client.custom(cmd, args).await?; - println!("TS.RANGE Values: {:?}", values); + client.wait_for_connect().await?; + client.lpush("foo", vec![1, 2, 3]).await?; - let _: () = client.lpush("foo", vec![1, 2, 3]).await?; - let (key, hash_slot) = get_hash_slot(&client, "foo"); - let cmd = CustomCommand::new_static("LRANGE", hash_slot, false); // some types require TryInto - let args: Vec = vec![key.into(), 0.into(), 3_u64.try_into()?]; + let args: Vec = vec!["foo".into(), 0.into(), 3_u64.try_into()?]; // returns a frame (https://docs.rs/redis-protocol/latest/redis_protocol/resp3/types/enum.Frame.html) - let frame = client.custom_raw(cmd, args).await?; + let frame = client.custom_raw(cmd!("LRANGE"), args).await?; // or convert back to client types let value: RedisValue = frame.try_into()?; // and/or use the type conversion shorthand let value: Vec = value.convert()?; println!("LRANGE Values: {:?}", value); + // or customize routing and blocking parameters + let _command = cmd!("FOO.BAR", blocking: true); + let _command = cmd!("FOO.BAR", hash: ClusterHash::FirstKey); + let _command = cmd!("FOO.BAR", hash: ClusterHash::FirstKey, blocking: true); + // which is shorthand for + let _command = CustomCommand::new("FOO.BAR", ClusterHash::FirstKey, true); + Ok(()) } diff --git a/examples/dns.rs b/examples/dns.rs index 94a3a9ec..41f84d7d 100644 --- a/examples/dns.rs +++ b/examples/dns.rs @@ -1,4 +1,8 @@ +#![allow(clippy::disallowed_names)] +#![allow(clippy::let_underscore_future)] + use async_trait::async_trait; +use bytes_utils::Str; use fred::{prelude::*, types::Resolve}; use std::{net::SocketAddr, sync::Arc}; use trust_dns_resolver::{ @@ -10,28 +14,31 @@ pub struct TrustDnsResolver(TokioAsyncResolver); impl TrustDnsResolver { fn new() -> Self { - TrustDnsResolver(TokioAsyncResolver::tokio(ResolverConfig::default(), ResolverOpts::default()).unwrap()) + TrustDnsResolver( + 
TokioAsyncResolver::tokio(ResolverConfig::default(), ResolverOpts::default()) + .expect("Failed to create DNS resolver"), + ) } } #[async_trait] impl Resolve for TrustDnsResolver { - async fn resolve(&self, host: String, port: u16) -> Result { - self.0.lookup_ip(&host).await.map_err(|e| e.into()).and_then(|ips| { - let ip = match ips.iter().next() { - Some(ip) => ip, - None => return Err(RedisError::new(RedisErrorKind::IO, "Failed to lookup IP address.")), - }; - - Ok(SocketAddr::new(ip, port)) - }) + async fn resolve(&self, host: Str, port: u16) -> Result, RedisError> { + Ok( + self + .0 + .lookup_ip(&host) + .await? + .into_iter() + .map(|ip| SocketAddr::new(ip, port)) + .collect(), + ) } } #[tokio::main] async fn main() -> Result<(), RedisError> { - let config = RedisConfig::default(); - let client = RedisClient::new(config, None, None); + let client = Builder::default_centralized().build()?; client.set_resolver(Arc::new(TrustDnsResolver::new())).await; let _ = client.connect(); diff --git a/examples/globals.rs b/examples/globals.rs index f2d9be27..75b1de7c 100644 --- a/examples/globals.rs +++ b/examples/globals.rs @@ -1,3 +1,6 @@ +#![allow(clippy::disallowed_names)] +#![allow(clippy::let_underscore_future)] + use fred::{globals, prelude::*}; #[cfg(feature = "custom-reconnect-errors")] @@ -5,7 +8,8 @@ use globals::ReconnectError; #[tokio::main] async fn main() -> Result<(), RedisError> { - globals::set_sentinel_connection_timeout_ms(10_000); + globals::set_default_broadcast_channel_capacity(64); + #[cfg(feature = "blocking-encoding")] globals::set_blocking_encode_threshold(10_000_000); #[cfg(feature = "custom-reconnect-errors")] @@ -14,6 +18,8 @@ async fn main() -> Result<(), RedisError> { ReconnectError::MasterDown, ReconnectError::ReadOnly, ]); + #[cfg(feature = "check-unresponsive")] + globals::set_unresponsive_interval_ms(1000); // ... 
diff --git a/examples/lua.rs b/examples/lua.rs index fe57c6d1..b59d3cef 100644 --- a/examples/lua.rs +++ b/examples/lua.rs @@ -1,10 +1,14 @@ +#![allow(clippy::disallowed_names)] +#![allow(clippy::let_underscore_future)] +#![allow(clippy::let_unit_value)] + use fred::{ prelude::*, types::{Library, Script}, util as fred_utils, }; -static SCRIPTS: &'static [&'static str] = &[ +static SCRIPTS: &[&str] = &[ "return {KEYS[1],KEYS[2],ARGV[1],ARGV[2]}", "return {KEYS[2],KEYS[1],ARGV[1],ARGV[2]}", "return {KEYS[1],KEYS[2],ARGV[2],ARGV[1]}", @@ -15,14 +19,14 @@ static SCRIPTS: &'static [&'static str] = &[ async fn main() -> Result<(), RedisError> { let client = RedisClient::default(); let _ = client.connect(); - let _ = client.wait_for_connect().await?; + client.wait_for_connect().await?; for script in SCRIPTS.iter() { let hash = fred_utils::sha1_hash(script); let mut script_exists: Vec = client.script_exists(&hash).await?; if !script_exists.pop().unwrap_or(false) { - let _ = client.script_load(*script).await?; + client.script_load(*script).await?; } let results = client.evalsha(&hash, vec!["foo", "bar"], vec![1, 2]).await?; @@ -38,6 +42,7 @@ async fn main() -> Result<(), RedisError> { } // or use the `Script` utility types +#[allow(dead_code)] async fn scripts() -> Result<(), RedisError> { let client = RedisClient::default(); let _ = client.connect(); @@ -52,6 +57,7 @@ async fn scripts() -> Result<(), RedisError> { } // use the `Function` and `Library` utility types +#[allow(dead_code)] async fn functions() -> Result<(), RedisError> { let client = RedisClient::default(); let _ = client.connect(); diff --git a/examples/misc.rs b/examples/misc.rs index dbcea01f..afb13767 100644 --- a/examples/misc.rs +++ b/examples/misc.rs @@ -1,51 +1,68 @@ -use fred::{ - prelude::*, - types::{BackpressureConfig, BackpressurePolicy, PerformanceConfig}, -}; +#![allow(clippy::disallowed_names)] +#![allow(clippy::let_underscore_future)] + +use fred::prelude::*; +use std::time::Duration; 
#[tokio::main] async fn main() -> Result<(), RedisError> { - // full configuration for performance tuning options - let perf = PerformanceConfig { - // whether or not to automatically pipeline commands across tasks - auto_pipeline: true, - // the max number of frames to feed into a socket before flushing it - max_feed_count: 1000, - // a default timeout to apply to all commands (0 means no timeout) - default_command_timeout_ms: 0, - // the amount of time to wait before rebuilding the client's cached cluster state after a MOVED error. - cluster_cache_update_delay_ms: 10, - // the maximum number of times to retry commands - max_command_attempts: 3, - // backpressure config options - backpressure: BackpressureConfig { - // whether to disable automatic backpressure features - disable_auto_backpressure: false, - // the max number of in-flight commands before applying backpressure or returning backpressure errors - max_in_flight_commands: 5000, - // the policy to apply when the max in-flight commands count is reached - policy: BackpressurePolicy::Drain, - }, - // the amount of time a command can wait in memory without a response before the connection is considered - // unresponsive - #[cfg(feature = "check-unresponsive")] - network_timeout_ms: 60_000, - }; - let config = RedisConfig { - server: ServerConfig::default_clustered(), - ..RedisConfig::default() - }; - - let client = RedisClient::new(config, Some(perf), None); + let client = Builder::default_centralized() + .with_performance_config(|config| { + config.max_feed_count = 1000; + config.auto_pipeline = true; + }) + .with_connection_config(|config| { + config.tcp = TcpConfig { + nodelay: Some(true), + ..Default::default() + }; + config.max_command_attempts = 5; + config.max_redirections = 5; + config.internal_command_timeout = Duration::from_secs(2); + config.connection_timeout = Duration::from_secs(10); + }) + // use exponential backoff, starting at 100 ms and doubling on each failed attempt up to 30 sec + 
.set_policy(ReconnectPolicy::new_exponential(0, 100, 30_000, 2)) + .build()?; let _ = client.connect(); - let _ = client.wait_for_connect().await?; + client.wait_for_connect().await?; + + // run all event listener functions in one task + let events_task = client.on_any( + |error| { + println!("Connection error: {:?}", error); + Ok(()) + }, + |server| { + println!("Reconnected to {:?}", server); + Ok(()) + }, + |changes| { + println!("Cluster changed: {:?}", changes); + Ok(()) + }, + ); // update performance config options let mut perf_config = client.perf_config(); - perf_config.max_command_attempts = 100; perf_config.max_feed_count = 1000; client.update_perf_config(perf_config); + // overwrite configuration options on individual commands + let options = Options { + max_attempts: Some(5), + max_redirections: Some(5), + timeout: Some(Duration::from_secs(10)), + ..Default::default() + }; + let _: Option = client.with_options(&options).get("foo").await?; + + // apply custom options to a pipeline + let pipeline = client.pipeline().with_options(&options); + pipeline.get("foo").await?; + pipeline.get("bar").await?; + let (_, _): (Option, Option) = pipeline.all().await?; + // interact with specific cluster nodes if client.is_clustered() { let connections = client.active_connections().await?; @@ -56,6 +73,7 @@ async fn main() -> Result<(), RedisError> { } } - let _ = client.quit().await?; + client.quit().await?; + let _ = events_task.await; Ok(()) } diff --git a/examples/monitor.rs b/examples/monitor.rs index bf42441b..9e5546ef 100644 --- a/examples/monitor.rs +++ b/examples/monitor.rs @@ -1,3 +1,6 @@ +#![allow(clippy::disallowed_names)] +#![allow(clippy::let_underscore_future)] + use fred::{monitor, prelude::*}; use futures::stream::StreamExt; use std::time::Duration; diff --git a/examples/pipeline.rs b/examples/pipeline.rs index db9a1c17..34d47ed6 100644 --- a/examples/pipeline.rs +++ b/examples/pipeline.rs @@ -1,3 +1,6 @@ +#![allow(clippy::disallowed_names)] 
+#![allow(clippy::let_underscore_future)] + use fred::prelude::*; #[tokio::main] @@ -6,7 +9,7 @@ async fn main() -> Result<(), RedisError> { // this example shows how to pipeline commands within one task. let client = RedisClient::default(); let _ = client.connect(); - let _ = client.wait_for_connect().await?; + client.wait_for_connect().await?; let pipeline = client.pipeline(); // commands are queued in memory @@ -19,22 +22,22 @@ async fn main() -> Result<(), RedisError> { let (first, second): (i64, i64) = pipeline.all().await?; assert_eq!((first, second), (1, 2)); - let _: () = client.del("foo").await?; + client.del("foo").await?; // or send the pipeline and only return the last result let pipeline = client.pipeline(); - let _: () = pipeline.incr("foo").await?; - let _: () = pipeline.incr("foo").await?; + pipeline.incr("foo").await?; + pipeline.incr("foo").await?; assert_eq!(pipeline.last::().await?, 2); - let _: () = client.del("foo").await?; + client.del("foo").await?; // or handle each command result individually let pipeline = client.pipeline(); - let _: () = pipeline.incr("foo").await?; - let _: () = pipeline.hgetall("foo").await?; // this will result in a `WRONGTYPE` error + pipeline.incr("foo").await?; + pipeline.hgetall("foo").await?; // this will result in a `WRONGTYPE` error let results = pipeline.try_all::().await; assert_eq!(results[0].clone().unwrap(), 1); assert!(results[1].is_err()); - let _ = client.quit().await?; + client.quit().await?; Ok(()) } diff --git a/examples/pool.rs b/examples/pool.rs index d31a70c4..eb5c8251 100644 --- a/examples/pool.rs +++ b/examples/pool.rs @@ -1,21 +1,35 @@ -use fred::{pool::RedisPool, prelude::*}; +#![allow(clippy::disallowed_names)] +#![allow(clippy::let_underscore_future)] + +use fred::prelude::*; #[tokio::main] async fn main() -> Result<(), RedisError> { - let config = RedisConfig::default(); - let pool = RedisPool::new(config, None, None, 5)?; + let pool = Builder::default_centralized().build_pool(5)?; let _ = 
pool.connect(); - let _ = pool.wait_for_connect().await?; + pool.wait_for_connect().await?; + + // interact with specific clients via next(), last(), or clients() + let pipeline = pool.next().pipeline(); + pipeline.incr("foo").await?; + pipeline.incr("foo").await?; + let _: i64 = pipeline.last().await?; for client in pool.clients() { println!("{} connected to {:?}", client.id(), client.active_connections().await?); + + // set up event listeners on each client + client.on_error(|error| { + println!("Connection error: {:?}", error); + Ok(()) + }); } - // use the pool like any other RedisClient - let _ = pool.get("foo").await?; - let _ = pool.set("foo", "bar", None, None, false).await?; - let _ = pool.get("foo").await?; + // or use the pool like any other RedisClient + pool.get("foo").await?; + pool.set("foo", "bar", None, None, false).await?; + pool.get("foo").await?; - let _ = pool.quit_pool().await; + let _ = pool.quit().await; Ok(()) } diff --git a/examples/prometheus.rs b/examples/prometheus.rs deleted file mode 100644 index 8b49002f..00000000 --- a/examples/prometheus.rs +++ /dev/null @@ -1,42 +0,0 @@ -use fred::prelude::*; -use prometheus::{register_int_counter_vec, register_int_gauge_vec, IntCounterVec, IntGaugeVec}; - -fn sample_metrics( - client: &RedisClient, - num_commands: IntCounterVec, - avg_latency: IntGaugeVec, - bytes_sent: IntCounterVec, -) { - let client_id = client.id(); - let latency_stats = client.take_latency_metrics(); - let req_size_stats = client.take_req_size_metrics(); - - if let Ok(metric) = num_commands.get_metric_with_label_values(&[client_id]) { - metric.inc_by(latency_stats.samples); - } - if let Ok(metric) = avg_latency.get_metric_with_label_values(&[client_id]) { - metric.set(latency_stats.avg as i64); - } - if let Ok(metric) = bytes_sent.get_metric_with_label_values(&[client_id]) { - metric.inc_by(req_size_stats.sum as u64); - } -} - -#[tokio::main] -async fn main() -> Result<(), RedisError> { - let num_commands = 
register_int_counter_vec!("redis_num_commands", "Number of redis commands", &["id"]).unwrap(); - let avg_latency = register_int_gauge_vec!("redis_avg_latency", "Average latency to redis.", &["id"]).unwrap(); - let bytes_sent = register_int_counter_vec!("redis_bytes_sent", "Total bytes sent to redis.", &["id"]).unwrap(); - - let config = RedisConfig::default(); - let client = RedisClient::new(config, None, None); - - let _ = client.connect(); - let _ = client.wait_for_connect(); - - // ... - - sample_metrics(&client, num_commands, avg_latency, bytes_sent); - let _ = client.quit().await?; - Ok(()) -} diff --git a/examples/pubsub.rs b/examples/pubsub.rs index a77ffa44..0e43cf83 100644 --- a/examples/pubsub.rs +++ b/examples/pubsub.rs @@ -1,3 +1,6 @@ +#![allow(clippy::disallowed_names)] +#![allow(clippy::let_underscore_future)] + #[allow(unused_imports)] use fred::clients::SubscriberClient; use fred::{prelude::*, types::PerformanceConfig}; @@ -5,8 +8,6 @@ use futures::stream::StreamExt; use std::time::Duration; use tokio::time::sleep; -const COUNT: usize = 60; - #[tokio::main] async fn main() -> Result<(), RedisError> { let publisher_client = RedisClient::default(); @@ -30,9 +31,9 @@ async fn main() -> Result<(), RedisError> { Ok::<_, RedisError>(()) }); - for idx in 0 .. COUNT { + for idx in 0 .. 
50 { let _ = publisher_client.publish("foo", idx).await?; - sleep(Duration::from_millis(1000)).await; + sleep(Duration::from_secs(1)).await; } let _ = subscribe_task.abort(); @@ -41,8 +42,7 @@ async fn main() -> Result<(), RedisError> { #[cfg(feature = "subscriber-client")] async fn subscriber_example() -> Result<(), RedisError> { - let config = RedisConfig::default(); - let subscriber = SubscriberClient::new(config, None, None); + let subscriber = Builder::default_centralized().build_subscriber_client()?; let _ = subscriber.connect(); let _ = subscriber.wait_for_connect().await?; diff --git a/examples/redis_json.rs b/examples/redis_json.rs new file mode 100644 index 00000000..11368f09 --- /dev/null +++ b/examples/redis_json.rs @@ -0,0 +1,42 @@ +#![allow(clippy::disallowed_names)] +#![allow(clippy::let_underscore_future)] + +use fred::{interfaces::RedisJsonInterface, json_quote, prelude::*, util::NONE}; +use serde_json::{json, Value}; + +// see the serde-json example for more information on deserializing responses +#[tokio::main] +async fn main() -> Result<(), RedisError> { + let client = Builder::default_centralized().build()?; + + // operate on objects + let value = json!({ + "a": "b", + "c": 1, + "d": true + }); + let _: () = client.json_set("foo", "$", value.clone(), None).await?; + let result: Value = client.json_get("foo", NONE, NONE, NONE, "$").await?; + assert_eq!(value, result[0]); + let count: i64 = client.json_del("foo", "$..c").await?; + assert_eq!(count, 1); + let result: Value = client.json_get("foo", NONE, NONE, NONE, "$").await?; + assert_eq!(result[0], json!({ "a": "b", "d": true })); + + // operate on arrays + let _: () = client.json_set("foo", "$", json!(["a", "b"]), None).await?; + let size: i64 = client + .json_arrappend("foo", "$", vec![json_quote!("c"), json_quote!("d")]) + .await?; + assert_eq!(size, 4); + let size: i64 = client.json_arrappend("foo", "$", vec![json!({"e": "f"})]).await?; + assert_eq!(size, 5); + let len: i64 = 
client.json_arrlen("foo", NONE).await?; + assert_eq!(len, 5); + + let result: Value = client.json_get("foo", NONE, NONE, NONE, "$").await?; + assert_eq!(result[0], json!(["a", "b", "c", "d", { "e": "f" }])); + + // or see the redis-json integration tests for more + Ok(()) +} diff --git a/examples/scan.rs b/examples/scan.rs index d833bf11..800745e2 100644 --- a/examples/scan.rs +++ b/examples/scan.rs @@ -1,18 +1,21 @@ +#![allow(clippy::disallowed_names)] +#![allow(clippy::let_underscore_future)] + use fred::{prelude::*, types::Scanner}; use futures::stream::StreamExt; -static COUNT: u32 = 50; +static COUNT: usize = 50; async fn create_fake_data(client: &RedisClient) -> Result<(), RedisError> { for idx in 0 .. COUNT { - let _ = client.set(format!("foo-{}", idx), idx, None, None, false).await?; + client.set(format!("foo-{}", idx), idx, None, None, false).await?; } Ok(()) } async fn delete_fake_data(client: &RedisClient) -> Result<(), RedisError> { for idx in 0 .. COUNT { - let _ = client.del(format!("foo-{}", idx)).await?; + client.del(format!("foo-{}", idx)).await?; } Ok(()) } @@ -21,11 +24,11 @@ async fn delete_fake_data(client: &RedisClient) -> Result<(), RedisError> { async fn main() -> Result<(), RedisError> { let client = RedisClient::default(); let _ = client.connect(); - let _ = client.wait_for_connect().await?; - let _ = create_fake_data(&client).await?; + client.wait_for_connect().await?; + create_fake_data(&client).await?; // build up a buffer of (key, value) pairs from pages (~10 keys per page) - let mut buffer = Vec::with_capacity(COUNT as usize); + let mut buffer = Vec::with_capacity(COUNT); let mut scan_stream = client.scan("foo*", Some(10), None); while let Some(result) = scan_stream.next().await { @@ -42,12 +45,12 @@ async fn main() -> Result<(), RedisError> { } } - // move on to the next page now that we're done reading the values. or move this before we call `get` on each key - // to scan results in the background as quickly as possible. 
+ // **important:** move on to the next page now that we're done reading the values. or move this before we call + // `get` on each key to scan results in the background as quickly as possible. let _ = page.next(); } - let _ = delete_fake_data(&client).await?; - let _ = client.quit().await?; + delete_fake_data(&client).await?; + client.quit().await?; Ok(()) } diff --git a/examples/sentinel.rs b/examples/sentinel.rs index 4960a3e0..4923e80f 100644 --- a/examples/sentinel.rs +++ b/examples/sentinel.rs @@ -1,3 +1,6 @@ +#![allow(clippy::disallowed_names)] +#![allow(clippy::let_underscore_future)] + use fred::{prelude::*, types::Server}; #[tokio::main] @@ -13,7 +16,6 @@ async fn main() -> Result<(), RedisError> { Server::new("localhost", 26380), Server::new("localhost", 26381), ], - // note: by default sentinel nodes use the same authentication settings as the redis servers, however // callers can also use the `sentinel-auth` feature to use different credentials to sentinel nodes #[cfg(feature = "sentinel-auth")] username: None, @@ -24,13 +26,12 @@ async fn main() -> Result<(), RedisError> { ..Default::default() }; - let policy = ReconnectPolicy::default(); - let client = RedisClient::new(config, None, Some(policy)); + let client = Builder::from_config(config).build()?; let _ = client.connect(); - let _ = client.wait_for_connect().await?; + client.wait_for_connect().await?; // ... 
- let _ = client.quit().await?; + client.quit().await?; Ok(()) } diff --git a/examples/serde.rs b/examples/serde_json.rs similarity index 52% rename from examples/serde.rs rename to examples/serde_json.rs index b0301366..4c9cc31d 100644 --- a/examples/serde.rs +++ b/examples/serde_json.rs @@ -1,3 +1,6 @@ +#![allow(clippy::disallowed_names)] +#![allow(clippy::let_underscore_future)] + use fred::prelude::*; use serde::{Deserialize, Serialize}; use serde_json::{json, Value}; @@ -12,8 +15,6 @@ struct Person { #[tokio::main] async fn main() -> Result<(), RedisError> { - pretty_env_logger::init(); - let client = RedisClient::default(); let _ = client.connect(); let _ = client.wait_for_connect().await?; @@ -22,19 +23,11 @@ async fn main() -> Result<(), RedisError> { "foo": "a", "bar": "b" }); - // json `Value` objects can also be used interchangeably with `RedisMap` type arguments. - let _: () = client.hset("wobble", value.clone()).await?; let _: () = client.set("wibble", value.to_string(), None, None, false).await?; - // converting back to a json `Value` will also try to parse nested json strings, if possible. - // the type conversion logic will not attempt the json parsing if the value doesn't look like json. - // if a value looks like json, but cannot be parsed as json, then it will be returned as a string. - let get_result: Value = client.get("wibble").await?; - println!("GET Result: {}", get_result); - let hget_result: Value = client.hgetall("wobble").await?; - println!("HGETALL Result: {}", hget_result); - assert_eq!(value, get_result); - assert_eq!(value, hget_result); + // converting back to a json `Value` will also try to parse nested json strings. if a value looks like json, but + // cannot be parsed as json, then it will be returned as a string. 
+ assert_eq!(value, client.get("wibble").await?); // or store types as json strings via Serialize and Deserialize let person = Person { @@ -44,12 +37,14 @@ async fn main() -> Result<(), RedisError> { }; let serialized = serde_json::to_string(&person)?; - let _: () = client.set("person 1", serialized, None, None, false).await?; + let _: () = client.set("foo", serialized, None, None, false).await?; // deserialize as a json value - let deserialized: Person = serde_json::from_value(client.get::("person 1").await?)?; + let person_json: Value = client.get("foo").await?; + let deserialized: Person = serde_json::from_value(person_json)?; assert_eq!(person, deserialized); // or as a json string - let deserialized: Person = serde_json::from_str(&client.get::("person 1").await?)?; + let person_string: String = client.get("foo").await?; + let deserialized: Person = serde_json::from_str(&person_string)?; assert_eq!(person, deserialized); let _ = client.quit().await; diff --git a/examples/tls.rs b/examples/tls.rs index 8ed29ab2..a5e76535 100644 --- a/examples/tls.rs +++ b/examples/tls.rs @@ -1,3 +1,6 @@ +#![allow(clippy::disallowed_names)] +#![allow(clippy::let_underscore_future)] + use fred::prelude::*; #[cfg(any(feature = "enable-native-tls", feature = "enable-rustls"))] @@ -36,7 +39,7 @@ async fn main() -> Result<(), RedisError> { tls: Some(create_tls_config().into()), ..RedisConfig::default() }; - let client = RedisClient::new(config, None, None); + let client = Builder::from_config(config).build()?; let _ = client.connect(); if let Err(error) = client.wait_for_connect().await { @@ -45,6 +48,6 @@ async fn main() -> Result<(), RedisError> { // ... 
- let _ = client.quit().await?; + client.quit().await?; Ok(()) } diff --git a/examples/transactions.rs b/examples/transactions.rs index a95dc203..3b152886 100644 --- a/examples/transactions.rs +++ b/examples/transactions.rs @@ -1,11 +1,13 @@ +#![allow(clippy::disallowed_names)] +#![allow(clippy::let_underscore_future)] + use fred::prelude::*; #[tokio::main] async fn main() -> Result<(), RedisError> { let client = RedisClient::default(); - let _ = client.connect(); - let _ = client.wait_for_connect().await?; + client.wait_for_connect().await?; let trx = client.multi(); let result: RedisValue = trx.get("foo").await?; @@ -18,6 +20,6 @@ async fn main() -> Result<(), RedisError> { let values: (Option, (), String) = trx.exec(true).await?; println!("Transaction results: {:?}", values); - let _ = client.quit().await?; + client.quit().await?; Ok(()) } diff --git a/src/clients/caching.rs b/src/clients/caching.rs deleted file mode 100644 index a0d3ff00..00000000 --- a/src/clients/caching.rs +++ /dev/null @@ -1,100 +0,0 @@ -use crate::{ - interfaces::{ - AclInterface, - ClientLike, - FunctionInterface, - GeoInterface, - HashesInterface, - HyperloglogInterface, - KeysInterface, - ListInterface, - MemoryInterface, - SetsInterface, - SortedSetsInterface, - StreamsInterface, - }, - modules::inner::RedisClientInner, - protocol::command::RedisCommand, - utils, -}; -use std::{ - fmt, - fmt::Formatter, - sync::{atomic::AtomicBool, Arc}, -}; - -/// A struct for controlling [client caching](https://redis.io/commands/client-caching/) on commands. 
-/// -/// ```rust no_run -/// # use fred::prelude::*; -/// -/// async fn example(client: &RedisClient) -> Result<(), RedisError> { -/// // send `CLIENT CACHING no` before `HSET foo bar baz` -/// let _ = client.caching(false).hset("foo", "bar", "baz").await?; -/// -/// // or reuse the caching interface -/// let caching = client.caching(true); -/// // send `CLIENT CACHING yes` before each `incr` command -/// println!("abc: {}", caching.incr::("abc").await?); -/// println!("abc: {}", caching.incr::("abc").await?); -/// Ok(()) -/// } -/// ``` -#[derive(Clone)] -#[cfg_attr(docsrs, doc(cfg(feature = "client-tracking")))] -pub struct Caching { - inner: Arc, - enabled: Arc, -} - -impl fmt::Debug for Caching { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - f.debug_struct("Caching") - .field("id", &self.inner.id) - .field("enabled", &utils::read_bool_atomic(&self.enabled)) - .finish() - } -} - -impl Caching { - pub(crate) fn new(inner: &Arc, value: bool) -> Caching { - Caching { - inner: inner.clone(), - enabled: Arc::new(AtomicBool::new(value)), - } - } - - /// Read whether caching is enabled. - pub fn is_enabled(&self) -> bool { - utils::read_bool_atomic(&self.enabled) - } - - /// Set whether caching is enabled, returning the previous value. 
- pub fn set_enabled(&self, val: bool) -> bool { - utils::set_bool_atomic(&self.enabled, val) - } -} - -impl ClientLike for Caching { - #[doc(hidden)] - fn inner(&self) -> &Arc { - &self.inner - } - - #[doc(hidden)] - fn change_command(&self, cmd: &mut RedisCommand) { - cmd.caching = Some(utils::read_bool_atomic(&self.enabled)); - } -} - -impl AclInterface for Caching {} -impl GeoInterface for Caching {} -impl HashesInterface for Caching {} -impl HyperloglogInterface for Caching {} -impl KeysInterface for Caching {} -impl ListInterface for Caching {} -impl MemoryInterface for Caching {} -impl SetsInterface for Caching {} -impl SortedSetsInterface for Caching {} -impl FunctionInterface for Caching {} -impl StreamsInterface for Caching {} diff --git a/src/clients/mod.rs b/src/clients/mod.rs index 1b125270..490ca4d0 100644 --- a/src/clients/mod.rs +++ b/src/clients/mod.rs @@ -1,10 +1,12 @@ -mod node; +mod options; mod pipeline; +mod pool; mod redis; mod transaction; -pub use node::Node; +pub use options::WithOptions; pub use pipeline::Pipeline; +pub use pool::RedisPool; pub use redis::RedisClient; pub use transaction::Transaction; @@ -25,9 +27,3 @@ mod replica; #[cfg(feature = "replicas")] #[cfg_attr(docsrs, doc(cfg(feature = "replicas")))] pub use replica::Replicas; - -#[cfg(feature = "client-tracking")] -mod caching; -#[cfg(feature = "client-tracking")] -#[cfg_attr(docsrs, doc(cfg(feature = "client-tracking")))] -pub use caching::Caching; diff --git a/src/clients/node.rs b/src/clients/node.rs deleted file mode 100644 index 1274ae06..00000000 --- a/src/clients/node.rs +++ /dev/null @@ -1,211 +0,0 @@ -use crate::{ - clients::{Pipeline, RedisClient}, - commands, - error::RedisError, - interfaces::{ - AclInterface, - AuthInterface, - ClientInterface, - ClientLike, - ClusterInterface, - ConfigInterface, - FunctionInterface, - GeoInterface, - HashesInterface, - HyperloglogInterface, - KeysInterface, - ListInterface, - LuaInterface, - MemoryInterface, - ServerInterface, - 
SetsInterface, - SlowlogInterface, - SortedSetsInterface, - StreamsInterface, - }, - modules::inner::RedisClientInner, - protocol::command::RedisCommand, - types::{ScanResult, ScanType, Server}, -}; -use bytes_utils::Str; -use futures::Stream; -use std::sync::Arc; - -use crate::interfaces::PubsubInterface; -#[cfg(feature = "client-tracking")] -use crate::{ - interfaces::RedisResult, - types::{FromRedis, MultipleStrings, Toggle}, -}; - -/// A struct for interacting with individual nodes in a cluster. -/// -/// See [with_cluster_node](crate::clients::RedisClient::with_cluster_node) for more information. -/// -/// ``` -/// # use fred::prelude::*; -/// async fn example(client: &RedisClient) -> Result<(), RedisError> { -/// // discover servers via the `RedisConfig` or active connections -/// let connections = client.active_connections().await?; -/// -/// // ping each node in the cluster individually -/// for server in connections.into_iter() { -/// let _: () = client.with_cluster_node(server).ping().await?; -/// } -/// -/// // or use the cached cluster routing table to discover servers -/// let servers = client -/// .cached_cluster_state() -/// .expect("Failed to read cached cluster state") -/// .unique_primary_nodes(); -/// for server in servers { -/// // verify the server address with `CLIENT INFO` -/// let server_addr = client -/// .with_cluster_node(&server) -/// .client_info::() -/// .await? 
-/// .split(" ") -/// .find_map(|s| { -/// let parts: Vec<&str> = s.split("=").collect(); -/// if parts[0] == "laddr" { -/// Some(parts[1].to_owned()) -/// } else { -/// None -/// } -/// }) -/// .expect("Failed to read or parse client info."); -/// -/// assert_eq!(server_addr, server.to_string()); -/// } -/// -/// Ok(()) -/// } -/// ``` -#[derive(Clone)] -pub struct Node { - inner: Arc, - server: Server, -} - -impl ClientLike for Node { - #[doc(hidden)] - fn inner(&self) -> &Arc { - &self.inner - } - - #[doc(hidden)] - fn change_command(&self, cmd: &mut RedisCommand) { - cmd.cluster_node = Some(self.server.clone()); - } -} - -impl Node { - pub(crate) fn new(inner: &Arc, server: Server) -> Node { - Node { - inner: inner.clone(), - server, - } - } - - /// Read the server to which all commands will be sent. - pub fn server(&self) -> &Server { - &self.server - } - - /// Create a client instance that can interact with all cluster nodes. - pub fn client(&self) -> RedisClient { - self.inner().into() - } - - /// Send a series of commands in a pipeline to the cluster node. - pub fn pipeline(&self) -> Pipeline { - Pipeline::from(self.clone()) - } - - /// Incrementally iterate over a set of keys matching the `pattern` argument, returning `count` results per page, if - /// specified. - /// - /// The scan operation can be canceled by dropping the returned stream. - /// - /// - pub fn scan

( - &self, - pattern: P, - count: Option, - r#type: Option, - ) -> impl Stream> - where - P: Into, - { - commands::scan::scan(&self.inner, pattern.into(), count, r#type, Some(self.server.clone())) - } -} - -impl AclInterface for Node {} -impl ClusterInterface for Node {} -impl ConfigInterface for Node {} -impl GeoInterface for Node {} -impl HashesInterface for Node {} -impl HyperloglogInterface for Node {} -impl KeysInterface for Node {} -impl LuaInterface for Node {} -impl ListInterface for Node {} -impl MemoryInterface for Node {} -impl AuthInterface for Node {} -impl ServerInterface for Node {} -impl SlowlogInterface for Node {} -impl SetsInterface for Node {} -impl SortedSetsInterface for Node {} -impl StreamsInterface for Node {} -impl FunctionInterface for Node {} -impl PubsubInterface for Node {} - -// remove the restriction on clustered deployments with the basic `CLIENT TRACKING` commands here -#[async_trait] -impl ClientInterface for Node { - /// This command enables the tracking feature of the Redis server that is used for server assisted client side - /// caching. - /// - /// - #[cfg(feature = "client-tracking")] - #[cfg_attr(docsrs, doc(cfg(feature = "client-tracking")))] - async fn client_tracking( - &self, - toggle: T, - redirect: Option, - prefixes: P, - bcast: bool, - optin: bool, - optout: bool, - noloop: bool, - ) -> RedisResult - where - R: FromRedis, - T: TryInto + Send, - T::Error: Into + Send, - P: Into + Send, - { - try_into!(toggle); - into!(prefixes); - commands::tracking::client_tracking(self, toggle, redirect, prefixes, bcast, optin, optout, noloop) - .await? - .convert() - } - - /// This command controls the tracking of the keys in the next command executed by the connection, when tracking is - /// enabled in OPTIN or OPTOUT mode. - /// - /// - /// - /// Note: **This function requires a centralized server**. See - /// [crate::interfaces::TrackingInterface::caching] for a version that works with all server deployment - /// modes. 
- #[cfg(feature = "client-tracking")] - #[cfg_attr(docsrs, doc(cfg(feature = "client-tracking")))] - async fn client_caching(&self, enabled: bool) -> RedisResult - where - R: FromRedis, - { - commands::tracking::client_caching(self, enabled).await?.convert() - } -} diff --git a/src/clients/options.rs b/src/clients/options.rs new file mode 100644 index 00000000..ac80cb93 --- /dev/null +++ b/src/clients/options.rs @@ -0,0 +1,118 @@ +use crate::{ + error::RedisError, + interfaces::*, + modules::inner::RedisClientInner, + protocol::command::RedisCommand, + types::Options, +}; +use std::{fmt, ops::Deref, sync::Arc}; + +/// A client interface used to customize command configuration options. +/// +/// See [Options](crate::types::Options) for more information. +/// +/// ```rust +/// # use fred::prelude::*; +/// # use std::time::Duration; +/// async fn example() -> Result<(), RedisError> { +/// let client = RedisClient::default(); +/// let _ = client.connect(); +/// let _ = client.wait_for_connect().await?; +/// +/// let options = Options { +/// max_redirections: Some(3), +/// max_attempts: Some(1), +/// timeout: Some(Duration::from_secs(10)), +/// ..Default::default() +/// }; +/// let foo: Option = client.with_options(&options).get("foo").await?; +/// +/// // reuse the options bindings +/// let with_options = client.with_options(&options); +/// let foo: () = with_options.get("foo").await?; +/// let bar: () = with_options.get("bar").await?; +/// +/// // combine with other client types +/// let pipeline = client.pipeline().with_options(&options); +/// let _: () = pipeline.get("foo").await?; +/// let _: () = pipeline.get("bar").await?; +/// // custom options will be applied to each command +/// println!("results: {:?}", pipeline.all::().await?); +/// +/// Ok(()) +/// } +/// ``` +#[derive(Clone)] +pub struct WithOptions { + pub(crate) client: C, + pub(crate) options: Options, +} + +impl WithOptions { + /// Read the options that will be applied to commands. 
+ pub fn options(&self) -> &Options { + &self.options + } +} + +impl Deref for WithOptions { + type Target = C; + + fn deref(&self) -> &Self::Target { + &self.client + } +} + +impl fmt::Debug for WithOptions { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("WithOptions") + .field("client", &self.client.id()) + .field("options", &self.options) + .finish() + } +} + +impl ClientLike for WithOptions { + #[doc(hidden)] + fn inner(&self) -> &Arc { + self.client.inner() + } + + #[doc(hidden)] + fn change_command(&self, command: &mut RedisCommand) { + self.client.change_command(command); + self.options.apply(command); + } + + #[doc(hidden)] + fn send_command(&self, command: T) -> Result<(), RedisError> + where + T: Into, + { + let mut command: RedisCommand = command.into(); + self.options.apply(&mut command); + self.client.send_command(command) + } +} + +impl AclInterface for WithOptions {} +impl ClientInterface for WithOptions {} +impl ClusterInterface for WithOptions {} +impl PubsubInterface for WithOptions {} +impl ConfigInterface for WithOptions {} +impl GeoInterface for WithOptions {} +impl HashesInterface for WithOptions {} +impl HyperloglogInterface for WithOptions {} +impl KeysInterface for WithOptions {} +impl ListInterface for WithOptions {} +impl MemoryInterface for WithOptions {} +impl AuthInterface for WithOptions {} +impl ServerInterface for WithOptions {} +impl SlowlogInterface for WithOptions {} +impl SetsInterface for WithOptions {} +impl SortedSetsInterface for WithOptions {} +impl StreamsInterface for WithOptions {} +impl FunctionInterface for WithOptions {} +#[cfg(feature = "redis-json")] +#[cfg_attr(docsrs, doc(cfg(feature = "redis-json")))] +impl RedisJsonInterface for WithOptions {} diff --git a/src/clients/pipeline.rs b/src/clients/pipeline.rs index 5bba6b7c..5b06a5df 100644 --- a/src/clients/pipeline.rs +++ b/src/clients/pipeline.rs @@ -36,6 +36,20 @@ use parking_lot::Mutex; use std::{collections::VecDeque, fmt, 
fmt::Formatter, sync::Arc}; use tokio::sync::oneshot::{channel as oneshot_channel, Receiver as OneshotReceiver}; +#[cfg(feature = "redis-json")] +use crate::interfaces::RedisJsonInterface; + +fn clone_buffered_commands(buffer: &Mutex>) -> VecDeque { + let guard = buffer.lock(); + let mut out = VecDeque::with_capacity(guard.len()); + + for command in guard.iter() { + out.push_back(command.duplicate(ResponseKind::Skip)); + } + + out +} + fn prepare_all_commands( commands: VecDeque, error_early: bool, @@ -63,6 +77,8 @@ fn prepare_all_commands( } /// Send a series of commands in a [pipeline](https://redis.io/docs/manual/pipelining/). +/// +/// See the [all](Self::all), [last](Self::last), and [try_all](Self::try_all) functions for more information. pub struct Pipeline { commands: Arc>>, client: C, @@ -100,7 +116,7 @@ impl From for Pipeline { impl ClientLike for Pipeline { #[doc(hidden)] fn inner(&self) -> &Arc { - &self.client.inner() + self.client.inner() } #[doc(hidden)] @@ -125,7 +141,7 @@ impl ClientLike for Pipeline { let _ = tx.send(Ok(protocol_utils::queued_frame())); } - self.commands.lock().push_back(command.into()); + self.commands.lock().push_back(command); Ok(()) } } @@ -148,6 +164,9 @@ impl SetsInterface for Pipeline {} impl SortedSetsInterface for Pipeline {} impl StreamsInterface for Pipeline {} impl FunctionInterface for Pipeline {} +#[cfg(feature = "redis-json")] +#[cfg_attr(docsrs, doc(cfg(feature = "redis-json")))] +impl RedisJsonInterface for Pipeline {} impl Pipeline { /// Send the pipeline and respond with an array of all responses. 
@@ -166,11 +185,11 @@ impl Pipeline { /// Ok(()) /// } /// ``` - pub async fn all(self) -> Result + pub async fn all(&self) -> Result where R: FromRedis, { - let commands = { self.commands.lock().drain(..).collect() }; + let commands = clone_buffered_commands(&self.commands); send_all(self.client.inner(), commands).await?.convert() } @@ -195,11 +214,11 @@ impl Pipeline { /// Ok(()) /// } /// ``` - pub async fn try_all(self) -> Vec> + pub async fn try_all(&self) -> Vec> where R: FromRedis, { - let commands = { self.commands.lock().drain(..).collect() }; + let commands = clone_buffered_commands(&self.commands); try_send_all(self.client.inner(), commands) .await .into_iter() @@ -216,16 +235,17 @@ impl Pipeline { /// let _: () = pipeline.incr("foo").await?; // returns when the command is queued in memory /// let _: () = pipeline.incr("foo").await?; // returns when the command is queued in memory /// - /// let result: i64 = pipeline.last().await?; - /// assert_eq!(results, 2); + /// assert_eq!(pipeline.last::().await?, 2); + /// // pipelines can also be reused + /// assert_eq!(pipeline.last::().await?, 4); /// Ok(()) /// } /// ``` - pub async fn last(self) -> Result + pub async fn last(&self) -> Result where R: FromRedis, { - let commands = { self.commands.lock().drain(..).collect() }; + let commands = clone_buffered_commands(&self.commands); send_last(self.client.inner(), commands).await?.convert() } } @@ -238,11 +258,14 @@ async fn try_send_all( return Vec::new(); } - let (command, rx) = prepare_all_commands(commands, false); + let (mut command, rx) = prepare_all_commands(commands, false); + command.inherit_options(inner); + let timeout_dur = command.timeout_dur().unwrap_or_else(|| inner.default_command_timeout()); + if let Err(e) = interfaces::send_to_router(inner, command) { return vec![Err(e)]; }; - let frame = match utils::apply_timeout(rx, inner.default_command_timeout()).await { + let frame = match utils::apply_timeout(rx, timeout_dur).await { Ok(result) => 
match result { Ok(f) => f, Err(e) => return vec![Err(e)], @@ -253,10 +276,10 @@ async fn try_send_all( if let Resp3Frame::Array { data, .. } = frame { data .into_iter() - .map(|frame| protocol_utils::frame_to_results(frame)) + .map(protocol_utils::frame_to_results) .collect() } else { - vec![protocol_utils::frame_to_results_raw(frame)] + vec![protocol_utils::frame_to_results(frame)] } } @@ -265,10 +288,13 @@ async fn send_all(inner: &Arc, commands: VecDeque = commands.into_iter().collect(); commands[len - 1].response = ResponseKind::Respond(Some(tx)); - let command = RouterCommand::Pipeline { commands }; + let mut command = RouterCommand::Pipeline { commands }; + command.inherit_options(inner); + let timeout_dur = command.timeout_dur().unwrap_or_else(|| inner.default_command_timeout()); - let _ = interfaces::send_to_router(inner, command)?; - let frame = utils::apply_timeout(rx, inner.default_command_timeout()).await??; - protocol_utils::frame_to_results_raw(frame) + interfaces::send_to_router(inner, command)?; + let frame = utils::apply_timeout(rx, timeout_dur).await??; + protocol_utils::frame_to_results(frame) } diff --git a/src/clients/pool.rs b/src/clients/pool.rs new file mode 100644 index 00000000..6ea4b5ab --- /dev/null +++ b/src/clients/pool.rs @@ -0,0 +1,269 @@ +use crate::{ + clients::RedisClient, + error::{RedisError, RedisErrorKind}, + interfaces::*, + modules::inner::RedisClientInner, + types::{ConnectHandle, ConnectionConfig, PerformanceConfig, ReconnectPolicy, RedisConfig, Server}, + utils, +}; +use futures::future::{join_all, try_join_all}; +use std::{ + fmt, + sync::{atomic::AtomicUsize, Arc}, + time::Duration, +}; +use tokio::time::interval as tokio_interval; + +#[cfg(feature = "dns")] +use crate::protocol::types::Resolve; + +#[cfg(feature = "replicas")] +use crate::clients::Replicas; + +/// A cheaply cloneable round-robin client pool. 
+/// +/// ### Restrictions +/// +/// The following interfaces are not implemented on `RedisPool`: +/// * [MetricsInterface](crate::interfaces::MetricsInterface) +/// * [PubsubInterface](crate::interfaces::PubsubInterface) +/// * [TransactionInterface](crate::interfaces::TransactionInterface) +/// * [EventInterface](crate::interfaces::EventInterface) +/// * [ClientInterface](crate::interfaces::ClientInterface) +/// +/// In some cases, such as [publish](crate::interfaces::PubsubInterface::publish), callers can work around this by +/// adding a call to [next](Self::next), but in other scenarios this may not work. As a general rule, any commands +/// that change or depend on local connection state will not be implemented directly on `RedisPool`. Callers can use +/// [clients](Self::clients), [next](Self::next), or [last](Self::last) to operate on individual clients if needed. +#[derive(Clone)] +pub struct RedisPool { + clients: Arc>, + counter: Arc, +} + +impl fmt::Debug for RedisPool { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_struct("RedisPool").field("size", &self.clients.len()).finish() + } +} + +impl RedisPool { + /// Create a new pool from an existing set of clients. + pub fn from_clients(clients: Vec) -> Result { + if clients.is_empty() { + Err(RedisError::new(RedisErrorKind::Config, "Pool cannot be empty.")) + } else { + Ok(RedisPool { + clients: Arc::new(clients), + counter: Arc::new(AtomicUsize::new(0)), + }) + } + } + + /// Create a new pool without connecting to the server. + /// + /// See the [builder](crate::types::Builder) interface for more information. + pub fn new( + config: RedisConfig, + perf: Option, + connection: Option, + policy: Option, + size: usize, + ) -> Result { + if size == 0 { + Err(RedisError::new(RedisErrorKind::Config, "Pool cannot be empty.")) + } else { + let mut clients = Vec::with_capacity(size); + for _ in 0 .. 
size { + clients.push(RedisClient::new( + config.clone(), + perf.clone(), + connection.clone(), + policy.clone(), + )); + } + + Ok(RedisPool { + clients: Arc::new(clients), + counter: Arc::new(AtomicUsize::new(0)), + }) + } + } + + /// Read the individual clients in the pool. + pub fn clients(&self) -> &[RedisClient] { + &self.clients + } + + /// Connect each client to the server, returning the task driving each connection. + /// + /// Use the base [connect](Self::connect) function to return one handle that drives all connections via [join](https://docs.rs/futures/latest/futures/macro.join.html). + pub fn connect_pool(&self) -> Vec { + self.clients.iter().map(|c| c.connect()).collect() + } + + /// Read the size of the pool. + pub fn size(&self) -> usize { + self.clients.len() + } + + /// Read the client that should run the next command. + #[cfg(feature = "pool-prefer-active")] + pub fn next(&self) -> &RedisClient { + let mut idx = utils::incr_atomic(&self.counter) % self.clients.len(); + + for _ in 0 .. self.clients.len() { + let client = &self.clients[idx]; + if client.is_connected() { + return client; + } + idx = (idx + 1) % self.clients.len(); + } + + &self.clients[idx] + } + + /// Read the client that should run the next command. + #[cfg(not(feature = "pool-prefer-active"))] + pub fn next(&self) -> &RedisClient { + &self.clients[utils::incr_atomic(&self.counter) % self.clients.len()] + } + + /// Read the client that ran the last command. + pub fn last(&self) -> &RedisClient { + &self.clients[utils::read_atomic(&self.counter) % self.clients.len()] + } + + /// Create a client that interacts with the replica nodes associated with the [next](Self::next) client. 
+ #[cfg(feature = "replicas")] + #[cfg_attr(docsrs, doc(cfg(feature = "replicas")))] + pub fn replicas(&self) -> Replicas { + Replicas::from(self.inner()) + } +} + +#[async_trait] +impl ClientLike for RedisPool { + #[doc(hidden)] + fn inner(&self) -> &Arc { + &self.next().inner + } + + /// Update the internal [PerformanceConfig](crate::types::PerformanceConfig) on each client in place with new + /// values. + fn update_perf_config(&self, config: PerformanceConfig) { + for client in self.clients.iter() { + client.update_perf_config(config.clone()); + } + } + + /// Read the set of active connections across all clients in the pool. + async fn active_connections(&self) -> Result, RedisError> { + let all_connections = try_join_all(self.clients.iter().map(|c| c.active_connections())).await?; + let total_size = if all_connections.is_empty() { + return Ok(Vec::new()); + } else { + all_connections.len() * all_connections[0].len() + }; + let mut out = Vec::with_capacity(total_size); + + for connections in all_connections.into_iter() { + out.extend(connections); + } + Ok(out) + } + + /// Override the DNS resolution logic for all clients in the pool. + #[cfg(feature = "dns")] + #[cfg_attr(docsrs, doc(cfg(feature = "dns")))] + async fn set_resolver(&self, resolver: Arc) { + for client in self.clients.iter() { + client.set_resolver(resolver.clone()).await; + } + } + + /// Connect each client to the Redis server. + /// + /// This function returns a `JoinHandle` to a task that drives **all** connections via [join](https://docs.rs/futures/latest/futures/macro.join.html). + /// + /// See [connect_pool](crate::clients::RedisPool::connect_pool) for a variation of this function that separates the + /// connection tasks. 
+ fn connect(&self) -> ConnectHandle { + let clients = self.clients.clone(); + tokio::spawn(async move { + let tasks: Vec<_> = clients.iter().map(|c| c.connect()).collect(); + for result in join_all(tasks).await.into_iter() { + result??; + } + + Ok(()) + }) + } + + /// Force a reconnection to the server(s) for each client. + /// + /// When running against a cluster this function will also refresh the cached cluster routing table. + async fn force_reconnection(&self) -> RedisResult<()> { + let _ = try_join_all(self.clients.iter().map(|c| c.force_reconnection())).await?; + + Ok(()) + } + + /// Wait for all the clients to connect to the server. + async fn wait_for_connect(&self) -> RedisResult<()> { + let _ = try_join_all(self.clients.iter().map(|c| c.wait_for_connect())).await?; + + Ok(()) + } + + /// Close the connection to the Redis server for each client. The returned future resolves when the command has been + /// written to the socket, not when the connection has been fully closed. Some time after this future resolves the + /// future returned by [connect](Self::connect) will resolve which indicates that the connection has been fully + /// closed. + /// + /// This function will also close all error, pubsub message, and reconnection event streams on all clients in the + /// pool. 
+ async fn quit(&self) -> RedisResult<()> { + let _ = join_all(self.clients.iter().map(|c| c.quit())).await; + + Ok(()) + } +} + +#[async_trait] +impl HeartbeatInterface for RedisPool { + async fn enable_heartbeat(&self, interval: Duration, break_on_error: bool) -> RedisResult<()> { + let mut interval = tokio_interval(interval); + + loop { + interval.tick().await; + + if let Err(error) = try_join_all(self.clients.iter().map(|c| c.ping::<()>())).await { + if break_on_error { + return Err(error); + } + } + } + } +} + +impl AclInterface for RedisPool {} +impl ClusterInterface for RedisPool {} +impl ConfigInterface for RedisPool {} +impl GeoInterface for RedisPool {} +impl HashesInterface for RedisPool {} +impl HyperloglogInterface for RedisPool {} +impl KeysInterface for RedisPool {} +impl LuaInterface for RedisPool {} +impl ListInterface for RedisPool {} +impl MemoryInterface for RedisPool {} +impl AuthInterface for RedisPool {} +impl ServerInterface for RedisPool {} +impl SlowlogInterface for RedisPool {} +impl SetsInterface for RedisPool {} +impl SortedSetsInterface for RedisPool {} +impl StreamsInterface for RedisPool {} +impl FunctionInterface for RedisPool {} +#[cfg(feature = "redis-json")] +#[cfg_attr(docsrs, doc(cfg(feature = "redis-json")))] +impl RedisJsonInterface for RedisPool {} diff --git a/src/clients/pubsub.rs b/src/clients/pubsub.rs index 2060c365..5b2fbc82 100644 --- a/src/clients/pubsub.rs +++ b/src/clients/pubsub.rs @@ -1,35 +1,10 @@ use crate::{ commands, error::RedisError, - interfaces::{ - AclInterface, - AuthInterface, - ClientInterface, - ClientLike, - ClusterInterface, - ConfigInterface, - FunctionInterface, - GeoInterface, - HashesInterface, - HeartbeatInterface, - HyperloglogInterface, - KeysInterface, - ListInterface, - LuaInterface, - MemoryInterface, - MetricsInterface, - PubsubInterface, - RedisResult, - ServerInterface, - SetsInterface, - SlowlogInterface, - SortedSetsInterface, - StreamsInterface, - TransactionInterface, - }, + 
interfaces::*, modules::inner::RedisClientInner, prelude::{FromRedis, RedisClient}, - types::{MultipleStrings, PerformanceConfig, ReconnectPolicy, RedisConfig, RedisKey}, + types::{ConnectionConfig, MultipleStrings, PerformanceConfig, ReconnectPolicy, RedisConfig, RedisKey}, }; use bytes_utils::Str; use parking_lot::RwLock; @@ -107,6 +82,7 @@ impl ClientLike for SubscriberClient { } } +impl EventInterface for SubscriberClient {} impl AclInterface for SubscriberClient {} impl ClientInterface for SubscriberClient {} impl ClusterInterface for SubscriberClient {} @@ -128,6 +104,9 @@ impl SortedSetsInterface for SubscriberClient {} impl HeartbeatInterface for SubscriberClient {} impl StreamsInterface for SubscriberClient {} impl FunctionInterface for SubscriberClient {} +#[cfg(feature = "redis-json")] +#[cfg_attr(docsrs, doc(cfg(feature = "redis-json")))] +impl RedisJsonInterface for SubscriberClient {} #[cfg(feature = "client-tracking")] #[cfg_attr(docsrs, doc(cfg(feature = "client-tracking")))] @@ -268,16 +247,19 @@ impl PubsubInterface for SubscriberClient { impl SubscriberClient { /// Create a new client instance without connecting to the server. + /// + /// See the [builder](crate::types::Builder) interface for more information. 
pub fn new( config: RedisConfig, perf: Option, + connection: Option, policy: Option, ) -> SubscriberClient { SubscriberClient { channels: Arc::new(RwLock::new(BTreeSet::new())), patterns: Arc::new(RwLock::new(BTreeSet::new())), shard_channels: Arc::new(RwLock::new(BTreeSet::new())), - inner: RedisClientInner::new(config, perf.unwrap_or_default(), policy), + inner: RedisClientInner::new(config, perf.unwrap_or_default(), connection.unwrap_or_default(), policy), } } @@ -289,6 +271,7 @@ impl SubscriberClient { let inner = RedisClientInner::new( self.inner.config.as_ref().clone(), self.inner.performance_config(), + self.inner.connection.as_ref().clone(), self.inner.reconnect_policy(), ); @@ -304,7 +287,7 @@ impl SubscriberClient { pub fn manage_subscriptions(&self) -> JoinHandle<()> { let _self = self.clone(); tokio::spawn(async move { - let mut stream = _self.on_reconnect(); + let mut stream = _self.reconnect_rx(); while let Ok(_) = stream.recv().await { if let Err(error) = _self.resubscribe_all().await { diff --git a/src/clients/redis.rs b/src/clients/redis.rs index 70fb98d7..e948da11 100644 --- a/src/clients/redis.rs +++ b/src/clients/redis.rs @@ -1,45 +1,22 @@ use crate::{ - clients::{Node, Pipeline}, + clients::{Pipeline, WithOptions}, commands, error::{RedisError, RedisErrorKind}, - interfaces::{ - AclInterface, - AuthInterface, - ClientInterface, - ClusterInterface, - ConfigInterface, - FunctionInterface, - GeoInterface, - HashesInterface, - HeartbeatInterface, - HyperloglogInterface, - KeysInterface, - ListInterface, - LuaInterface, - MemoryInterface, - MetricsInterface, - PubsubInterface, - ServerInterface, - SetsInterface, - SlowlogInterface, - SortedSetsInterface, - TransactionInterface, - }, + interfaces::*, modules::inner::RedisClientInner, prelude::{ClientLike, StreamsInterface}, types::*, }; use bytes_utils::Str; use futures::Stream; -use std::{fmt, sync::Arc}; - -#[cfg(feature = "client-tracking")] -use crate::{clients::Caching, 
interfaces::TrackingInterface}; +use std::{fmt, fmt::Formatter, sync::Arc}; #[cfg(feature = "replicas")] use crate::clients::Replicas; +#[cfg(feature = "client-tracking")] +use crate::interfaces::TrackingInterface; -/// The primary Redis client struct. +/// A cheaply cloneable Redis client struct. #[derive(Clone)] pub struct RedisClient { pub(crate) inner: Arc, @@ -47,11 +24,11 @@ pub struct RedisClient { impl Default for RedisClient { fn default() -> Self { - RedisClient::new(RedisConfig::default(), None, None) + RedisClient::new(RedisConfig::default(), None, None, None) } } -impl fmt::Display for RedisClient { +impl fmt::Debug for RedisClient { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("RedisClient") .field("id", &self.inner.id) @@ -60,6 +37,12 @@ impl fmt::Display for RedisClient { } } +impl fmt::Display for RedisClient { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.inner.id) + } +} + #[doc(hidden)] impl<'a> From<&'a Arc> for RedisClient { fn from(inner: &'a Arc) -> RedisClient { @@ -74,6 +57,7 @@ impl ClientLike for RedisClient { } } +impl EventInterface for RedisClient {} impl AclInterface for RedisClient {} impl ClientInterface for RedisClient {} impl ClusterInterface for RedisClient {} @@ -96,6 +80,9 @@ impl SortedSetsInterface for RedisClient {} impl HeartbeatInterface for RedisClient {} impl StreamsInterface for RedisClient {} impl FunctionInterface for RedisClient {} +#[cfg(feature = "redis-json")] +#[cfg_attr(docsrs, doc(cfg(feature = "redis-json")))] +impl RedisJsonInterface for RedisClient {} #[cfg(feature = "client-tracking")] #[cfg_attr(docsrs, doc(cfg(feature = "client-tracking")))] @@ -103,9 +90,16 @@ impl TrackingInterface for RedisClient {} impl RedisClient { /// Create a new client instance without connecting to the server. 
- pub fn new(config: RedisConfig, perf: Option, policy: Option) -> RedisClient { + /// + /// See the [builder](crate::types::Builder) interface for more information. + pub fn new( + config: RedisConfig, + perf: Option, + connection: Option, + policy: Option, + ) -> RedisClient { RedisClient { - inner: RedisClientInner::new(config, perf.unwrap_or_default(), policy), + inner: RedisClientInner::new(config, perf.unwrap_or_default(), connection.unwrap_or_default(), policy), } } @@ -121,17 +115,15 @@ impl RedisClient { RedisClient::new( self.inner.config.as_ref().clone(), Some(self.inner.performance_config()), + Some(self.inner.connection_config()), policy, ) } /// Split a clustered Redis client into a set of centralized clients - one for each primary node in the cluster. /// - /// Some Redis commands are not designed to work with hash slots against a clustered deployment. For example, - /// `FLUSHDB`, `PING`, etc all work on one node in the cluster, but no interface exists for the client to - /// select a specific node in the cluster against which to run the command. This function allows the caller to - /// create a list of clients such that each connect to one of the primary nodes in the cluster and functions - /// as if it were operating against a single centralized Redis server. + /// Alternatively, callers can use [with_cluster_node](crate::clients::RedisClient::with_cluster_node) to avoid + /// creating new connections. /// /// The clients returned by this function will not be connected to their associated servers. The caller needs to /// call `connect` on each client before sending any commands. @@ -242,17 +234,61 @@ impl RedisClient { Pipeline::from(self.clone()) } - /// Send commands to the provided cluster node. + /// Shorthand to bind subsequent commands to the provided server. /// - /// The caller will receive a `RedisErrorKind::Cluster` error if the provided server does not exist. 
+ /// See [with_options](crate::interfaces::ClientLike::with_options) for more information. /// - /// The client will still automatically follow `MOVED` errors via this interface. Callers may not notice this, but - /// incorrect server arguments here could result in unnecessary calls to refresh the cached cluster routing table. - pub fn with_cluster_node(&self, server: S) -> Node + /// ```rust + /// # use fred::prelude::*; + /// async fn example(client: &RedisClient) -> Result<(), RedisError> { + /// // discover servers via the `RedisConfig` or active connections + /// let connections = client.active_connections().await?; + /// + /// // ping each node in the cluster individually + /// for server in connections.into_iter() { + /// let _: () = client.with_cluster_node(server).ping().await?; + /// } + /// + /// // or use the cached cluster routing table to discover servers + /// let servers = client + /// .cached_cluster_state() + /// .expect("Failed to read cached cluster state") + /// .unique_primary_nodes(); + /// + /// for server in servers.into_iter() { + /// // verify the server address with `CLIENT INFO` + /// let server_addr = client + /// .with_cluster_node(&server) + /// .client_info::() + /// .await? + /// .split(" ") + /// .find_map(|s| { + /// let parts: Vec<&str> = s.split("=").collect(); + /// if parts[0] == "laddr" { + /// Some(parts[1].to_owned()) + /// } else { + /// None + /// } + /// }) + /// .expect("Failed to read or parse client info."); + /// + /// assert_eq!(server_addr, server.to_string()); + /// } + /// + /// Ok(()) + /// } + /// ``` + pub fn with_cluster_node(&self, server: S) -> WithOptions where S: Into, { - Node::new(&self.inner, server.into()) + WithOptions { + client: self.clone(), + options: Options { + cluster_node: Some(server.into()), + ..Default::default() + }, + } } /// Create a client that interacts with replica nodes. 
@@ -261,20 +297,15 @@ impl RedisClient { pub fn replicas(&self) -> Replicas { Replicas::from(&self.inner) } - - /// Send a [CLIENT CACHING yes|no](https://redis.io/commands/client-caching/) command before subsequent commands. - #[cfg(feature = "client-tracking")] - #[cfg_attr(docsrs, doc(cfg(feature = "client-tracking")))] - pub fn caching(&self, enabled: bool) -> Caching { - Caching::new(&self.inner, enabled) - } } #[cfg(test)] mod tests { + #[cfg(feature = "sha-1")] use crate::util; #[test] + #[cfg(feature = "sha-1")] fn should_correctly_sha1_hash() { assert_eq!( &util::sha1_hash("foobarbaz"), diff --git a/src/clients/replica.rs b/src/clients/replica.rs index adbc1e18..06803fb8 100644 --- a/src/clients/replica.rs +++ b/src/clients/replica.rs @@ -1,25 +1,7 @@ use crate::{ clients::{Pipeline, RedisClient}, error::RedisError, - interfaces::{ - self, - AuthInterface, - ClientLike, - FunctionInterface, - GeoInterface, - HashesInterface, - HyperloglogInterface, - KeysInterface, - ListInterface, - LuaInterface, - MemoryInterface, - MetricsInterface, - ServerInterface, - SetsInterface, - SlowlogInterface, - SortedSetsInterface, - StreamsInterface, - }, + interfaces::{self, *}, modules::inner::RedisClientInner, protocol::command::{RedisCommand, RouterCommand}, types::Server, @@ -27,7 +9,7 @@ use crate::{ use std::{collections::HashMap, fmt, fmt::Formatter, sync::Arc}; use tokio::sync::oneshot::channel as oneshot_channel; -/// A struct for interacting with replica nodes. +/// A struct for interacting with cluster replica nodes. /// /// All commands sent via this interface will use a replica node, if possible. 
The underlying connections are shared /// with the main client in order to maintain an up-to-date view of the system in the event that replicas change or @@ -81,6 +63,9 @@ impl SlowlogInterface for Replicas {} impl SetsInterface for Replicas {} impl SortedSetsInterface for Replicas {} impl StreamsInterface for Replicas {} +#[cfg(feature = "redis-json")] +#[cfg_attr(docsrs, doc(cfg(feature = "redis-json")))] +impl RedisJsonInterface for Replicas {} impl Replicas { /// Read a mapping of replica server IDs to primary server IDs. diff --git a/src/clients/sentinel.rs b/src/clients/sentinel.rs index 392376ec..8ddba2b6 100644 --- a/src/clients/sentinel.rs +++ b/src/clients/sentinel.rs @@ -1,7 +1,7 @@ use crate::{ interfaces::*, modules::inner::RedisClientInner, - types::{PerformanceConfig, ReconnectPolicy, SentinelConfig}, + types::{ConnectionConfig, PerformanceConfig, ReconnectPolicy, SentinelConfig}, }; use std::{fmt, sync::Arc}; @@ -42,6 +42,7 @@ impl<'a> From<&'a Arc> for SentinelClient { } } +impl EventInterface for SentinelClient {} impl SentinelInterface for SentinelClient {} impl MetricsInterface for SentinelClient {} impl AclInterface for SentinelClient {} @@ -52,13 +53,21 @@ impl HeartbeatInterface for SentinelClient {} impl SentinelClient { /// Create a new client instance without connecting to the sentinel node. + /// + /// See the [builder](crate::types::Builder) interface for more information. 
pub fn new( config: SentinelConfig, perf: Option, + connection: Option, policy: Option, ) -> SentinelClient { SentinelClient { - inner: RedisClientInner::new(config.into(), perf.unwrap_or_default(), policy), + inner: RedisClientInner::new( + config.into(), + perf.unwrap_or_default(), + connection.unwrap_or_default(), + policy, + ), } } } diff --git a/src/clients/transaction.rs b/src/clients/transaction.rs index a0665037..1cc661e4 100644 --- a/src/clients/transaction.rs +++ b/src/clients/transaction.rs @@ -10,14 +10,15 @@ use crate::{ responders::ResponseKind, utils as protocol_utils, }, - types::{FromRedis, MultipleKeys, RedisKey, Server}, + types::{FromRedis, MultipleKeys, Options, RedisKey, Server}, utils, }; use parking_lot::Mutex; use std::{collections::VecDeque, fmt, sync::Arc}; use tokio::sync::oneshot::channel as oneshot_channel; -/// A client struct for commands in a `MULTI`/`EXEC` transaction block. +/// A cheaply cloneable transaction block. +#[derive(Clone)] pub struct Transaction { id: u64, inner: Arc, @@ -26,19 +27,6 @@ pub struct Transaction { hash_slot: Arc>>, } -#[doc(hidden)] -impl Clone for Transaction { - fn clone(&self) -> Self { - Transaction { - id: self.id.clone(), - inner: self.inner.clone(), - commands: self.commands.clone(), - watched: self.watched.clone(), - hash_slot: self.hash_slot.clone(), - } - } -} - impl fmt::Debug for Transaction { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Transaction") @@ -50,6 +38,14 @@ impl fmt::Debug for Transaction { } } +impl PartialEq for Transaction { + fn eq(&self, other: &Self) -> bool { + self.id == other.id + } +} + +impl Eq for Transaction {} + impl ClientLike for Transaction { #[doc(hidden)] fn inner(&self) -> &Arc { @@ -62,9 +58,9 @@ impl ClientLike for Transaction { C: Into, { let mut command: RedisCommand = command.into(); - let _ = self.disallow_all_cluster_commands(&command)?; + self.disallow_all_cluster_commands(&command)?; // check cluster slot mappings as 
commands are added - let _ = self.update_hash_slot(&command)?; + self.update_hash_slot(&command)?; if let Some(tx) = command.take_responder() { trace!( @@ -97,8 +93,22 @@ impl SetsInterface for Transaction {} impl SortedSetsInterface for Transaction {} impl StreamsInterface for Transaction {} impl FunctionInterface for Transaction {} +#[cfg(feature = "redis-json")] +#[cfg_attr(docsrs, doc(cfg(feature = "redis-json")))] +impl RedisJsonInterface for Transaction {} impl Transaction { + /// Create a new transaction. + pub(crate) fn from_inner(inner: &Arc) -> Self { + Transaction { + inner: inner.clone(), + commands: Arc::new(Mutex::new(VecDeque::new())), + watched: Arc::new(Mutex::new(VecDeque::new())), + hash_slot: Arc::new(Mutex::new(None)), + id: utils::random_u64(u64::MAX), + } + } + /// Check and update the hash slot for the transaction. pub(crate) fn update_hash_slot(&self, command: &RedisCommand) -> Result<(), RedisError> { if !self.inner.config.server.is_clustered() { @@ -141,6 +151,28 @@ impl Transaction { } } + /// An ID identifying the underlying transaction state. + pub fn id(&self) -> u64 { + self.id + } + + /// Clear the internal command buffer and watched keys. + pub fn reset(&self) { + self.commands.lock().clear(); + self.watched.lock().clear(); + self.hash_slot.lock().take(); + } + + /// Read the number of commands queued to run. + pub fn len(&self) -> usize { + self.commands.lock().len() + } + + /// Read the number of keys to `WATCH` before the starting the transaction. + pub fn watched_len(&self) -> usize { + self.watched.lock().len() + } + /// Executes all previously queued commands in a transaction. 
/// /// If `abort_on_error` is `true` the client will automatically send `DISCARD` if an error is received from @@ -164,13 +196,20 @@ impl Transaction { /// Ok(()) /// } /// ``` - pub async fn exec(self, abort_on_error: bool) -> Result + pub async fn exec(&self, abort_on_error: bool) -> Result where R: FromRedis, { - let commands = { self.commands.lock().drain(..).collect() }; - let watched = { self.watched.lock().drain(..).collect() }; - let hash_slot = utils::take_mutex(&self.hash_slot); + let commands = { + self + .commands + .lock() + .iter() + .map(|cmd| cmd.duplicate(ResponseKind::Skip)) + .collect() + }; + let watched = { self.watched.lock().iter().cloned().collect() }; + let hash_slot = utils::read_mutex(&self.hash_slot); exec(&self.inner, commands, watched, hash_slot, abort_on_error, self.id) .await? .convert() @@ -184,14 +223,6 @@ impl Transaction { self.watched.lock().extend(keys.into().inner()); } - /// Flushes all previously queued commands in a transaction. - /// - /// - pub async fn discard(self) -> Result<(), RedisError> { - // don't need to do anything here since the commands are queued in memory - Ok(()) - } - /// Read the hash slot against which this transaction will run, if known. 
pub fn hash_slot(&self) -> Option { utils::read_mutex(&self.hash_slot) @@ -209,22 +240,6 @@ impl Transaction { } } -#[doc(hidden)] -impl<'a> From<&'a Arc> for Transaction { - fn from(inner: &'a Arc) -> Self { - let mut commands = VecDeque::with_capacity(4); - commands.push_back(RedisCommandKind::Multi.into()); - - Transaction { - inner: inner.clone(), - commands: Arc::new(Mutex::new(commands)), - watched: Arc::new(Mutex::new(VecDeque::new())), - hash_slot: Arc::new(Mutex::new(None)), - id: utils::random_u64(u64::MAX), - } - } -} - async fn exec( inner: &Arc, commands: VecDeque, @@ -237,16 +252,23 @@ async fn exec( return Ok(RedisValue::Null); } let (tx, rx) = oneshot_channel(); + let trx_options = Options::from_command(&commands[0]); + + let mut multi = RedisCommand::new(RedisCommandKind::Multi, vec![]); + trx_options.apply(&mut multi); - let commands: Vec = commands + let commands: Vec = [multi] .into_iter() + .chain(commands.into_iter()) .map(|mut command| { + command.inherit_options(inner); command.response = ResponseKind::Skip; command.can_pipeline = false; command.skip_backpressure = true; - command.transaction_id = Some(id.clone()); + command.transaction_id = Some(id); + command.use_replica = false; if let Some(hash_slot) = hash_slot.as_ref() { - command.hasher = ClusterHash::Custom(hash_slot.clone()); + command.hasher = ClusterHash::Custom(*hash_slot); } command }) @@ -259,9 +281,9 @@ async fn exec( let mut watch_cmd = RedisCommand::new(RedisCommandKind::Watch, args); watch_cmd.can_pipeline = false; watch_cmd.skip_backpressure = true; - watch_cmd.transaction_id = Some(id.clone()); + watch_cmd.transaction_id = Some(id); if let Some(hash_slot) = hash_slot.as_ref() { - watch_cmd.hasher = ClusterHash::Custom(hash_slot.clone()); + watch_cmd.hasher = ClusterHash::Custom(*hash_slot); } Some(watch_cmd) }; @@ -280,8 +302,9 @@ async fn exec( watched, abort_on_error, }; + let timeout_dur = trx_options.timeout.unwrap_or_else(|| inner.default_command_timeout()); - let _ 
= interfaces::send_to_router(inner, command)?; - let frame = utils::apply_timeout(rx, inner.default_command_timeout()).await??; - protocol_utils::frame_to_results_raw(frame) + interfaces::send_to_router(inner, command)?; + let frame = utils::apply_timeout(rx, timeout_dur).await??; + protocol_utils::frame_to_results(frame) } diff --git a/src/commands/impls/acl.rs b/src/commands/impls/acl.rs index d4b9050f..32b36f6c 100644 --- a/src/commands/impls/acl.rs +++ b/src/commands/impls/acl.rs @@ -1,12 +1,10 @@ use super::*; use crate::{ - error::*, protocol::{command::RedisCommandKind, utils as protocol_utils}, types::*, utils, }; use bytes_utils::Str; -use redis_protocol::resp3::types::Frame; ok_cmd!(acl_load, AclLoad); ok_cmd!(acl_save, AclSave); @@ -14,48 +12,35 @@ values_cmd!(acl_list, AclList); values_cmd!(acl_users, AclUsers); value_cmd!(acl_whoami, AclWhoAmI); -pub async fn acl_setuser(client: &C, username: Str, rules: Vec) -> Result<(), RedisError> { +pub async fn acl_setuser(client: &C, username: Str, rules: MultipleValues) -> Result<(), RedisError> { let frame = utils::request_response(client, move || { + let rules = rules.into_multiple_values(); let mut args = Vec::with_capacity(rules.len() + 1); args.push(username.into()); for rule in rules.into_iter() { - args.push(rule.to_value()); + args.push(rule); } - Ok((RedisCommandKind::AclSetUser, args)) }) .await?; - let response = protocol_utils::frame_to_single_result(frame)?; + let response = protocol_utils::frame_to_results(frame)?; protocol_utils::expect_ok(&response) } -pub async fn acl_getuser(client: &C, username: Str) -> Result, RedisError> { +pub async fn acl_getuser(client: &C, username: Str) -> Result { let frame = utils::request_response(client, move || { Ok((RedisCommandKind::AclGetUser, vec![username.into()])) }) .await?; - - if protocol_utils::is_null(&frame) { - return Ok(None); - } - let frame = protocol_utils::frame_map_or_set_to_nested_array(frame)?; - - if let Frame::Array { data, .. 
} = frame { - protocol_utils::parse_acl_getuser_frames(data).map(|u| Some(u)) - } else { - Err(RedisError::new( - RedisErrorKind::Protocol, - "Invalid response frame. Expected array or nil.", - )) - } + protocol_utils::frame_to_results(frame) } pub async fn acl_deluser(client: &C, usernames: MultipleKeys) -> Result { let args: Vec = usernames.inner().into_iter().map(|k| k.into()).collect(); let frame = utils::request_response(client, move || Ok((RedisCommandKind::AclDelUser, args))).await?; - protocol_utils::frame_to_single_result(frame) + protocol_utils::frame_to_results(frame) } pub async fn acl_cat(client: &C, category: Option) -> Result { @@ -77,12 +62,12 @@ pub async fn acl_genpass(client: &C, bits: Option) -> Result }; let frame = utils::request_response(client, move || Ok((RedisCommandKind::AclGenPass, args))).await?; - protocol_utils::frame_to_single_result(frame) + protocol_utils::frame_to_results(frame) } pub async fn acl_log_reset(client: &C) -> Result<(), RedisError> { let frame = utils::request_response(client, || Ok((RedisCommandKind::AclLog, vec![static_val!(RESET)]))).await?; - let response = protocol_utils::frame_to_single_result(frame)?; + let response = protocol_utils::frame_to_results(frame)?; protocol_utils::expect_ok(&response) } diff --git a/src/commands/impls/client.rs b/src/commands/impls/client.rs index da62b5b6..174c8b36 100644 --- a/src/commands/impls/client.rs +++ b/src/commands/impls/client.rs @@ -2,7 +2,7 @@ use super::*; use crate::{ interfaces, protocol::{ - command::{RouterCommand, RedisCommand, RedisCommandKind}, + command::{RedisCommand, RedisCommandKind, RouterCommand}, utils as protocol_utils, }, types::*, @@ -31,7 +31,7 @@ pub async fn client_kill( }) .await?; - protocol_utils::frame_to_single_result(frame) + protocol_utils::frame_to_results(frame) } pub async fn client_list( @@ -62,7 +62,7 @@ pub async fn client_list( }) .await?; - protocol_utils::frame_to_single_result(frame) + protocol_utils::frame_to_results(frame) } pub 
async fn client_pause( @@ -82,24 +82,16 @@ pub async fn client_pause( }) .await?; - let response = protocol_utils::frame_to_single_result(frame)?; + let response = protocol_utils::frame_to_results(frame)?; protocol_utils::expect_ok(&response) } value_cmd!(client_getname, ClientGetName); pub async fn client_setname(client: &C, name: Str) -> Result<(), RedisError> { - let inner = client.inner(); - _warn!( - inner, - "Changing client name from {} to {}", - client.inner().id.as_str(), - name - ); - let frame = utils::request_response(client, move || Ok((RedisCommandKind::ClientSetname, vec![name.into()]))).await?; - let response = protocol_utils::frame_to_single_result(frame)?; + let response = protocol_utils::frame_to_results(frame)?; protocol_utils::expect_ok(&response) } @@ -111,7 +103,7 @@ pub async fn client_reply(client: &C, flag: ClientReplyFlag) -> R }) .await?; - let response = protocol_utils::frame_to_single_result(frame)?; + let response = protocol_utils::frame_to_results(frame)?; protocol_utils::expect_ok(&response) } @@ -130,7 +122,7 @@ pub async fn client_unblock( let command = RedisCommand::new(RedisCommandKind::ClientUnblock, args); let frame = utils::backchannel_request_response(inner, command, false).await?; - protocol_utils::frame_to_single_result(frame) + protocol_utils::frame_to_results(frame) } pub async fn unblock_self(client: &C, flag: Option) -> Result<(), RedisError> { @@ -144,7 +136,7 @@ pub async fn unblock_self(client: &C, flag: Option(client: &C) -> Result, RedisError> { let (tx, rx) = oneshot_channel(); let command = RouterCommand::Connections { tx }; - let _ = interfaces::send_to_router(client.inner(), command)?; + interfaces::send_to_router(client.inner(), command)?; rx.await.map_err(|e| e.into()) } diff --git a/src/commands/impls/cluster.rs b/src/commands/impls/cluster.rs index 856ecb3c..1a29121f 100644 --- a/src/commands/impls/cluster.rs +++ b/src/commands/impls/cluster.rs @@ -2,7 +2,7 @@ use super::*; use crate::{ interfaces, 
protocol::{ - command::{RouterCommand, RedisCommandKind}, + command::{RedisCommandKind, RouterCommand}, utils as protocol_utils, }, types::*, @@ -19,9 +19,9 @@ value_cmd!(cluster_nodes, ClusterNodes); ok_cmd!(cluster_saveconfig, ClusterSaveConfig); values_cmd!(cluster_slots, ClusterSlots); -pub async fn cluster_info(client: &C) -> Result { +pub async fn cluster_info(client: &C) -> Result { let frame = utils::request_response(client, || Ok((RedisCommandKind::ClusterInfo, vec![]))).await?; - protocol_utils::parse_cluster_info(frame) + protocol_utils::frame_to_results(frame) } pub async fn cluster_add_slots(client: &C, slots: MultipleHashSlots) -> Result<(), RedisError> { @@ -36,7 +36,7 @@ pub async fn cluster_add_slots(client: &C, slots: MultipleHashSlo }) .await?; - let response = protocol_utils::frame_to_single_result(frame)?; + let response = protocol_utils::frame_to_results(frame)?; protocol_utils::expect_ok(&response) } @@ -48,7 +48,7 @@ pub async fn cluster_count_failure_reports( Ok((RedisCommandKind::ClusterCountFailureReports, vec![node_id.into()])) }) .await?; - protocol_utils::frame_to_single_result(frame) + protocol_utils::frame_to_results(frame) } pub async fn cluster_count_keys_in_slot(client: &C, slot: u16) -> Result { @@ -57,7 +57,7 @@ pub async fn cluster_count_keys_in_slot(client: &C, slot: u16) -> }) .await?; - protocol_utils::frame_to_single_result(frame) + protocol_utils::frame_to_results(frame) } pub async fn cluster_del_slots(client: &C, slots: MultipleHashSlots) -> Result<(), RedisError> { @@ -72,7 +72,7 @@ pub async fn cluster_del_slots(client: &C, slots: MultipleHashSlo }) .await?; - let response = protocol_utils::frame_to_single_result(frame)?; + let response = protocol_utils::frame_to_results(frame)?; protocol_utils::expect_ok(&response) } @@ -91,7 +91,7 @@ pub async fn cluster_failover( }) .await?; - let response = protocol_utils::frame_to_single_result(frame)?; + let response = protocol_utils::frame_to_results(frame)?; 
protocol_utils::expect_ok(&response) } @@ -142,7 +142,7 @@ pub async fn cluster_reset(client: &C, mode: Option( }) .await?; - let response = protocol_utils::frame_to_single_result(frame)?; + let response = protocol_utils::frame_to_results(frame)?; protocol_utils::expect_ok(&response) } pub async fn sync_cluster(client: &C) -> Result<(), RedisError> { let (tx, rx) = oneshot_channel(); let command = RouterCommand::SyncCluster { tx }; - let _ = interfaces::send_to_router(client.inner(), command)?; + interfaces::send_to_router(client.inner(), command)?; rx.await? } diff --git a/src/commands/impls/geo.rs b/src/commands/impls/geo.rs index 01a60abd..4a01479f 100644 --- a/src/commands/impls/geo.rs +++ b/src/commands/impls/geo.rs @@ -7,14 +7,14 @@ use crate::{ }; use std::convert::TryInto; -static WITH_COORD: &'static str = "WITHCOORD"; -static WITH_DIST: &'static str = "WITHDIST"; -static WITH_HASH: &'static str = "WITHHASH"; -static STORE_DIST: &'static str = "STOREDIST"; -static FROM_MEMBER: &'static str = "FROMMEMBER"; -static FROM_LONLAT: &'static str = "FROMLONLAT"; -static BY_RADIUS: &'static str = "BYRADIUS"; -static BY_BOX: &'static str = "BYBOX"; +static WITH_COORD: &str = "WITHCOORD"; +static WITH_DIST: &str = "WITHDIST"; +static WITH_HASH: &str = "WITHHASH"; +static STORE_DIST: &str = "STOREDIST"; +static FROM_MEMBER: &str = "FROMMEMBER"; +static FROM_LONLAT: &str = "FROMLONLAT"; +static BY_RADIUS: &str = "BYRADIUS"; +static BY_BOX: &str = "BYBOX"; pub async fn geoadd( client: &C, @@ -44,7 +44,7 @@ pub async fn geoadd( }) .await?; - protocol_utils::frame_to_single_result(frame) + protocol_utils::frame_to_results(frame) } pub async fn geohash( @@ -53,10 +53,11 @@ pub async fn geohash( members: MultipleValues, ) -> Result { let frame = utils::request_response(client, move || { + let members = members.into_multiple_values(); let mut args = Vec::with_capacity(1 + members.len()); args.push(key.into()); - for member in members.inner().into_iter() { + for member in 
members.into_iter() { args.push(member); } @@ -73,10 +74,11 @@ pub async fn geopos( members: MultipleValues, ) -> Result { let frame = utils::request_response(client, move || { + let members = members.into_multiple_values(); let mut args = Vec::with_capacity(1 + members.len()); args.push(key.into()); - for member in members.inner().into_iter() { + for member in members.into_iter() { args.push(member); } @@ -108,7 +110,7 @@ pub async fn geodist( }) .await?; - protocol_utils::frame_to_single_result(frame) + protocol_utils::frame_to_results(frame) } pub async fn georadius( @@ -124,7 +126,7 @@ pub async fn georadius( ord: Option, store: Option, storedist: Option, -) -> Result, RedisError> { +) -> Result { let frame = utils::request_response(client, move || { let mut args = Vec::with_capacity(16); args.push(key.into()); @@ -165,7 +167,7 @@ pub async fn georadius( }) .await?; - protocol_utils::parse_georadius_result(frame, withcoord, withdist, withhash) + protocol_utils::frame_to_results(frame) } pub async fn georadiusbymember( @@ -181,7 +183,7 @@ pub async fn georadiusbymember( ord: Option, store: Option, storedist: Option, -) -> Result, RedisError> { +) -> Result { let frame = utils::request_response(client, move || { let mut args = Vec::with_capacity(15); args.push(key.into()); @@ -221,7 +223,7 @@ pub async fn georadiusbymember( }) .await?; - protocol_utils::parse_georadius_result(frame, withcoord, withdist, withhash) + protocol_utils::frame_to_results(frame) } pub async fn geosearch( @@ -236,7 +238,7 @@ pub async fn geosearch( withcoord: bool, withdist: bool, withhash: bool, -) -> Result, RedisError> { +) -> Result { let frame = utils::request_response(client, move || { let mut args = Vec::with_capacity(15); args.push(key.into()); @@ -286,7 +288,7 @@ pub async fn geosearch( }) .await?; - protocol_utils::parse_georadius_result(frame, withcoord, withdist, withhash) + protocol_utils::frame_to_results(frame) } pub async fn geosearchstore( @@ -344,5 +346,5 @@ pub async fn 
geosearchstore( }) .await?; - protocol_utils::frame_to_single_result(frame) + protocol_utils::frame_to_results(frame) } diff --git a/src/commands/impls/hashes.rs b/src/commands/impls/hashes.rs index 47b95282..fec1537e 100644 --- a/src/commands/impls/hashes.rs +++ b/src/commands/impls/hashes.rs @@ -19,7 +19,7 @@ pub async fn hdel(client: &C, key: RedisKey, fields: MultipleKeys }) .await?; - protocol_utils::frame_to_single_result(frame) + protocol_utils::frame_to_results(frame) } pub async fn hexists(client: &C, key: RedisKey, field: RedisKey) -> Result { @@ -99,7 +99,7 @@ pub async fn hmset(client: &C, key: RedisKey, values: RedisMap) - }) .await?; - protocol_utils::frame_to_single_result(frame) + protocol_utils::frame_to_results(frame) } pub async fn hset(client: &C, key: RedisKey, values: RedisMap) -> Result { @@ -116,7 +116,7 @@ pub async fn hset(client: &C, key: RedisKey, values: RedisMap) -> }) .await?; - protocol_utils::frame_to_single_result(frame) + protocol_utils::frame_to_results(frame) } pub async fn hsetnx( @@ -130,7 +130,7 @@ pub async fn hsetnx( }) .await?; - protocol_utils::frame_to_single_result(frame) + protocol_utils::frame_to_results(frame) } pub async fn hrandfield( @@ -158,12 +158,12 @@ pub async fn hrandfield( if has_count { if has_values && !protocol_utils::frame_is_queued(&frame) { let frame = protocol_utils::flatten_frame(frame); - protocol_utils::frame_to_map(frame).map(|m| RedisValue::Map(m)) + protocol_utils::frame_to_map(frame).map(RedisValue::Map) } else { protocol_utils::frame_to_results(frame) } } else { - protocol_utils::frame_to_single_result(frame) + protocol_utils::frame_to_results(frame) } } @@ -173,7 +173,7 @@ pub async fn hstrlen(client: &C, key: RedisKey, field: RedisKey) }) .await?; - protocol_utils::frame_to_single_result(frame) + protocol_utils::frame_to_results(frame) } pub async fn hvals(client: &C, key: RedisKey) -> Result { diff --git a/src/commands/impls/hyperloglog.rs b/src/commands/impls/hyperloglog.rs index 
9e32f2d4..71793815 100644 --- a/src/commands/impls/hyperloglog.rs +++ b/src/commands/impls/hyperloglog.rs @@ -11,17 +11,18 @@ pub async fn pfadd( elements: MultipleValues, ) -> Result { let frame = utils::request_response(client, move || { + let elements = elements.into_multiple_values(); let mut args = Vec::with_capacity(1 + elements.len()); args.push(key.into()); - for element in elements.inner().into_iter() { + for element in elements.into_iter() { args.push(element); } Ok((RedisCommandKind::Pfadd, args)) }) .await?; - protocol_utils::frame_to_single_result(frame) + protocol_utils::frame_to_results(frame) } pub async fn pfcount(client: &C, keys: MultipleKeys) -> Result { @@ -45,5 +46,5 @@ pub async fn pfmerge( }) .await?; - protocol_utils::frame_to_single_result(frame) + protocol_utils::frame_to_results(frame) } diff --git a/src/commands/impls/keys.rs b/src/commands/impls/keys.rs index c2aa15c5..846401cf 100644 --- a/src/commands/impls/keys.rs +++ b/src/commands/impls/keys.rs @@ -81,7 +81,7 @@ pub async fn incr_by(client: &C, key: RedisKey, val: i64) -> Resu }) .await?; - protocol_utils::frame_to_single_result(frame) + protocol_utils::frame_to_results(frame) } pub async fn decr_by(client: &C, key: RedisKey, val: i64) -> Result { @@ -90,7 +90,7 @@ pub async fn decr_by(client: &C, key: RedisKey, val: i64) -> Resu }) .await?; - protocol_utils::frame_to_single_result(frame) + protocol_utils::frame_to_results(frame) } pub async fn incr_by_float(client: &C, key: RedisKey, val: f64) -> Result { @@ -100,7 +100,7 @@ pub async fn incr_by_float(client: &C, key: RedisKey, val: f64) - }) .await?; - protocol_utils::frame_to_single_result(frame) + protocol_utils::frame_to_results(frame) } pub async fn ttl(client: &C, key: RedisKey) -> Result { @@ -121,7 +121,7 @@ pub async fn expire(client: &C, key: RedisKey, seconds: i64) -> R }) .await?; - protocol_utils::frame_to_single_result(frame) + protocol_utils::frame_to_results(frame) } pub async fn expire_at(client: &C, key: 
RedisKey, timestamp: i64) -> Result { @@ -130,7 +130,7 @@ pub async fn expire_at(client: &C, key: RedisKey, timestamp: i64) }) .await?; - protocol_utils::frame_to_single_result(frame) + protocol_utils::frame_to_results(frame) } pub async fn exists(client: &C, keys: MultipleKeys) -> Result { @@ -147,7 +147,7 @@ pub async fn exists(client: &C, keys: MultipleKeys) -> Result(client: &C, key: RedisKey) -> Result { @@ -207,7 +207,7 @@ pub async fn getrange( }) .await?; - protocol_utils::frame_to_single_result(frame) + protocol_utils::frame_to_results(frame) } pub async fn setrange( @@ -221,19 +221,35 @@ pub async fn setrange( }) .await?; - protocol_utils::frame_to_single_result(frame) + protocol_utils::frame_to_results(frame) } pub async fn getset(client: &C, key: RedisKey, value: RedisValue) -> Result { args_values_cmd(client, RedisCommandKind::GetSet, vec![key.into(), value]).await } -pub async fn rename(client: &C, source: RedisKey, destination: RedisKey) -> Result { - args_values_cmd(client, RedisCommandKind::Rename, vec![source.into(), destination.into()]).await +pub async fn rename( + client: &C, + source: RedisKey, + destination: RedisKey, +) -> Result { + args_values_cmd(client, RedisCommandKind::Rename, vec![ + source.into(), + destination.into(), + ]) + .await } -pub async fn renamenx(client: &C, source: RedisKey, destination: RedisKey) -> Result { - args_values_cmd(client, RedisCommandKind::Renamenx, vec![source.into(), destination.into()]).await +pub async fn renamenx( + client: &C, + source: RedisKey, + destination: RedisKey, +) -> Result { + args_values_cmd(client, RedisCommandKind::Renamenx, vec![ + source.into(), + destination.into(), + ]) + .await } pub async fn getdel(client: &C, key: RedisKey) -> Result { @@ -281,7 +297,7 @@ pub async fn mset(client: &C, values: RedisMap) -> Result(client: &C, values: RedisMap) -> Result { @@ -304,7 +320,7 @@ pub async fn msetnx(client: &C, values: RedisMap) -> Result( @@ -331,7 +347,7 @@ pub async fn copy( }) .await?; 
- protocol_utils::frame_to_single_result(frame) + protocol_utils::frame_to_results(frame) } pub async fn watch(client: &C, keys: MultipleKeys) -> Result<(), RedisError> { diff --git a/src/commands/impls/lists.rs b/src/commands/impls/lists.rs index de39accb..833e117b 100644 --- a/src/commands/impls/lists.rs +++ b/src/commands/impls/lists.rs @@ -32,7 +32,7 @@ pub async fn blmpop( }) .await?; - let _ = protocol_utils::check_null_timeout(&frame)?; + protocol_utils::check_null_timeout(&frame)?; protocol_utils::frame_to_results(frame) } @@ -50,7 +50,7 @@ pub async fn blpop(client: &C, keys: MultipleKeys, timeout: f64) }) .await?; - let _ = protocol_utils::check_null_timeout(&frame)?; + protocol_utils::check_null_timeout(&frame)?; protocol_utils::frame_to_results(frame) } @@ -68,7 +68,7 @@ pub async fn brpop(client: &C, keys: MultipleKeys, timeout: f64) }) .await?; - let _ = protocol_utils::check_null_timeout(&frame)?; + protocol_utils::check_null_timeout(&frame)?; protocol_utils::frame_to_results(frame) } @@ -89,7 +89,7 @@ pub async fn brpoplpush( }) .await?; - let _ = protocol_utils::check_null_timeout(&frame)?; + protocol_utils::check_null_timeout(&frame)?; protocol_utils::frame_to_results(frame) } @@ -116,7 +116,7 @@ pub async fn blmove( }) .await?; - let _ = protocol_utils::check_null_timeout(&frame)?; + protocol_utils::check_null_timeout(&frame)?; protocol_utils::frame_to_results(frame) } @@ -167,7 +167,7 @@ pub async fn linsert( }) .await?; - protocol_utils::frame_to_single_result(frame) + protocol_utils::frame_to_results(frame) } pub async fn llen(client: &C, key: RedisKey) -> Result { @@ -229,10 +229,11 @@ pub async fn lpush( elements: MultipleValues, ) -> Result { let frame = utils::request_response(client, move || { + let elements = elements.into_multiple_values(); let mut args = Vec::with_capacity(1 + elements.len()); args.push(key.into()); - for element in elements.inner().into_iter() { + for element in elements.into_iter() { args.push(element); } @@ -240,7 
+241,7 @@ pub async fn lpush( }) .await?; - protocol_utils::frame_to_single_result(frame) + protocol_utils::frame_to_results(frame) } pub async fn lpushx( @@ -249,10 +250,11 @@ pub async fn lpushx( elements: MultipleValues, ) -> Result { let frame = utils::request_response(client, move || { + let elements = elements.into_multiple_values(); let mut args = Vec::with_capacity(1 + elements.len()); args.push(key.into()); - for element in elements.inner().into_iter() { + for element in elements.into_iter() { args.push(element); } @@ -260,7 +262,7 @@ pub async fn lpushx( }) .await?; - protocol_utils::frame_to_single_result(frame) + protocol_utils::frame_to_results(frame) } pub async fn lrange( @@ -347,7 +349,7 @@ pub async fn lmove( }) .await?; - protocol_utils::frame_to_single_result(frame) + protocol_utils::frame_to_results(frame) } pub async fn rpush( @@ -356,10 +358,11 @@ pub async fn rpush( elements: MultipleValues, ) -> Result { let frame = utils::request_response(client, move || { + let elements = elements.into_multiple_values(); let mut args = Vec::with_capacity(1 + elements.len()); args.push(key.into()); - for element in elements.inner().into_iter() { + for element in elements.into_iter() { args.push(element); } @@ -367,7 +370,7 @@ pub async fn rpush( }) .await?; - protocol_utils::frame_to_single_result(frame) + protocol_utils::frame_to_results(frame) } pub async fn rpushx( @@ -376,10 +379,11 @@ pub async fn rpushx( elements: MultipleValues, ) -> Result { let frame = utils::request_response(client, move || { + let elements = elements.into_multiple_values(); let mut args = Vec::with_capacity(1 + elements.len()); args.push(key.into()); - for element in elements.inner().into_iter() { + for element in elements.into_iter() { args.push(element); } @@ -387,5 +391,5 @@ pub async fn rpushx( }) .await?; - protocol_utils::frame_to_single_result(frame) + protocol_utils::frame_to_results(frame) } diff --git a/src/commands/impls/lua.rs b/src/commands/impls/lua.rs index 
13894634..dfd06699 100644 --- a/src/commands/impls/lua.rs +++ b/src/commands/impls/lua.rs @@ -10,9 +10,10 @@ use crate::{ utils as protocol_utils, }, types::*, - util::sha1_hash, utils, }; +#[cfg(feature = "sha-1")] +use crate::util::sha1_hash; use bytes::Bytes; use bytes_utils::Str; use std::{convert::TryInto, str, sync::Arc}; @@ -20,7 +21,7 @@ use tokio::sync::oneshot::channel as oneshot_channel; /// Check that all the keys in an EVAL* command belong to the same server, returning a key slot that maps to that /// server. -pub fn check_key_slot(inner: &Arc, keys: &Vec) -> Result, RedisError> { +pub fn check_key_slot(inner: &Arc, keys: &[RedisKey]) -> Result, RedisError> { if inner.config.server.is_clustered() { inner.with_cluster_state(|state| { let (mut cmd_server, mut cmd_slot) = (None, None); @@ -58,6 +59,7 @@ pub async fn script_load(client: &C, script: Str) -> Result(client: &C, script: Str) -> Result { if !client.inner().config.server.is_clustered() { return script_load(client, script).await; @@ -66,10 +68,11 @@ pub async fn script_load_cluster(client: &C, script: Str) -> Resu let (tx, rx) = oneshot_channel(); let response = ResponseKind::new_buffer(tx); - let command: RedisCommand = (RedisCommandKind::_ScriptLoadCluster, vec![script.into()], response).into(); - let _ = client.send_command(command)?; + let mut command: RedisCommand = (RedisCommandKind::_ScriptLoadCluster, vec![script.into()], response).into(); - let _ = rx.await??; + let timeout_dur = utils::prepare_command(client, &mut command); + client.send_command(command)?; + let _ = utils::apply_timeout(rx, timeout_dur).await??; Ok(hash.into()) } @@ -82,10 +85,11 @@ pub async fn script_kill_cluster(client: &C) -> Result<(), RedisE let (tx, rx) = oneshot_channel(); let response = ResponseKind::new_buffer(tx); - let command: RedisCommand = (RedisCommandKind::_ScriptKillCluster, vec![], response).into(); - let _ = client.send_command(command)?; + let mut command: RedisCommand = 
(RedisCommandKind::_ScriptKillCluster, vec![], response).into(); - let _ = rx.await??; + let timeout_dur = utils::prepare_command(client, &mut command); + client.send_command(command)?; + let _ = utils::apply_timeout(rx, timeout_dur).await??; Ok(()) } @@ -96,7 +100,7 @@ pub async fn script_flush(client: &C, r#async: bool) -> Result<() }) .await?; - let response = protocol_utils::frame_to_single_result(frame)?; + let response = protocol_utils::frame_to_results(frame)?; protocol_utils::expect_ok(&response) } @@ -107,12 +111,13 @@ pub async fn script_flush_cluster(client: &C, r#async: bool) -> R let (tx, rx) = oneshot_channel(); let arg = static_val!(if r#async { ASYNC } else { SYNC }); - let response = ResponseKind::new_buffer(tx); - let command: RedisCommand = (RedisCommandKind::_ScriptFlushCluster, vec![arg], response).into(); - let _ = client.send_command(command)?; + let mut command: RedisCommand = (RedisCommandKind::_ScriptFlushCluster, vec![arg], response).into(); - let _ = rx.await??; + let timeout_dur = utils::prepare_command(client, &mut command); + client.send_command(command)?; + + let _ = utils::apply_timeout(rx, timeout_dur).await??; Ok(()) } @@ -132,7 +137,7 @@ pub async fn script_debug(client: &C, flag: ScriptDebugFlag) -> R }) .await?; - let response = protocol_utils::frame_to_single_result(frame)?; + let response = protocol_utils::frame_to_results(frame)?; protocol_utils::expect_ok(&response) } @@ -146,6 +151,7 @@ pub async fn evalsha( let custom_key_slot = check_key_slot(client.inner(), &keys)?; let frame = utils::request_response(client, move || { + let cmd_args = cmd_args.into_multiple_values(); let mut args = Vec::with_capacity(2 + keys.len() + cmd_args.len()); args.push(hash.into()); args.push(keys.len().try_into()?); @@ -153,14 +159,12 @@ pub async fn evalsha( for key in keys.into_iter() { args.push(key.into()); } - for arg in cmd_args.inner().into_iter() { + for arg in cmd_args.into_iter() { args.push(arg); } let mut command: RedisCommand = 
(RedisCommandKind::EvalSha, args).into(); - command.hasher = custom_key_slot - .map(|slot| ClusterHash::Custom(slot)) - .unwrap_or(ClusterHash::Random); + command.hasher = custom_key_slot.map(ClusterHash::Custom).unwrap_or(ClusterHash::Random); command.can_pipeline = false; Ok(command) }) @@ -179,6 +183,7 @@ pub async fn eval( let custom_key_slot = check_key_slot(client.inner(), &keys)?; let frame = utils::request_response(client, move || { + let cmd_args = cmd_args.into_multiple_values(); let mut args = Vec::with_capacity(2 + keys.len() + cmd_args.len()); args.push(script.into()); args.push(keys.len().try_into()?); @@ -186,14 +191,12 @@ pub async fn eval( for key in keys.into_iter() { args.push(key.into()); } - for arg in cmd_args.inner().into_iter() { + for arg in cmd_args.into_iter() { args.push(arg); } let mut command: RedisCommand = (RedisCommandKind::Eval, args).into(); - command.hasher = custom_key_slot - .map(|slot| ClusterHash::Custom(slot)) - .unwrap_or(ClusterHash::Random); + command.hasher = custom_key_slot.map(ClusterHash::Custom).unwrap_or(ClusterHash::Random); command.can_pipeline = false; Ok(command) }) @@ -209,6 +212,7 @@ pub async fn fcall( args: MultipleValues, ) -> Result { let frame = utils::request_response(client, move || { + let args = args.into_multiple_values(); let mut arguments = Vec::with_capacity(keys.len() + args.len() + 2); let mut custom_key_slot = None; @@ -219,14 +223,12 @@ pub async fn fcall( custom_key_slot = Some(key.cluster_hash()); arguments.push(key.into()); } - for arg in args.inner().into_iter() { + for arg in args.into_iter() { arguments.push(arg); } let mut command: RedisCommand = (RedisCommandKind::Fcall, arguments).into(); - command.hasher = custom_key_slot - .map(|slot| ClusterHash::Custom(slot)) - .unwrap_or(ClusterHash::Random); + command.hasher = custom_key_slot.map(ClusterHash::Custom).unwrap_or(ClusterHash::Random); command.can_pipeline = false; Ok(command) }) @@ -242,6 +244,7 @@ pub async fn fcall_ro( args: 
MultipleValues, ) -> Result { let frame = utils::request_response(client, move || { + let args = args.into_multiple_values(); let mut arguments = Vec::with_capacity(keys.len() + args.len() + 2); let mut custom_key_slot = None; @@ -252,14 +255,12 @@ pub async fn fcall_ro( custom_key_slot = Some(key.cluster_hash()); arguments.push(key.into()); } - for arg in args.inner().into_iter() { + for arg in args.into_iter() { arguments.push(arg); } let mut command: RedisCommand = (RedisCommandKind::FcallRO, arguments).into(); - command.hasher = custom_key_slot - .map(|slot| ClusterHash::Custom(slot)) - .unwrap_or(ClusterHash::Random); + command.hasher = custom_key_slot.map(ClusterHash::Custom).unwrap_or(ClusterHash::Random); command.can_pipeline = false; Ok(command) }) @@ -274,7 +275,7 @@ pub async fn function_delete(client: &C, library_name: Str) -> Re }) .await?; - protocol_utils::frame_to_single_result(frame) + protocol_utils::frame_to_results(frame) } pub async fn function_delete_cluster(client: &C, library_name: Str) -> Result<(), RedisError> { @@ -286,10 +287,12 @@ pub async fn function_delete_cluster(client: &C, library_name: St let args: Vec = vec![library_name.into()]; let response = ResponseKind::new_buffer(tx); - let command: RedisCommand = (RedisCommandKind::_FunctionDeleteCluster, args, response).into(); - let _ = client.send_command(command)?; + let mut command: RedisCommand = (RedisCommandKind::_FunctionDeleteCluster, args, response).into(); - let _ = rx.await??; + let timeout_dur = utils::prepare_command(client, &mut command); + client.send_command(command)?; + + let _ = utils::apply_timeout(rx, timeout_dur).await??; Ok(()) } @@ -305,7 +308,7 @@ pub async fn function_flush(client: &C, r#async: bool) -> Result< }) .await?; - protocol_utils::frame_to_single_result(frame) + protocol_utils::frame_to_results(frame) } pub async fn function_flush_cluster(client: &C, r#async: bool) -> Result<(), RedisError> { @@ -322,7 +325,7 @@ pub async fn 
function_flush_cluster(client: &C, r#async: bool) -> let response = ResponseKind::new_buffer(tx); let command: RedisCommand = (RedisCommandKind::_FunctionFlushCluster, args, response).into(); - let _ = client.send_command(command)?; + client.send_command(command)?; let _ = rx.await??; Ok(()) @@ -356,7 +359,7 @@ pub async fn function_list( }) .await?; - protocol_utils::frame_to_results_raw(frame) + protocol_utils::frame_to_results(frame) } pub async fn function_load(client: &C, replace: bool, code: Str) -> Result { @@ -371,7 +374,7 @@ pub async fn function_load(client: &C, replace: bool, code: Str) }) .await?; - protocol_utils::frame_to_single_result(frame) + protocol_utils::frame_to_results(frame) } pub async fn function_load_cluster( @@ -391,14 +394,16 @@ pub async fn function_load_cluster( args.push(code.into()); let response = ResponseKind::new_buffer(tx); - let command: RedisCommand = (RedisCommandKind::_FunctionLoadCluster, args, response).into(); - let _ = client.send_command(command)?; + let mut command: RedisCommand = (RedisCommandKind::_FunctionLoadCluster, args, response).into(); + + let timeout_dur = utils::prepare_command(client, &mut command); + client.send_command(command)?; // each value in the response array is the response from a different primary node - match rx.await?? { + match utils::apply_timeout(rx, timeout_dur).await?? { Frame::Array { mut data, .. 
} => { if let Some(frame) = data.pop() { - protocol_utils::frame_to_single_result(frame) + protocol_utils::frame_to_results(frame) } else { Err(RedisError::new( RedisErrorKind::Protocol, @@ -429,7 +434,7 @@ pub async fn function_restore( }) .await?; - protocol_utils::frame_to_single_result(frame) + protocol_utils::frame_to_results(frame) } pub async fn function_restore_cluster( @@ -445,10 +450,11 @@ pub async fn function_restore_cluster( let args: Vec = vec![serialized.into(), policy.to_str().into()]; let response = ResponseKind::new_buffer(tx); - let command: RedisCommand = (RedisCommandKind::_FunctionRestoreCluster, args, response).into(); - let _ = client.send_command(command)?; + let mut command: RedisCommand = (RedisCommandKind::_FunctionRestoreCluster, args, response).into(); - let _ = rx.await??; + let timeout_dur = utils::prepare_command(client, &mut command); + client.send_command(command)?; + let _ = utils::apply_timeout(rx, timeout_dur).await??; Ok(()) } @@ -457,7 +463,7 @@ pub async fn function_stats(client: &C) -> Result(client: &C) -> Result { +pub async fn memory_stats(client: &C) -> Result { let response = utils::request_response(client, || Ok((RedisCommandKind::MemoryStats, vec![]))).await?; - - let frame = protocol_utils::frame_map_or_set_to_nested_array(response)?; - if let Frame::Array { data, .. 
} = frame { - protocol_utils::parse_memory_stats(&data) - } else { - Err(RedisError::new(RedisErrorKind::Protocol, "Expected array response.")) - } + protocol_utils::frame_to_results(response) } pub async fn memory_usage( @@ -40,5 +32,5 @@ pub async fn memory_usage( }) .await?; - protocol_utils::frame_to_single_result(frame) + protocol_utils::frame_to_results(frame) } diff --git a/src/commands/impls/mod.rs b/src/commands/impls/mod.rs index 78b821cf..8abb7b97 100644 --- a/src/commands/impls/mod.rs +++ b/src/commands/impls/mod.rs @@ -6,54 +6,54 @@ use crate::{ utils, }; -pub static MATCH: &'static str = "MATCH"; -pub static COUNT: &'static str = "COUNT"; -pub static TYPE: &'static str = "TYPE"; -pub static CHANGED: &'static str = "CH"; -pub static INCR: &'static str = "INCR"; -pub static WITH_SCORES: &'static str = "WITHSCORES"; -pub static LIMIT: &'static str = "LIMIT"; -pub static AGGREGATE: &'static str = "AGGREGATE"; -pub static WEIGHTS: &'static str = "WEIGHTS"; -pub static GET: &'static str = "GET"; -pub static RESET: &'static str = "RESET"; -pub static TO: &'static str = "TO"; -pub static FORCE: &'static str = "FORCE"; -pub static ABORT: &'static str = "ABORT"; -pub static TIMEOUT: &'static str = "TIMEOUT"; -pub static LEN: &'static str = "LEN"; -pub static DB: &'static str = "DB"; -pub static REPLACE: &'static str = "REPLACE"; -pub static ID: &'static str = "ID"; -pub static ANY: &'static str = "ANY"; -pub static STORE: &'static str = "STORE"; -pub static WITH_VALUES: &'static str = "WITHVALUES"; -pub static SYNC: &'static str = "SYNC"; -pub static ASYNC: &'static str = "ASYNC"; -pub static RANK: &'static str = "RANK"; -pub static MAXLEN: &'static str = "MAXLEN"; -pub static REV: &'static str = "REV"; -pub static ABSTTL: &'static str = "ABSTTL"; -pub static IDLE_TIME: &'static str = "IDLETIME"; -pub static FREQ: &'static str = "FREQ"; -pub static FULL: &'static str = "FULL"; -pub static NOMKSTREAM: &'static str = "NOMKSTREAM"; -pub static MINID: &'static str 
= "MINID"; -pub static BLOCK: &'static str = "BLOCK"; -pub static STREAMS: &'static str = "STREAMS"; -pub static MKSTREAM: &'static str = "MKSTREAM"; -pub static GROUP: &'static str = "GROUP"; -pub static NOACK: &'static str = "NOACK"; -pub static IDLE: &'static str = "IDLE"; -pub static TIME: &'static str = "TIME"; -pub static RETRYCOUNT: &'static str = "RETRYCOUNT"; -pub static JUSTID: &'static str = "JUSTID"; -pub static SAMPLES: &'static str = "SAMPLES"; -pub static LIBRARYNAME: &'static str = "LIBRARYNAME"; -pub static WITHCODE: &'static str = "WITHCODE"; -pub static IDX: &'static str = "IDX"; -pub static MINMATCHLEN: &'static str = "MINMATCHLEN"; -pub static WITHMATCHLEN: &'static str = "WITHMATCHLEN"; +pub static MATCH: &str = "MATCH"; +pub static COUNT: &str = "COUNT"; +pub static TYPE: &str = "TYPE"; +pub static CHANGED: &str = "CH"; +pub static INCR: &str = "INCR"; +pub static WITH_SCORES: &str = "WITHSCORES"; +pub static LIMIT: &str = "LIMIT"; +pub static AGGREGATE: &str = "AGGREGATE"; +pub static WEIGHTS: &str = "WEIGHTS"; +pub static GET: &str = "GET"; +pub static RESET: &str = "RESET"; +pub static TO: &str = "TO"; +pub static FORCE: &str = "FORCE"; +pub static ABORT: &str = "ABORT"; +pub static TIMEOUT: &str = "TIMEOUT"; +pub static LEN: &str = "LEN"; +pub static DB: &str = "DB"; +pub static REPLACE: &str = "REPLACE"; +pub static ID: &str = "ID"; +pub static ANY: &str = "ANY"; +pub static STORE: &str = "STORE"; +pub static WITH_VALUES: &str = "WITHVALUES"; +pub static SYNC: &str = "SYNC"; +pub static ASYNC: &str = "ASYNC"; +pub static RANK: &str = "RANK"; +pub static MAXLEN: &str = "MAXLEN"; +pub static REV: &str = "REV"; +pub static ABSTTL: &str = "ABSTTL"; +pub static IDLE_TIME: &str = "IDLETIME"; +pub static FREQ: &str = "FREQ"; +pub static FULL: &str = "FULL"; +pub static NOMKSTREAM: &str = "NOMKSTREAM"; +pub static MINID: &str = "MINID"; +pub static BLOCK: &str = "BLOCK"; +pub static STREAMS: &str = "STREAMS"; +pub static MKSTREAM: &str = 
"MKSTREAM"; +pub static GROUP: &str = "GROUP"; +pub static NOACK: &str = "NOACK"; +pub static IDLE: &str = "IDLE"; +pub static TIME: &str = "TIME"; +pub static RETRYCOUNT: &str = "RETRYCOUNT"; +pub static JUSTID: &str = "JUSTID"; +pub static SAMPLES: &str = "SAMPLES"; +pub static LIBRARYNAME: &str = "LIBRARYNAME"; +pub static WITHCODE: &str = "WITHCODE"; +pub static IDX: &str = "IDX"; +pub static MINMATCHLEN: &str = "MINMATCHLEN"; +pub static WITHMATCHLEN: &str = "WITHMATCHLEN"; /// Macro to generate a command function that takes no arguments and expects an OK response - returning `()` to the /// caller. @@ -61,7 +61,7 @@ macro_rules! ok_cmd( ($name:ident, $cmd:tt) => { pub async fn $name(client: &C) -> Result<(), RedisError> { let frame = crate::utils::request_response(client, || Ok((RedisCommandKind::$cmd, vec![]))).await?; - let response = crate::protocol::utils::frame_to_single_result(frame)?; + let response = crate::protocol::utils::frame_to_results(frame)?; crate::protocol::utils::expect_ok(&response) } } @@ -72,7 +72,7 @@ macro_rules! simple_cmd( ($name:ident, $cmd:tt, $res:ty) => { pub async fn $name(client: &C) -> Result<$res, RedisError> { let frame = crate::utils::request_response(client, || Ok((RedisCommandKind::$cmd, vec![]))).await?; - crate::protocol::utils::frame_to_single_result(frame) + crate::protocol::utils::frame_to_results(frame) } } ); @@ -102,7 +102,7 @@ pub async fn one_arg_value_cmd( arg: RedisValue, ) -> Result { let frame = utils::request_response(client, move || Ok((kind, vec![arg]))).await?; - protocol_utils::frame_to_single_result(frame) + protocol_utils::frame_to_results(frame) } /// A function that issues a command that only takes one argument and returns a potentially nested `RedisValue`. 
@@ -124,7 +124,7 @@ pub async fn one_arg_ok_cmd( ) -> Result<(), RedisError> { let frame = utils::request_response(client, move || Ok((kind, vec![arg]))).await?; - let response = protocol_utils::frame_to_single_result(frame)?; + let response = protocol_utils::frame_to_results(frame)?; protocol_utils::expect_ok(&response) } @@ -136,7 +136,7 @@ pub async fn args_value_cmd( args: Vec, ) -> Result { let frame = utils::request_response(client, move || Ok((kind, args))).await?; - protocol_utils::frame_to_single_result(frame) + protocol_utils::frame_to_results(frame) } /// A function that issues a command that takes any number of arguments and returns a potentially nested `RedisValue` @@ -158,7 +158,7 @@ pub async fn args_ok_cmd( args: Vec, ) -> Result<(), RedisError> { let frame = utils::request_response(client, move || Ok((kind, args))).await?; - let response = protocol_utils::frame_to_single_result(frame)?; + let response = protocol_utils::frame_to_results(frame)?; protocol_utils::expect_ok(&response) } @@ -182,6 +182,8 @@ pub mod sorted_sets; pub mod streams; pub mod strings; +#[cfg(feature = "redis-json")] +pub mod redis_json; #[cfg(feature = "sentinel-client")] pub mod sentinel; #[cfg(feature = "client-tracking")] diff --git a/src/commands/impls/pubsub.rs b/src/commands/impls/pubsub.rs index d8cd3015..c2730b13 100644 --- a/src/commands/impls/pubsub.rs +++ b/src/commands/impls/pubsub.rs @@ -31,9 +31,11 @@ pub async fn subscribe(client: &C, channels: MultipleStrings) -> let args = channels.inner().into_iter().map(|c| c.into()).collect(); let mut command: RedisCommand = (RedisCommandKind::Subscribe, args, response).into(); cluster_hash_legacy_command(client, &mut command); - let _ = client.send_command(command)?; - let frame = rx.await??; + let timeout_dur = utils::prepare_command(client, &mut command); + client.send_command(command)?; + + let frame = utils::apply_timeout(rx, timeout_dur).await??; protocol_utils::frame_to_results(frame) } @@ -47,9 +49,11 @@ pub async 
fn unsubscribe(client: &C, channels: MultipleStrings) - let args = channels.inner().into_iter().map(|c| c.into()).collect(); let mut command: RedisCommand = (RedisCommandKind::Unsubscribe, args, response).into(); cluster_hash_legacy_command(client, &mut command); - let _ = client.send_command(command)?; - let _ = rx.await??; + let timeout_dur = utils::prepare_command(client, &mut command); + client.send_command(command)?; + + let _ = utils::apply_timeout(rx, timeout_dur).await??; Ok(RedisValue::Null) } @@ -59,7 +63,7 @@ pub async fn publish(client: &C, channel: Str, message: RedisValu }) .await?; - protocol_utils::frame_to_single_result(frame) + protocol_utils::frame_to_results(frame) } pub async fn psubscribe(client: &C, patterns: MultipleStrings) -> Result { @@ -72,9 +76,11 @@ pub async fn psubscribe(client: &C, patterns: MultipleStrings) -> let args = patterns.inner().into_iter().map(|p| p.into()).collect(); let mut command: RedisCommand = (RedisCommandKind::Psubscribe, args, response).into(); cluster_hash_legacy_command(client, &mut command); - let _ = client.send_command(command)?; - let frame = rx.await??; + let timeout_dur = utils::prepare_command(client, &mut command); + client.send_command(command)?; + + let frame = utils::apply_timeout(rx, timeout_dur).await??; protocol_utils::frame_to_results(frame) } @@ -88,9 +94,10 @@ pub async fn punsubscribe(client: &C, patterns: MultipleStrings) let args = patterns.inner().into_iter().map(|p| p.into()).collect(); let mut command: RedisCommand = (RedisCommandKind::Punsubscribe, args, response).into(); cluster_hash_legacy_command(client, &mut command); - let _ = client.send_command(command)?; - let _ = rx.await??; + let timeout_dur = utils::prepare_command(client, &mut command); + client.send_command(command)?; + let _ = utils::apply_timeout(rx, timeout_dur).await??; Ok(RedisValue::Null) } @@ -107,7 +114,7 @@ pub async fn spublish( }) .await?; - protocol_utils::frame_to_single_result(frame) + 
protocol_utils::frame_to_results(frame) } pub async fn ssubscribe(client: &C, channels: MultipleStrings) -> Result { @@ -120,9 +127,11 @@ pub async fn ssubscribe(client: &C, channels: MultipleStrings) -> let args = channels.inner().into_iter().map(|p| p.into()).collect(); let mut command: RedisCommand = (RedisCommandKind::Ssubscribe, args, response).into(); command.hasher = ClusterHash::FirstKey; - let _ = client.send_command(command)?; - let frame = rx.await??; + let timeout_dur = utils::prepare_command(client, &mut command); + client.send_command(command)?; + + let frame = utils::apply_timeout(rx, timeout_dur).await??; protocol_utils::frame_to_results(frame) } @@ -137,9 +146,10 @@ pub async fn sunsubscribe(client: &C, channels: MultipleStrings) let args = channels.inner().into_iter().map(|p| p.into()).collect(); let mut command: RedisCommand = (RedisCommandKind::Sunsubscribe, args, response).into(); command.hasher = hasher; - let _ = client.send_command(command)?; + let timeout_dur = utils::prepare_command(client, &mut command); + client.send_command(command)?; - let _ = rx.await??; + let _ = utils::apply_timeout(rx, timeout_dur).await??; Ok(RedisValue::Null) } @@ -158,7 +168,7 @@ pub async fn pubsub_channels(client: &C, pattern: Str) -> Result< }) .await?; - protocol_utils::frame_to_results_raw(frame) + protocol_utils::frame_to_results(frame) } pub async fn pubsub_numpat(client: &C) -> Result { @@ -183,7 +193,7 @@ pub async fn pubsub_numsub(client: &C, channels: MultipleStrings) }) .await?; - protocol_utils::frame_to_results_raw(frame) + protocol_utils::frame_to_results(frame) } pub async fn pubsub_shardchannels(client: &C, pattern: Str) -> Result { @@ -192,7 +202,7 @@ pub async fn pubsub_shardchannels(client: &C, pattern: Str) -> Re }) .await?; - protocol_utils::frame_to_results_raw(frame) + protocol_utils::frame_to_results(frame) } pub async fn pubsub_shardnumsub( @@ -201,7 +211,7 @@ pub async fn pubsub_shardnumsub( ) -> Result { let frame = 
utils::request_response(client, || { let args: Vec = channels.inner().into_iter().map(|s| s.into()).collect(); - let has_args = args.len() > 0; + let has_args = !args.is_empty(); let mut command: RedisCommand = RedisCommand::new(RedisCommandKind::PubsubShardnumsub, args); if !has_args { cluster_hash_legacy_command(client, &mut command); @@ -211,5 +221,5 @@ pub async fn pubsub_shardnumsub( }) .await?; - protocol_utils::frame_to_results_raw(frame) + protocol_utils::frame_to_results(frame) } diff --git a/src/commands/impls/redis_json.rs b/src/commands/impls/redis_json.rs new file mode 100644 index 00000000..e25d881c --- /dev/null +++ b/src/commands/impls/redis_json.rs @@ -0,0 +1,358 @@ +use crate::{ + error::{RedisError, RedisErrorKind}, + interfaces::{ClientLike, RedisResult}, + protocol::{command::RedisCommandKind, utils as protocol_utils}, + types::{MultipleKeys, MultipleStrings, RedisKey, RedisValue, SetOptions}, + utils, +}; +use bytes_utils::Str; +use serde_json::Value; + +const INDENT: &'static str = "INDENT"; +const NEWLINE: &'static str = "NEWLINE"; +const SPACE: &'static str = "SPACE"; + +fn key_path_args(key: RedisKey, path: Option, extra: usize) -> Vec { + let mut out = Vec::with_capacity(2 + extra); + out.push(key.into()); + if let Some(path) = path { + out.push(path.into()); + } + out +} + +/// Convert the provided json value to a redis value by serializing into a json string. +fn value_to_bulk_str(value: &Value) -> Result { + Ok(match value { + Value::String(ref s) => RedisValue::String(Str::from(s)), + _ => RedisValue::String(Str::from(serde_json::to_string(value)?)), + }) +} + +/// Convert the provided json value to a redis value directly without serializing into a string. This only works with +/// scalar values. 
+fn json_to_redis(value: Value) -> Result { + let out = match value { + Value::String(s) => Some(RedisValue::String(Str::from(s))), + Value::Null => Some(RedisValue::Null), + Value::Number(n) => { + if n.is_f64() { + n.as_f64().map(RedisValue::Double) + } else { + n.as_i64().map(RedisValue::Integer) + } + }, + Value::Bool(b) => Some(RedisValue::Boolean(b)), + _ => None, + }; + + out.ok_or(RedisError::new( + RedisErrorKind::InvalidArgument, + "Expected string or number.", + )) +} + +fn values_to_bulk(values: &Vec) -> Result, RedisError> { + values.iter().map(value_to_bulk_str).collect() +} + +pub async fn json_arrappend( + client: &C, + key: RedisKey, + path: Str, + values: Vec, +) -> RedisResult { + let frame = utils::request_response(client, || { + let mut args = key_path_args(key, Some(path), values.len()); + args.extend(values_to_bulk(&values)?); + + Ok((RedisCommandKind::JsonArrAppend, args)) + }) + .await?; + protocol_utils::frame_to_results(frame) +} + +pub async fn json_arrindex( + client: &C, + key: RedisKey, + path: Str, + value: Value, + start: Option, + stop: Option, +) -> RedisResult { + let frame = utils::request_response(client, || { + let mut args = Vec::with_capacity(5); + args.extend([key.into(), path.into(), value_to_bulk_str(&value)?]); + if let Some(start) = start { + args.push(start.into()); + } + if let Some(stop) = stop { + args.push(stop.into()); + } + + Ok((RedisCommandKind::JsonArrIndex, args)) + }) + .await?; + protocol_utils::frame_to_results(frame) +} + +pub async fn json_arrinsert( + client: &C, + key: RedisKey, + path: Str, + index: i64, + values: Vec, +) -> RedisResult { + let frame = utils::request_response(client, || { + let mut args = Vec::with_capacity(3 + values.len()); + args.extend([key.into(), path.into(), index.into()]); + args.extend(values_to_bulk(&values)?); + + Ok((RedisCommandKind::JsonArrInsert, args)) + }) + .await?; + protocol_utils::frame_to_results(frame) +} + +pub async fn json_arrlen(client: &C, key: RedisKey, 
path: Option) -> RedisResult { + let frame = utils::request_response(client, || { + Ok((RedisCommandKind::JsonArrLen, key_path_args(key, path, 0))) + }) + .await?; + protocol_utils::frame_to_results(frame) +} + +pub async fn json_arrpop( + client: &C, + key: RedisKey, + path: Option, + index: Option, +) -> RedisResult { + let frame = utils::request_response(client, || { + let mut args = key_path_args(key, path, 1); + if let Some(index) = index { + args.push(index.into()); + } + + Ok((RedisCommandKind::JsonArrPop, args)) + }) + .await?; + protocol_utils::frame_to_results(frame) +} + +pub async fn json_arrtrim( + client: &C, + key: RedisKey, + path: Str, + start: i64, + stop: i64, +) -> RedisResult { + let frame = utils::request_response(client, || { + Ok((RedisCommandKind::JsonArrTrim, vec![ + key.into(), + path.into(), + start.into(), + stop.into(), + ])) + }) + .await?; + protocol_utils::frame_to_results(frame) +} + +pub async fn json_clear(client: &C, key: RedisKey, path: Option) -> RedisResult { + let frame = utils::request_response(client, || { + Ok((RedisCommandKind::JsonClear, key_path_args(key, path, 0))) + }) + .await?; + protocol_utils::frame_to_results(frame) +} + +pub async fn json_debug_memory( + client: &C, + key: RedisKey, + path: Option, +) -> RedisResult { + let frame = utils::request_response(client, || { + Ok((RedisCommandKind::JsonDebugMemory, key_path_args(key, path, 0))) + }) + .await?; + protocol_utils::frame_to_results(frame) +} + +pub async fn json_del(client: &C, key: RedisKey, path: Str) -> RedisResult { + let frame = utils::request_response(client, || { + Ok((RedisCommandKind::JsonDel, key_path_args(key, Some(path), 0))) + }) + .await?; + protocol_utils::frame_to_results(frame) +} + +pub async fn json_get( + client: &C, + key: RedisKey, + indent: Option, + newline: Option, + space: Option, + paths: MultipleStrings, +) -> RedisResult { + let frame = utils::request_response(client, || { + let mut args = Vec::with_capacity(7 + paths.len()); 
+ args.push(key.into()); + if let Some(indent) = indent { + args.push(static_val!(INDENT)); + args.push(indent.into()); + } + if let Some(newline) = newline { + args.push(static_val!(NEWLINE)); + args.push(newline.into()); + } + if let Some(space) = space { + args.push(static_val!(SPACE)); + args.push(space.into()); + } + args.extend(paths.into_values()); + + Ok((RedisCommandKind::JsonGet, args)) + }) + .await?; + protocol_utils::frame_to_results(frame) +} + +pub async fn json_merge( + client: &C, + key: RedisKey, + path: Str, + value: Value, +) -> RedisResult { + let frame = utils::request_response(client, || { + Ok((RedisCommandKind::JsonMerge, vec![ + key.into(), + path.into(), + value_to_bulk_str(&value)?, + ])) + }) + .await?; + protocol_utils::frame_to_results(frame) +} + +pub async fn json_mget(client: &C, keys: MultipleKeys, path: Str) -> RedisResult { + let frame = utils::request_response(client, || { + let mut args = Vec::with_capacity(keys.len() + 1); + args.extend(keys.into_values()); + args.push(path.into()); + + Ok((RedisCommandKind::JsonMGet, args)) + }) + .await?; + protocol_utils::frame_to_results(frame) +} + +pub async fn json_mset(client: &C, values: Vec<(RedisKey, Str, Value)>) -> RedisResult { + let frame = utils::request_response(client, || { + let mut args = Vec::with_capacity(values.len() * 3); + for (key, path, value) in values.into_iter() { + args.extend([key.into(), path.into(), value_to_bulk_str(&value)?]); + } + + Ok((RedisCommandKind::JsonMSet, args)) + }) + .await?; + protocol_utils::frame_to_results(frame) +} + +pub async fn json_numincrby( + client: &C, + key: RedisKey, + path: Str, + value: Value, +) -> RedisResult { + let frame = utils::request_response(client, || { + Ok((RedisCommandKind::JsonNumIncrBy, vec![ + key.into(), + path.into(), + json_to_redis(value)?, + ])) + }) + .await?; + protocol_utils::frame_to_results(frame) +} + +pub async fn json_objkeys(client: &C, key: RedisKey, path: Option) -> RedisResult { + let frame = 
utils::request_response(client, || { + Ok((RedisCommandKind::JsonObjKeys, key_path_args(key, path, 0))) + }) + .await?; + protocol_utils::frame_to_results(frame) +} + +pub async fn json_objlen(client: &C, key: RedisKey, path: Option) -> RedisResult { + let frame = utils::request_response(client, || { + Ok((RedisCommandKind::JsonObjLen, key_path_args(key, path, 0))) + }) + .await?; + protocol_utils::frame_to_results(frame) +} + +pub async fn json_resp(client: &C, key: RedisKey, path: Option) -> RedisResult { + let frame = + utils::request_response(client, || Ok((RedisCommandKind::JsonResp, key_path_args(key, path, 0)))).await?; + protocol_utils::frame_to_results(frame) +} + +pub async fn json_set( + client: &C, + key: RedisKey, + path: Str, + value: Value, + options: Option, +) -> RedisResult { + let frame = utils::request_response(client, || { + let mut args = key_path_args(key, Some(path), 2); + args.push(value_to_bulk_str(&value)?); + if let Some(options) = options { + args.push(options.to_str().into()); + } + + Ok((RedisCommandKind::JsonSet, args)) + }) + .await?; + protocol_utils::frame_to_results(frame) +} + +pub async fn json_strappend( + client: &C, + key: RedisKey, + path: Option, + value: Value, +) -> RedisResult { + let frame = utils::request_response(client, || { + let mut args = key_path_args(key, path, 1); + args.push(value_to_bulk_str(&value)?); + + Ok((RedisCommandKind::JsonStrAppend, args)) + }) + .await?; + protocol_utils::frame_to_results(frame) +} + +pub async fn json_strlen(client: &C, key: RedisKey, path: Option) -> RedisResult { + let frame = utils::request_response(client, || { + Ok((RedisCommandKind::JsonStrLen, key_path_args(key, path, 0))) + }) + .await?; + protocol_utils::frame_to_results(frame) +} + +pub async fn json_toggle(client: &C, key: RedisKey, path: Str) -> RedisResult { + let frame = utils::request_response(client, || { + Ok((RedisCommandKind::JsonToggle, key_path_args(key, Some(path), 0))) + }) + .await?; + 
protocol_utils::frame_to_results(frame) +} + +pub async fn json_type(client: &C, key: RedisKey, path: Option) -> RedisResult { + let frame = + utils::request_response(client, || Ok((RedisCommandKind::JsonType, key_path_args(key, path, 0)))).await?; + protocol_utils::frame_to_results(frame) +} diff --git a/src/commands/impls/scan.rs b/src/commands/impls/scan.rs index 640657da..a00ec236 100644 --- a/src/commands/impls/scan.rs +++ b/src/commands/impls/scan.rs @@ -17,7 +17,7 @@ use std::sync::Arc; use tokio::sync::mpsc::{unbounded_channel, UnboundedSender}; use tokio_stream::wrappers::UnboundedReceiverStream; -static STARTING_CURSOR: &'static str = "0"; +static STARTING_CURSOR: &str = "0"; fn values_args(key: RedisKey, pattern: Str, count: Option) -> Vec { let mut args = Vec::with_capacity(6); @@ -36,7 +36,7 @@ fn values_args(key: RedisKey, pattern: Str, count: Option) -> Vec(tx: &UnboundedSender>, error: RedisError) { let tx = tx.clone(); - let _ = tokio::spawn(async move { + tokio::spawn(async move { let _ = tx.send(Err(error)); }); } @@ -204,7 +204,7 @@ pub fn zscan( count: Option, ) -> impl Stream> { let (tx, rx) = unbounded_channel(); - let args = values_args(key.into(), pattern, count); + let args = values_args(key, pattern, count); let response = ResponseKind::ValueScan(ValueScanInner { tx: tx.clone(), diff --git a/src/commands/impls/server.rs b/src/commands/impls/server.rs index fbcac954..0c375a84 100644 --- a/src/commands/impls/server.rs +++ b/src/commands/impls/server.rs @@ -22,7 +22,7 @@ pub async fn quit(client: &C) -> Result<(), RedisError> { _debug!(inner, "Closing Redis connection with Quit command."); let (tx, rx) = oneshot_channel(); - let command: RedisCommand = if inner.config.server.is_clustered() { + let mut command: RedisCommand = if inner.config.server.is_clustered() { let response = ResponseKind::new_buffer(tx); (RedisCommandKind::Quit, vec![], response).into() } else { @@ -32,9 +32,9 @@ pub async fn quit(client: &C) -> Result<(), RedisError> { 
utils::set_client_state(&inner.state, ClientState::Disconnecting); inner.notifications.broadcast_close(); - let _ = client.send_command(command)?; - let _ = rx.await??; - utils::abort_network_timeout_task(&inner); + let timeout_dur = utils::prepare_command(client, &mut command); + client.send_command(command)?; + let _ = utils::apply_timeout(rx, timeout_dur).await??; inner.notifications.close_public_receivers(); inner.backchannel.write().await.check_and_disconnect(&inner, None).await; @@ -51,7 +51,7 @@ pub async fn shutdown(client: &C, flags: Option) - Vec::new() }; let (tx, rx) = oneshot_channel(); - let command: RedisCommand = if inner.config.server.is_clustered() { + let mut command: RedisCommand = if inner.config.server.is_clustered() { let response = ResponseKind::new_buffer(tx); (RedisCommandKind::Shutdown, args, response).into() } else { @@ -61,9 +61,9 @@ pub async fn shutdown(client: &C, flags: Option) - utils::set_client_state(&inner.state, ClientState::Disconnecting); inner.notifications.broadcast_close(); - let _ = client.send_command(command)?; - let _ = rx.await??; - utils::abort_network_timeout_task(&inner); + let timeout_dur = utils::prepare_command(client, &mut command); + client.send_command(command)?; + let _ = utils::apply_timeout(rx, timeout_dur).await??; inner.notifications.close_public_receivers(); inner.backchannel.write().await.check_and_disconnect(&inner, None).await; @@ -89,8 +89,9 @@ pub fn split(inner: &Arc) -> Result, RedisErr config.server = ServerConfig::Centralized { server }; let perf = inner.performance_config(); let policy = inner.reconnect_policy(); + let connection = inner.connection_config(); - RedisClient::new(config, Some(perf), policy) + RedisClient::new(config, Some(perf), Some(connection), policy) }) .collect(), ) @@ -105,7 +106,7 @@ pub async fn force_reconnection(inner: &Arc) -> Result<(), Red #[cfg(feature = "replicas")] replica: false, }; - let _ = interfaces::send_to_router(inner, command)?; + 
interfaces::send_to_router(inner, command)?; rx.await?.map(|_| ()) } @@ -114,7 +115,7 @@ pub async fn flushall(client: &C, r#async: bool) -> Result(client: &C) -> Result<(), RedisError> { @@ -124,10 +125,11 @@ pub async fn flushall_cluster(client: &C) -> Result<(), RedisErro let (tx, rx) = oneshot_channel(); let response = ResponseKind::new_buffer(tx); - let command: RedisCommand = (RedisCommandKind::_FlushAllCluster, vec![], response).into(); - let _ = client.send_command(command)?; + let mut command: RedisCommand = (RedisCommandKind::_FlushAllCluster, vec![], response).into(); + let timeout_dur = utils::prepare_command(client, &mut command); + client.send_command(command)?; - let _ = rx.await??; + let _ = utils::apply_timeout(rx, timeout_dur).await??; Ok(()) } @@ -138,7 +140,7 @@ pub async fn ping(client: &C) -> Result { pub async fn select(client: &C, db: u8) -> Result { let frame = utils::request_response(client, || Ok((RedisCommandKind::Select, vec![db.into()]))).await?; - protocol_utils::frame_to_single_result(frame) + protocol_utils::frame_to_results(frame) } pub async fn info(client: &C, section: Option) -> Result { @@ -152,7 +154,7 @@ pub async fn info(client: &C, section: Option) -> Resul }) .await?; - protocol_utils::frame_to_single_result(frame) + protocol_utils::frame_to_results(frame) } pub async fn hello( @@ -171,8 +173,9 @@ pub async fn hello( let mut command: RedisCommand = RedisCommandKind::_HelloAllCluster(version).into(); command.response = ResponseKind::new_buffer(tx); - let _ = client.send_command(command)?; - let _ = rx.await??; + let timeout_dur = utils::prepare_command(client, &mut command); + client.send_command(command)?; + let _ = utils::apply_timeout(rx, timeout_dur).await??; Ok(()) } else { let frame = utils::request_response(client, move || Ok((RedisCommandKind::_Hello(version), args))).await?; @@ -191,15 +194,16 @@ pub async fn auth(client: &C, username: Option, password: if client.inner().config.server.is_clustered() { let (tx, rx) 
= oneshot_channel(); let response = ResponseKind::new_buffer(tx); - let command: RedisCommand = (RedisCommandKind::_AuthAllCluster, args, response).into(); - let _ = client.send_command(command)?; + let mut command: RedisCommand = (RedisCommandKind::_AuthAllCluster, args, response).into(); - let _ = rx.await??; + let timeout_dur = utils::prepare_command(client, &mut command); + client.send_command(command)?; + let _ = utils::apply_timeout(rx, timeout_dur).await??; Ok(()) } else { let frame = utils::request_response(client, move || Ok((RedisCommandKind::Auth, args))).await?; - let response = protocol_utils::frame_to_single_result(frame)?; + let response = protocol_utils::frame_to_results(frame)?; protocol_utils::expect_ok(&response) } } @@ -253,7 +257,7 @@ pub async fn failover( }) .await?; - let response = protocol_utils::frame_to_single_result(frame)?; + let response = protocol_utils::frame_to_results(frame)?; protocol_utils::expect_ok(&response) } @@ -265,5 +269,5 @@ pub async fn wait(client: &C, numreplicas: i64, timeout: i64) -> }) .await?; - protocol_utils::frame_to_single_result(frame) + protocol_utils::frame_to_results(frame) } diff --git a/src/commands/impls/sets.rs b/src/commands/impls/sets.rs index d030e594..5ae47e82 100644 --- a/src/commands/impls/sets.rs +++ b/src/commands/impls/sets.rs @@ -12,17 +12,18 @@ pub async fn sadd( members: MultipleValues, ) -> Result { let frame = utils::request_response(client, move || { + let members = members.into_multiple_values(); let mut args = Vec::with_capacity(1 + members.len()); args.push(key.into()); - for member in members.inner().into_iter() { + for member in members.into_iter() { args.push(member); } Ok((RedisCommandKind::Sadd, args)) }) .await?; - protocol_utils::frame_to_single_result(frame) + protocol_utils::frame_to_results(frame) } pub async fn scard(client: &C, key: RedisKey) -> Result { @@ -59,7 +60,7 @@ pub async fn sdiffstore( }) .await?; - protocol_utils::frame_to_single_result(frame) + 
protocol_utils::frame_to_results(frame) } pub async fn sinter(client: &C, keys: MultipleKeys) -> Result { @@ -92,7 +93,7 @@ pub async fn sinterstore( }) .await?; - protocol_utils::frame_to_single_result(frame) + protocol_utils::frame_to_results(frame) } pub async fn sismember( @@ -109,10 +110,11 @@ pub async fn smismember( members: MultipleValues, ) -> Result { let frame = utils::request_response(client, move || { + let members = members.into_multiple_values(); let mut args = Vec::with_capacity(1 + members.len()); args.push(key.into()); - for member in members.inner().into_iter() { + for member in members.into_iter() { args.push(member); } Ok((RedisCommandKind::Smismember, args)) @@ -176,17 +178,18 @@ pub async fn srem( members: MultipleValues, ) -> Result { let frame = utils::request_response(client, move || { + let members = members.into_multiple_values(); let mut args = Vec::with_capacity(1 + members.len()); args.push(key.into()); - for member in members.inner().into_iter() { + for member in members.into_iter() { args.push(member); } Ok((RedisCommandKind::Srem, args)) }) .await?; - protocol_utils::frame_to_single_result(frame) + protocol_utils::frame_to_results(frame) } pub async fn sunion(client: &C, keys: MultipleKeys) -> Result { @@ -219,5 +222,5 @@ pub async fn sunionstore( }) .await?; - protocol_utils::frame_to_single_result(frame) + protocol_utils::frame_to_results(frame) } diff --git a/src/commands/impls/slowlog.rs b/src/commands/impls/slowlog.rs index d14238d4..3659a442 100644 --- a/src/commands/impls/slowlog.rs +++ b/src/commands/impls/slowlog.rs @@ -1,13 +1,11 @@ use super::*; use crate::{ - error::*, + prelude::*, protocol::{command::RedisCommandKind, utils as protocol_utils}, - types::*, utils, }; -use redis_protocol::resp3::types::Frame; -pub async fn slowlog_get(client: &C, count: Option) -> Result, RedisError> { +pub async fn slowlog_get(client: &C, count: Option) -> Result { let frame = utils::request_response(client, move || { let mut args = 
Vec::with_capacity(2); args.push(static_val!(GET)); @@ -20,22 +18,12 @@ pub async fn slowlog_get(client: &C, count: Option) -> Resul }) .await?; - if let Frame::Array { data, .. } = frame { - protocol_utils::parse_slowlog_entries(data) - } else { - Err(RedisError::new(RedisErrorKind::Protocol, "Expected array response.")) - } + protocol_utils::frame_to_results(frame) } -pub async fn slowlog_length(client: &C) -> Result { +pub async fn slowlog_length(client: &C) -> Result { let frame = utils::request_response(client, || Ok((RedisCommandKind::Slowlog, vec![LEN.into()]))).await?; - let response = protocol_utils::frame_to_single_result(frame)?; - - if let RedisValue::Integer(len) = response { - Ok(len as u64) - } else { - Err(RedisError::new(RedisErrorKind::Protocol, "Expected integer response.")) - } + protocol_utils::frame_to_results(frame) } pub async fn slowlog_reset(client: &C) -> Result<(), RedisError> { diff --git a/src/commands/impls/sorted_sets.rs b/src/commands/impls/sorted_sets.rs index 646fa227..f452655b 100644 --- a/src/commands/impls/sorted_sets.rs +++ b/src/commands/impls/sorted_sets.rs @@ -41,8 +41,8 @@ fn check_range_type(range: &ZRange, kind: &Option) -> Result<(), RedisErr } fn check_range_types(min: &ZRange, max: &ZRange, kind: &Option) -> Result<(), RedisError> { - let _ = check_range_type(min, kind)?; - let _ = check_range_type(max, kind)?; + check_range_type(min, kind)?; + check_range_type(max, kind)?; Ok(()) } @@ -72,7 +72,7 @@ pub async fn bzmpop( }) .await?; - let _ = protocol_utils::check_null_timeout(&frame)?; + protocol_utils::check_null_timeout(&frame)?; protocol_utils::frame_to_results(frame) } @@ -89,7 +89,7 @@ pub async fn bzpopmin(client: &C, keys: MultipleKeys, timeout: f6 }) .await?; - let _ = protocol_utils::check_null_timeout(&frame)?; + protocol_utils::check_null_timeout(&frame)?; protocol_utils::frame_to_results(frame) } @@ -106,7 +106,7 @@ pub async fn bzpopmax(client: &C, keys: MultipleKeys, timeout: f6 }) .await?; - let _ = 
protocol_utils::check_null_timeout(&frame)?; + protocol_utils::check_null_timeout(&frame)?; protocol_utils::frame_to_results(frame) } @@ -197,7 +197,7 @@ pub async fn zdiffstore( }) .await?; - protocol_utils::frame_to_single_result(frame) + protocol_utils::frame_to_results(frame) } pub async fn zincrby( @@ -278,7 +278,7 @@ pub async fn zinterstore( }) .await?; - protocol_utils::frame_to_single_result(frame) + protocol_utils::frame_to_results(frame) } pub async fn zlexcount( @@ -287,7 +287,7 @@ pub async fn zlexcount( min: ZRange, max: ZRange, ) -> Result { - let _ = check_range_types(&min, &max, &Some(ZSort::ByLex))?; + check_range_types(&min, &max, &Some(ZSort::ByLex))?; let args = vec![key.into(), min.into_value()?, max.into_value()?]; args_value_cmd(client, RedisCommandKind::Zlexcount, args).await @@ -379,7 +379,7 @@ pub async fn zrangestore( rev: bool, limit: Option, ) -> Result { - let _ = check_range_types(&min, &max, &sort)?; + check_range_types(&min, &max, &sort)?; let frame = utils::request_response(client, move || { let mut args = Vec::with_capacity(9); @@ -404,7 +404,7 @@ pub async fn zrangestore( }) .await?; - protocol_utils::frame_to_single_result(frame) + protocol_utils::frame_to_results(frame) } pub async fn zrange( @@ -417,7 +417,7 @@ pub async fn zrange( limit: Option, withscores: bool, ) -> Result { - let _ = check_range_types(&min, &max, &sort)?; + check_range_types(&min, &max, &sort)?; let frame = utils::request_response(client, move || { let mut args = Vec::with_capacity(9); @@ -454,7 +454,7 @@ pub async fn zrangebylex( max: ZRange, limit: Option, ) -> Result { - let _ = check_range_types(&min, &max, &Some(ZSort::ByLex))?; + check_range_types(&min, &max, &Some(ZSort::ByLex))?; let frame = utils::request_response(client, move || { let mut args = Vec::with_capacity(6); @@ -482,7 +482,7 @@ pub async fn zrevrangebylex( min: ZRange, limit: Option, ) -> Result { - let _ = check_range_types(&min, &max, &Some(ZSort::ByLex))?; + check_range_types(&min, 
&max, &Some(ZSort::ByLex))?; let frame = utils::request_response(client, move || { let mut args = Vec::with_capacity(6); @@ -573,17 +573,18 @@ pub async fn zrem( members: MultipleValues, ) -> Result { let frame = utils::request_response(client, move || { + let members = members.into_multiple_values(); let mut args = Vec::with_capacity(1 + members.len()); args.push(key.into()); - for member in members.inner().into_iter() { + for member in members.into_iter() { args.push(member); } Ok((RedisCommandKind::Zrem, args)) }) .await?; - protocol_utils::frame_to_single_result(frame) + protocol_utils::frame_to_results(frame) } pub async fn zremrangebylex( @@ -593,7 +594,7 @@ pub async fn zremrangebylex( max: ZRange, ) -> Result { let frame = utils::request_response(client, move || { - let _ = check_range_types(&min, &max, &Some(ZSort::ByLex))?; + check_range_types(&min, &max, &Some(ZSort::ByLex))?; Ok((RedisCommandKind::Zremrangebylex, vec![ key.into(), @@ -603,7 +604,7 @@ pub async fn zremrangebylex( }) .await?; - protocol_utils::frame_to_single_result(frame) + protocol_utils::frame_to_results(frame) } pub async fn zremrangebyrank( @@ -623,7 +624,7 @@ pub async fn zremrangebyscore( max: ZRange, ) -> Result { let frame = utils::request_response(client, move || { - let _ = check_range_types(&min, &max, &Some(ZSort::ByScore))?; + check_range_types(&min, &max, &Some(ZSort::ByScore))?; Ok((RedisCommandKind::Zremrangebyscore, vec![ key.into(), @@ -633,7 +634,7 @@ pub async fn zremrangebyscore( }) .await?; - protocol_utils::frame_to_single_result(frame) + protocol_utils::frame_to_results(frame) } pub async fn zrevrange( @@ -742,7 +743,7 @@ pub async fn zunionstore( }) .await?; - protocol_utils::frame_to_single_result(frame) + protocol_utils::frame_to_results(frame) } pub async fn zmscore( @@ -751,10 +752,11 @@ pub async fn zmscore( members: MultipleValues, ) -> Result { let frame = utils::request_response(client, move || { + let members = members.into_multiple_values(); let mut 
args = Vec::with_capacity(1 + members.len()); args.push(key.into()); - for member in members.inner().into_iter() { + for member in members.into_iter() { args.push(member); } Ok((RedisCommandKind::Zmscore, args)) diff --git a/src/commands/impls/streams.rs b/src/commands/impls/streams.rs index a0e03bb3..a1025759 100644 --- a/src/commands/impls/streams.rs +++ b/src/commands/impls/streams.rs @@ -45,12 +45,12 @@ pub async fn xinfo_consumers( }) .await?; - protocol_utils::frame_to_results_raw(frame) + protocol_utils::frame_to_results(frame) } pub async fn xinfo_groups(client: &C, key: RedisKey) -> Result { let frame = utils::request_response(client, move || Ok((RedisCommandKind::XinfoGroups, vec![key.into()]))).await?; - protocol_utils::frame_to_results_raw(frame) + protocol_utils::frame_to_results(frame) } pub async fn xinfo_stream( @@ -158,7 +158,7 @@ pub async fn xrange( }) .await?; - protocol_utils::frame_to_results_raw(frame) + protocol_utils::frame_to_results(frame) } pub async fn xrevrange( @@ -183,7 +183,7 @@ pub async fn xrevrange( }) .await?; - protocol_utils::frame_to_results_raw(frame) + protocol_utils::frame_to_results(frame) } pub async fn xlen(client: &C, key: RedisKey) -> Result { @@ -233,7 +233,7 @@ pub async fn xread( }) .await?; - protocol_utils::frame_to_results_raw(frame) + protocol_utils::frame_to_results(frame) } pub async fn xgroup_create( @@ -377,7 +377,7 @@ pub async fn xreadgroup( }) .await?; - protocol_utils::frame_to_results_raw(frame) + protocol_utils::frame_to_results(frame) } pub async fn xack( @@ -447,7 +447,7 @@ pub async fn xclaim( }) .await?; - protocol_utils::frame_to_results_raw(frame) + protocol_utils::frame_to_results(frame) } pub async fn xautoclaim( @@ -480,7 +480,7 @@ pub async fn xautoclaim( }) .await?; - protocol_utils::frame_to_results_raw(frame) + protocol_utils::frame_to_results(frame) } pub async fn xpending( diff --git a/src/commands/impls/tracking.rs b/src/commands/impls/tracking.rs index 7f540359..ba61effd 100644 --- 
a/src/commands/impls/tracking.rs +++ b/src/commands/impls/tracking.rs @@ -90,7 +90,8 @@ pub async fn start_tracking( let command: RedisCommand = (RedisCommandKind::_ClientTrackingCluster, args, response).into(); let _ = client.send_command(command)?; - let _ = protocol_utils::frame_to_results(rx.await??)?; + let frame = utils::apply_timeout(rx, client.inner().internal_command_timeout()).await??; + let _ = protocol_utils::frame_to_results(frame)?; Ok(()) } } else { @@ -117,7 +118,8 @@ pub async fn stop_tracking(client: &C) -> Result<(), RedisError> let command: RedisCommand = (RedisCommandKind::_ClientTrackingCluster, args, response).into(); let _ = client.send_command(command)?; - let _ = protocol_utils::frame_to_results(rx.await??)?; + let frame = utils::apply_timeout(rx, client.inner().internal_command_timeout()).await??; + let _ = protocol_utils::frame_to_results(frame)?; Ok(()) } else { utils::request_response(client, move || Ok((RedisCommandKind::ClientTracking, args))) diff --git a/src/commands/interfaces/acl.rs b/src/commands/interfaces/acl.rs index c37c71b7..6de423ec 100644 --- a/src/commands/interfaces/acl.rs +++ b/src/commands/interfaces/acl.rs @@ -1,7 +1,8 @@ use crate::{ commands, + error::RedisError, interfaces::{ClientLike, RedisResult}, - types::{AclRule, AclUser, FromRedis, MultipleStrings, RedisValue}, + types::{FromRedis, MultipleStrings, MultipleValues}, }; use bytes_utils::Str; @@ -11,11 +12,14 @@ pub trait AclInterface: ClientLike + Sized { /// Create an ACL user with the specified rules or modify the rules of an existing user. 
/// /// - async fn acl_setuser(&self, username: S, rules: Vec) -> RedisResult<()> + async fn acl_setuser(&self, username: S, rules: V) -> RedisResult<()> where S: Into + Send, + V: TryInto + Send, + V::Error: Into + Send, { into!(username); + try_into!(rules); commands::acl::acl_setuser(self, username, rules).await } @@ -58,12 +62,13 @@ pub trait AclInterface: ClientLike + Sized { /// The command returns all the rules defined for an existing ACL user. /// /// - async fn acl_getuser(&self, username: S) -> RedisResult> + async fn acl_getuser(&self, username: S) -> RedisResult where + R: FromRedis, S: Into + Send, { into!(username); - commands::acl::acl_getuser(self, username).await + commands::acl::acl_getuser(self, username).await?.convert() } /// Delete all the specified ACL users and terminate all the connections that are authenticated with such users. @@ -82,14 +87,20 @@ pub trait AclInterface: ClientLike + Sized { /// the command shows all the Redis commands in the specified category. /// /// - async fn acl_cat(&self, category: Option) -> RedisResult> { + async fn acl_cat(&self, category: Option) -> RedisResult + where + R: FromRedis, + { commands::acl::acl_cat(self, category).await?.convert() } /// Generate a password with length `bits`, returning the password. /// /// - async fn acl_genpass(&self, bits: Option) -> RedisResult { + async fn acl_genpass(&self, bits: Option) -> RedisResult + where + R: FromRedis, + { commands::acl::acl_genpass(self, bits).await?.convert() } @@ -97,15 +108,21 @@ pub trait AclInterface: ClientLike + Sized { /// with the "default" user. /// /// - async fn acl_whoami(&self) -> RedisResult { + async fn acl_whoami(&self) -> RedisResult + where + R: FromRedis, + { commands::acl::acl_whoami(self).await?.convert() } /// Read `count` recent ACL security events. 
/// /// - async fn acl_log_count(&self, count: Option) -> RedisResult { - commands::acl::acl_log_count(self, count).await + async fn acl_log_count(&self, count: Option) -> RedisResult + where + R: FromRedis, + { + commands::acl::acl_log_count(self, count).await?.convert() } /// Clear the ACL security events logs. diff --git a/src/commands/interfaces/client.rs b/src/commands/interfaces/client.rs index c070918a..a15fa212 100644 --- a/src/commands/interfaces/client.rs +++ b/src/commands/interfaces/client.rs @@ -17,7 +17,7 @@ use std::collections::HashMap; #[cfg(feature = "client-tracking")] use crate::{ - error::{RedisError, RedisErrorKind}, + error::RedisError, types::{MultipleStrings, Toggle}, }; @@ -150,7 +150,8 @@ pub trait ClientInterface: ClientLike + Sized { /// /// /// - /// Note: **This function requires a centralized server**. See + /// This function is designed to work against a specific server, either via a centralized server config or + /// [with_options](crate::interfaces::ClientLike::with_options). See /// [crate::interfaces::TrackingInterface::start_tracking] for a version that works with all server deployment /// modes. #[cfg(feature = "client-tracking")] @@ -171,13 +172,6 @@ pub trait ClientInterface: ClientLike + Sized { T::Error: Into + Send, P: Into + Send, { - if self.inner().config.server.is_clustered() { - return Err(RedisError::new( - RedisErrorKind::Config, - "Invalid server type. Expected centralized server.", - )); - } - try_into!(toggle); into!(prefixes); commands::tracking::client_tracking(self, toggle, redirect, prefixes, bcast, optin, optout, noloop) @@ -215,22 +209,15 @@ pub trait ClientInterface: ClientLike + Sized { /// /// /// - /// Note: **This function requires a centralized server**. See - /// [TrackingInterface::caching](crate::interfaces::TrackingInterface::caching) for a version that works with all - /// server deployment modes. + /// This function is designed to work against a specific server. 
See + /// [with_options](crate::interfaces::ClientLike::with_options) for a variation that works with all deployment + /// types. #[cfg(feature = "client-tracking")] #[cfg_attr(docsrs, doc(cfg(feature = "client-tracking")))] async fn client_caching(&self, enabled: bool) -> RedisResult where R: FromRedis, { - if self.inner().config.server.is_clustered() { - return Err(RedisError::new( - RedisErrorKind::Config, - "Invalid server type. Expected centralized server.", - )); - } - commands::tracking::client_caching(self, enabled).await?.convert() } } diff --git a/src/commands/interfaces/cluster.rs b/src/commands/interfaces/cluster.rs index 4b1ae7e0..92a78e80 100644 --- a/src/commands/interfaces/cluster.rs +++ b/src/commands/interfaces/cluster.rs @@ -3,16 +3,7 @@ use crate::{ error::RedisError, interfaces::{ClientLike, RedisResult}, protocol::types::ClusterRouting, - types::{ - ClusterFailoverFlag, - ClusterInfo, - ClusterResetFlag, - ClusterSetSlotState, - FromRedis, - MultipleHashSlots, - RedisKey, - RedisValue, - }, + types::{ClusterFailoverFlag, ClusterResetFlag, ClusterSetSlotState, FromRedis, MultipleHashSlots, RedisKey}, }; use bytes_utils::Str; @@ -70,7 +61,10 @@ pub trait ClusterInterface: ClientLike + Sized { /// [cached_cluster_state](Self::cached_cluster_state). /// /// - async fn cluster_nodes(&self) -> RedisResult { + async fn cluster_nodes(&self) -> RedisResult + where + R: FromRedis, + { commands::cluster::cluster_nodes(self).await?.convert() } @@ -84,15 +78,21 @@ pub trait ClusterInterface: ClientLike + Sized { /// CLUSTER SLOTS returns details about which cluster slots map to which Redis instances. /// /// - async fn cluster_slots(&self) -> RedisResult { - commands::cluster::cluster_slots(self).await + async fn cluster_slots(&self) -> RedisResult + where + R: FromRedis, + { + commands::cluster::cluster_slots(self).await?.convert() } /// CLUSTER INFO provides INFO style information about Redis Cluster vital parameters. 
/// /// - async fn cluster_info(&self) -> RedisResult { - commands::cluster::cluster_info(self).await + async fn cluster_info(&self) -> RedisResult + where + R: FromRedis, + { + commands::cluster::cluster_info(self).await?.convert() } /// This command is useful in order to modify a node's view of the cluster configuration. Specifically it assigns a @@ -217,8 +217,9 @@ pub trait ClusterInterface: ClientLike + Sized { /// The command provides a list of replica nodes replicating from the specified master node. /// /// - async fn cluster_replicas(&self, node_id: S) -> RedisResult + async fn cluster_replicas(&self, node_id: S) -> RedisResult where + R: FromRedis, S: Into + Send, { into!(node_id); diff --git a/src/commands/interfaces/geo.rs b/src/commands/interfaces/geo.rs index 52831e91..c17c6247 100644 --- a/src/commands/interfaces/geo.rs +++ b/src/commands/interfaces/geo.rs @@ -6,7 +6,6 @@ use crate::{ Any, FromRedis, GeoPosition, - GeoRadiusInfo, GeoUnit, MultipleGeoValues, MultipleValues, @@ -58,15 +57,16 @@ pub trait GeoInterface: ClientLike + Sized { /// Callers can use [as_geo_position](crate::types::RedisValue::as_geo_position) to lazily parse results as needed. /// /// - async fn geopos(&self, key: K, members: V) -> RedisResult + async fn geopos(&self, key: K, members: V) -> RedisResult where + R: FromRedis, K: Into + Send, V: TryInto + Send, V::Error: Into + Send, { into!(key); try_into!(members); - commands::geo::geopos(self, key, members).await + commands::geo::geopos(self, key, members).await?.convert() } /// Return the distance between two members in the geospatial index represented by the sorted set. @@ -90,7 +90,7 @@ pub trait GeoInterface: ClientLike + Sized { /// borders of the area specified with the center location and the maximum distance from the center (the radius). 
/// /// - async fn georadius( + async fn georadius( &self, key: K, position: P, @@ -103,8 +103,9 @@ pub trait GeoInterface: ClientLike + Sized { ord: Option, store: Option, storedist: Option, - ) -> RedisResult> + ) -> RedisResult where + R: FromRedis, K: Into + Send, P: Into + Send, { @@ -112,7 +113,8 @@ pub trait GeoInterface: ClientLike + Sized { commands::geo::georadius( self, key, position, radius, unit, withcoord, withdist, withhash, count, ord, store, storedist, ) - .await + .await? + .convert() } /// This command is exactly like GEORADIUS with the sole difference that instead of taking, as the center of the @@ -120,7 +122,7 @@ pub trait GeoInterface: ClientLike + Sized { /// geospatial index represented by the sorted set. /// /// - async fn georadiusbymember( + async fn georadiusbymember( &self, key: K, member: V, @@ -133,8 +135,9 @@ pub trait GeoInterface: ClientLike + Sized { ord: Option, store: Option, storedist: Option, - ) -> RedisResult> + ) -> RedisResult where + R: FromRedis, K: Into + Send, V: TryInto + Send, V::Error: Into + Send, @@ -155,14 +158,15 @@ pub trait GeoInterface: ClientLike + Sized { store, storedist, ) - .await + .await? + .convert() } /// Return the members of a sorted set populated with geospatial information using GEOADD, which are within the /// borders of the area specified by a given shape. /// /// - async fn geosearch( + async fn geosearch( &self, key: K, from_member: Option, @@ -174,8 +178,9 @@ pub trait GeoInterface: ClientLike + Sized { withcoord: bool, withdist: bool, withhash: bool, - ) -> RedisResult> + ) -> RedisResult where + R: FromRedis, K: Into + Send, { into!(key); @@ -192,7 +197,8 @@ pub trait GeoInterface: ClientLike + Sized { withdist, withhash, ) - .await + .await? + .convert() } /// This command is like GEOSEARCH, but stores the result in destination key. 
Returns the number of members added to diff --git a/src/commands/interfaces/keys.rs b/src/commands/interfaces/keys.rs index 2ca3c1a9..f020ae1d 100644 --- a/src/commands/interfaces/keys.rs +++ b/src/commands/interfaces/keys.rs @@ -55,18 +55,19 @@ pub trait KeysInterface: ClientLike + Sized { /// Serialize the value stored at `key` in a Redis-specific format and return it as bulk string. /// /// - async fn dump(&self, key: K) -> RedisResult + async fn dump(&self, key: K) -> RedisResult where + R: FromRedis, K: Into + Send, { into!(key); - commands::keys::dump(self, key).await + commands::keys::dump(self, key).await?.convert() } /// Create a key associated with a value that is obtained by deserializing the provided serialized value /// /// - async fn restore( + async fn restore( &self, key: K, ttl: i64, @@ -75,12 +76,15 @@ pub trait KeysInterface: ClientLike + Sized { absttl: bool, idletime: Option, frequency: Option, - ) -> RedisResult + ) -> RedisResult where + R: FromRedis, K: Into + Send, { into!(key); - commands::keys::restore(self, key, ttl, serialized, replace, absttl, idletime, frequency).await + commands::keys::restore(self, key, ttl, serialized, replace, absttl, idletime, frequency) + .await? + .convert() } /// Set a value with optional NX|XX, EX|PX|EXAT|PXAT|KEEPTTL, and GET arguments. @@ -256,7 +260,7 @@ pub trait KeysInterface: ClientLike + Sized { /// Append `value` to `key` if it's a string. /// /// - async fn append(&self, key: K, value: V) -> Result + async fn append(&self, key: K, value: V) -> RedisResult where R: FromRedis, K: Into + Send, diff --git a/src/commands/interfaces/lua.rs b/src/commands/interfaces/lua.rs index c087ca45..f6528b69 100644 --- a/src/commands/interfaces/lua.rs +++ b/src/commands/interfaces/lua.rs @@ -29,6 +29,8 @@ pub trait LuaInterface: ClientLike + Sized { /// A clustered variant of [script_load](Self::script_load) that loads the script on all primary nodes in a cluster. /// /// Returns the SHA-1 hash of the script. 
+ #[cfg(feature = "sha-1")] + #[cfg_attr(docsrs, doc(cfg(feature = "sha-1")))] async fn script_load_cluster(&self, script: S) -> RedisResult where R: FromRedis, diff --git a/src/commands/interfaces/memory.rs b/src/commands/interfaces/memory.rs index 4cf1d043..bf5735af 100644 --- a/src/commands/interfaces/memory.rs +++ b/src/commands/interfaces/memory.rs @@ -1,7 +1,8 @@ use crate::{ commands, interfaces::{ClientLike, RedisResult}, - types::{MemoryStats, RedisKey}, + prelude::FromRedis, + types::RedisKey, }; /// Functions that implement the [memory](https://redis.io/commands#server) interface. @@ -11,14 +12,20 @@ pub trait MemoryInterface: ClientLike + Sized { /// advises about possible remedies. /// /// - async fn memory_doctor(&self) -> RedisResult { + async fn memory_doctor(&self) -> RedisResult + where + R: FromRedis, + { commands::memory::memory_doctor(self).await?.convert() } /// The MEMORY MALLOC-STATS command provides an internal statistics report from the memory allocator. /// /// - async fn memory_malloc_stats(&self) -> RedisResult { + async fn memory_malloc_stats(&self) -> RedisResult + where + R: FromRedis, + { commands::memory::memory_malloc_stats(self).await?.convert() } @@ -32,15 +39,19 @@ pub trait MemoryInterface: ClientLike + Sized { /// The MEMORY STATS command returns an Array reply about the memory usage of the server. /// /// - async fn memory_stats(&self) -> RedisResult { - commands::memory::memory_stats(self).await + async fn memory_stats(&self) -> RedisResult + where + R: FromRedis, + { + commands::memory::memory_stats(self).await?.convert() } /// The MEMORY USAGE command reports the number of bytes that a key and its value require to be stored in RAM. 
/// /// - async fn memory_usage(&self, key: K, samples: Option) -> RedisResult> + async fn memory_usage(&self, key: K, samples: Option) -> RedisResult where + R: FromRedis, K: Into + Send, { into!(key); diff --git a/src/commands/interfaces/mod.rs b/src/commands/interfaces/mod.rs index 851be0cd..10fa98c0 100644 --- a/src/commands/interfaces/mod.rs +++ b/src/commands/interfaces/mod.rs @@ -25,3 +25,6 @@ pub mod tracking; #[cfg(feature = "sentinel-client")] pub mod sentinel; + +#[cfg(feature = "redis-json")] +pub mod redis_json; diff --git a/src/commands/interfaces/redis_json.rs b/src/commands/interfaces/redis_json.rs new file mode 100644 index 00000000..3c4d6cf9 --- /dev/null +++ b/src/commands/interfaces/redis_json.rs @@ -0,0 +1,399 @@ +use crate::{ + commands, + interfaces::{ClientLike, RedisResult}, + types::{FromRedis, MultipleKeys, MultipleStrings, RedisKey, SetOptions}, +}; +use bytes_utils::Str; +use serde_json::Value; + +/// The client commands in the [RedisJSON](https://redis.io/docs/data-types/json/) interface. +/// +/// ## String Values +/// +/// This interface uses [serde_json::Value](serde_json::Value) as the baseline type and will convert non-string values +/// to RESP bulk strings via [to_string](serde_json::to_string). +/// +/// Some of the RedisJSON commands include the following notice in the documentation: +/// +/// > To specify a string as an array value to append, wrap the quoted string with an additional set of single quotes. +/// > Example: '"silver"'. +/// +/// The [serde_json::to_string](serde_json::to_string) functions are often the easiest way to do +/// this. The [json_quote](crate::json_quote) macro can also help. 
+/// +/// For example: +/// +/// ```rust +/// use fred::{json_quote, prelude::*}; +/// use serde_json::json; +/// async fn example(client: &RedisClient) -> Result<(), RedisError> { +/// let _: () = client.json_set("foo", "$", json!(["a", "b"]), None).await?; +/// +/// // need to double quote string values in this context +/// let size: i64 = client +/// .json_arrappend("foo", Some("$"), vec![ +/// json!("c").to_string(), +/// // or +/// serde_json::to_string(&json!("d"))?, +/// // or +/// json_quote!("e"), +/// ]) +/// .await?; +/// assert_eq!(size, 5); +/// Ok(()) +/// } +/// ``` +#[async_trait] +#[cfg_attr(docsrs, doc(cfg(feature = "redis-json")))] +pub trait RedisJsonInterface: ClientLike + Sized { + /// Append the json values into the array at path after the last element in it. + /// + /// + async fn json_arrappend(&self, key: K, path: P, values: Vec) -> RedisResult + where + R: FromRedis, + K: Into + Send, + P: Into + Send, + V: Into + Send, + { + into!(key, path); + let values = values.into_iter().map(|v| v.into()).collect(); + commands::redis_json::json_arrappend(self, key, path, values) + .await? + .convert() + } + + /// Search for the first occurrence of a JSON value in an array. + /// + /// + async fn json_arrindex( + &self, + key: K, + path: P, + value: V, + start: Option, + stop: Option, + ) -> RedisResult + where + R: FromRedis, + K: Into + Send, + P: Into + Send, + V: Into + Send, + { + into!(key, path, value); + commands::redis_json::json_arrindex(self, key, path, value, start, stop) + .await? + .convert() + } + + /// Insert the json values into the array at path before the index (shifts to the right). + /// + /// + async fn json_arrinsert(&self, key: K, path: P, index: i64, values: Vec) -> RedisResult + where + R: FromRedis, + K: Into + Send, + P: Into + Send, + V: Into + Send, + { + into!(key, path); + let values = values.into_iter().map(|v| v.into()).collect(); + commands::redis_json::json_arrinsert(self, key, path, index, values) + .await? 
+ .convert() + } + + /// Report the length of the JSON array at path in key. + /// + /// + async fn json_arrlen(&self, key: K, path: Option

) -> RedisResult + where + R: FromRedis, + K: Into + Send, + P: Into + Send, + { + into!(key); + let path = path.map(|p| p.into()); + commands::redis_json::json_arrlen(self, key, path).await?.convert() + } + + /// Remove and return an element from the index in the array + /// + /// + async fn json_arrpop(&self, key: K, path: Option

, index: Option) -> RedisResult + where + R: FromRedis, + K: Into + Send, + P: Into + Send, + { + into!(key); + let path = path.map(|p| p.into()); + commands::redis_json::json_arrpop(self, key, path, index) + .await? + .convert() + } + + /// Trim an array so that it contains only the specified inclusive range of elements + /// + /// + async fn json_arrtrim(&self, key: K, path: P, start: i64, stop: i64) -> RedisResult + where + R: FromRedis, + K: Into + Send, + P: Into + Send, + { + into!(key, path); + commands::redis_json::json_arrtrim(self, key, path, start, stop) + .await? + .convert() + } + + /// Clear container values (arrays/objects) and set numeric values to 0 + /// + /// + async fn json_clear(&self, key: K, path: Option

) -> RedisResult + where + R: FromRedis, + K: Into + Send, + P: Into + Send, + { + into!(key); + let path = path.map(|p| p.into()); + commands::redis_json::json_clear(self, key, path).await?.convert() + } + + /// Report a value's memory usage in bytes + /// + /// + async fn json_debug_memory(&self, key: K, path: Option

) -> RedisResult + where + R: FromRedis, + K: Into + Send, + P: Into + Send, + { + into!(key); + let path = path.map(|p| p.into()); + commands::redis_json::json_debug_memory(self, key, path) + .await? + .convert() + } + + /// Delete a value. + /// + /// + async fn json_del(&self, key: K, path: P) -> RedisResult + where + R: FromRedis, + K: Into + Send, + P: Into + Send, + { + into!(key, path); + commands::redis_json::json_del(self, key, path).await?.convert() + } + + /// Return the value at path in JSON serialized form. + /// + /// + async fn json_get( + &self, + key: K, + indent: Option, + newline: Option, + space: Option, + paths: P, + ) -> RedisResult + where + R: FromRedis, + K: Into + Send, + I: Into + Send, + N: Into + Send, + S: Into + Send, + P: Into + Send, + { + into!(key, paths); + let indent = indent.map(|v| v.into()); + let newline = newline.map(|v| v.into()); + let space = space.map(|v| v.into()); + commands::redis_json::json_get(self, key, indent, newline, space, paths) + .await? + .convert() + } + + /// Merge a given JSON value into matching paths. + /// + /// + async fn json_merge(&self, key: K, path: P, value: V) -> RedisResult + where + R: FromRedis, + K: Into + Send, + P: Into + Send, + V: Into + Send, + { + into!(key, path, value); + commands::redis_json::json_merge(self, key, path, value) + .await? + .convert() + } + + /// Return the values at path from multiple key arguments. + /// + /// + async fn json_mget(&self, keys: K, path: P) -> RedisResult + where + R: FromRedis, + K: Into + Send, + P: Into + Send, + { + into!(keys, path); + commands::redis_json::json_mget(self, keys, path).await?.convert() + } + + /// Set or update one or more JSON values according to the specified key-path-value triplets. 
+ /// + /// + async fn json_mset(&self, values: Vec<(K, P, V)>) -> RedisResult + where + R: FromRedis, + K: Into + Send, + P: Into + Send, + V: Into + Send, + { + let values = values + .into_iter() + .map(|(k, p, v)| (k.into(), p.into(), v.into())) + .collect(); + commands::redis_json::json_mset(self, values).await?.convert() + } + + /// Increment the number value stored at path by number + /// + /// + async fn json_numincrby(&self, key: K, path: P, value: V) -> RedisResult + where + R: FromRedis, + K: Into + Send, + P: Into + Send, + V: Into + Send, + { + into!(key, path, value); + commands::redis_json::json_numincrby(self, key, path, value) + .await? + .convert() + } + + /// Return the keys in the object that's referenced by path. + /// + /// + async fn json_objkeys(&self, key: K, path: Option

) -> RedisResult + where + R: FromRedis, + K: Into + Send, + P: Into + Send, + { + into!(key); + let path = path.map(|p| p.into()); + commands::redis_json::json_objkeys(self, key, path).await?.convert() + } + + /// Report the number of keys in the JSON object at path in key. + /// + /// + async fn json_objlen(&self, key: K, path: Option

) -> RedisResult + where + R: FromRedis, + K: Into + Send, + P: Into + Send, + { + into!(key); + let path = path.map(|p| p.into()); + commands::redis_json::json_objlen(self, key, path).await?.convert() + } + + /// Return the JSON in key in Redis serialization protocol specification form. + /// + /// + async fn json_resp(&self, key: K, path: Option

) -> RedisResult + where + R: FromRedis, + K: Into + Send, + P: Into + Send, + { + into!(key); + let path = path.map(|p| p.into()); + commands::redis_json::json_resp(self, key, path).await?.convert() + } + + /// Set the JSON value at path in key. + /// + /// + async fn json_set(&self, key: K, path: P, value: V, options: Option) -> RedisResult + where + R: FromRedis, + K: Into + Send, + P: Into + Send, + V: Into + Send, + { + into!(key, path, value); + commands::redis_json::json_set(self, key, path, value, options) + .await? + .convert() + } + + /// Append the json-string values to the string at path. + /// + /// + async fn json_strappend(&self, key: K, path: Option

, value: V) -> RedisResult + where + R: FromRedis, + K: Into + Send, + P: Into + Send, + V: Into + Send, + { + into!(key, value); + let path = path.map(|p| p.into()); + commands::redis_json::json_strappend(self, key, path, value) + .await? + .convert() + } + + /// Report the length of the JSON String at path in key. + /// + /// + async fn json_strlen(&self, key: K, path: Option

) -> RedisResult + where + R: FromRedis, + K: Into + Send, + P: Into + Send, + { + into!(key); + let path = path.map(|p| p.into()); + commands::redis_json::json_strlen(self, key, path).await?.convert() + } + + /// Toggle a Boolean value stored at path. + /// + /// + async fn json_toggle(&self, key: K, path: P) -> RedisResult + where + R: FromRedis, + K: Into + Send, + P: Into + Send, + { + into!(key, path); + commands::redis_json::json_toggle(self, key, path).await?.convert() + } + + /// Report the type of JSON value at path. + /// + /// + async fn json_type(&self, key: K, path: Option

) -> RedisResult + where + R: FromRedis, + K: Into + Send, + P: Into + Send, + { + into!(key); + let path = path.map(|p| p.into()); + commands::redis_json::json_type(self, key, path).await?.convert() + } +} diff --git a/src/commands/interfaces/server.rs b/src/commands/interfaces/server.rs index 5e8f7bc7..701dd432 100644 --- a/src/commands/interfaces/server.rs +++ b/src/commands/interfaces/server.rs @@ -53,10 +53,8 @@ pub trait HeartbeatInterface: ClientLike { if break_on_error { let _: () = _self.ping().await?; - } else { - if let Err(e) = _self.ping::<()>().await { - warn!("{}: Heartbeat ping failed with error: {:?}", _self.inner().id, e); - } + } else if let Err(e) = _self.ping::<()>().await { + warn!("{}: Heartbeat ping failed with error: {:?}", _self.inner().id, e); } } diff --git a/src/commands/interfaces/slowlog.rs b/src/commands/interfaces/slowlog.rs index 32a51772..4b8a2ee4 100644 --- a/src/commands/interfaces/slowlog.rs +++ b/src/commands/interfaces/slowlog.rs @@ -1,7 +1,7 @@ use crate::{ commands, interfaces::{ClientLike, RedisResult}, - types::SlowlogEntry, + types::FromRedis, }; /// Functions that implement the [slowlog](https://redis.io/commands#server) interface. @@ -10,15 +10,21 @@ pub trait SlowlogInterface: ClientLike + Sized { /// This command is used to read the slow queries log. /// /// - async fn slowlog_get(&self, count: Option) -> RedisResult> { - commands::slowlog::slowlog_get(self, count).await + async fn slowlog_get(&self, count: Option) -> RedisResult + where + R: FromRedis, + { + commands::slowlog::slowlog_get(self, count).await?.convert() } /// This command is used to read length of the slow queries log. /// /// - async fn slowlog_length(&self) -> RedisResult { - commands::slowlog::slowlog_length(self).await + async fn slowlog_length(&self) -> RedisResult + where + R: FromRedis, + { + commands::slowlog::slowlog_length(self).await?.convert() } /// This command is used to reset the slow queries log. 
diff --git a/src/commands/interfaces/sorted_sets.rs b/src/commands/interfaces/sorted_sets.rs index 4594956d..b74c419e 100644 --- a/src/commands/interfaces/sorted_sets.rs +++ b/src/commands/interfaces/sorted_sets.rs @@ -551,19 +551,22 @@ pub trait SortedSetsInterface: ClientLike + Sized { /// client. /// /// - async fn zunion( + async fn zunion( &self, keys: K, weights: W, aggregate: Option, withscores: bool, - ) -> RedisResult + ) -> RedisResult where + R: FromRedis, K: Into + Send, W: Into + Send, { into!(keys, weights); - commands::sorted_sets::zunion(self, keys, weights, aggregate, withscores).await + commands::sorted_sets::zunion(self, keys, weights, aggregate, withscores) + .await? + .convert() } /// Computes the union of the sorted sets given by the specified keys, and stores the result in `destination`. diff --git a/src/commands/interfaces/streams.rs b/src/commands/interfaces/streams.rs index 79c620cd..52ff20ee 100644 --- a/src/commands/interfaces/streams.rs +++ b/src/commands/interfaces/streams.rs @@ -242,18 +242,22 @@ pub trait StreamsInterface: ClientLike + Sized { /// between RESP2 and RESP3. /// /// ```rust no_run - /// # use fred::types::XReadResponse; - /// // borrowed from the tests. XREAD and XREADGROUP are very similar. - /// let result: XReadResponse = client - /// .xreadgroup_map("group1", "consumer1", None, None, false, "foo", ">") - /// .await?; - /// println!("Result: {:?}", result); - /// // Result: {"foo": [("1646240801081-0", {"count": 0}), ("1646240801082-0", {"count": 1}), ("1646240801082-1", {"count": 2})]} - /// - /// assert_eq!(result.len(), 1); - /// for (idx, (id, record)) in result.get("foo").unwrap().into_iter().enumerate() { - /// let value = record.get("count").expect("Failed to read count"); - /// assert_eq!(idx, *value); + /// # use fred::{prelude::*, types::XReadResponse}; + /// async fn example(client: RedisClient) -> Result<(), RedisError> { + /// // borrowed from the tests. XREAD and XREADGROUP are very similar. 
+ /// let result: XReadResponse = client + /// .xreadgroup_map("group1", "consumer1", None, None, false, "foo", ">") + /// .await?; + /// println!("Result: {:?}", result); + /// // Result: {"foo": [("1646240801081-0", {"count": 0}), ("1646240801082-0", {"count": 1}), ("1646240801082-1", {"count": 2})]} + /// + /// assert_eq!(result.len(), 1); + /// for (idx, (id, record)) in result.get("foo").unwrap().into_iter().enumerate() { + /// let value = record.get("count").expect("Failed to read count"); + /// assert_eq!(idx, *value); + /// } + /// + /// Ok(()) /// } /// ``` // The underlying issue here isn't so much a semantic difference between RESP2 and RESP3, but rather an assumption @@ -313,9 +317,6 @@ pub trait StreamsInterface: ClientLike + Sized { // 2) "6" // ``` // - // This function (and `xreadgroup_map`) provide an easier but optional way to handle the encoding differences with - // the streams interface. - // // The underlying functions that do the RESP2 vs RESP3 conversion are public for callers as well, so one could use a // `BTreeMap` instead of a `HashMap` like so: // @@ -326,8 +327,6 @@ pub trait StreamsInterface: ClientLike + Sized { // .flatten_array_values(2) // .convert()?; // ``` - // - // Thanks for attending my TED talk. async fn xread_map( &self, count: Option, diff --git a/src/commands/interfaces/tracking.rs b/src/commands/interfaces/tracking.rs index aebc5025..fd2fb78a 100644 --- a/src/commands/interfaces/tracking.rs +++ b/src/commands/interfaces/tracking.rs @@ -1,11 +1,10 @@ use crate::{ - clients::Caching, commands, interfaces::ClientLike, prelude::RedisResult, types::{Invalidation, MultipleStrings}, }; -use tokio::sync::broadcast::Receiver as BroadcastReceiver; +use tokio::{sync::broadcast::Receiver as BroadcastReceiver, task::JoinHandle}; /// A high level interface that supports [client side caching](https://redis.io/docs/manual/client-side-caching/) via the [client tracking](https://redis.io/commands/client-tracking/) interface. 
#[async_trait] @@ -38,13 +37,30 @@ pub trait TrackingInterface: ClientLike + Sized { commands::tracking::stop_tracking(self).await } - /// Subscribe to invalidation messages from the server(s). - fn on_invalidation(&self) -> BroadcastReceiver { - self.inner().notifications.invalidations.load().subscribe() + /// Spawn a task that processes invalidation messages from the server. + /// + /// See [invalidation_rx](Self::invalidation_rx) for a more flexible variation of this function. + fn on_invalidation(&self, func: F) -> JoinHandle> + where + F: Fn(Invalidation) -> RedisResult<()> + Send + 'static, + { + let mut invalidation_rx = self.invalidation_rx(); + + tokio::spawn(async move { + let mut result = Ok(()); + + while let Ok(invalidation) = invalidation_rx.recv().await { + if let Err(err) = func(invalidation) { + result = Err(err); + break; + } + } + result + }) } - /// Send a `CLIENT CACHING yes|no` command before each subsequent command. - fn caching(&self, enabled: bool) -> Caching { - Caching::new(self.inner(), enabled) + /// Subscribe to invalidation messages from the server(s). + fn invalidation_rx(&self) -> BroadcastReceiver { + self.inner().notifications.invalidations.load().subscribe() } } diff --git a/src/commands/interfaces/transactions.rs b/src/commands/interfaces/transactions.rs index 30096d6d..7e122678 100644 --- a/src/commands/interfaces/transactions.rs +++ b/src/commands/interfaces/transactions.rs @@ -9,6 +9,6 @@ pub trait TransactionInterface: ClientLike + Sized { /// /// fn multi(&self) -> Transaction { - self.inner().into() + Transaction::from_inner(self.inner()) } } diff --git a/src/error.rs b/src/error.rs index 6c0434e7..f03184e0 100644 --- a/src/error.rs +++ b/src/error.rs @@ -346,7 +346,7 @@ impl RedisError { } /// Create a new empty Canceled error. 
- pub(crate) fn new_canceled() -> Self { + pub fn new_canceled() -> Self { RedisError::new(RedisErrorKind::Canceled, "Canceled.") } @@ -365,44 +365,29 @@ impl RedisError { /// Whether reconnection logic should be skipped in all cases. pub(crate) fn should_not_reconnect(&self) -> bool { - match self.kind { - RedisErrorKind::Config | RedisErrorKind::Url => true, - _ => false, - } + matches!(self.kind, RedisErrorKind::Config | RedisErrorKind::Url) } /// Whether the error is a `Cluster` error. pub fn is_cluster(&self) -> bool { - match self.kind { - RedisErrorKind::Cluster => true, - _ => false, - } + matches!(self.kind, RedisErrorKind::Cluster) } /// Whether the error is a `Canceled` error. pub fn is_canceled(&self) -> bool { - match self.kind { - RedisErrorKind::Canceled => true, - _ => false, - } + matches!(self.kind, RedisErrorKind::Canceled) } /// Whether the error is a `Replica` error. #[cfg(feature = "replicas")] #[cfg_attr(docsrs, doc(cfg(feature = "replicas")))] pub fn is_replica(&self) -> bool { - match self.kind { - RedisErrorKind::Replica => true, - _ => false, - } + matches!(self.kind, RedisErrorKind::Replica) } /// Whether the error is a `NotFound` error. 
pub fn is_not_found(&self) -> bool { - match self.kind { - RedisErrorKind::NotFound => true, - _ => false, - } + matches!(self.kind, RedisErrorKind::NotFound) } } diff --git a/src/interfaces.rs b/src/interfaces.rs index 31973f7f..300ecbd0 100644 --- a/src/interfaces.rs +++ b/src/interfaces.rs @@ -1,4 +1,5 @@ use crate::{ + clients::WithOptions, commands, error::{RedisError, RedisErrorKind}, modules::inner::RedisClientInner, @@ -8,9 +9,11 @@ use crate::{ ClientState, ClusterStateChange, ConnectHandle, + ConnectionConfig, CustomCommand, FromRedis, InfoKind, + Options, PerformanceConfig, ReconnectPolicy, RedisConfig, @@ -24,7 +27,10 @@ use crate::{ pub use redis_protocol::resp3::types::Frame as Resp3Frame; use semver::Version; use std::{convert::TryInto, sync::Arc}; -use tokio::sync::broadcast::Receiver as BroadcastReceiver; +use tokio::{ + sync::{broadcast::Receiver as BroadcastReceiver, mpsc::unbounded_channel}, + task::JoinHandle, +}; /// Type alias for `Result`. pub type RedisResult = Result; @@ -37,20 +43,22 @@ pub(crate) fn default_send_command(inner: &Arc, command: C) where C: Into, { - let command: RedisCommand = command.into(); + let mut command: RedisCommand = command.into(); _trace!( inner, "Sending command {} ({}) to router.", command.kind.to_str_debug(), command.debug_id() ); + command.inherit_options(inner); + send_to_router(inner, command.into()) } /// Send a `RouterCommand` to the router. 
-pub(crate) fn send_to_router(inner: &RedisClientInner, command: RouterCommand) -> Result<(), RedisError> { +pub(crate) fn send_to_router(inner: &Arc, command: RouterCommand) -> Result<(), RedisError> { inner.counters.incr_cmd_buffer_len(); - if let Err(e) = inner.command_tx.send(command) { + if let Err(e) = inner.command_tx.load().send(command) { // usually happens if the caller tries to send a command before calling `connect` or after calling `quit` inner.counters.decr_cmd_buffer_len(); @@ -61,7 +69,6 @@ pub(crate) fn send_to_router(inner: &RedisClientInner, command: RouterCommand) - command.kind.to_str_debug() ); - // if a caller manages to trigger this it means that a connection task is not running command.respond_to_caller(Err(RedisError::new( RedisErrorKind::Unknown, "Client is not initialized.", @@ -100,15 +107,12 @@ pub trait ClientLike: Clone + Send + Sized { { let mut command: RedisCommand = command.into(); self.change_command(&mut command); - default_send_command(&self.inner(), command) + default_send_command(self.inner(), command) } /// The unique ID identifying this client and underlying connections. - /// - /// All connections created by this client will use `CLIENT SETNAME` with this value unless the `no-client-setname` - /// feature is enabled. fn id(&self) -> &str { - self.inner().id.as_str() + &self.inner().id } /// Read the config used to initialize the client. @@ -121,6 +125,11 @@ pub trait ClientLike: Clone + Send + Sized { self.inner().policy.read().clone() } + /// Read the connection config used to initialize the client. + fn connection_config(&self) -> &ConnectionConfig { + self.inner().connection.as_ref() + } + /// Read the RESP version used by the client when communicating with the server. fn protocol_version(&self) -> RespVersion { if self.inner().is_resp3() { @@ -193,18 +202,40 @@ pub trait ClientLike: Clone + Send + Sized { /// /// This function returns a `JoinHandle` to a task that drives the connection. 
It will not resolve until the /// connection closes, and if a reconnection policy with unlimited attempts is provided then the `JoinHandle` will - /// run forever, or until `QUIT` is called. + /// run until `QUIT` is called. + /// + /// **Calling this function more than once will drop all state associated with the previous connection(s).** Any + /// pending commands on the old connection(s) will either finish or timeout, but they will not be retried on the + /// new connection(s). fn connect(&self) -> ConnectHandle { let inner = self.inner().clone(); + { + let _guard = inner._lock.lock(); + + if !inner.has_command_rx() { + _trace!(inner, "Resetting command channel before connecting."); + // another connection task is running. this will let the command channel drain, then it'll drop everything on + // the old connection/router interface. + let (tx, rx) = unbounded_channel(); + let old_command_tx = inner.swap_command_tx(tx); + inner.store_command_rx(rx, true); + utils::close_router_channel(&inner, old_command_tx); + } + } tokio::spawn(async move { + utils::clear_backchannel_state(&inner).await; let result = router_commands::start(&inner).await; + // a canceled error means we intentionally closed the client _trace!(inner, "Ending connection task with {:?}", result); - if let Err(ref e) = result { - inner.notifications.broadcast_connect(Err(e.clone())); + if let Err(ref error) = result { + if !error.is_canceled() { + inner.notifications.broadcast_connect(Err(error.clone())); + } } - utils::set_client_state(&inner.state, ClientState::Disconnected); + + utils::check_and_set_client_state(&inner.state, ClientState::Disconnecting, ClientState::Disconnected); result }) } @@ -229,32 +260,6 @@ pub trait ClientLike: Clone + Send + Sized { } } - /// Listen for reconnection notifications. - /// - /// This function can be used to receive notifications whenever the client successfully reconnects in order to - /// re-subscribe to channels, etc. 
- /// - /// A reconnection event is also triggered upon first connecting to the server. - fn on_reconnect(&self) -> BroadcastReceiver<()> { - self.inner().notifications.reconnect.load().subscribe() - } - - /// Listen for notifications whenever the cluster state changes. - /// - /// This is usually triggered in response to a `MOVED` error, but can also happen when connections close - /// unexpectedly. - fn on_cluster_change(&self) -> BroadcastReceiver> { - self.inner().notifications.cluster_change.load().subscribe() - } - - /// Listen for protocol and connection errors. This stream can be used to more intelligently handle errors that may - /// not appear in the request-response cycle, and so cannot be handled by response futures. - /// - /// This function does not need to be called again if the connection closes. - fn on_error(&self) -> BroadcastReceiver { - self.inner().notifications.errors.load().subscribe() - } - /// Close the connection to the Redis server. The returned future resolves when the command has been written to the /// socket, not when the connection has been fully closed. Some time after this future resolves the future /// returned by [connect](Self::connect) will resolve which indicates that the connection has been fully closed. @@ -321,6 +326,157 @@ pub trait ClientLike: Clone + Send + Sized { let args = utils::try_into_vec(args)?; commands::server::custom_raw(self, cmd, args).await } + + /// Customize various configuration options on commands. 
+ fn with_options(&self, options: &Options) -> WithOptions { + WithOptions { + client: self.clone(), + options: options.clone(), + } + } +} + +fn spawn_event_listener(mut rx: BroadcastReceiver, func: F) -> JoinHandle> +where + T: Clone + Send + 'static, + F: Fn(T) -> RedisResult<()> + Send + 'static, +{ + tokio::spawn(async move { + let mut result = Ok(()); + + while let Ok(val) = rx.recv().await { + if let Err(err) = func(val) { + result = Err(err); + break; + } + } + + result + }) +} + +/// An interface that exposes various connection events. +/// +/// Calling [quit](crate::interfaces::ClientLike::quit) will exit or close all event streams. +pub trait EventInterface: ClientLike { + /// Spawn a task that runs the provided function on each reconnection event. + /// + /// Errors returned by `func` will exit the task. + fn on_reconnect(&self, func: F) -> JoinHandle> + where + F: Fn(Server) -> RedisResult<()> + Send + 'static, + { + let rx = self.reconnect_rx(); + spawn_event_listener(rx, func) + } + + /// Spawn a task that runs the provided function on each cluster change event. + /// + /// Errors returned by `func` will exit the task. + fn on_cluster_change(&self, func: F) -> JoinHandle> + where + F: Fn(Vec) -> RedisResult<()> + Send + 'static, + { + let rx = self.cluster_change_rx(); + spawn_event_listener(rx, func) + } + + /// Spawn a task that runs the provided function on each connection error event. + /// + /// Errors returned by `func` will exit the task. + fn on_error(&self, func: F) -> JoinHandle> + where + F: Fn(RedisError) -> RedisResult<()> + Send + 'static, + { + let rx = self.error_rx(); + spawn_event_listener(rx, func) + } + + /// Spawn a task that runs the provided function whenever the client detects an unresponsive connection. 
+ #[cfg(feature = "check-unresponsive")] + #[cfg_attr(docsrs, doc(cfg(feature = "check-unresponsive")))] + fn on_unresponsive(&self, func: F) -> JoinHandle> + where + F: Fn(Server) -> RedisResult<()> + Send + 'static, + { + let rx = self.unresponsive_rx(); + spawn_event_listener(rx, func) + } + + /// Spawn one task that listens for all event types. + /// + /// Errors in any of the provided functions will exit the task. + fn on_any(&self, error_fn: Fe, reconnect_fn: Fr, cluster_change_fn: Fc) -> JoinHandle> + where + Fe: Fn(RedisError) -> RedisResult<()> + Send + 'static, + Fr: Fn(Server) -> RedisResult<()> + Send + 'static, + Fc: Fn(Vec) -> RedisResult<()> + Send + 'static, + { + let mut error_rx = self.error_rx(); + let mut reconnect_rx = self.reconnect_rx(); + let mut cluster_rx = self.cluster_change_rx(); + + tokio::spawn(async move { + #[allow(unused_assignments)] + let mut result = Ok(()); + + loop { + tokio::select! { + Ok(error) = error_rx.recv() => { + if let Err(err) = error_fn(error) { + result = Err(err); + break; + } + } + Ok(server) = reconnect_rx.recv() => { + if let Err(err) = reconnect_fn(server) { + result = Err(err); + break; + } + } + Ok(changes) = cluster_rx.recv() => { + if let Err(err) = cluster_change_fn(changes) { + result = Err(err); + break; + } + } + } + } + + result + }) + } + + /// Listen for reconnection notifications. + /// + /// This function can be used to receive notifications whenever the client reconnects in order to + /// re-subscribe to channels, etc. + /// + /// A reconnection event is also triggered upon first connecting to the server. + fn reconnect_rx(&self) -> BroadcastReceiver { + self.inner().notifications.reconnect.load().subscribe() + } + + /// Listen for notifications whenever the cluster state changes. + /// + /// This is usually triggered in response to a `MOVED` error, but can also happen when connections close + /// unexpectedly. 
+ fn cluster_change_rx(&self) -> BroadcastReceiver> { + self.inner().notifications.cluster_change.load().subscribe() + } + + /// Listen for protocol and connection errors. This stream can be used to more intelligently handle errors that may + /// not appear in the request-response cycle, and so cannot be handled by response futures. + fn error_rx(&self) -> BroadcastReceiver { + self.inner().notifications.errors.load().subscribe() + } + + /// Receive a message when the client initiates a reconnection after detecting an unresponsive connection. + #[cfg(feature = "check-unresponsive")] + #[cfg_attr(docsrs, doc(cfg(feature = "check-unresponsive")))] + fn unresponsive_rx(&self) -> BroadcastReceiver { + self.inner().notifications.unresponsive.load().subscribe() + } } pub use crate::commands::interfaces::{ @@ -345,6 +501,8 @@ pub use crate::commands::interfaces::{ transactions::TransactionInterface, }; +#[cfg(feature = "redis-json")] +pub use crate::commands::interfaces::redis_json::RedisJsonInterface; #[cfg(feature = "sentinel-client")] pub use crate::commands::interfaces::sentinel::SentinelInterface; #[cfg(feature = "client-tracking")] diff --git a/src/lib.rs b/src/lib.rs index 1f208649..117e1887 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,14 @@ +#![allow(clippy::redundant_pattern_matching)] +#![allow(clippy::mutable_key_type)] +#![allow(clippy::derivable_impls)] +#![allow(clippy::enum_variant_names)] +#![allow(clippy::iter_kv_map)] +#![allow(clippy::len_without_is_empty)] +#![allow(clippy::vec_init_then_push)] +#![allow(clippy::while_let_on_iterator)] +#![allow(clippy::type_complexity)] +#![allow(clippy::too_many_arguments)] +#![allow(clippy::new_without_default)] #![cfg_attr(docsrs, deny(rustdoc::broken_intra_doc_links))] #![cfg_attr(docsrs, feature(doc_cfg))] #![cfg_attr(docsrs, allow(unused_attributes))] @@ -21,6 +32,10 @@ pub extern crate rustls; pub extern crate rustls_native_certs; #[cfg(feature = "serde-json")] pub extern crate serde_json; +pub extern 
crate socket2; +#[cfg(feature = "codec")] +#[cfg_attr(docsrs, doc(cfg(feature = "codec")))] +pub extern crate tokio_util; #[cfg(feature = "partial-tracing")] #[cfg_attr(docsrs, doc(cfg(feature = "partial-tracing")))] pub extern crate tracing; @@ -53,16 +68,29 @@ pub mod monitor; /// The structs and enums used by the Redis client. pub mod types; +/// Codecs for use with the [tokio codec](https://docs.rs/tokio-util/latest/tokio_util/codec/index.html) interface. +#[cfg(feature = "codec")] +#[cfg_attr(docsrs, doc(cfg(feature = "codec")))] +pub mod codec { + pub use super::protocol::public::*; +} + /// Utility functions used by the client that may also be useful to callers. pub mod util { - pub use crate::{ - s, - utils::{f64_to_redis_string, redis_string_to_f64, static_bytes, static_str}, - }; + pub use crate::utils::{f64_to_redis_string, redis_string_to_f64, static_bytes, static_str}; pub use redis_protocol::redis_keyslot; + /// A convenience constant for `None` values used as generic arguments. + /// + /// Functions that take `Option` as an argument often require the caller to use a turbofish when the + /// variant is `None`. In many cases this constant can be used instead. + // pretty much everything in this crate supports From + pub const NONE: Option = None; + /// Calculate the SHA1 hash output as a hex string. This is provided for clients that use the Lua interface to /// manage their own script caches. + #[cfg(feature = "sha-1")] + #[cfg_attr(docsrs, doc(cfg(feature = "sha-1")))] pub fn sha1_hash(input: &str) -> String { use sha1::Digest; @@ -72,21 +100,26 @@ pub mod util { } } -pub use crate::modules::{globals, pool}; +pub use crate::modules::globals; /// Convenience module to import a `RedisClient`, all possible interfaces, error types, and common argument types or /// return value types. 
pub mod prelude { #[cfg(feature = "partial-tracing")] + #[cfg_attr(docsrs, doc(cfg(feature = "partial-tracing")))] pub use crate::types::TracingConfig; + pub use crate::{ - clients::RedisClient, + clients::{RedisClient, RedisPool}, error::{RedisError, RedisErrorKind}, interfaces::*, types::{ Blocking, + Builder, + ConnectionConfig, Expiration, FromRedis, + Options, PerformanceConfig, ReconnectPolicy, RedisConfig, @@ -94,6 +127,11 @@ pub mod prelude { RedisValueKind, ServerConfig, SetOptions, + TcpConfig, }, }; + + #[cfg(any(feature = "enable-native-tls", feature = "enable-rustls"))] + #[cfg_attr(docsrs, doc(cfg(any(feature = "enable-rustls", feature = "enable-native-tls"))))] + pub use crate::types::{TlsConfig, TlsConnector}; } diff --git a/src/macros.rs b/src/macros.rs index d8bbd5b3..bf6ae75a 100644 --- a/src/macros.rs +++ b/src/macros.rs @@ -8,41 +8,41 @@ macro_rules! to( macro_rules! _trace( ($inner:tt, $($arg:tt)*) => { { - $inner.log_client_name_fn(log::Level::Trace, |name| { - log::trace!("{}: {}", name, format!($($arg)*)); - }) + if log::log_enabled!(log::Level::Trace) { + log::trace!("{}: {}", $inner.id, format!($($arg)*)) + } } } ); macro_rules! _debug( ($inner:tt, $($arg:tt)*) => { { - $inner.log_client_name_fn(log::Level::Debug, |name| { - log::debug!("{}: {}", name, format!($($arg)*)); - }) + if log::log_enabled!(log::Level::Debug) { + log::debug!("{}: {}", $inner.id, format!($($arg)*)) + } } } ); macro_rules! _error( ($inner:tt, $($arg:tt)*) => { { - $inner.log_client_name_fn(log::Level::Error, |name| { - log::error!("{}: {}", name, format!($($arg)*)); - }) + if log::log_enabled!(log::Level::Error) { + log::error!("{}: {}", $inner.id, format!($($arg)*)) + } } } ); macro_rules! _warn( ($inner:tt, $($arg:tt)*) => { { - $inner.log_client_name_fn(log::Level::Warn, |name| { - log::warn!("{}: {}", name, format!($($arg)*)); - }) + if log::log_enabled!(log::Level::Warn) { + log::warn!("{}: {}", $inner.id, format!($($arg)*)) + } } } ); macro_rules! 
_info( ($inner:tt, $($arg:tt)*) => { { - $inner.log_client_name_fn(log::Level::Info, |name| { - log::info!("{}: {}", name, format!($($arg)*)); - }) + if log::log_enabled!(log::Level::Info) { + log::info!("{}: {}", $inner.id, format!($($arg)*)) + } } } ); @@ -75,16 +75,6 @@ macro_rules! fspan ( } ); -/// Async try! for `AsyncResult`. This is rarely used on its own, but rather as a part of try_into!. -macro_rules! atry ( - ($expr:expr) => { - match $expr { - Ok(val) => val, - Err(e) => return crate::interfaces::AsyncResult::from(Err(e)) - } - } -); - /// Similar to `try`/`?`, but `continue` instead of breaking out with an error. macro_rules! try_or_continue ( ($expr:expr) => { @@ -105,30 +95,43 @@ macro_rules! static_str( } ); -/// Public macro to create a `Str` from a static str slice without copying. +/// A helper macro to wrap a string value in quotes via the [json](serde_json::json) macro. /// -/// ```rust no_run -/// // use "foo" without copying or parsing the underlying data. this uses the `Bytes::from_static` interface under the hood. -/// let _ = client.get(s!("foo")).await?; -/// ``` +/// See the [RedisJSON interface](crate::interfaces::RedisJsonInterface) for more information. +#[cfg(feature = "redis-json")] +#[cfg_attr(docsrs, doc(cfg(feature = "redis-json")))] #[macro_export] -macro_rules! s( - ($val:expr) => { - fred::util::static_str($val) +macro_rules! json_quote( + ($($json:tt)+) => { + serde_json::json!($($json)+).to_string() } ); -/// Public macro to create a `Bytes` from a static byte slice without copying. +/// Shorthand to create a [CustomCommand](crate::types::CustomCommand). /// /// ```rust no_run -/// // use "bar" without copying or parsing the underlying data. this uses the `Bytes::from_static` interface under the hood. 
-/// let _ = client.set(s!("foo"), b!(b"bar")).await?; +/// # use fred::{cmd, types::{CustomCommand, ClusterHash}}; +/// let _cmd = cmd!("FOO.BAR"); +/// let _cmd = cmd!("FOO.BAR", blocking: true); +/// let _cmd = cmd!("FOO.BAR", hash: ClusterHash::FirstKey); +/// let _cmd = cmd!("FOO.BAR", hash: ClusterHash::FirstKey, blocking: true); +/// // which is shorthand for +/// let _cmd = CustomCommand::new("FOO.BAR", ClusterHash::FirstKey, true); /// ``` #[macro_export] -macro_rules! b( - ($val:expr) => { - fred::util::static_bytes($val) - } +macro_rules! cmd( + ($name:expr) => { + fred::types::CustomCommand::new($name, fred::types::ClusterHash::FirstKey, false) + }; + ($name:expr, blocking: $blk:expr) => { + fred::types::CustomCommand::new($name, fred::types::ClusterHash::FirstKey, $blk) + }; + ($name:expr, hash: $hash:expr) => { + fred::types::CustomCommand::new($name, $hash, false) + }; + ($name:expr, hash: $hash:expr, blocking: $blk:expr) => { + fred::types::CustomCommand::new($name, $hash, $blk) + }; ); macro_rules! 
static_val( diff --git a/src/modules/backchannel.rs b/src/modules/backchannel.rs index 49203958..04396239 100644 --- a/src/modules/backchannel.rs +++ b/src/modules/backchannel.rs @@ -1,9 +1,8 @@ use crate::{ error::{RedisError, RedisErrorKind}, - globals::globals, modules::inner::RedisClientInner, - router::Connections, protocol::{command::RedisCommand, connection, connection::RedisTransport, types::Server}, + router::Connections, utils, }; use redis_protocol::resp3::types::Frame as Resp3Frame; @@ -18,24 +17,15 @@ async fn check_and_create_transport( server: &Server, ) -> Result { if let Some(ref mut transport) = backchannel.transport { - if &transport.server == server { - if transport.ping(inner).await.is_ok() { - _debug!(inner, "Using existing backchannel connection to {}", server); - return Ok(false); - } + if &transport.server == server && transport.ping(inner).await.is_ok() { + _debug!(inner, "Using existing backchannel connection to {}", server); + return Ok(false); } } backchannel.transport = None; - let mut transport = connection::create( - inner, - server.host.as_str().to_owned(), - server.port, - None, - server.tls_server_name.as_ref(), - ) - .await?; - let _ = transport.setup(inner, None).await?; + let mut transport = connection::create(inner, server, None).await?; + transport.setup(inner, None).await?; backchannel.transport = Some(transport); Ok(true) @@ -69,6 +59,17 @@ impl Backchannel { } } + /// Clear all local state that depends on the associated `Router` instance. + pub async fn clear_router_state(&mut self, inner: &Arc) { + self.connection_ids.clear(); + self.blocked = None; + + if let Some(ref mut transport) = self.transport { + let _ = transport.disconnect(inner).await; + } + self.transport = None; + } + /// Set the connection IDs from the router. 
pub fn update_connection_ids(&mut self, connections: &Connections) { self.connection_ids = connections.connection_ids(); @@ -162,14 +163,12 @@ impl Backchannel { command.debug_id(), server ); - let timeout = globals().default_connection_timeout_ms(); - let timeout = if timeout == 0 { - connection::DEFAULT_CONNECTION_TIMEOUT_MS - } else { - timeout - }; - utils::apply_timeout(transport.request_response(command, inner.is_resp3()), timeout).await + utils::apply_timeout( + transport.request_response(command, inner.is_resp3()), + inner.connection_timeout(), + ) + .await } else { Err(RedisError::new( RedisErrorKind::Unknown, @@ -200,36 +199,34 @@ impl Backchannel { // should this be more relaxed? Err(RedisError::new(RedisErrorKind::Unknown, "No connections are blocked.")) } - } else { - if inner.config.server.is_clustered() { - if command.kind.use_random_cluster_node() { - self.any_server().ok_or(RedisError::new( - RedisErrorKind::Unknown, - "Failed to find backchannel server.", - )) - } else { - inner.with_cluster_state(|state| { - let slot = match command.cluster_hash() { - Some(slot) => slot, - None => { - return Err(RedisError::new( - RedisErrorKind::Cluster, - "Failed to find cluster hash slot.", - )) - }, - }; - state.get_server(slot).cloned().ok_or(RedisError::new( - RedisErrorKind::Cluster, - "Failed to find cluster owner.", - )) - }) - } - } else { + } else if inner.config.server.is_clustered() { + if command.kind.use_random_cluster_node() { self.any_server().ok_or(RedisError::new( RedisErrorKind::Unknown, "Failed to find backchannel server.", )) + } else { + inner.with_cluster_state(|state| { + let slot = match command.cluster_hash() { + Some(slot) => slot, + None => { + return Err(RedisError::new( + RedisErrorKind::Cluster, + "Failed to find cluster hash slot.", + )) + }, + }; + state.get_server(slot).cloned().ok_or(RedisError::new( + RedisErrorKind::Cluster, + "Failed to find cluster owner.", + )) + }) } + } else { + self.any_server().ok_or(RedisError::new( + 
RedisErrorKind::Unknown, + "Failed to find backchannel server.", + )) } } } diff --git a/src/modules/globals.rs b/src/modules/globals.rs index e0a78835..35727744 100644 --- a/src/modules/globals.rs +++ b/src/modules/globals.rs @@ -10,6 +10,7 @@ use parking_lot::RwLock; /// `MOVED`, `ASK`, and `NOAUTH` errors are handled separately by the client. #[derive(Clone, Debug, Eq, PartialEq)] #[cfg(feature = "custom-reconnect-errors")] +#[cfg_attr(docsrs, doc(cfg(feature = "custom-reconnect-errors")))] pub enum ReconnectError { /// The CLUSTERDOWN prefix. ClusterDown, @@ -28,7 +29,7 @@ pub enum ReconnectError { NoReplicas, /// A case-sensitive prefix on an error message. /// - /// See [the source](https://github.com/redis/redis/blob/unstable/src/server.c#L2506-L2538) for examples. + /// See [the source](https://github.com/redis/redis/blob/fe37e4fc874a92dcf61b3b0de899ec6f674d2442/src/server.c#L1845) for examples. Custom(&'static str), } @@ -54,10 +55,6 @@ impl ReconnectError { pub(crate) struct Globals { /// The default capacity of all broadcast channels behind the `on_*` functions. pub default_broadcast_channel_capacity: Arc, - /// The default timeout to apply when connecting or initializing connections to servers. - pub default_connection_timeout_ms: Arc, - /// The default timeout to apply to connections to sentinel nodes. - pub sentinel_connection_timeout_ms: Arc, #[cfg(feature = "blocking-encoding")] /// The minimum size, in bytes, of frames that should be encoded or decoded with a blocking task. 
pub blocking_encode_threshold: Arc, @@ -73,10 +70,8 @@ impl Default for Globals { fn default() -> Self { Globals { default_broadcast_channel_capacity: Arc::new(AtomicUsize::new(32)), - default_connection_timeout_ms: Arc::new(AtomicUsize::new(60_000)), - sentinel_connection_timeout_ms: Arc::new(AtomicUsize::new(2_000)), #[cfg(feature = "blocking-encoding")] - blocking_encode_threshold: Arc::new(AtomicUsize::new(500_000)), + blocking_encode_threshold: Arc::new(AtomicUsize::new(50_000_000)), #[cfg(feature = "custom-reconnect-errors")] reconnect_errors: Arc::new(RwLock::new(vec![ ReconnectError::ClusterDown, @@ -84,7 +79,7 @@ impl Default for Globals { ReconnectError::ReadOnly, ])), #[cfg(feature = "check-unresponsive")] - unresponsive_interval: Arc::new(AtomicUsize::new(5_000)), + unresponsive_interval: Arc::new(AtomicUsize::new(2_000)), } } } @@ -94,14 +89,6 @@ impl Globals { read_atomic(&self.default_broadcast_channel_capacity) } - pub fn default_connection_timeout_ms(&self) -> u64 { - read_atomic(&self.default_connection_timeout_ms) as u64 - } - - pub fn sentinel_connection_timeout_ms(&self) -> usize { - read_atomic(&self.sentinel_connection_timeout_ms) - } - #[cfg(feature = "check-unresponsive")] pub fn unresponsive_interval_ms(&self) -> u64 { read_atomic(&self.unresponsive_interval) as u64 @@ -143,7 +130,7 @@ pub fn set_custom_reconnect_errors(prefixes: Vec) { /// /// See [block_in_place](https://docs.rs/tokio/1.9.0/tokio/task/fn.block_in_place.html) for more information. /// -/// Default: 500 Kb +/// Default: 50 MB #[cfg(feature = "blocking-encoding")] #[cfg_attr(docsrs, doc(cfg(feature = "blocking-encoding")))] pub fn get_blocking_encode_threshold() -> usize { @@ -157,33 +144,9 @@ pub fn set_blocking_encode_threshold(val: usize) -> usize { set_atomic(&globals().blocking_encode_threshold, val) } -/// The timeout to apply to connections to sentinel servers. 
-/// -/// Default: 200 ms -pub fn get_sentinel_connection_timeout_ms() -> usize { - read_atomic(&globals().sentinel_connection_timeout_ms) -} - -/// See [get_sentinel_connection_timeout_ms] for more information. -pub fn set_sentinel_connection_timeout_ms(val: usize) -> usize { - set_atomic(&globals().sentinel_connection_timeout_ms, val) -} - -/// The timeout to apply when connecting and initializing connections to servers. -/// -/// Default: 60 sec -pub fn get_default_connection_timeout_ms() -> u64 { - read_atomic(&globals().default_connection_timeout_ms) as u64 -} - -/// See [get_default_connection_timeout_ms] for more information. -pub fn set_default_connection_timeout_ms(val: u64) -> u64 { - set_atomic(&globals().default_connection_timeout_ms, val as usize) as u64 -} - /// The interval on which to check for unresponsive connections. /// -/// Default: 5 sec +/// Default: 2 sec #[cfg(feature = "check-unresponsive")] #[cfg_attr(docsrs, doc(cfg(feature = "check-unresponsive")))] pub fn get_unresponsive_interval_ms() -> u64 { @@ -197,7 +160,7 @@ pub fn set_unresponsive_interval_ms(val: u64) -> u64 { set_atomic(&globals().unresponsive_interval, val as usize) as u64 } -/// The default capacity used when creating [broadcast channels](https://docs.rs/tokio/latest/tokio/sync/broadcast/fn.channel.html) for the `on_*` notification functions. +/// The default capacity used when creating [broadcast channels](https://docs.rs/tokio/latest/tokio/sync/broadcast/fn.channel.html) in the [EventInterface](crate::interfaces::EventInterface). 
/// /// Default: 32 pub fn get_default_broadcast_channel_capacity() -> usize { diff --git a/src/modules/inner.rs b/src/modules/inner.rs index f3b6e85a..d63ce467 100644 --- a/src/modules/inner.rs +++ b/src/modules/inner.rs @@ -12,9 +12,8 @@ use crate::{ utils, }; use arc_swap::ArcSwap; -use arcstr::ArcStr; use futures::future::{select, Either}; -use parking_lot::RwLock; +use parking_lot::{Mutex, RwLock}; use semver::Version; use std::{ ops::DerefMut, @@ -37,6 +36,7 @@ use tokio::{ use crate::modules::metrics::MovingStats; #[cfg(feature = "check-unresponsive")] use crate::router::types::NetworkTimeout; +use bytes_utils::Str; #[cfg(feature = "replicas")] use std::collections::HashMap; @@ -48,7 +48,7 @@ use crate::types::Invalidation; pub struct Notifications { /// The client ID. - pub id: ArcStr, + pub id: Str, /// A broadcast channel for the `on_error` interface. pub errors: ArcSwap>, /// A broadcast channel for the `on_message` interface. @@ -56,7 +56,7 @@ pub struct Notifications { /// A broadcast channel for the `on_keyspace_event` interface. pub keyspace: ArcSwap>, /// A broadcast channel for the `on_reconnect` interface. - pub reconnect: ArcSwap>, + pub reconnect: ArcSwap>, /// A broadcast channel for the `on_cluster_change` interface. pub cluster_change: ArcSwap>>, /// A broadcast channel for the `on_connect` interface. @@ -68,23 +68,28 @@ pub struct Notifications { /// A broadcast channel for the `on_invalidation` interface. #[cfg(feature = "client-tracking")] pub invalidations: ArcSwap>, + /// A broadcast channel for notifying callers when servers go unresponsive. 
+ #[cfg(feature = "check-unresponsive")] + pub unresponsive: ArcSwap>, } impl Notifications { - pub fn new(id: &ArcStr) -> Self { + pub fn new(id: &Str) -> Self { let capacity = globals().default_broadcast_channel_capacity(); Notifications { - id: id.clone(), - close: broadcast::channel(capacity).0, - errors: ArcSwap::new(Arc::new(broadcast::channel(capacity).0)), - pubsub: ArcSwap::new(Arc::new(broadcast::channel(capacity).0)), - keyspace: ArcSwap::new(Arc::new(broadcast::channel(capacity).0)), - reconnect: ArcSwap::new(Arc::new(broadcast::channel(capacity).0)), - cluster_change: ArcSwap::new(Arc::new(broadcast::channel(capacity).0)), - connect: ArcSwap::new(Arc::new(broadcast::channel(capacity).0)), + id: id.clone(), + close: broadcast::channel(capacity).0, + errors: ArcSwap::new(Arc::new(broadcast::channel(capacity).0)), + pubsub: ArcSwap::new(Arc::new(broadcast::channel(capacity).0)), + keyspace: ArcSwap::new(Arc::new(broadcast::channel(capacity).0)), + reconnect: ArcSwap::new(Arc::new(broadcast::channel(capacity).0)), + cluster_change: ArcSwap::new(Arc::new(broadcast::channel(capacity).0)), + connect: ArcSwap::new(Arc::new(broadcast::channel(capacity).0)), #[cfg(feature = "client-tracking")] - invalidations: ArcSwap::new(Arc::new(broadcast::channel(capacity).0)), + invalidations: ArcSwap::new(Arc::new(broadcast::channel(capacity).0)), + #[cfg(feature = "check-unresponsive")] + unresponsive: ArcSwap::new(Arc::new(broadcast::channel(capacity).0)), } } @@ -98,6 +103,8 @@ impl Notifications { utils::swap_new_broadcast_channel(&self.connect); #[cfg(feature = "client-tracking")] utils::swap_new_broadcast_channel(&self.invalidations); + #[cfg(feature = "check-unresponsive")] + utils::swap_new_broadcast_channel(&self.unresponsive); } pub fn broadcast_error(&self, error: RedisError) { @@ -118,8 +125,8 @@ impl Notifications { } } - pub fn broadcast_reconnect(&self) { - if let Err(_) = self.reconnect.load().send(()) { + pub fn broadcast_reconnect(&self, server: Server) { 
+ if let Err(_) = self.reconnect.load().send(server) { debug!("{}: No `on_reconnect` listeners.", self.id); } } @@ -136,6 +143,8 @@ impl Notifications { } } + /// Interrupt any tokio `sleep` calls. + //`RedisClientInner::wait_with_interrupt` hides the subscription part from callers. pub fn broadcast_close(&self) { if let Err(_) = self.close.send(()) { debug!("{}: No `close` listeners.", self.id); @@ -148,6 +157,13 @@ impl Notifications { debug!("{}: No `on_invalidation` listeners.", self.id); } } + + #[cfg(feature = "check-unresponsive")] + pub fn broadcast_unresponsive(&self, server: Server) { + if let Err(_) = self.unresponsive.load().send(server) { + debug!("{}: No unresponsive listeners", self.id); + } + } } #[derive(Clone)] @@ -222,7 +238,7 @@ impl ServerState { } } -/// Added state associated with different server deployment types. +/// Added state associated with different server deployment types, synchronized by the router task. pub enum ServerKind { Sentinel { version: Option, @@ -361,19 +377,23 @@ impl ServerKind { } // TODO make a config option for other defaults and extend this -fn create_resolver(id: &ArcStr) -> Arc { +fn create_resolver(id: &Str) -> Arc { Arc::new(DefaultResolver::new(id)) } pub struct RedisClientInner { + /// An internal lock used to sync certain operations that should not run concurrently across tasks. + pub _lock: Mutex<()>, /// The client ID used for logging and the default `CLIENT SETNAME` value. - pub id: ArcStr, + pub id: Str, /// Whether the client uses RESP3. pub resp3: Arc, /// The state of the underlying connection. pub state: RwLock, /// Client configuration options. pub config: Arc, + /// Connection configuration options. + pub connection: Arc, /// Performance config options for the client. pub performance: ArcSwap, /// An optional reconnect policy. @@ -381,7 +401,7 @@ pub struct RedisClientInner { /// Notification channels for the event interfaces. pub notifications: Arc, /// An mpsc sender for commands to the router. 
- pub command_tx: CommandSender, + pub command_tx: ArcSwap, /// Temporary storage for the receiver half of the router command channel. pub command_rx: RwLock>, /// Shared counters. @@ -421,8 +441,13 @@ impl Drop for RedisClientInner { } impl RedisClientInner { - pub fn new(config: RedisConfig, perf: PerformanceConfig, policy: Option) -> Arc { - let id = ArcStr::from(format!("fred-{}", utils::random_string(10))); + pub fn new( + config: RedisConfig, + perf: PerformanceConfig, + connection: ConnectionConfig, + policy: Option, + ) -> Arc { + let id = Str::from(format!("fred-{}", utils::random_string(10))); let resolver = AsyncRwLock::new(create_resolver(&id)); let (command_tx, command_rx) = unbounded_channel(); let notifications = Arc::new(Notifications::new(&id)); @@ -437,8 +462,11 @@ impl RedisClientInner { } else { Arc::new(AtomicBool::new(false)) }; + let connection = Arc::new(connection); + let command_tx = ArcSwap::new(Arc::new(command_tx)); let inner = Arc::new(RedisClientInner { + _lock: Mutex::new(()), #[cfg(feature = "metrics")] latency_stats: RwLock::new(MovingStats::default()), #[cfg(feature = "metrics")] @@ -462,6 +490,7 @@ impl RedisClientInner { resp3, notifications, resolver, + connection, id, }); inner.spawn_timeout_task(); @@ -482,7 +511,7 @@ impl RedisClientInner { #[cfg(feature = "replicas")] pub fn ignore_replica_reconnect_errors(&self) -> bool { - self.config.replica.ignore_reconnection_errors + self.connection.replica.ignore_reconnection_errors } #[cfg(not(feature = "replicas"))] @@ -490,6 +519,22 @@ impl RedisClientInner { true } + /// Swap the command channel sender, returning the old one. + pub fn swap_command_tx(&self, tx: CommandSender) -> Arc { + self.command_tx.swap(Arc::new(tx)) + } + + /// Whether the client has the command channel receiver stored. If not then the caller can assume another + /// connection/router instance is using it. 
+ pub fn has_command_rx(&self) -> bool { + self.command_rx.read().is_some() + } + + pub fn reset_server_state(&self) { + #[cfg(feature = "replicas")] + self.server_state.write().replicas.clear() + } + pub fn shared_resp3(&self) -> Arc { self.resp3.clone() } @@ -499,7 +544,7 @@ impl RedisClientInner { F: FnOnce(&str), { if log_enabled!(level) { - func(self.id.as_str()) + func(&self.id) } } @@ -513,7 +558,7 @@ impl RedisClientInner { } pub fn client_name(&self) -> &str { - self.id.as_str() + &self.id } pub fn num_cluster_nodes(&self) -> usize { @@ -559,9 +604,11 @@ impl RedisClientInner { self.command_rx.write().take() } - pub fn store_command_rx(&self, rx: CommandReceiver) { + pub fn store_command_rx(&self, rx: CommandReceiver, force: bool) { let mut guard = self.command_rx.write(); - *guard = Some(rx); + if guard.is_none() || force { + *guard = Some(rx); + } } pub fn is_resp3(&self) -> bool { @@ -583,8 +630,12 @@ impl RedisClientInner { self.performance.load().as_ref().clone() } + pub fn connection_config(&self) -> ConnectionConfig { + self.connection.as_ref().clone() + } + pub fn reconnect_policy(&self) -> Option { - self.policy.read().as_ref().map(|p| p.clone()) + self.policy.read().as_ref().cloned() } pub fn reset_protocol_version(&self) { @@ -597,15 +648,23 @@ impl RedisClientInner { } pub fn max_command_attempts(&self) -> u32 { - self.performance.load().max_command_attempts + self.connection.max_command_attempts } pub fn max_feed_count(&self) -> u64 { self.performance.load().max_feed_count } - pub fn default_command_timeout(&self) -> u64 { - self.performance.load().default_command_timeout_ms + pub fn default_command_timeout(&self) -> Duration { + self.performance.load().default_command_timeout + } + + pub fn connection_timeout(&self) -> Duration { + self.connection.connection_timeout + } + + pub fn internal_command_timeout(&self) -> Duration { + self.connection.internal_command_timeout } pub async fn set_blocked_server(&self, server: &Server) { diff --git 
a/src/modules/metrics.rs b/src/modules/metrics.rs index 974282a8..0688bd6f 100644 --- a/src/modules/metrics.rs +++ b/src/modules/metrics.rs @@ -103,7 +103,7 @@ impl<'a> From<&'a MovingStats> for Stats { stddev: stats.variance.sqrt(), min: stats.min, max: stats.max, - samples: stats.samples as u64, + samples: stats.samples, sum: stats.sum, } } diff --git a/src/modules/mocks.rs b/src/modules/mocks.rs index f2d15a71..deee82b2 100644 --- a/src/modules/mocks.rs +++ b/src/modules/mocks.rs @@ -21,7 +21,6 @@ use parking_lot::Mutex; use std::{ collections::{HashMap, VecDeque}, fmt::Debug, - sync::Arc, }; /// A wrapper type for the parts of an internal Redis command. @@ -75,10 +74,10 @@ pub trait Mocks: Debug + Send + Sync + 'static { /// #[tokio::test] /// async fn should_use_echo_mock() { /// let config = RedisConfig { -/// mocks: Arc::new(Echo), +/// mocks: Some(Arc::new(Echo)), /// ..Default::default() /// }; -/// let client = RedisClient::new(config, None, None); +/// let client = Builder::from_config(config).build().unwrap(); /// let _ = client.connect(); /// let _ = client.wait_for_connect().await.expect("Failed to connect"); /// @@ -121,10 +120,10 @@ impl Mocks for Echo { /// #[tokio::test] /// async fn should_use_echo_mock() { /// let config = RedisConfig { -/// mocks: Arc::new(SimpleMap::new()), +/// mocks: Some(Arc::new(SimpleMap::new())), /// ..Default::default() /// }; -/// let client = RedisClient::new(config, None, None); +/// let client = Builder::from_config(config).build().unwrap(); /// let _ = client.connect(); /// let _ = client.wait_for_connect().await.expect("Failed to connect"); /// @@ -226,10 +225,10 @@ impl Mocks for SimpleMap { /// async fn should_use_buffer_mock() { /// let buffer = Arc::new(Buffer::new()); /// let config = RedisConfig { -/// mocks: buffer.clone(), +/// mocks: Some(buffer.clone()), /// ..Default::default() /// }; -/// let client = RedisClient::new(config, None, None); +/// let client = Builder::from_config(config).build().unwrap(); 
/// let _ = client.connect(); /// let _ = client.wait_for_connect().await.expect("Failed to connect"); /// @@ -331,14 +330,15 @@ mod tests { prelude::Expiration, types::{RedisConfig, RedisValue, SetOptions}, }; + use std::sync::Arc; use tokio::task::JoinHandle; async fn create_mock_client(mocks: Arc) -> (RedisClient, JoinHandle>) { let config = RedisConfig { - mocks, + mocks: Some(mocks), ..Default::default() }; - let client = RedisClient::new(config, None, None); + let client = RedisClient::new(config, None, None, None); let jh = client.connect(); let _ = client.wait_for_connect().await.expect("Failed to connect"); diff --git a/src/modules/mod.rs b/src/modules/mod.rs index 8c588317..6337de15 100644 --- a/src/modules/mod.rs +++ b/src/modules/mod.rs @@ -3,8 +3,6 @@ pub mod backchannel; pub mod globals; pub mod inner; pub mod metrics; -/// Client pooling structs. -pub mod pool; pub mod response; #[cfg(feature = "mocks")] diff --git a/src/modules/pool.rs b/src/modules/pool.rs deleted file mode 100644 index 15f016d4..00000000 --- a/src/modules/pool.rs +++ /dev/null @@ -1,149 +0,0 @@ -use crate::{ - clients::RedisClient, - error::{RedisError, RedisErrorKind}, - interfaces::ClientLike, - types::{ConnectHandle, PerformanceConfig, ReconnectPolicy, RedisConfig}, - utils, -}; -use futures::future::{join_all, try_join_all}; -use std::{ - fmt, - ops::Deref, - sync::{atomic::AtomicUsize, Arc}, -}; - -#[cfg(feature = "dns")] -use crate::types::Resolve; - -/// The inner state used by a `RedisPool`. -#[derive(Clone)] -pub(crate) struct RedisPoolInner { - clients: Vec, - last: Arc, -} - -/// A struct to pool multiple Redis clients together into one interface that will round-robin requests among clients, -/// preferring clients with an active connection if specified. 
-#[derive(Clone)] -pub struct RedisPool { - inner: Arc, -} - -impl fmt::Debug for RedisPool { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - f.debug_struct("RedisPool") - .field("size", &self.inner.clients.len()) - .finish() - } -} - -impl Deref for RedisPool { - type Target = RedisClient; - - fn deref(&self) -> &Self::Target { - self.next() - } -} - -impl<'a> From<&'a RedisPool> for &'a RedisClient { - fn from(p: &'a RedisPool) -> &'a RedisClient { - p.next() - } -} - -impl<'a> From<&'a RedisPool> for RedisClient { - fn from(p: &'a RedisPool) -> RedisClient { - p.next().clone() - } -} - -impl RedisPool { - /// Create a new pool without connecting to the server. - pub fn new( - config: RedisConfig, - perf: Option, - policy: Option, - size: usize, - ) -> Result { - if size > 0 { - let mut clients = Vec::with_capacity(size); - for _ in 0 .. size { - clients.push(RedisClient::new(config.clone(), perf.clone(), policy.clone())); - } - let last = Arc::new(AtomicUsize::new(0)); - - Ok(RedisPool { - inner: Arc::new(RedisPoolInner { clients, last }), - }) - } else { - Err(RedisError::new(RedisErrorKind::Config, "Pool cannot be empty.")) - } - } - - /// Read the individual clients in the pool. - pub fn clients(&self) -> &[RedisClient] { - &self.inner.clients - } - - /// Connect each client to the server, returning the task driving each connection. - /// - /// The caller is responsible for calling any `on_*` functions on each client. - pub fn connect(&self) -> Vec { - self.inner.clients.iter().map(|c| c.connect()).collect() - } - - /// Wait for all the clients to connect to the server. - pub async fn wait_for_connect(&self) -> Result<(), RedisError> { - let futures = self.inner.clients.iter().map(|c| c.wait_for_connect()); - let _ = try_join_all(futures).await?; - - Ok(()) - } - - /// Override the DNS resolution logic for all clients in the pool. 
- #[cfg(feature = "dns")] - #[cfg_attr(docsrs, doc(cfg(feature = "dns")))] - pub async fn set_resolver(&self, resolver: Arc) { - for client in self.inner.clients.iter() { - client.set_resolver(resolver.clone()).await; - } - } - - /// Read the size of the pool. - pub fn size(&self) -> usize { - self.inner.clients.len() - } - - /// Read the client that should run the next command. - #[cfg(feature = "pool-prefer-active")] - pub fn next(&self) -> &RedisClient { - let mut idx = utils::incr_atomic(&self.inner.last) % self.inner.clients.len(); - - for _ in 0 .. self.inner.clients.len() { - let client = &self.inner.clients[idx]; - if client.is_connected() { - return client; - } - idx = (idx + 1) % self.inner.clients.len(); - } - - &self.inner.clients[idx] - } - - /// Read the client that should run the next command. - #[cfg(not(feature = "pool-prefer-active"))] - pub fn next(&self) -> &RedisClient { - &self.inner.clients[utils::incr_atomic(&self.inner.last) % self.inner.clients.len()] - } - - /// Read the client that ran the last command. - pub fn last(&self) -> &RedisClient { - &self.inner.clients[utils::read_atomic(&self.inner.last) % self.inner.clients.len()] - } - - /// Call `QUIT` on each client in the pool. - pub async fn quit_pool(&self) { - let futures = self.inner.clients.iter().map(|c| c.quit()); - let _ = join_all(futures).await; - } -} diff --git a/src/modules/response.rs b/src/modules/response.rs index 46ee484f..4050f5db 100644 --- a/src/modules/response.rs +++ b/src/modules/response.rs @@ -1,7 +1,6 @@ use crate::{ error::{RedisError, RedisErrorKind}, - types::{RedisKey, RedisValue, NIL, QUEUED}, - utils, + types::{ClusterInfo, DatabaseMemoryStats, GeoPosition, MemoryStats, RedisKey, RedisValue, SlowlogEntry, QUEUED}, }; use bytes::Bytes; use bytes_utils::Str; @@ -10,15 +9,28 @@ use std::{ hash::{BuildHasher, Hash}, }; +#[allow(unused_imports)] +use std::any::type_name; + #[cfg(feature = "serde-json")] use serde_json::{Map, Value}; macro_rules! 
debug_type( ($($arg:tt)*) => { - cfg_if::cfg_if! { - if #[cfg(feature="network-logs")] { - log::trace!($($arg)*); - } + #[cfg(feature="network-logs")] + log::trace!($($arg)*); + } +); + +macro_rules! check_single_bulk_reply( + ($v:expr) => { + if $v.is_single_element_vec() { + return Self::from_value($v.pop_or_take()); + } + }; + ($t:ty, $v:expr) => { + if $v.is_single_element_vec() { + return $t::from_value($v.pop_or_take()); } } ); @@ -29,17 +41,23 @@ macro_rules! to_signed_number( RedisValue::Double(f) => Ok(f as $t), RedisValue::Integer(i) => Ok(i as $t), RedisValue::String(s) => s.parse::<$t>().map_err(|e| e.into()), - RedisValue::Null => Err(RedisError::new(RedisErrorKind::NotFound, "Cannot convert nil to number.")), RedisValue::Array(mut a) => if a.len() == 1 { match a.pop().unwrap() { RedisValue::Integer(i) => Ok(i as $t), RedisValue::String(s) => s.parse::<$t>().map_err(|e| e.into()), + #[cfg(feature = "default-nil-types")] + RedisValue::Null => Ok(0), + #[cfg(not(feature = "default-nil-types"))] RedisValue::Null => Err(RedisError::new(RedisErrorKind::NotFound, "Cannot convert nil to number.")), _ => Err(RedisError::new_parse("Cannot convert to number.")) } }else{ Err(RedisError::new_parse("Cannot convert array to number.")) } + #[cfg(feature = "default-nil-types")] + RedisValue::Null => Ok(0), + #[cfg(not(feature = "default-nil-types"))] + RedisValue::Null => Err(RedisError::new(RedisErrorKind::NotFound, "Cannot convert nil to number.")), _ => Err(RedisError::new_parse("Cannot convert to number.")), } } @@ -66,6 +84,9 @@ macro_rules! to_unsigned_number( }else{ Ok(i as $t) }, + #[cfg(feature = "default-nil-types")] + RedisValue::Null => Ok(0), + #[cfg(not(feature = "default-nil-types"))] RedisValue::Null => Err(RedisError::new(RedisErrorKind::NotFound, "Cannot convert nil to number.")), RedisValue::String(s) => s.parse::<$t>().map_err(|e| e.into()), _ => Err(RedisError::new_parse("Cannot convert to number.")) @@ -73,6 +94,9 @@ macro_rules! 
to_unsigned_number( }else{ Err(RedisError::new_parse("Cannot convert array to number.")) }, + #[cfg(feature = "default-nil-types")] + RedisValue::Null => Ok(0), + #[cfg(not(feature = "default-nil-types"))] RedisValue::Null => Err(RedisError::new(RedisErrorKind::NotFound, "Cannot convert nil to number.")), _ => Err(RedisError::new_parse("Cannot convert to number.")), } @@ -83,6 +107,7 @@ macro_rules! impl_signed_number ( ($t:ty) => { impl FromRedis for $t { fn from_value(value: RedisValue) -> Result<$t, RedisError> { + check_single_bulk_reply!(value); to_signed_number!($t, value) } } @@ -93,16 +118,66 @@ macro_rules! impl_unsigned_number ( ($t:ty) => { impl FromRedis for $t { fn from_value(value: RedisValue) -> Result<$t, RedisError> { + check_single_bulk_reply!(value); to_unsigned_number!($t, value) } } } ); -/// A trait used to convert various forms of [RedisValue](crate::types::RedisValue) into different types. +/// A trait used to [convert](RedisValue::convert) various forms of [RedisValue](RedisValue) into different types. +/// +/// ## Examples +/// +/// ```rust +/// # use fred::types::RedisValue; +/// # use std::collections::HashMap; +/// let foo: usize = RedisValue::String("123".into()).convert()?; +/// let foo: i64 = RedisValue::String("123".into()).convert()?; +/// let foo: String = RedisValue::String("123".into()).convert()?; +/// let foo: Vec = RedisValue::Bytes(vec![102, 111, 111].into()).convert()?; +/// let foo: Vec = RedisValue::String("foo".into()).convert()?; +/// let foo: Vec = RedisValue::Array(vec!["a".into(), "b".into()]).convert()?; +/// let foo: HashMap = +/// RedisValue::Array(vec!["a".into(), 1.into(), "b".into(), 2.into()]).convert()?; +/// let foo: (String, i64) = RedisValue::Array(vec!["a".into(), 1.into()]).convert()?; +/// let foo: Vec<(String, i64)> = +/// RedisValue::Array(vec!["a".into(), 1.into(), "b".into(), 2.into()]).convert()?; +/// // ... 
+/// ``` +/// +/// ## Bulk Values +/// +/// This interface can also convert single-element vectors to scalar values in certain scenarios. This is often +/// useful with commands that conditionally return bulk values, or where the number of elements in the response +/// depends on the number of arguments (`MGET`, etc). /// -/// See the [convert](crate::types::RedisValue::convert) documentation for important information regarding performance -/// considerations and examples. +/// For example: +/// +/// ```rust +/// # use fred::types::RedisValue; +/// let _: String = RedisValue::Array(vec![]).convert()?; // error +/// let _: String = RedisValue::Array(vec!["a".into()]).convert()?; // "a" +/// let _: String = RedisValue::Array(vec!["a".into(), "b".into()]).convert()?; // error +/// let _: Option = RedisValue::Array(vec![]).convert()?; // None +/// let _: Option = RedisValue::Array(vec!["a".into()]).convert()?; // Some("a") +/// let _: Option = RedisValue::Array(vec!["a".into(), "b".into()]).convert()?; // error +/// ``` +/// +/// ## The `default-nil-types` Feature Flag +/// +/// By default a `nil` value cannot be converted directly into any of the scalar types (`u8`, `String`, `Bytes`, +/// etc). In practice this often requires callers to use an `Option` or `Vec` container with commands that can return +/// `nil`. +/// +/// The `default-nil-types` feature flag can enable some further type conversion branches that treat `nil` values as +/// default values for the relevant type. For `RedisValue::Null` these include: +/// +/// * `impl FromRedis` for `String` or `Str` returns an empty string. +/// * `impl FromRedis` for `Bytes` or `Vec` returns an empty array. +/// * `impl FromRedis` for any integer or float type returns `0` +/// * `impl FromRedis` for `bool` returns `false` +/// * `impl FromRedis` for map or set types return an empty map or set. 
pub trait FromRedis: Sized { fn from_value(value: RedisValue) -> Result; @@ -143,6 +218,7 @@ impl_signed_number!(isize); impl FromRedis for u8 { fn from_value(value: RedisValue) -> Result { + check_single_bulk_reply!(value); to_unsigned_number!(u8, value) } @@ -160,74 +236,64 @@ impl_unsigned_number!(usize); impl FromRedis for String { fn from_value(value: RedisValue) -> Result { debug_type!("FromRedis(String): {:?}", value); - if value.is_null() { - Ok(NIL.to_owned()) - } else { - value - .into_string() - .ok_or(RedisError::new_parse("Could not convert to string.")) - } + check_single_bulk_reply!(value); + + value + .into_string() + .ok_or(RedisError::new_parse("Could not convert to string.")) } } impl FromRedis for Str { fn from_value(value: RedisValue) -> Result { debug_type!("FromRedis(Str): {:?}", value); - if value.is_null() { - Ok(utils::static_str(NIL)) - } else { - value - .into_bytes_str() - .ok_or(RedisError::new_parse("Could not convert to string.")) - } + check_single_bulk_reply!(value); + + value + .into_bytes_str() + .ok_or(RedisError::new_parse("Could not convert to string.")) } } impl FromRedis for f64 { fn from_value(value: RedisValue) -> Result { debug_type!("FromRedis(f64): {:?}", value); - if value.is_null() { - Err(RedisError::new( - RedisErrorKind::NotFound, - "Cannot convert nil response to double.", - )) - } else { - value - .as_f64() - .ok_or(RedisError::new_parse("Could not convert to double.")) - } + check_single_bulk_reply!(value); + + value + .as_f64() + .ok_or(RedisError::new_parse("Could not convert to double.")) } } impl FromRedis for f32 { fn from_value(value: RedisValue) -> Result { debug_type!("FromRedis(f32): {:?}", value); - if value.is_null() { - Err(RedisError::new( - RedisErrorKind::NotFound, - "Cannot convert nil response to float.", - )) - } else { - value - .as_f64() - .map(|f| f as f32) - .ok_or(RedisError::new_parse("Could not convert to float.")) - } + check_single_bulk_reply!(value); + + value + .as_f64() + .map(|f| f 
as f32) + .ok_or(RedisError::new_parse("Could not convert to float.")) } } impl FromRedis for bool { fn from_value(value: RedisValue) -> Result { debug_type!("FromRedis(bool): {:?}", value); - if value.is_null() { - Err(RedisError::new( - RedisErrorKind::NotFound, - "Cannot convert nil response to bool.", - )) + check_single_bulk_reply!(value); + + if let Some(val) = value.as_bool() { + Ok(val) } else { - value - .as_bool() - .ok_or(RedisError::new_parse("Could not convert to bool.")) + // it's not obvious how to convert the value to a bool in this block, so we go with a + // tried and true approach that i'm sure we'll never regret - JS semantics + Ok(match value { + RedisValue::String(s) => !s.is_empty(), + RedisValue::Bytes(b) => !b.is_empty(), + // everything else should be covered by `as_bool` above + _ => return Err(RedisError::new_parse("Could not convert to bool.")), + }) } } } @@ -237,8 +303,11 @@ where T: FromRedis, { fn from_value(value: RedisValue) -> Result, RedisError> { - debug_type!("FromRedis(Option): {:?}", value); - if value.is_null() { + debug_type!("FromRedis(Option<{}>): {:?}", type_name::(), value); + + if let Some(0) = value.array_len() { + Ok(None) + } else if value.is_null() { Ok(None) } else { Ok(Some(T::from_value(value)?)) @@ -249,6 +318,8 @@ where impl FromRedis for Bytes { fn from_value(value: RedisValue) -> Result { debug_type!("FromRedis(Bytes): {:?}", value); + check_single_bulk_reply!(value); + value .into_bytes() .ok_or(RedisError::new_parse("Cannot parse into bytes.")) @@ -260,14 +331,15 @@ where T: FromRedis, { fn from_value(value: RedisValue) -> Result, RedisError> { - debug_type!("FromRedis(Vec): {:?}", value); + debug_type!("FromRedis(Vec<{}>): {:?}", type_name::(), value); + match value { RedisValue::Bytes(bytes) => { T::from_owned_bytes(bytes.to_vec()).ok_or(RedisError::new_parse("Cannot convert from bytes")) }, RedisValue::String(string) => { // hacky way to check if T is bytes without consuming `string` - if 
T::from_owned_bytes(vec![]).is_some() { + if T::from_owned_bytes(Vec::new()).is_some() { T::from_owned_bytes(string.into_inner().to_vec()) .ok_or(RedisError::new_parse("Could not convert string to bytes.")) } else { @@ -275,7 +347,7 @@ where } }, RedisValue::Array(values) => { - if values.len() > 0 { + if !values.is_empty() { if let RedisValue::Array(_) = &values[0] { values.into_iter().map(|x| T::from_value(x)).collect() } else { @@ -288,29 +360,43 @@ where RedisValue::Map(map) => { // not being able to use collect() here is unfortunate let out = Vec::with_capacity(map.len() * 2); - map.inner().into_iter().fold(Ok(out), |out, (key, value)| { - out.and_then(|mut out| { - if T::is_tuple() { - // try to convert to a 2-element tuple since that's a common use case from `HGETALL`, etc - out.push(T::from_value(RedisValue::Array(vec![key.into(), value]))?); - } else { - out.push(T::from_value(key.into())?); - out.push(T::from_value(value)?); - } - - Ok(out) - }) + map.inner().into_iter().try_fold(out, |mut out, (key, value)| { + if T::is_tuple() { + // try to convert to a 2-element tuple since that's a common use case from `HGETALL`, etc + out.push(T::from_value(RedisValue::Array(vec![key.into(), value]))?); + } else { + out.push(T::from_value(key.into())?); + out.push(T::from_value(value)?); + } + + Ok(out) }) }, - RedisValue::Null => Ok(vec![]), RedisValue::Integer(i) => Ok(vec![T::from_value(RedisValue::Integer(i))?]), RedisValue::Double(f) => Ok(vec![T::from_value(RedisValue::Double(f))?]), RedisValue::Boolean(b) => Ok(vec![T::from_value(RedisValue::Boolean(b))?]), RedisValue::Queued => Ok(vec![T::from_value(RedisValue::from_static_str(QUEUED))?]), + RedisValue::Null => Ok(Vec::new()), } } } +impl FromRedis for [T; N] +where + T: FromRedis, +{ + fn from_value(value: RedisValue) -> Result<[T; N], RedisError> { + debug_type!("FromRedis([{}; {}]): {:?}", type_name::(), N, value); + // use the `from_value` impl for Vec + let value: Vec = value.convert()?; + let len = 
value.len(); + + value + .try_into() + .map_err(|_| RedisError::new_parse(format!("Failed to convert to array. Expected {}, found {}.", N, len))) + } +} + impl FromRedis for HashMap where K: FromRedisKey + Eq + Hash, @@ -318,12 +404,14 @@ where S: BuildHasher + Default, { fn from_value(value: RedisValue) -> Result { - debug_type!("FromRedis(HashMap): {:?}", value); - if value.is_null() { - return Err(RedisError::new(RedisErrorKind::NotFound, "Cannot convert nil to map.")); - } + debug_type!( + "FromRedis(HashMap<{}, {}>): {:?}", + type_name::(), + type_name::(), + value + ); - let as_map = if value.is_array() || value.is_map() { + let as_map = if value.is_array() || value.is_map() || value.is_null() { value .into_map() .map_err(|_| RedisError::new_parse("Cannot convert to map."))? @@ -331,7 +419,6 @@ where return Err(RedisError::new_parse("Cannot convert to map.")); }; - debug_type!("FromRedis(HashMap) Map: {:?}", as_map); as_map .inner() .into_iter() @@ -346,8 +433,8 @@ where S: BuildHasher + Default, { fn from_value(value: RedisValue) -> Result { - debug_type!("FromRedis(HashSet): {:?}", value); - value.into_array().into_iter().map(|v| V::from_value(v)).collect() + debug_type!("FromRedis(HashSet<{}>): {:?}", type_name::(), value); + value.into_set()?.into_iter().map(|v| V::from_value(v)).collect() } } @@ -357,8 +444,13 @@ where V: FromRedis, { fn from_value(value: RedisValue) -> Result { - debug_type!("FromRedis(BTreeMap): {:?}", value); - let as_map = if value.is_array() || value.is_map() { + debug_type!( + "FromRedis(BTreeMap<{}, {}>): {:?}", + type_name::(), + type_name::(), + value + ); + let as_map = if value.is_array() || value.is_map() || value.is_null() { value .into_map() .map_err(|_| RedisError::new_parse("Cannot convert to map."))? 
@@ -379,8 +471,8 @@ where V: FromRedis + Ord, { fn from_value(value: RedisValue) -> Result { - debug_type!("FromRedis(BTreeSet): {:?}", value); - value.into_array().into_iter().map(|v| V::from_value(v)).collect() + debug_type!("FromRedis(BTreeSet<{}>): {:?}", type_name::(), value); + value.into_set()?.into_iter().map(|v| V::from_value(v)).collect() } } @@ -401,7 +493,7 @@ macro_rules! impl_from_redis_tuple { $(let $name = (); n += 1;)* debug_type!("FromRedis({}-tuple): {:?}", n, values); if values.len() != n { - return Err(RedisError::new_parse("Invalid tuple dimension.")); + return Err(RedisError::new_parse(format!("Invalid tuple dimension. Expected {}, found {}.", n, values.len()))); } // since we have ownership over the values we have some freedom in how to implement this @@ -422,7 +514,7 @@ macro_rules! impl_from_redis_tuple { $(let $name = (); n += 1;)* debug_type!("FromRedis({}-tuple): {:?}", n, values); if values.len() % n != 0 { - return Err(RedisError::new_parse("Invalid tuple dimension.")) + return Err(RedisError::new_parse(format!("Invalid tuple dimension. Expected {}, found {}.", n, values.len()))); } let mut out = Vec::with_capacity(values.len() / n); @@ -468,13 +560,13 @@ impl FromRedis for Value { RedisValue::Null => Value::Null, RedisValue::Queued => QUEUED.into(), RedisValue::String(s) => { - if let Some(parsed) = utils::parse_nested_json(&s) { - parsed - } else { - s.to_string().into() - } + // check for nested json. this is particularly useful with JSON.GET + serde_json::from_str(&s).ok().unwrap_or_else(|| s.to_string().into()) + }, + RedisValue::Bytes(b) => { + let val = RedisValue::String(Str::from_inner(b)?); + Self::from_value(val)? 
}, - RedisValue::Bytes(b) => String::from_utf8(b.to_vec())?.into(), RedisValue::Integer(i) => i.into(), RedisValue::Double(f) => f.into(), RedisValue::Boolean(b) => b.into(), @@ -503,6 +595,36 @@ impl FromRedis for Value { } } +impl FromRedis for GeoPosition { + fn from_value(value: RedisValue) -> Result { + GeoPosition::try_from(value) + } +} + +impl FromRedis for SlowlogEntry { + fn from_value(value: RedisValue) -> Result { + SlowlogEntry::try_from(value) + } +} + +impl FromRedis for ClusterInfo { + fn from_value(value: RedisValue) -> Result { + ClusterInfo::try_from(value) + } +} + +impl FromRedis for MemoryStats { + fn from_value(value: RedisValue) -> Result { + MemoryStats::try_from(value) + } +} + +impl FromRedis for DatabaseMemoryStats { + fn from_value(value: RedisValue) -> Result { + DatabaseMemoryStats::try_from(value) + } +} + impl FromRedis for RedisKey { fn from_value(value: RedisValue) -> Result { let key = match value { @@ -590,13 +712,11 @@ impl FromRedisKey for Bytes { #[cfg(test)] mod tests { - use crate::{error::RedisError, types::RedisValue}; + use crate::types::RedisValue; use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet}; - #[test] - fn should_convert_null() { - let _foo: () = RedisValue::Null.convert().unwrap(); - } + #[cfg(not(feature = "default-nil-types"))] + use crate::error::RedisError; #[test] fn should_convert_signed_numeric_types() { @@ -659,7 +779,8 @@ mod tests { } #[test] - fn should_return_not_found_with_null_scalar_values() { + #[cfg(not(feature = "default-nil-types"))] + fn should_return_not_found_with_null_number_types() { let result: Result = RedisValue::Null.convert(); assert!(result.unwrap_err().is_not_found()); let result: Result = RedisValue::Null.convert(); @@ -687,9 +808,35 @@ mod tests { } #[test] - fn should_return_not_found_with_null_strings_and_bools() { - let result: Result = RedisValue::Null.convert(); - assert!(result.unwrap_err().is_not_found()); + #[cfg(feature = "default-nil-types")] + fn 
should_return_zero_with_null_number_types() { + assert_eq!(0, RedisValue::Null.convert::().unwrap()); + assert_eq!(0, RedisValue::Null.convert::().unwrap()); + assert_eq!(0, RedisValue::Null.convert::().unwrap()); + assert_eq!(0, RedisValue::Null.convert::().unwrap()); + assert_eq!(0, RedisValue::Null.convert::().unwrap()); + assert_eq!(0, RedisValue::Null.convert::().unwrap()); + assert_eq!(0, RedisValue::Null.convert::().unwrap()); + assert_eq!(0, RedisValue::Null.convert::().unwrap()); + assert_eq!(0, RedisValue::Null.convert::().unwrap()); + assert_eq!(0, RedisValue::Null.convert::().unwrap()); + assert_eq!(0, RedisValue::Null.convert::().unwrap()); + assert_eq!(0, RedisValue::Null.convert::().unwrap()); + assert_eq!(0.0, RedisValue::Null.convert::().unwrap()); + assert_eq!(0.0, RedisValue::Null.convert::().unwrap()); + } + + #[test] + #[cfg(feature = "default-nil-types")] + fn should_convert_null_to_false() { + assert!(!RedisValue::Null.convert::().unwrap()); + } + + #[test] + #[should_panic] + #[cfg(not(feature = "default-nil-types"))] + fn should_not_convert_null_to_false() { + assert!(!RedisValue::Null.convert::().unwrap()); } #[test] @@ -699,27 +846,27 @@ mod tests { } #[test] - fn should_convert_bools() { - let _foo: bool = RedisValue::Integer(0).convert().unwrap(); - assert_eq!(_foo, false); - let _foo: bool = RedisValue::Integer(1).convert().unwrap(); - assert_eq!(_foo, true); - let _foo: bool = RedisValue::String("0".into()).convert().unwrap(); - assert_eq!(_foo, false); - let _foo: bool = RedisValue::String("1".into()).convert().unwrap(); - assert_eq!(_foo, true); + fn should_convert_numbers_to_bools() { + let foo: bool = RedisValue::Integer(0).convert().unwrap(); + assert!(!foo); + let foo: bool = RedisValue::Integer(1).convert().unwrap(); + assert!(foo); + let foo: bool = RedisValue::String("0".into()).convert().unwrap(); + assert!(!foo); + let foo: bool = RedisValue::String("1".into()).convert().unwrap(); + assert!(foo); } #[test] fn 
should_convert_bytes() { - let _foo: Vec = RedisValue::Bytes("foo".as_bytes().to_vec().into()).convert().unwrap(); - assert_eq!(_foo, "foo".as_bytes().to_vec()); - let _foo: Vec = RedisValue::String("foo".into()).convert().unwrap(); - assert_eq!(_foo, "foo".as_bytes().to_vec()); - let _foo: Vec = RedisValue::Array(vec![102.into(), 111.into(), 111.into()]) + let foo: Vec = RedisValue::Bytes("foo".as_bytes().to_vec().into()).convert().unwrap(); + assert_eq!(foo, "foo".as_bytes().to_vec()); + let foo: Vec = RedisValue::String("foo".into()).convert().unwrap(); + assert_eq!(foo, "foo".as_bytes().to_vec()); + let foo: Vec = RedisValue::Array(vec![102.into(), 111.into(), 111.into()]) .convert() .unwrap(); - assert_eq!(_foo, "foo".as_bytes().to_vec()); + assert_eq!(foo, "foo".as_bytes().to_vec()); } #[test] @@ -785,4 +932,42 @@ mod tests { .unwrap(); assert_eq!(foo, vec![("a".to_owned(), 1), ("b".to_owned(), 2)]); } + + #[test] + fn should_handle_single_element_vector_to_scalar() { + assert!(RedisValue::Array(vec![]).convert::().is_err()); + assert_eq!( + RedisValue::Array(vec!["foo".into()]).convert::(), + Ok("foo".into()) + ); + assert!(RedisValue::Array(vec!["foo".into(), "bar".into()]) + .convert::() + .is_err()); + + assert_eq!(RedisValue::Array(vec![]).convert::>(), Ok(None)); + assert_eq!( + RedisValue::Array(vec!["foo".into()]).convert::>(), + Ok(Some("foo".into())) + ); + assert!(RedisValue::Array(vec!["foo".into(), "bar".into()]) + .convert::>() + .is_err()); + } + + #[test] + fn should_convert_null_to_empty_array() { + assert_eq!(Vec::::new(), RedisValue::Null.convert::>().unwrap()); + assert_eq!(Vec::::new(), RedisValue::Null.convert::>().unwrap()); + } + + #[test] + fn should_convert_to_fixed_arrays() { + let foo: [i64; 2] = RedisValue::Array(vec![1.into(), 2.into()]).convert().unwrap(); + assert_eq!(foo, [1, 2]); + + assert!(RedisValue::Array(vec![1.into(), 2.into()]) + .convert::<[i64; 3]>() + .is_err()); + assert!(RedisValue::Array(vec![]).convert::<[i64; 
3]>().is_err()); + } } diff --git a/src/monitor/utils.rs b/src/monitor/utils.rs index 444738ed..06924ea4 100644 --- a/src/monitor/utils.rs +++ b/src/monitor/utils.rs @@ -5,11 +5,11 @@ use crate::{ protocol::{ codec::RedisCodec, command::{RedisCommand, RedisCommandKind}, - connection::{self, RedisTransport}, + connection::{self, ConnectionKind, RedisTransport}, types::ProtocolFrame, utils as protocol_utils, }, - types::{PerformanceConfig, RedisConfig, ServerConfig}, + types::{ConnectionConfig, PerformanceConfig, RedisConfig, ServerConfig}, }; use futures::stream::{Stream, StreamExt}; use std::sync::Arc; @@ -22,7 +22,6 @@ use tokio_util::codec::Framed; #[cfg(feature = "blocking-encoding")] use crate::globals::globals; -use crate::protocol::connection::ConnectionKind; #[cfg(feature = "blocking-encoding")] async fn handle_monitor_frame( @@ -82,7 +81,7 @@ async fn send_monitor_command( let frame = connection.request_response(command, inner.is_resp3()).await?; _trace!(inner, "Recv MONITOR response: {:?}", frame); - let response = protocol_utils::frame_to_single_result(frame)?; + let response = protocol_utils::frame_to_results(frame)?; let _ = protocol_utils::expect_ok(&response)?; Ok(connection) } @@ -125,6 +124,7 @@ pub async fn start(config: RedisConfig) -> Result, R auto_pipeline: false, ..Default::default() }; + let connection = ConnectionConfig::default(); let server = match config.server { ServerConfig::Centralized { ref server } => server.clone(), _ => { @@ -135,8 +135,8 @@ pub async fn start(config: RedisConfig) -> Result, R }, }; - let inner = RedisClientInner::new(config, perf, None); - let mut connection = connection::create(&inner, server.host.as_str().to_owned(), server.port, None, None).await?; + let inner = RedisClientInner::new(config, perf, connection, None); + let mut connection = connection::create(&inner, &server, None).await?; let _ = connection.setup(&inner, None).await?; let connection = send_monitor_command(&inner, connection).await?; diff --git 
a/src/protocol/cluster.rs b/src/protocol/cluster.rs index 377552a4..c7a93318 100644 --- a/src/protocol/cluster.rs +++ b/src/protocol/cluster.rs @@ -3,8 +3,9 @@ use crate::{ modules::inner::RedisClientInner, protocol::types::{Server, SlotRange}, types::RedisValue, + utils, }; -use arcstr::ArcStr; +use bytes_utils::Str; use std::{collections::HashMap, net::IpAddr, str::FromStr, sync::Arc}; #[cfg(any(feature = "enable-native-tls", feature = "enable-rustls"))] @@ -27,12 +28,12 @@ fn parse_as_u16(value: RedisValue) -> Result { } } -fn is_ip_address(value: &str) -> bool { +fn is_ip_address(value: &Str) -> bool { IpAddr::from_str(value).is_ok() } -fn check_metadata_hostname(data: &HashMap) -> Option<&str> { - data.get("hostname").map(|s| s.as_str()) +fn check_metadata_hostname(data: &HashMap) -> Option<&Str> { + data.get(&utils::static_str("hostname")) } /// Find the correct hostname for the server, preferring hostnames over IP addresses for TLS purposes. @@ -47,8 +48,10 @@ fn check_metadata_hostname(data: &HashMap) -> Option<&str> { /// 3. If `server[0]` is null, but `server[3]` has a "hostname" metadata field, then use the metadata field. Otherwise /// use `default_host`. /// +/// The `default_host` is the host that returned the `CLUSTER SLOTS` response. 
+/// /// -fn parse_cluster_slot_hostname(server: &[RedisValue], default_host: &str) -> Result { +fn parse_cluster_slot_hostname(server: &[RedisValue], default_host: &Str) -> Result { if server.is_empty() { return Err(RedisError::new( RedisErrorKind::Protocol, @@ -57,19 +60,19 @@ fn parse_cluster_slot_hostname(server: &[RedisValue], default_host: &str) -> Res } let should_parse_metadata = server.len() >= 4 && !server[3].is_null() && server[3].array_len().unwrap_or(0) > 0; - let metadata: HashMap = if should_parse_metadata { - // not ideal, but all the variants with data on the heap are ref counted (`Bytes`, `Str`, etc) + let metadata: HashMap = if should_parse_metadata { + // all the variants with data on the heap are ref counted (`Bytes`, `Str`, etc) server[3].clone().convert()? } else { HashMap::new() }; if server[0].is_null() { - // step 3 - Ok(check_metadata_hostname(&metadata).unwrap_or(default_host).to_owned()) + // option 3 + Ok(check_metadata_hostname(&metadata).unwrap_or(default_host).clone()) } else { - let preferred_host = match server[0].clone().into_string() { - Some(host) => host, - None => { + let preferred_host = match server[0].clone().convert::() { + Ok(host) => host, + Err(_) => { return Err(RedisError::new( RedisErrorKind::Protocol, "Invalid CLUSTER SLOTS server block hostname.", @@ -78,26 +81,22 @@ fn parse_cluster_slot_hostname(server: &[RedisValue], default_host: &str) -> Res }; if is_ip_address(&preferred_host) { - // step 2 - Ok( - check_metadata_hostname(&metadata) - .map(|s| s.to_owned()) - .unwrap_or(preferred_host), - ) + // option 2 + Ok(check_metadata_hostname(&metadata).unwrap_or(&preferred_host).clone()) } else { - // step 1 + // option 1 Ok(preferred_host) } } } /// Read the node block with format `|null, , , [metadata]` -fn parse_node_block(data: &Vec, default_host: &str) -> Option<(String, u16, ArcStr, ArcStr)> { +fn parse_node_block(data: &Vec, default_host: &Str) -> Option<(Str, u16, Str, Str)> { if data.len() < 3 { return None; 
} - let hostname = match parse_cluster_slot_hostname(&data, default_host) { + let hostname = match parse_cluster_slot_hostname(data, default_host) { Ok(host) => host, Err(_) => return None, }; @@ -105,19 +104,15 @@ fn parse_node_block(data: &Vec, default_host: &str) -> Option<(Strin Ok(port) => port, Err(_) => return None, }; - let primary = ArcStr::from(format!("{}:{}", hostname, port)); - let id = if let Some(s) = data[2].as_str() { - ArcStr::from(s.as_ref().to_string()) - } else { - return None; - }; + let primary = Str::from(format!("{}:{}", hostname, port)); + let id = data[2].as_bytes_str()?; Some((hostname, port, primary, id)) } /// Parse the optional trailing replica nodes in each `CLUSTER SLOTS` slot range block. #[cfg(feature = "replicas")] -fn parse_cluster_slot_replica_nodes(slot_range: Vec, default_host: &str) -> Vec { +fn parse_cluster_slot_replica_nodes(slot_range: Vec, default_host: &Str) -> Vec { slot_range .into_iter() .filter_map(|value| { @@ -130,7 +125,7 @@ fn parse_cluster_slot_replica_nodes(slot_range: Vec, default_host: & }; let (host, port) = match parse_node_block(&server_block, default_host) { - Some((h, p, _, _)) => (ArcStr::from(h), p), + Some((h, p, _, _)) => (h, p), None => { warn!("Skip replica CLUSTER SLOTS block from {}", default_host); return None; @@ -140,6 +135,7 @@ fn parse_cluster_slot_replica_nodes(slot_range: Vec, default_host: & Some(Server { host, port, + #[cfg(any(feature = "enable-native-tls", feature = "enable-rustls"))] tls_server_name: None, }) }) @@ -147,7 +143,7 @@ fn parse_cluster_slot_replica_nodes(slot_range: Vec, default_host: & } /// Parse the cluster slot range and associated server blocks. 
-fn parse_cluster_slot_nodes(mut slot_range: Vec, default_host: &str) -> Result { +fn parse_cluster_slot_nodes(mut slot_range: Vec, default_host: &Str) -> Result { if slot_range.len() < 3 { return Err(RedisError::new( RedisErrorKind::Protocol, @@ -163,7 +159,7 @@ fn parse_cluster_slot_nodes(mut slot_range: Vec, default_host: &str) // length checked above. format is `|null, , , [metadata]` let server_block: Vec = slot_range.pop().unwrap().convert()?; let (host, port, id) = match parse_node_block(&server_block, default_host) { - Some((h, p, _, i)) => (ArcStr::from(h), p, i), + Some((h, p, _, i)) => (h, p, i), None => { trace!("Failed to parse CLUSTER SLOTS response: {:?}", server_block); return Err(RedisError::new( @@ -180,6 +176,7 @@ fn parse_cluster_slot_nodes(mut slot_range: Vec, default_host: &str) primary: Server { host, port, + #[cfg(any(feature = "enable-native-tls", feature = "enable-rustls"))] tls_server_name: None, }, #[cfg(feature = "replicas")] @@ -189,7 +186,7 @@ fn parse_cluster_slot_nodes(mut slot_range: Vec, default_host: &str) /// Parse the entire CLUSTER SLOTS response with the provided `default_host` of the connection used to send the /// command. -pub fn parse_cluster_slots(frame: RedisValue, default_host: &str) -> Result, RedisError> { +pub fn parse_cluster_slots(frame: RedisValue, default_host: &Str) -> Result, RedisError> { let slot_ranges: Vec> = frame.convert()?; let mut out: Vec = Vec::with_capacity(slot_ranges.len()); @@ -202,7 +199,7 @@ pub fn parse_cluster_slots(frame: RedisValue, default_host: &str) -> Result, default_host: &str) { +fn replace_tls_server_names(policy: &TlsHostMapping, ranges: &mut Vec, default_host: &Str) { for slot_range in ranges.iter_mut() { slot_range.primary.set_tls_server_name(policy, default_host); @@ -215,7 +212,7 @@ fn replace_tls_server_names(policy: &TlsHostMapping, ranges: &mut Vec /// Modify the `CLUSTER SLOTS` command according to the hostname mapping policy in the `TlsHostMapping`. 
#[cfg(any(feature = "enable-rustls", feature = "enable-native-tls"))] -pub fn modify_cluster_slot_hostnames(inner: &Arc, ranges: &mut Vec, default_host: &str) { +pub fn modify_cluster_slot_hostnames(inner: &Arc, ranges: &mut Vec, default_host: &Str) { let policy = match inner.config.tls { Some(ref config) => &config.hostnames, None => { @@ -232,7 +229,7 @@ pub fn modify_cluster_slot_hostnames(inner: &Arc, ranges: &mut } #[cfg(not(any(feature = "enable-rustls", feature = "enable-native-tls")))] -pub fn modify_cluster_slot_hostnames(inner: &Arc, _: &mut Vec, _: &str) { +pub fn modify_cluster_slot_hostnames(inner: &Arc, _: &mut Vec, _: &Str) { _trace!(inner, "Skip modifying TLS hostnames.") } @@ -359,8 +356,8 @@ mod tests { fn should_modify_cluster_slot_hostnames_default_host_without_metadata() { let policy = TlsHostMapping::DefaultHost; let fake_data = fake_cluster_slots_without_metadata(); - let mut ranges = parse_cluster_slots(fake_data, "default-host").unwrap(); - replace_tls_server_names(&policy, &mut ranges, "default-host"); + let mut ranges = parse_cluster_slots(fake_data, &Str::from("default-host")).unwrap(); + replace_tls_server_names(&policy, &mut ranges, &Str::from("default-host")); for slot_range in ranges.iter() { assert_ne!(slot_range.primary.host, "default-host"); @@ -379,8 +376,8 @@ mod tests { fn should_not_modify_cluster_slot_hostnames_default_host_with_metadata() { let policy = TlsHostMapping::DefaultHost; let fake_data = fake_cluster_slots_with_metadata(); - let mut ranges = parse_cluster_slots(fake_data, "default-host").unwrap(); - replace_tls_server_names(&policy, &mut ranges, "default-host"); + let mut ranges = parse_cluster_slots(fake_data, &Str::from("default-host")).unwrap(); + replace_tls_server_names(&policy, &mut ranges, &Str::from("default-host")); for slot_range in ranges.iter() { assert_ne!(slot_range.primary.host, "default-host"); @@ -400,8 +397,8 @@ mod tests { fn should_modify_cluster_slot_hostnames_custom() { let policy = 
TlsHostMapping::Custom(Arc::new(FakeHostMapper)); let fake_data = fake_cluster_slots_without_metadata(); - let mut ranges = parse_cluster_slots(fake_data, "default-host").unwrap(); - replace_tls_server_names(&policy, &mut ranges, "default-host"); + let mut ranges = parse_cluster_slots(fake_data, &Str::from("default-host")).unwrap(); + replace_tls_server_names(&policy, &mut ranges, &Str::from("default-host")); for slot_range in ranges.iter() { assert_ne!(slot_range.primary.host, "default-host"); @@ -419,21 +416,23 @@ mod tests { fn should_parse_cluster_slots_example_metadata_hostnames() { let input = fake_cluster_slots_with_metadata(); - let actual = parse_cluster_slots(input, "bad-host").expect("Failed to parse input"); + let actual = parse_cluster_slots(input, &Str::from("bad-host")).expect("Failed to parse input"); let expected = vec![ SlotRange { start: 0, end: 5460, primary: Server { - host: "host-1.redis.example.com".into(), - port: 30001, + host: "host-1.redis.example.com".into(), + port: 30001, + #[cfg(any(feature = "enable-native-tls", feature = "enable-rustls"))] tls_server_name: None, }, id: "09dbe9720cda62f7865eabc5fd8857c5d2678366".into(), #[cfg(feature = "replicas")] replicas: vec![Server { - host: "host-2.redis.example.com".into(), - port: 30004, + host: "host-2.redis.example.com".into(), + port: 30004, + #[cfg(any(feature = "enable-native-tls", feature = "enable-rustls"))] tls_server_name: None, }], }, @@ -441,15 +440,17 @@ mod tests { start: 5461, end: 10922, primary: Server { - host: "host-3.redis.example.com".into(), - port: 30002, + host: "host-3.redis.example.com".into(), + port: 30002, + #[cfg(any(feature = "enable-native-tls", feature = "enable-rustls"))] tls_server_name: None, }, id: "c9d93d9f2c0c524ff34cc11838c2003d8c29e013".into(), #[cfg(feature = "replicas")] replicas: vec![Server { - host: "host-4.redis.example.com".into(), - port: 30005, + host: "host-4.redis.example.com".into(), + port: 30005, + #[cfg(any(feature = "enable-native-tls", 
feature = "enable-rustls"))] tls_server_name: None, }], }, @@ -457,15 +458,17 @@ mod tests { start: 10923, end: 16383, primary: Server { - host: "host-5.redis.example.com".into(), - port: 30003, + host: "host-5.redis.example.com".into(), + port: 30003, + #[cfg(any(feature = "enable-native-tls", feature = "enable-rustls"))] tls_server_name: None, }, id: "044ec91f325b7595e76dbcb18cc688b6a5b434a1".into(), #[cfg(feature = "replicas")] replicas: vec![Server { - host: "host-6.redis.example.com".into(), - port: 30006, + host: "host-6.redis.example.com".into(), + port: 30006, + #[cfg(any(feature = "enable-native-tls", feature = "enable-rustls"))] tls_server_name: None, }], }, @@ -477,21 +480,23 @@ mod tests { fn should_parse_cluster_slots_example_no_metadata() { let input = fake_cluster_slots_without_metadata(); - let actual = parse_cluster_slots(input, "bad-host").expect("Failed to parse input"); + let actual = parse_cluster_slots(input, &Str::from("bad-host")).expect("Failed to parse input"); let expected = vec![ SlotRange { start: 0, end: 5460, primary: Server { - host: "127.0.0.1".into(), - port: 30001, + host: "127.0.0.1".into(), + port: 30001, + #[cfg(any(feature = "enable-native-tls", feature = "enable-rustls"))] tls_server_name: None, }, id: "09dbe9720cda62f7865eabc5fd8857c5d2678366".into(), #[cfg(feature = "replicas")] replicas: vec![Server { - host: "127.0.0.1".into(), - port: 30004, + host: "127.0.0.1".into(), + port: 30004, + #[cfg(any(feature = "enable-native-tls", feature = "enable-rustls"))] tls_server_name: None, }], }, @@ -499,15 +504,17 @@ mod tests { start: 5461, end: 10922, primary: Server { - host: "127.0.0.1".into(), - port: 30002, + host: "127.0.0.1".into(), + port: 30002, + #[cfg(any(feature = "enable-native-tls", feature = "enable-rustls"))] tls_server_name: None, }, id: "c9d93d9f2c0c524ff34cc11838c2003d8c29e013".into(), #[cfg(feature = "replicas")] replicas: vec![Server { - host: "127.0.0.1".into(), - port: 30005, + host: "127.0.0.1".into(), + 
port: 30005, + #[cfg(any(feature = "enable-native-tls", feature = "enable-rustls"))] tls_server_name: None, }], }, @@ -515,15 +522,17 @@ mod tests { start: 10923, end: 16383, primary: Server { - host: "127.0.0.1".into(), - port: 30003, + host: "127.0.0.1".into(), + port: 30003, + #[cfg(any(feature = "enable-native-tls", feature = "enable-rustls"))] tls_server_name: None, }, id: "044ec91f325b7595e76dbcb18cc688b6a5b434a1".into(), #[cfg(feature = "replicas")] replicas: vec![Server { - host: "127.0.0.1".into(), - port: 30006, + host: "127.0.0.1".into(), + port: 30006, + #[cfg(any(feature = "enable-native-tls", feature = "enable-rustls"))] tls_server_name: None, }], }, @@ -583,21 +592,23 @@ mod tests { ]); let input = RedisValue::Array(vec![first_slot_range, second_slot_range, third_slot_range]); - let actual = parse_cluster_slots(input, "bad-host").expect("Failed to parse input"); + let actual = parse_cluster_slots(input, &Str::from("bad-host")).expect("Failed to parse input"); let expected = vec![ SlotRange { start: 0, end: 5460, primary: Server { - host: "127.0.0.1".into(), - port: 30001, + host: "127.0.0.1".into(), + port: 30001, + #[cfg(any(feature = "enable-native-tls", feature = "enable-rustls"))] tls_server_name: None, }, id: "09dbe9720cda62f7865eabc5fd8857c5d2678366".into(), #[cfg(feature = "replicas")] replicas: vec![Server { - host: "127.0.0.1".into(), - port: 30004, + host: "127.0.0.1".into(), + port: 30004, + #[cfg(any(feature = "enable-native-tls", feature = "enable-rustls"))] tls_server_name: None, }], }, @@ -605,15 +616,17 @@ mod tests { start: 5461, end: 10922, primary: Server { - host: "127.0.0.1".into(), - port: 30002, + host: "127.0.0.1".into(), + port: 30002, + #[cfg(any(feature = "enable-native-tls", feature = "enable-rustls"))] tls_server_name: None, }, id: "c9d93d9f2c0c524ff34cc11838c2003d8c29e013".into(), #[cfg(feature = "replicas")] replicas: vec![Server { - host: "127.0.0.1".into(), - port: 30005, + host: "127.0.0.1".into(), + port: 30005, + 
#[cfg(any(feature = "enable-native-tls", feature = "enable-rustls"))] tls_server_name: None, }], }, @@ -621,15 +634,17 @@ mod tests { start: 10923, end: 16383, primary: Server { - host: "127.0.0.1".into(), - port: 30003, + host: "127.0.0.1".into(), + port: 30003, + #[cfg(any(feature = "enable-native-tls", feature = "enable-rustls"))] tls_server_name: None, }, id: "044ec91f325b7595e76dbcb18cc688b6a5b434a1".into(), #[cfg(feature = "replicas")] replicas: vec![Server { - host: "127.0.0.1".into(), - port: 30006, + host: "127.0.0.1".into(), + port: 30006, + #[cfg(any(feature = "enable-native-tls", feature = "enable-rustls"))] tls_server_name: None, }], }, @@ -689,21 +704,23 @@ mod tests { ]); let input = RedisValue::Array(vec![first_slot_range, second_slot_range, third_slot_range]); - let actual = parse_cluster_slots(input, "fake-host").expect("Failed to parse input"); + let actual = parse_cluster_slots(input, &Str::from("fake-host")).expect("Failed to parse input"); let expected = vec![ SlotRange { start: 0, end: 5460, primary: Server { - host: "fake-host".into(), - port: 30001, + host: "fake-host".into(), + port: 30001, + #[cfg(any(feature = "enable-native-tls", feature = "enable-rustls"))] tls_server_name: None, }, id: "09dbe9720cda62f7865eabc5fd8857c5d2678366".into(), #[cfg(feature = "replicas")] replicas: vec![Server { - host: "fake-host".into(), - port: 30004, + host: "fake-host".into(), + port: 30004, + #[cfg(any(feature = "enable-native-tls", feature = "enable-rustls"))] tls_server_name: None, }], }, @@ -711,15 +728,17 @@ mod tests { start: 5461, end: 10922, primary: Server { - host: "fake-host".into(), - port: 30002, + host: "fake-host".into(), + port: 30002, + #[cfg(any(feature = "enable-native-tls", feature = "enable-rustls"))] tls_server_name: None, }, id: "c9d93d9f2c0c524ff34cc11838c2003d8c29e013".into(), #[cfg(feature = "replicas")] replicas: vec![Server { - host: "fake-host".into(), - port: 30005, + host: "fake-host".into(), + port: 30005, + 
#[cfg(any(feature = "enable-native-tls", feature = "enable-rustls"))] tls_server_name: None, }], }, @@ -727,15 +746,17 @@ mod tests { start: 10923, end: 16383, primary: Server { - host: "fake-host".into(), - port: 30003, + host: "fake-host".into(), + port: 30003, + #[cfg(any(feature = "enable-native-tls", feature = "enable-rustls"))] tls_server_name: None, }, id: "044ec91f325b7595e76dbcb18cc688b6a5b434a1".into(), #[cfg(feature = "replicas")] replicas: vec![Server { - host: "fake-host".into(), - port: 30006, + host: "fake-host".into(), + port: 30006, + #[cfg(any(feature = "enable-native-tls", feature = "enable-rustls"))] tls_server_name: None, }], }, diff --git a/src/protocol/codec.rs b/src/protocol/codec.rs index bcfddcf7..ceb47558 100644 --- a/src/protocol/codec.rs +++ b/src/protocol/codec.rs @@ -4,8 +4,8 @@ use crate::{ protocol::{types::ProtocolFrame, utils as protocol_utils}, utils, }; -use arcstr::ArcStr; use bytes::BytesMut; +use bytes_utils::Str; use redis_protocol::{ resp2::{decode::decode_mut as resp2_decode, encode::encode_bytes as resp2_encode, types::Frame as Resp2Frame}, resp3::{ @@ -60,7 +60,7 @@ fn resp2_encode_frame(codec: &RedisCodec, item: Resp2Frame, dst: &mut BytesMut) res ); log_resp2_frame(&codec.name, &item, true); - sample_stats(&codec, false, len as i64); + sample_stats(codec, false, len as i64); Ok(()) } @@ -79,7 +79,7 @@ fn resp2_decode_frame(codec: &RedisCodec, src: &mut BytesMut) -> Result Result, pub streaming_state: Option, @@ -221,8 +221,8 @@ impl Encoder for RedisCodec { #[cfg(not(feature = "blocking-encoding"))] fn encode(&mut self, item: ProtocolFrame, dst: &mut BytesMut) -> Result<(), Self::Error> { match item { - ProtocolFrame::Resp2(frame) => resp2_encode_frame(&self, frame, dst), - ProtocolFrame::Resp3(frame) => resp3_encode_frame(&self, frame, dst), + ProtocolFrame::Resp2(frame) => resp2_encode_frame(self, frame, dst), + ProtocolFrame::Resp3(frame) => resp3_encode_frame(self, frame, dst), } } diff --git 
a/src/protocol/command.rs b/src/protocol/command.rs index 069e637d..8edf6256 100644 --- a/src/protocol/command.rs +++ b/src/protocol/command.rs @@ -23,7 +23,7 @@ use std::{ mem, str, sync::{atomic::AtomicBool, Arc}, - time::Instant, + time::{Duration, Instant}, }; use tokio::sync::oneshot::{channel as oneshot_channel, Receiver as OneshotReceiver, Sender as OneshotSender}; @@ -108,7 +108,7 @@ impl<'a> TryFrom<&'a str> for ClusterErrorKind { type Error = RedisError; fn try_from(value: &'a str) -> Result { - match value.as_ref() { + match value { "MOVED" => Ok(ClusterErrorKind::Moved), "ASK" => Ok(ClusterErrorKind::Ask), _ => Err(RedisError::new( @@ -276,11 +276,7 @@ pub enum RedisCommandKind { Pfmerge, Ping, Psetex, - Psubscribe, - Pubsub, Pttl, - Publish, - Punsubscribe, Quit, Randomkey, Readonly, @@ -319,7 +315,6 @@ pub enum RedisCommandKind { Srandmember, Srem, Strlen, - Subscribe, Sunion, Sunionstore, Swapdb, @@ -328,11 +323,11 @@ pub enum RedisCommandKind { Touch, Ttl, Type, - Unsubscribe, Unlink, Unwatch, Wait, Watch, + // Streams XinfoConsumers, XinfoGroups, XinfoStream, @@ -353,6 +348,7 @@ pub enum RedisCommandKind { Xclaim, Xautoclaim, Xpending, + // Sorted Sets Zadd, Zcard, Zcount, @@ -383,18 +379,18 @@ pub enum RedisCommandKind { Zpopmax, Zpopmin, Zmpop, + // Scripts ScriptLoad, ScriptDebug, ScriptExists, ScriptFlush, ScriptKill, + // Scanning Scan, Sscan, Hscan, Zscan, - Spublish, - Ssubscribe, - Sunsubscribe, + // Function Fcall, FcallRO, FunctionDelete, @@ -405,11 +401,43 @@ pub enum RedisCommandKind { FunctionLoad, FunctionRestore, FunctionStats, + // Pubsub + Publish, PubsubChannels, PubsubNumpat, PubsubNumsub, PubsubShardchannels, PubsubShardnumsub, + Spublish, + Ssubscribe, + Sunsubscribe, + Unsubscribe, + Subscribe, + Psubscribe, + Punsubscribe, + // RedisJSON + JsonArrAppend, + JsonArrIndex, + JsonArrInsert, + JsonArrLen, + JsonArrPop, + JsonArrTrim, + JsonClear, + JsonDebugMemory, + JsonDel, + JsonGet, + JsonMerge, + JsonMGet, + JsonMSet, + 
JsonNumIncrBy, + JsonObjKeys, + JsonObjLen, + JsonResp, + JsonSet, + JsonStrAppend, + JsonStrLen, + JsonToggle, + JsonType, // Commands with custom state or commands that don't map directly to the server's command interface. _Hello(RespVersion), _AuthAllCluster, @@ -435,101 +463,65 @@ impl fmt::Debug for RedisCommandKind { impl RedisCommandKind { pub fn is_scan(&self) -> bool { - match *self { - RedisCommandKind::Scan => true, - _ => false, - } + matches!(*self, RedisCommandKind::Scan) } pub fn is_hscan(&self) -> bool { - match *self { - RedisCommandKind::Hscan => true, - _ => false, - } + matches!(*self, RedisCommandKind::Hscan) } pub fn is_sscan(&self) -> bool { - match *self { - RedisCommandKind::Sscan => true, - _ => false, - } + matches!(*self, RedisCommandKind::Sscan) } pub fn is_zscan(&self) -> bool { - match *self { - RedisCommandKind::Zscan => true, - _ => false, - } + matches!(*self, RedisCommandKind::Zscan) } pub fn is_hello(&self) -> bool { - match *self { - RedisCommandKind::_Hello(_) | RedisCommandKind::_HelloAllCluster(_) => true, - _ => false, - } + matches!( + *self, + RedisCommandKind::_Hello(_) | RedisCommandKind::_HelloAllCluster(_) + ) } pub fn is_auth(&self) -> bool { - match *self { - RedisCommandKind::Auth => true, - _ => false, - } + matches!(*self, RedisCommandKind::Auth) } pub fn is_value_scan(&self) -> bool { - match *self { - RedisCommandKind::Zscan | RedisCommandKind::Hscan | RedisCommandKind::Sscan => true, - _ => false, - } + matches!( + *self, + RedisCommandKind::Zscan | RedisCommandKind::Hscan | RedisCommandKind::Sscan + ) } pub fn is_multi(&self) -> bool { - match *self { - RedisCommandKind::Multi => true, - _ => false, - } + matches!(*self, RedisCommandKind::Multi) } pub fn is_exec(&self) -> bool { - match *self { - RedisCommandKind::Exec => true, - _ => false, - } + matches!(*self, RedisCommandKind::Exec) } pub fn is_discard(&self) -> bool { - match *self { - RedisCommandKind::Discard => true, - _ => false, - } + matches!(*self, 
RedisCommandKind::Discard) } pub fn ends_transaction(&self) -> bool { - match *self { - RedisCommandKind::Exec | RedisCommandKind::Discard => true, - _ => false, - } + matches!(*self, RedisCommandKind::Exec | RedisCommandKind::Discard) } pub fn is_mset(&self) -> bool { - match *self { - RedisCommandKind::Mset | RedisCommandKind::Msetnx => true, - _ => false, - } + matches!(*self, RedisCommandKind::Mset | RedisCommandKind::Msetnx) } pub fn is_custom(&self) -> bool { - match *self { - RedisCommandKind::_Custom(_) => true, - _ => false, - } + matches!(*self, RedisCommandKind::_Custom(_)) } pub fn closes_connection(&self) -> bool { - match *self { - RedisCommandKind::Quit | RedisCommandKind::Shutdown => true, - _ => false, - } + matches!(*self, RedisCommandKind::Quit | RedisCommandKind::Shutdown) } pub fn custom_hash_slot(&self) -> Option { @@ -704,7 +696,6 @@ impl RedisCommandKind { RedisCommandKind::Ping => "PING", RedisCommandKind::Psetex => "PSETEX", RedisCommandKind::Psubscribe => "PSUBSCRIBE", - RedisCommandKind::Pubsub => "PUBSUB", RedisCommandKind::Pttl => "PTTL", RedisCommandKind::Publish => "PUBLISH", RedisCommandKind::Punsubscribe => "PUNSUBSCRIBE", @@ -848,6 +839,28 @@ impl RedisCommandKind { RedisCommandKind::PubsubNumsub => "PUBSUB NUMSUB", RedisCommandKind::PubsubShardchannels => "PUBSUB SHARDCHANNELS", RedisCommandKind::PubsubShardnumsub => "PUBSUB SHARDNUMSUB", + RedisCommandKind::JsonArrAppend => "JSON.ARRAPPEND", + RedisCommandKind::JsonArrIndex => "JSON.ARRINDEX", + RedisCommandKind::JsonArrInsert => "JSON.ARRINSERT", + RedisCommandKind::JsonArrLen => "JSON.ARRLEN", + RedisCommandKind::JsonArrPop => "JSON.ARRPOP", + RedisCommandKind::JsonArrTrim => "JSON.ARRTRIM", + RedisCommandKind::JsonClear => "JSON.CLEAR", + RedisCommandKind::JsonDebugMemory => "JSON.DEBUG MEMORY", + RedisCommandKind::JsonDel => "JSON.DEL", + RedisCommandKind::JsonGet => "JSON.GET", + RedisCommandKind::JsonMerge => "JSON.MERGE", + RedisCommandKind::JsonMGet => "JSON.MGET", + 
RedisCommandKind::JsonMSet => "JSON.MSET", + RedisCommandKind::JsonNumIncrBy => "JSON.NUMINCRBY", + RedisCommandKind::JsonObjKeys => "JSON.OBJKEYS", + RedisCommandKind::JsonObjLen => "JSON.OBJLEN", + RedisCommandKind::JsonResp => "JSON.RESP", + RedisCommandKind::JsonSet => "JSON.SET", + RedisCommandKind::JsonStrAppend => "JSON.STRAPPEND", + RedisCommandKind::JsonStrLen => "JSON.STRLEN", + RedisCommandKind::JsonToggle => "JSON.TOGGLE", + RedisCommandKind::JsonType => "JSON.TYPE", RedisCommandKind::_Custom(ref kind) => &kind.cmd, } } @@ -1014,7 +1027,6 @@ impl RedisCommandKind { RedisCommandKind::Ping => "PING", RedisCommandKind::Psetex => "PSETEX", RedisCommandKind::Psubscribe => "PSUBSCRIBE", - RedisCommandKind::Pubsub => "PUBSUB", RedisCommandKind::Pttl => "PTTL", RedisCommandKind::Publish => "PUBLISH", RedisCommandKind::Punsubscribe => "PUNSUBSCRIBE", @@ -1155,6 +1167,28 @@ impl RedisCommandKind { RedisCommandKind::_AuthAllCluster => "AUTH", RedisCommandKind::_HelloAllCluster(_) => "HELLO", RedisCommandKind::_ClientTrackingCluster => "CLIENT", + RedisCommandKind::JsonArrAppend => "JSON.ARRAPPEND", + RedisCommandKind::JsonArrIndex => "JSON.ARRINDEX", + RedisCommandKind::JsonArrInsert => "JSON.ARRINSERT", + RedisCommandKind::JsonArrLen => "JSON.ARRLEN", + RedisCommandKind::JsonArrPop => "JSON.ARRPOP", + RedisCommandKind::JsonArrTrim => "JSON.ARRTRIM", + RedisCommandKind::JsonClear => "JSON.CLEAR", + RedisCommandKind::JsonDebugMemory => "JSON.DEBUG", + RedisCommandKind::JsonDel => "JSON.DEL", + RedisCommandKind::JsonGet => "JSON.GET", + RedisCommandKind::JsonMerge => "JSON.MERGE", + RedisCommandKind::JsonMGet => "JSON.MGET", + RedisCommandKind::JsonMSet => "JSON.MSET", + RedisCommandKind::JsonNumIncrBy => "JSON.NUMINCRBY", + RedisCommandKind::JsonObjKeys => "JSON.OBJKEYS", + RedisCommandKind::JsonObjLen => "JSON.OBJLEN", + RedisCommandKind::JsonResp => "JSON.RESP", + RedisCommandKind::JsonSet => "JSON.SET", + RedisCommandKind::JsonStrAppend => "JSON.STRAPPEND", + 
RedisCommandKind::JsonStrLen => "JSON.STRLEN", + RedisCommandKind::JsonToggle => "JSON.TOGGLE", + RedisCommandKind::JsonType => "JSON.TYPE", RedisCommandKind::_Custom(ref kind) => return kind.cmd.clone(), }; @@ -1255,6 +1289,7 @@ impl RedisCommandKind { RedisCommandKind::_FunctionDeleteCluster => "DELETE", RedisCommandKind::_FunctionRestoreCluster => "RESTORE", RedisCommandKind::_ClientTrackingCluster => "TRACKING", + RedisCommandKind::JsonDebugMemory => "MEMORY", _ => return None, }; @@ -1262,15 +1297,15 @@ impl RedisCommandKind { } pub fn use_random_cluster_node(&self) -> bool { - match self { + matches!( + *self, RedisCommandKind::Publish - | RedisCommandKind::Ping - | RedisCommandKind::Info - | RedisCommandKind::Scan - | RedisCommandKind::FlushAll - | RedisCommandKind::FlushDB => true, - _ => false, - } + | RedisCommandKind::Ping + | RedisCommandKind::Info + | RedisCommandKind::Scan + | RedisCommandKind::FlushAll + | RedisCommandKind::FlushDB + ) } pub fn is_blocking(&self) -> bool { @@ -1289,44 +1324,44 @@ impl RedisCommandKind { // default is false, but can be changed by the BLOCKING args. the RedisCommand::can_pipeline function checks the // args too. 
RedisCommandKind::Xread | RedisCommandKind::Xreadgroup => false, - RedisCommandKind::_Custom(ref kind) => kind.is_blocking, + RedisCommandKind::_Custom(ref kind) => kind.blocking, _ => false, } } pub fn is_all_cluster_nodes(&self) -> bool { - match *self { + matches!( + *self, RedisCommandKind::_FlushAllCluster - | RedisCommandKind::_AuthAllCluster - | RedisCommandKind::_ScriptFlushCluster - | RedisCommandKind::_ScriptKillCluster - | RedisCommandKind::_HelloAllCluster(_) - | RedisCommandKind::_ClientTrackingCluster - | RedisCommandKind::_ScriptLoadCluster - | RedisCommandKind::_FunctionFlushCluster - | RedisCommandKind::_FunctionDeleteCluster - | RedisCommandKind::_FunctionRestoreCluster - | RedisCommandKind::_FunctionLoadCluster => true, - _ => false, - } + | RedisCommandKind::_AuthAllCluster + | RedisCommandKind::_ScriptFlushCluster + | RedisCommandKind::_ScriptKillCluster + | RedisCommandKind::_HelloAllCluster(_) + | RedisCommandKind::_ClientTrackingCluster + | RedisCommandKind::_ScriptLoadCluster + | RedisCommandKind::_FunctionFlushCluster + | RedisCommandKind::_FunctionDeleteCluster + | RedisCommandKind::_FunctionRestoreCluster + | RedisCommandKind::_FunctionLoadCluster + ) } pub fn should_flush(&self) -> bool { - match self { + matches!( + *self, RedisCommandKind::Quit - | RedisCommandKind::Shutdown - | RedisCommandKind::Ping - | RedisCommandKind::Auth - | RedisCommandKind::_Hello(_) - | RedisCommandKind::Exec - | RedisCommandKind::Discard - | RedisCommandKind::Eval - | RedisCommandKind::EvalSha - | RedisCommandKind::Fcall - | RedisCommandKind::FcallRO - | RedisCommandKind::_Custom(_) => true, - _ => false, - } + | RedisCommandKind::Shutdown + | RedisCommandKind::Ping + | RedisCommandKind::Auth + | RedisCommandKind::_Hello(_) + | RedisCommandKind::Exec + | RedisCommandKind::Discard + | RedisCommandKind::Eval + | RedisCommandKind::EvalSha + | RedisCommandKind::Fcall + | RedisCommandKind::FcallRO + | RedisCommandKind::_Custom(_) + ) } pub fn can_pipeline(&self) 
-> bool { @@ -1355,80 +1390,80 @@ impl RedisCommandKind { } pub fn is_eval(&self) -> bool { - match *self { - RedisCommandKind::EvalSha | RedisCommandKind::Eval | RedisCommandKind::Fcall | RedisCommandKind::FcallRO => { - true - }, - _ => false, - } + matches!( + *self, + RedisCommandKind::EvalSha | RedisCommandKind::Eval | RedisCommandKind::Fcall | RedisCommandKind::FcallRO + ) } } pub struct RedisCommand { /// The command and optional subcommand name. - pub kind: RedisCommandKind, + pub kind: RedisCommandKind, /// The policy to apply when handling the response. - pub response: ResponseKind, + pub response: ResponseKind, /// The policy to use when hashing the arguments for cluster routing. - pub hasher: ClusterHash, + pub hasher: ClusterHash, /// The provided arguments. /// /// Some commands store arguments differently. Callers should use `self.args()` to account for this. - pub arguments: Vec, + pub arguments: Vec, /// A oneshot sender used to communicate with the router. - pub router_tx: Arc>>, - /// The number of times the command was sent to the server. - pub attempted: u32, + pub router_tx: Arc>>, + /// The number of times the command has been written to a socket. + pub write_attempts: u32, + /// The number of write attempts remaining. + pub attempts_remaining: u32, + /// The number of cluster redirections remaining. + pub redirections_remaining: u32, /// Whether or not the command can be pipelined. /// /// Also used for commands like XREAD that block based on an argument. - pub can_pipeline: bool, + pub can_pipeline: bool, /// Whether or not to skip backpressure checks. - pub skip_backpressure: bool, + pub skip_backpressure: bool, /// The internal ID of a transaction. - pub transaction_id: Option, + pub transaction_id: Option, + /// The timeout duration provided by the `with_options` interface. + pub timeout_dur: Option, /// Whether the command has timed out from the perspective of the caller. 
- pub timed_out: Arc, + pub timed_out: Arc, /// A timestamp of when the command was last written to the socket. - pub network_start: Option, + pub network_start: Option, /// Whether to route the command to a replica, if possible. - pub use_replica: bool, + pub use_replica: bool, /// Only send the command to the provided server. - pub cluster_node: Option, + pub cluster_node: Option, /// A timestamp of when the command was first created from the public interface. #[cfg(feature = "metrics")] - pub created: Instant, + pub created: Instant, /// Tracing state that has to carry over across writer/reader tasks to track certain fields (response size, etc). #[cfg(feature = "partial-tracing")] - pub traces: CommandTraces, + pub traces: CommandTraces, /// A counter to differentiate unique commands. #[cfg(feature = "debug-ids")] - pub counter: usize, + pub counter: usize, /// Whether to send a `CLIENT CACHING yes|no` before the command. #[cfg(feature = "client-tracking")] - pub caching: Option, -} - -impl Drop for RedisCommand { - fn drop(&mut self) { - if self.has_response_tx() { - debug!( - "Dropping command `{}` ({}) without responding to caller.", - self.kind.to_str_debug(), - self.debug_id() - ); - } - } + pub caching: Option, } impl fmt::Debug for RedisCommand { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("RedisCommand") + let mut formatter = f.debug_struct("RedisCommand"); + formatter .field("command", &self.kind.to_str_debug()) - .field("attempted", &self.attempted) + .field("attempts_remaining", &self.attempts_remaining) + .field("redirections_remaining", &self.redirections_remaining) .field("can_pipeline", &self.can_pipeline) - .field("arguments", &self.args()) - .finish() + .field("write_attempts", &self.write_attempts) + .field("timeout_dur", &self.timeout_dur) + .field("no_backpressure", &self.skip_backpressure); + + #[cfg(feature = "network-logs")] + formatter.field("arguments", &self.args()); + + formatter.finish() } } @@ -1444,30 
+1479,43 @@ impl From for RedisCommand { } } -impl From<(RedisCommandKind, Vec)> for RedisCommand { - fn from((kind, arguments): (RedisCommandKind, Vec)) -> Self { +impl Default for RedisCommand { + fn default() -> Self { RedisCommand { - kind, - arguments, - timed_out: Arc::new(AtomicBool::new(false)), - response: ResponseKind::Respond(None), - hasher: ClusterHash::default(), - router_tx: Arc::new(Mutex::new(None)), - attempted: 0, - can_pipeline: true, - skip_backpressure: false, - transaction_id: None, - use_replica: false, - cluster_node: None, - network_start: None, + kind: RedisCommandKind::Ping, + arguments: Vec::new(), + timed_out: Arc::new(AtomicBool::new(false)), + timeout_dur: None, + response: ResponseKind::Respond(None), + hasher: ClusterHash::default(), + router_tx: Arc::new(Mutex::new(None)), + attempts_remaining: 0, + redirections_remaining: 0, + can_pipeline: true, + skip_backpressure: false, + transaction_id: None, + use_replica: false, + cluster_node: None, + network_start: None, + write_attempts: 0, #[cfg(feature = "metrics")] - created: Instant::now(), + created: Instant::now(), #[cfg(feature = "partial-tracing")] - traces: CommandTraces::default(), + traces: CommandTraces::default(), #[cfg(feature = "debug-ids")] - counter: command_counter(), + counter: command_counter(), #[cfg(feature = "client-tracking")] - caching: None, + caching: None, + } + } +} + +impl From<(RedisCommandKind, Vec)> for RedisCommand { + fn from((kind, arguments): (RedisCommandKind, Vec)) -> Self { + RedisCommand { + kind, + arguments, + ..RedisCommand::default() } } } @@ -1477,25 +1525,8 @@ impl From<(RedisCommandKind, Vec, ResponseSender)> for RedisCommand RedisCommand { kind, arguments, - timed_out: Arc::new(AtomicBool::new(false)), response: ResponseKind::Respond(Some(tx)), - hasher: ClusterHash::default(), - router_tx: Arc::new(Mutex::new(None)), - attempted: 0, - can_pipeline: true, - skip_backpressure: false, - transaction_id: None, - use_replica: false, - 
cluster_node: None, - network_start: None, - #[cfg(feature = "metrics")] - created: Instant::now(), - #[cfg(feature = "partial-tracing")] - traces: CommandTraces::default(), - #[cfg(feature = "debug-ids")] - counter: command_counter(), - #[cfg(feature = "client-tracking")] - caching: None, + ..RedisCommand::default() } } } @@ -1506,24 +1537,7 @@ impl From<(RedisCommandKind, Vec, ResponseKind)> for RedisCommand { kind, arguments, response, - timed_out: Arc::new(AtomicBool::new(false)), - hasher: ClusterHash::default(), - router_tx: Arc::new(Mutex::new(None)), - attempted: 0, - can_pipeline: true, - skip_backpressure: false, - transaction_id: None, - use_replica: false, - network_start: None, - cluster_node: None, - #[cfg(feature = "metrics")] - created: Instant::now(), - #[cfg(feature = "partial-tracing")] - traces: CommandTraces::default(), - #[cfg(feature = "debug-ids")] - counter: command_counter(), - #[cfg(feature = "client-tracking")] - caching: None, + ..RedisCommand::default() } } } @@ -1534,52 +1548,16 @@ impl RedisCommand { RedisCommand { kind, arguments: args, - timed_out: Arc::new(AtomicBool::new(false)), - response: ResponseKind::Skip, - hasher: ClusterHash::FirstKey, - router_tx: Arc::new(Mutex::new(None)), - attempted: 0, - can_pipeline: true, - skip_backpressure: false, - transaction_id: None, - use_replica: false, - cluster_node: None, - network_start: None, - #[cfg(feature = "metrics")] - created: Instant::now(), - #[cfg(feature = "partial-tracing")] - traces: CommandTraces::default(), - #[cfg(feature = "debug-ids")] - counter: command_counter(), - #[cfg(feature = "client-tracking")] - caching: None, + ..RedisCommand::default() } } /// Create a new empty `ASKING` command. 
pub fn new_asking(hash_slot: u16) -> Self { RedisCommand { - kind: RedisCommandKind::Asking, - arguments: Vec::new(), - response: ResponseKind::Skip, - hasher: ClusterHash::Custom(hash_slot), - timed_out: Arc::new(AtomicBool::new(false)), - router_tx: Arc::new(Mutex::new(None)), - attempted: 0, - can_pipeline: false, - skip_backpressure: true, - transaction_id: None, - use_replica: false, - cluster_node: None, - network_start: None, - #[cfg(feature = "metrics")] - created: Instant::now(), - #[cfg(feature = "partial-tracing")] - traces: CommandTraces::default(), - #[cfg(feature = "debug-ids")] - counter: command_counter(), - #[cfg(feature = "client-tracking")] - caching: None, + kind: RedisCommandKind::Asking, + hasher: ClusterHash::Custom(hash_slot), + ..RedisCommand::default() } } @@ -1604,24 +1582,29 @@ impl RedisCommand { } /// Whether errors writing the command should be returned to the caller. - pub fn should_send_write_error(&self, inner: &Arc) -> bool { - self.attempted >= inner.max_command_attempts() || inner.policy.read().is_none() - } - - /// Mark the command to only run once, returning connection write errors to the caller immediately. - pub fn set_try_once(&mut self, inner: &Arc) { - self.attempted = inner.max_command_attempts(); + pub fn should_finish_with_error(&self, inner: &Arc) -> bool { + self.attempts_remaining == 0 || inner.policy.read().is_none() } /// Increment and check the number of write attempts. 
- pub fn incr_check_attempted(&mut self, max: u32) -> Result<(), RedisError> { - self.attempted += 1; - if max > 0 && self.attempted > max { + pub fn decr_check_attempted(&mut self) -> Result<(), RedisError> { + if self.attempts_remaining == 0 { Err(RedisError::new( RedisErrorKind::Unknown, "Too many failed write attempts.", )) } else { + self.attempts_remaining -= 1; + Ok(()) + } + } + + /// + pub fn decr_check_redirections(&mut self) -> Result<(), RedisError> { + if self.redirections_remaining == 0 { + Err(RedisError::new(RedisErrorKind::Unknown, "Too many redirections.")) + } else { + self.redirections_remaining -= 1; Ok(()) } } @@ -1705,30 +1688,49 @@ impl RedisCommand { /// Note: this will **not** clone the router channel. pub fn duplicate(&self, response: ResponseKind) -> Self { RedisCommand { - timed_out: self.timed_out.clone(), + timed_out: Arc::new(AtomicBool::new(false)), kind: self.kind.clone(), arguments: self.arguments.clone(), hasher: self.hasher.clone(), - transaction_id: self.transaction_id.clone(), - attempted: self.attempted, + transaction_id: self.transaction_id, + attempts_remaining: self.attempts_remaining, + redirections_remaining: self.redirections_remaining, + timeout_dur: self.timeout_dur, can_pipeline: self.can_pipeline, skip_backpressure: self.skip_backpressure, router_tx: self.router_tx.clone(), cluster_node: self.cluster_node.clone(), response, use_replica: self.use_replica, + write_attempts: self.write_attempts, + network_start: self.network_start, #[cfg(feature = "metrics")] - created: self.created.clone(), - network_start: self.network_start.clone(), + created: Instant::now(), #[cfg(feature = "partial-tracing")] traces: CommandTraces::default(), #[cfg(feature = "debug-ids")] - counter: self.counter, + counter: command_counter(), #[cfg(feature = "client-tracking")] caching: self.caching.clone(), } } + /// Inherit connection and perf settings from the client. 
+ pub fn inherit_options(&mut self, inner: &Arc) { + if self.attempts_remaining == 0 { + self.attempts_remaining = inner.connection.max_command_attempts; + } + if self.redirections_remaining == 0 { + self.redirections_remaining = inner.connection.max_redirections; + } + if self.timeout_dur.is_none() { + let default_dur = inner.default_command_timeout(); + if !default_dur.is_zero() { + self.timeout_dur = Some(default_dur); + } + } + } + /// Take the command tracing state for the `queued` span. #[cfg(feature = "full-tracing")] pub fn take_queued_span(&mut self) -> Option { @@ -1770,6 +1772,12 @@ impl RedisCommand { } } + /// Finish the command, responding to both the caller and router. + pub fn finish(mut self, inner: &Arc, result: Result) { + self.respond_to_caller(result); + self.respond_to_router(inner, RouterResponse::Continue); + } + /// Read the first key in the arguments according to the `FirstKey` cluster hash policy. pub fn first_key(&self) -> Option<&[u8]> { ClusterHash::FirstKey.find_key(self.args()) @@ -1787,7 +1795,7 @@ impl RedisCommand { /// Read the custom hash slot assigned to a scan operation. pub fn scan_hash_slot(&self) -> Option { match self.response { - ResponseKind::KeyScan(ref inner) => inner.hash_slot.clone(), + ResponseKind::KeyScan(ref inner) => inner.hash_slot, _ => None, } } @@ -1836,41 +1844,23 @@ pub enum RouterCommand { /// Send a command to the server. Command(RedisCommand), /// Send a pipelined series of commands to the server. - /// - /// Commands may finish out of order in the following cluster scenario: - /// 1. The client sends `GET foo`. - /// 2. The client sends `GET bar`. - /// 3. The client sends `GET baz`. - /// 4. The client receives a successful response from `GET foo`. - /// 5. The client receives `MOVED` or `ASK` from `GET bar`. - /// 6. The client receives a successful response from `GET baz`. 
- /// - /// In this scenario the client will retry `GET bar` against the correct node, but after `GET baz` has already - /// finished. Callers should use a transaction if they require commands to always finish in order across - /// arbitrary keys in a cluster. Both a `Pipeline` and `Transaction` will run a series of commands without - /// interruption, but only a `Transaction` can guarantee in-order execution while accounting for cluster errors. - /// - /// Note: if the third command also operated on the `bar` key (such as `TTL bar` instead of `GET baz`) then the - /// commands **would** finish in order, since the server would respond with `MOVED` or `ASK` to both commands, - /// and the client would retry them in the same order. Pipeline { commands: Vec }, /// Send a transaction to the server. - /// - /// Notes: - /// * The inner command buffer will not contain the trailing `EXEC` command. - /// * Transactions are never pipelined in order to handle ASK responses. - /// * IDs must be unique w/r/t other transactions buffered in memory. - /// - /// There is one special failure mode that must be considered: - /// 1. The client sends `MULTI` and we receive an `OK` response. - /// 2. The caller sends `GET foo{1}` and we receive a `QUEUED` response. - /// 3. The caller sends `GET bar{1}` and we receive an `ASK` response. - /// - /// According to the cluster spec the client should retry the entire transaction against the node in the `ASK` - /// response, but with an `ASKING` command before `MULTI`. However, the future returned to the caller from `GET - /// foo{1}` will have already finished at this point. To account for this the client will never pipeline - /// transactions against a cluster, and may clone commands before sending them in order to replay them later with - /// a different cluster node mapping. + // Notes: + // * The inner command buffer will not contain the trailing `EXEC` command. + // * Transactions are never pipelined in order to handle ASK responses. 
+ // * IDs must be unique w/r/t other transactions buffered in memory. + // + // There is one special failure mode that must be considered: + // 1. The client sends `MULTI` and we receive an `OK` response. + // 2. The caller sends `GET foo{1}` and we receive a `QUEUED` response. + // 3. The caller sends `GET bar{1}` and we receive an `ASK` response. + // + // According to the cluster spec the client should retry the entire transaction against the node in the `ASK` + // response, but with an `ASKING` command before `MULTI`. However, the future returned to the caller from `GET + // foo{1}` will have already finished at this point. To account for this the client will never pipeline + // transactions against a cluster, and may clone commands before sending them in order to replay them later with + // a different cluster node mapping. Transaction { id: u64, commands: Vec, @@ -1879,27 +1869,21 @@ pub enum RouterCommand { tx: ResponseSender, }, /// Retry a command after a `MOVED` error. - /// - /// This will trigger a call to `CLUSTER SLOTS` before the command is retried. Additionally, - /// the client will **not** increment the command's write attempt counter. + // This will trigger a call to `CLUSTER SLOTS` before the command is retried. Moved { slot: u16, server: Server, command: RedisCommand, }, /// Retry a command after an `ASK` error. - /// - /// The client will **not** increment the command's write attempt counter. - /// - /// This is typically used instead of `RouterResponse::Ask` when a command was pipelined. + // This is typically used instead of `RouterResponse::Ask` when a command was pipelined. Ask { slot: u16, server: Server, command: RedisCommand, }, /// Initiate a reconnection to the provided server, or all servers. - /// - /// The client may not perform a reconnection if a healthy connection exists to `server`, unless `force` is `true`. + // The client may not perform a reconnection if a healthy connection exists to `server`, unless `force` is `true`. 
Reconnect { server: Option, force: bool, @@ -1916,6 +1900,38 @@ pub enum RouterCommand { SyncReplicas { tx: OneshotSender> }, } +impl RouterCommand { + /// Inherit settings from the configuration structs on `inner`. + pub fn inherit_options(&mut self, inner: &Arc) { + match self { + RouterCommand::Command(ref mut cmd) => { + cmd.inherit_options(inner); + }, + RouterCommand::Pipeline { ref mut commands, .. } => { + for cmd in commands.iter_mut() { + cmd.inherit_options(inner); + } + }, + RouterCommand::Transaction { ref mut commands, .. } => { + for cmd in commands.iter_mut() { + cmd.inherit_options(inner); + } + }, + _ => {}, + }; + } + + /// Apply a timeout to the response channel receiver based on the command and `inner` context. + pub fn timeout_dur(&self) -> Option { + match self { + RouterCommand::Command(ref command) => command.timeout_dur, + RouterCommand::Pipeline { ref commands, .. } => commands.first().and_then(|c| c.timeout_dur), + RouterCommand::Transaction { ref commands, .. 
} => commands.first().and_then(|c| c.timeout_dur), + _ => None, + } + } +} + impl fmt::Debug for RouterCommand { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { let mut formatter = f.debug_struct("RouterCommand"); diff --git a/src/protocol/connection.rs b/src/protocol/connection.rs index 791944c0..a2b07304 100644 --- a/src/protocol/connection.rs +++ b/src/protocol/connection.rs @@ -1,6 +1,5 @@ use crate::{ error::{RedisError, RedisErrorKind}, - globals::globals, modules::inner::RedisClientInner, protocol::{ codec::RedisCodec, @@ -12,7 +11,6 @@ use crate::{ utils as client_utils, utils, }; -use arcstr::ArcStr; use futures::{ sink::SinkExt, stream::{SplitSink, SplitStream, StreamExt}, @@ -22,6 +20,7 @@ use futures::{ use parking_lot::Mutex; use redis_protocol::resp3::types::{Frame as Resp3Frame, RespVersion}; use semver::Version; +use socket2::SockRef; use std::{ collections::VecDeque, fmt, @@ -30,6 +29,7 @@ use std::{ str, sync::{atomic::AtomicUsize, Arc}, task::{Context, Poll}, + time::Duration, }; use tokio::{net::TcpStream, task::JoinHandle}; use tokio_util::codec::Framed; @@ -41,8 +41,10 @@ use crate::{ protocol::{connection, responders::ResponseKind}, types::RedisValue, }; +use bytes_utils::Str; #[cfg(feature = "enable-rustls")] -use std::convert::TryInto; +use std::{convert::TryInto, ops::Deref}; + #[cfg(feature = "replicas")] use tokio::sync::oneshot::channel as oneshot_channel; #[cfg(feature = "enable-native-tls")] @@ -51,9 +53,9 @@ use tokio_native_tls::TlsStream as NativeTlsStream; use tokio_rustls::{client::TlsStream as RustlsStream, rustls::ServerName}; /// The contents of a simplestring OK response. -pub const OK: &'static str = "OK"; -/// The default timeout when establishing new connections. -pub const DEFAULT_CONNECTION_TIMEOUT_MS: u64 = 60_0000; +pub const OK: &str = "OK"; +/// The timeout duration used when dropping the split sink and waiting on the split stream to close. 
+pub const CONNECTION_CLOSE_TIMEOUT_MS: u64 = 5_000; pub type CommandBuffer = VecDeque; pub type SharedBuffer = Arc>; @@ -61,14 +63,48 @@ pub type SharedBuffer = Arc>; pub type SplitRedisSink = SplitSink, ProtocolFrame>; pub type SplitRedisStream = SplitStream>; -pub fn connection_timeout(timeout: Option) -> u64 { - let timeout = timeout.unwrap_or(globals().default_connection_timeout_ms()); +/// Connect to each socket addr and return the first successful connection. +async fn tcp_connect_any( + inner: &Arc, + server: &Server, + addrs: &Vec, +) -> Result<(TcpStream, SocketAddr), RedisError> { + let mut last_error: Option = None; - if timeout == 0 { - DEFAULT_CONNECTION_TIMEOUT_MS - } else { - timeout + for addr in addrs.iter() { + _debug!( + inner, + "Creating TCP connection to {} at {}:{}", + server.host, + addr.ip(), + addr.port() + ); + let socket = match TcpStream::connect(addr).await { + Ok(socket) => socket, + Err(e) => { + _debug!(inner, "Error connecting to {}: {:?}", addr, e); + last_error = Some(e.into()); + continue; + }, + }; + if let Some(val) = inner.connection.tcp.nodelay { + socket.set_nodelay(val)?; + } + if let Some(dur) = inner.connection.tcp.linger { + socket.set_linger(Some(dur))?; + } + if let Some(ttl) = inner.connection.tcp.ttl { + socket.set_ttl(ttl)?; + } + if let Some(ref keepalive) = inner.connection.tcp.keepalive { + SockRef::from(&socket).set_tcp_keepalive(keepalive)?; + } + + return Ok((socket, *addr)); } + + _trace!(inner, "Failed to connect to any of {:?}.", addrs); + Err(last_error.unwrap_or(RedisError::new(RedisErrorKind::IO, "Failed to connect."))) } pub enum ConnectionKind { @@ -150,7 +186,7 @@ impl Sink for ConnectionKind { fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match self.get_mut() { - ConnectionKind::Tcp(ref mut conn) => Pin::new(conn).poll_flush(cx).map_err(|e| e.into()), + ConnectionKind::Tcp(ref mut conn) => Pin::new(conn).poll_flush(cx).map_err(|e| e), #[cfg(feature = "enable-rustls")] 
ConnectionKind::Rustls(ref mut conn) => Pin::new(conn).poll_flush(cx).map_err(|e| e.into()), #[cfg(feature = "enable-native-tls")] @@ -160,7 +196,7 @@ impl Sink for ConnectionKind { fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match self.get_mut() { - ConnectionKind::Tcp(ref mut conn) => Pin::new(conn).poll_close(cx).map_err(|e| e.into()), + ConnectionKind::Tcp(ref mut conn) => Pin::new(conn).poll_close(cx).map_err(|e| e), #[cfg(feature = "enable-rustls")] ConnectionKind::Rustls(ref mut conn) => Pin::new(conn).poll_close(cx).map_err(|e| e.into()), #[cfg(feature = "enable-native-tls")] @@ -234,7 +270,7 @@ impl Sink for SplitSinkKind { fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match self.get_mut() { - SplitSinkKind::Tcp(ref mut conn) => Pin::new(conn).poll_flush(cx).map_err(|e| e.into()), + SplitSinkKind::Tcp(ref mut conn) => Pin::new(conn).poll_flush(cx).map_err(|e| e), #[cfg(feature = "enable-rustls")] SplitSinkKind::Rustls(ref mut conn) => Pin::new(conn).poll_flush(cx).map_err(|e| e.into()), #[cfg(feature = "enable-native-tls")] @@ -244,7 +280,7 @@ impl Sink for SplitSinkKind { fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match self.get_mut() { - SplitSinkKind::Tcp(ref mut conn) => Pin::new(conn).poll_close(cx).map_err(|e| e.into()), + SplitSinkKind::Tcp(ref mut conn) => Pin::new(conn).poll_close(cx).map_err(|e| e), #[cfg(feature = "enable-rustls")] SplitSinkKind::Rustls(ref mut conn) => Pin::new(conn).poll_close(cx).map_err(|e| e.into()), #[cfg(feature = "enable-native-tls")] @@ -303,7 +339,7 @@ pub struct RedisTransport { /// The parsed `SocketAddr` for the connection. pub addr: SocketAddr, /// The hostname used to initialize the connection. - pub default_host: ArcStr, + pub default_host: Str, /// The network connection. pub transport: ConnectionKind, /// The connection/client ID from the CLIENT ID command. 
@@ -315,29 +351,21 @@ pub struct RedisTransport { } impl RedisTransport { - pub async fn new_tcp(inner: &Arc, host: String, port: u16) -> Result { + pub async fn new_tcp(inner: &Arc, server: &Server) -> Result { let counters = Counters::new(&inner.counters.cmd_buffer_len); let (id, version) = (None, None); - let server = Server { - host: ArcStr::from(&host), - port, - tls_server_name: None, - }; - let default_host = ArcStr::from(host.clone()); - let codec = RedisCodec::new(inner, &server); - let addr = inner.get_resolver().await.resolve(host, port).await?; - _debug!( - inner, - "Creating TCP connection to {} at {}:{}", - server.host, - addr.ip(), - addr.port() - ); - let socket = TcpStream::connect(addr).await?; + let default_host = server.host.clone(); + let codec = RedisCodec::new(inner, server); + let addrs = inner + .get_resolver() + .await + .resolve(server.host.clone(), server.port) + .await?; + let (socket, addr) = tcp_connect_any(inner, server, &addrs).await?; let transport = ConnectionKind::Tcp(Framed::new(socket, codec)); Ok(RedisTransport { - server, + server: server.clone(), default_host, counters, addr, @@ -349,47 +377,38 @@ impl RedisTransport { #[cfg(feature = "enable-native-tls")] #[allow(unreachable_patterns)] - pub async fn new_native_tls( - inner: &Arc, - host: String, - port: u16, - server_name: Option<&ArcStr>, - ) -> Result { + pub async fn new_native_tls(inner: &Arc, server: &Server) -> Result { let connector = match inner.config.tls { Some(ref config) => match config.connector { TlsConnector::Native(ref connector) => connector.clone(), _ => return Err(RedisError::new(RedisErrorKind::Tls, "Invalid TLS configuration.")), }, - None => return RedisTransport::new_tcp(inner, host, port).await, + None => return RedisTransport::new_tcp(inner, server).await, }; let counters = Counters::new(&inner.counters.cmd_buffer_len); let (id, version) = (None, None); - let tls_server_name = server_name.map(|s| s.as_str()).unwrap_or(host.as_str()); + let 
tls_server_name = server + .tls_server_name + .as_ref() + .map(|s| s.clone()) + .unwrap_or(server.host.clone()); - let server = Server { - host: ArcStr::from(&host), - tls_server_name: Some(ArcStr::from(tls_server_name)), - port, - }; - let default_host = ArcStr::from(host.clone()); + let default_host = server.host.clone(); let codec = RedisCodec::new(inner, &server); - let addr = inner.get_resolver().await.resolve(host.clone(), port).await?; - _debug!( - inner, - "Creating `native-tls` connection to {} at {}:{}", - host, - addr.ip(), - addr.port() - ); + let addrs = inner + .get_resolver() + .await + .resolve(server.host.clone(), server.port) + .await?; + let (socket, addr) = tcp_connect_any(inner, &server, &addrs).await?; - let socket = TcpStream::connect(addr).await?; _debug!(inner, "native-tls handshake with server name/host: {}", tls_server_name); - let socket = connector.clone().connect(tls_server_name, socket).await?; + let socket = connector.clone().connect(&tls_server_name, socket).await?; let transport = ConnectionKind::NativeTls(Framed::new(socket, codec)); Ok(RedisTransport { - server, + server: server.clone(), default_host, counters, addr, @@ -400,59 +419,45 @@ impl RedisTransport { } #[cfg(not(feature = "enable-native-tls"))] - pub async fn new_native_tls( - inner: &Arc, - host: String, - port: u16, - _: Option<&ArcStr>, - ) -> Result { - RedisTransport::new_tcp(inner, host, port).await + pub async fn new_native_tls(inner: &Arc, server: &Server) -> Result { + RedisTransport::new_tcp(inner, server).await } #[cfg(feature = "enable-rustls")] #[allow(unreachable_patterns)] - pub async fn new_rustls( - inner: &Arc, - host: String, - port: u16, - server_name: Option<&ArcStr>, - ) -> Result { + pub async fn new_rustls(inner: &Arc, server: &Server) -> Result { let connector = match inner.config.tls { Some(ref config) => match config.connector { TlsConnector::Rustls(ref connector) => connector.clone(), _ => return Err(RedisError::new(RedisErrorKind::Tls, 
"Invalid TLS configuration.")), }, - None => return RedisTransport::new_tcp(inner, host, port).await, + None => return RedisTransport::new_tcp(inner, server).await, }; let counters = Counters::new(&inner.counters.cmd_buffer_len); let (id, version) = (None, None); - let tls_server_name = server_name.map(|s| s.as_str()).unwrap_or(host.as_str()); - let server = Server { - host: ArcStr::from(&host), - tls_server_name: Some(ArcStr::from(tls_server_name)), - port, - }; + let tls_server_name = server + .tls_server_name + .as_ref() + .map(|s| s.clone()) + .unwrap_or(server.host.clone()); - let default_host = ArcStr::from(host.clone()); - let codec = RedisCodec::new(inner, &server); - let addr = inner.get_resolver().await.resolve(host.clone(), port).await?; - _debug!( - inner, - "Creating `rustls` connection to {} at {}:{}", - host, - addr.ip(), - addr.port() - ); - let socket = TcpStream::connect(addr).await?; - let server_name: ServerName = tls_server_name.try_into()?; + let default_host = server.host.clone(); + let codec = RedisCodec::new(inner, server); + let addrs = inner + .get_resolver() + .await + .resolve(server.host.clone(), server.port) + .await?; + let (socket, addr) = tcp_connect_any(inner, &server, &addrs).await?; + let server_name: ServerName = tls_server_name.deref().try_into()?; _debug!(inner, "rustls handshake with server name/host: {:?}", tls_server_name); let socket = connector.clone().connect(server_name, socket).await?; let transport = ConnectionKind::Rustls(Framed::new(socket, codec)); Ok(RedisTransport { - server, + server: server.clone(), counters, default_host, addr, @@ -463,19 +468,14 @@ impl RedisTransport { } #[cfg(not(feature = "enable-rustls"))] - pub async fn new_rustls( - inner: &Arc, - host: String, - port: u16, - _: Option<&ArcStr>, - ) -> Result { - RedisTransport::new_tcp(inner, host, port).await + pub async fn new_rustls(inner: &Arc, server: &Server) -> Result { + RedisTransport::new_tcp(inner, server).await } /// Send a command to the 
server. pub async fn request_response(&mut self, cmd: RedisCommand, is_resp3: bool) -> Result { let frame = cmd.to_frame(is_resp3)?; - let _ = self.transport.send(frame).await?; + self.transport.send(frame).await?; let response = self.transport.next().await; match response { @@ -485,11 +485,11 @@ impl RedisTransport { } /// Set the client name with CLIENT SETNAME. - #[cfg(not(feature = "no-client-setname"))] + #[cfg(feature = "auto-client-setname")] pub async fn set_client_name(&mut self, inner: &Arc) -> Result<(), RedisError> { _debug!(inner, "Setting client name."); let name = &inner.id; - let command = RedisCommand::new(RedisCommandKind::ClientSetname, vec![name.as_str().into()]); + let command = RedisCommand::new(RedisCommandKind::ClientSetname, vec![name.clone().into()]); let response = self.request_response(command, inner.is_resp3()).await?; if protocol_utils::is_ok(&response) { @@ -501,7 +501,7 @@ impl RedisTransport { } } - #[cfg(feature = "no-client-setname")] + #[cfg(not(feature = "auto-client-setname"))] pub async fn set_client_name(&mut self, inner: &Arc) -> Result<(), RedisError> { _debug!(inner, "Skip setting client name."); Ok(()) @@ -532,13 +532,13 @@ impl RedisTransport { }; self.version = result.lines().find_map(|line| { - let parts: Vec<&str> = line.split(":").collect(); + let parts: Vec<&str> = line.split(':').collect(); if parts.len() < 2 { return None; } if parts[0] == "redis_version" { - Version::parse(&parts[1]).ok() + Version::parse(parts[1]).ok() } else { None } @@ -644,11 +644,10 @@ impl RedisTransport { /// Send `QUIT` and close the connection. 
pub async fn disconnect(&mut self, inner: &Arc) -> Result<(), RedisError> { - let timeout = globals().default_connection_timeout_ms(); let command: RedisCommand = RedisCommandKind::Quit.into(); let quit_ft = self.request_response(command, inner.is_resp3()); - if let Err(e) = client_utils::apply_timeout(quit_ft, timeout).await { + if let Err(e) = client_utils::apply_timeout(quit_ft, inner.internal_command_timeout()).await { _warn!(inner, "Error calling QUIT on backchannel: {:?}", e); } let _ = self.transport.close().await; @@ -689,14 +688,12 @@ impl RedisTransport { _trace!(inner, "Checking cluster info for {}", self.server); let command = RedisCommand::new(RedisCommandKind::ClusterInfo, vec![]); let response = self.request_response(command, inner.is_resp3()).await?; - let response: String = protocol_utils::frame_to_single_result(response)?.convert()?; + let response: String = protocol_utils::frame_to_results(response)?.convert()?; for line in response.lines() { - let parts: Vec<&str> = line.split(":").collect(); - if parts.len() == 2 { - if parts[0] == "cluster_state" && parts[1] == "ok" { - return Ok(()); - } + let parts: Vec<&str> = line.split(':').collect(); + if parts.len() == 2 && parts[0] == "cluster_state" && parts[1] == "ok" { + return Ok(()); } } @@ -708,17 +705,17 @@ impl RedisTransport { /// Authenticate, set the protocol version, set the client name, select the provided database, cache the /// connection ID and server version, and check the cluster state (if applicable). 
- pub async fn setup(&mut self, inner: &Arc, timeout: Option) -> Result<(), RedisError> { - let timeout = connection_timeout(timeout); + pub async fn setup(&mut self, inner: &Arc, timeout: Option) -> Result<(), RedisError> { + let timeout = timeout.unwrap_or(inner.internal_command_timeout()); utils::apply_timeout( async { - let _ = self.switch_protocols_and_authenticate(inner).await?; - let _ = self.set_client_name(inner).await?; - let _ = self.select_database(inner).await?; - let _ = self.cache_connection_id(inner).await?; - let _ = self.cache_server_version(inner).await?; - let _ = self.check_cluster_state(inner).await?; + self.switch_protocols_and_authenticate(inner).await?; + self.select_database(inner).await?; + self.set_client_name(inner).await?; + self.cache_connection_id(inner).await?; + self.cache_server_version(inner).await?; + self.check_cluster_state(inner).await?; Ok::<_, RedisError>(()) }, @@ -729,18 +726,22 @@ impl RedisTransport { /// Send `READONLY` to the server. #[cfg(feature = "replicas")] - pub async fn readonly(&mut self, inner: &Arc, timeout: Option) -> Result<(), RedisError> { + pub async fn readonly( + &mut self, + inner: &Arc, + timeout: Option, + ) -> Result<(), RedisError> { if !inner.config.server.is_clustered() { return Ok(()); } + let timeout = timeout.unwrap_or(inner.internal_command_timeout()); - let timeout = connection_timeout(timeout); utils::apply_timeout( async { _debug!(inner, "Sending READONLY to {}", self.server); let command = RedisCommand::new(RedisCommandKind::Readonly, vec![]); let response = self.request_response(command, inner.is_resp3()).await?; - let _ = protocol_utils::frame_to_single_result(response)?; + let _ = protocol_utils::frame_to_results(response)?; Ok::<_, RedisError>(()) }, @@ -754,9 +755,9 @@ impl RedisTransport { pub async fn role( &mut self, inner: &Arc, - timeout: Option, + timeout: Option, ) -> Result { - let timeout = connection_timeout(timeout); + let timeout = 
timeout.unwrap_or(inner.internal_command_timeout()); let command = RedisCommand::new(RedisCommandKind::Role, vec![]); utils::apply_timeout( @@ -764,7 +765,7 @@ impl RedisTransport { self .request_response(command, inner.is_resp3()) .await - .and_then(protocol_utils::frame_to_results_raw) + .and_then(protocol_utils::frame_to_results) }, timeout, ) @@ -801,7 +802,7 @@ impl RedisTransport { default_host, counters: counters.clone(), server: server.clone(), - addr: addr.clone(), + addr, buffer: buffer.clone(), reader: None, }; @@ -854,7 +855,7 @@ impl RedisReader { pub struct RedisWriter { pub sink: SplitSinkKind, pub server: Server, - pub default_host: ArcStr, + pub default_host: Str, pub addr: SocketAddr, pub buffer: SharedBuffer, pub version: Option, @@ -878,7 +879,7 @@ impl RedisWriter { /// Flush the sink and reset the feed counter. pub async fn flush(&mut self) -> Result<(), RedisError> { trace!("Flushing socket to {}", self.server); - let _ = self.sink.flush().await?; + self.sink.flush().await?; trace!("Flushed socket to {}", self.server); self.counters.reset_feed_count(); Ok(()) @@ -889,12 +890,12 @@ impl RedisWriter { let command = RedisCommand::new(RedisCommandKind::Role, vec![]); let role = connection::request_response(inner, self, command, None) .await - .and_then(protocol_utils::frame_to_results_raw)?; + .and_then(protocol_utils::frame_to_results)?; protocol_utils::parse_master_role_replicas(role) } - /// Check if the connection is connected and can send frames. + /// Check if the reader task is still running or awaiting frames. 
pub fn is_working(&self) -> bool { self .reader @@ -912,11 +913,11 @@ impl RedisWriter { if should_flush { trace!("Writing and flushing {}", self.server); - let _ = self.sink.send(frame).await?; + self.sink.send(frame).await?; self.counters.reset_feed_count(); } else { trace!("Writing without flushing {}", self.server); - let _ = self.sink.feed(frame).await?; + self.sink.feed(frame).await?; self.counters.incr_feed_count(); }; self.counters.incr_in_flight(); @@ -960,7 +961,6 @@ impl RedisWriter { /// /// Returns the in-flight commands that had not received a response. pub async fn graceful_close(mut self) -> CommandBuffer { - let timeout = globals().default_connection_timeout_ms(); let _ = utils::apply_timeout( async { let _ = self.sink.close().await; @@ -970,7 +970,7 @@ impl RedisWriter { Ok::<_, RedisError>(()) }, - timeout, + Duration::from_millis(CONNECTION_CLOSE_TIMEOUT_MS), ) .await; @@ -983,12 +983,10 @@ impl RedisWriter { /// The returned connection will not be initialized. pub async fn create( inner: &Arc, - host: String, - port: u16, - timeout_ms: Option, - tls_server_name: Option<&ArcStr>, + server: &Server, + timeout: Option, ) -> Result { - let timeout = connection_timeout(timeout_ms); + let timeout = timeout.unwrap_or(inner.connection_timeout()); _trace!( inner, @@ -997,15 +995,11 @@ pub async fn create( inner.config.uses_rustls() ); if inner.config.uses_native_tls() { - utils::apply_timeout( - RedisTransport::new_native_tls(inner, host, port, tls_server_name), - timeout, - ) - .await + utils::apply_timeout(RedisTransport::new_native_tls(inner, server), timeout).await } else if inner.config.uses_rustls() { - utils::apply_timeout(RedisTransport::new_rustls(inner, host, port, tls_server_name), timeout).await + utils::apply_timeout(RedisTransport::new_rustls(inner, server), timeout).await } else { - utils::apply_timeout(RedisTransport::new_tcp(inner, host, port), timeout).await + utils::apply_timeout(RedisTransport::new_tcp(inner, server), timeout).await 
} } @@ -1056,10 +1050,13 @@ pub async fn request_response( inner: &Arc, writer: &mut RedisWriter, mut command: RedisCommand, - timeout: Option, + timeout: Option, ) -> Result { let (tx, rx) = oneshot_channel(); command.response = ResponseKind::Respond(Some(tx)); + let timeout_dur = timeout + .or(command.timeout_dur.clone()) + .unwrap_or_else(|| inner.default_command_timeout()); _trace!( inner, @@ -1075,12 +1072,6 @@ pub async fn request_response( let _ = writer.pop_recent_command(); Err(e) } else { - let timeout = timeout.unwrap_or(inner.default_command_timeout()); - - if timeout > 0 { - utils::apply_timeout(async { rx.await? }, timeout).await - } else { - rx.await? - } + utils::apply_timeout(async { rx.await? }, timeout_dur).await } } diff --git a/src/protocol/hashers.rs b/src/protocol/hashers.rs index 212328d3..1b2fe3b5 100644 --- a/src/protocol/hashers.rs +++ b/src/protocol/hashers.rs @@ -4,7 +4,7 @@ use redis_protocol::redis_keyslot; fn hash_value(value: &RedisValue) -> Option { Some(match value { RedisValue::String(s) => redis_keyslot(s.as_bytes()), - RedisValue::Bytes(b) => redis_keyslot(&b), + RedisValue::Bytes(b) => redis_keyslot(b), RedisValue::Integer(i) => redis_keyslot(i.to_string().as_bytes()), RedisValue::Double(f) => redis_keyslot(f.to_string().as_bytes()), RedisValue::Null => redis_keyslot(b"nil"), @@ -16,13 +16,13 @@ fn hash_value(value: &RedisValue) -> Option { pub fn read_redis_key(value: &RedisValue) -> Option<&[u8]> { match value { RedisValue::String(s) => Some(s.as_bytes()), - RedisValue::Bytes(b) => Some(&b), + RedisValue::Bytes(b) => Some(b), _ => None, } } fn hash_key(value: &RedisValue) -> Option { - read_redis_key(value).map(|k| redis_keyslot(k)) + read_redis_key(value).map(redis_keyslot) } /// A cluster hashing policy. @@ -77,10 +77,10 @@ impl ClusterHash { /// Hash the provided arguments. 
pub fn hash(&self, args: &[RedisValue]) -> Option { match self { - ClusterHash::FirstValue => args.get(0).and_then(|v| hash_value(v)), - ClusterHash::FirstKey => args.iter().find_map(|v| hash_key(v)), + ClusterHash::FirstValue => args.get(0).and_then(hash_value), + ClusterHash::FirstKey => args.iter().find_map(hash_key), ClusterHash::Random => None, - ClusterHash::Offset(idx) => args.get(*idx).and_then(|v| hash_value(v)), + ClusterHash::Offset(idx) => args.get(*idx).and_then(hash_value), ClusterHash::Custom(val) => Some(*val), } } @@ -88,9 +88,9 @@ impl ClusterHash { /// Find the key to hash with the provided arguments. pub fn find_key<'a>(&self, args: &'a [RedisValue]) -> Option<&'a [u8]> { match self { - ClusterHash::FirstValue => args.get(0).and_then(|v| read_redis_key(v)), - ClusterHash::FirstKey => args.iter().find_map(|v| read_redis_key(v)), - ClusterHash::Offset(idx) => args.get(*idx).and_then(|v| read_redis_key(v)), + ClusterHash::FirstValue => args.get(0).and_then(read_redis_key), + ClusterHash::FirstKey => args.iter().find_map(read_redis_key), + ClusterHash::Offset(idx) => args.get(*idx).and_then(read_redis_key), ClusterHash::Random | ClusterHash::Custom(_) => None, } } diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index ba8450a3..6364794b 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -11,3 +11,6 @@ pub mod responders; pub mod tls; pub mod types; pub mod utils; + +#[cfg(feature = "codec")] +pub mod public; diff --git a/src/protocol/public.rs b/src/protocol/public.rs new file mode 100644 index 00000000..c54751d5 --- /dev/null +++ b/src/protocol/public.rs @@ -0,0 +1,210 @@ +use crate::error::{RedisError, RedisErrorKind}; +use bytes::BytesMut; +use redis_protocol::{ + resp2::{decode::decode_mut as resp2_decode, encode::encode_bytes as resp2_encode}, + resp3::{ + decode::streaming::decode_mut as resp3_decode, + encode::complete::encode_bytes as resp3_encode, + types::StreamedFrame, + }, +}; +use tokio_util::codec::{Decoder, Encoder}; 
+ +pub use redis_protocol::{ + redis_keyslot, + resp2::types::{Frame as Resp2Frame, FrameKind as Resp2FrameKind}, + resp2_frame_to_resp3, + resp3::types::{Auth, Frame as Resp3Frame, FrameKind as Resp3FrameKind, RespVersion}, +}; + +/// Encode a redis command string (`SET foo bar NX`, etc) into a RESP3 blob string array. +pub fn resp3_encode_command(cmd: &str) -> Resp3Frame { + Resp3Frame::Array { + data: cmd + .split(" ") + .map(|s| Resp3Frame::BlobString { + data: s.as_bytes().to_vec().into(), + attributes: None, + }) + .collect(), + attributes: None, + } +} + +/// Encode a redis command string (`SET foo bar NX`, etc) into a RESP2 bulk string array. +pub fn resp2_encode_command(cmd: &str) -> Resp2Frame { + Resp2Frame::Array( + cmd + .split(" ") + .map(|s| Resp2Frame::BulkString(s.as_bytes().to_vec().into())) + .collect(), + ) +} + +/// A framed RESP2 codec. +/// +/// ```rust +/// use fred::{ +/// codec::{resp2_encode_command, Resp2, Resp2Frame}, +/// prelude::*, +/// }; +/// use futures::{SinkExt, StreamExt}; +/// use tokio::net::TcpStream; +/// use tokio_util::codec::Framed; +/// +/// async fn example() -> Result<(), RedisError> { +/// let socket = TcpStream::connect("127.0.0.1:6379").await?; +/// let mut framed = Framed::new(socket, Resp2::default()); +/// +/// let auth = resp2_encode_command("AUTH foo bar"); +/// let get_foo = resp2_encode_command("GET foo"); +/// +/// let _ = framed.send(auth).await?; +/// let response = framed.next().await.unwrap().unwrap(); +/// assert_eq!(response.as_str().unwrap(), "OK"); +/// +/// let _ = framed.send(get_foo).await?; +/// let response = framed.next().await.unwrap().unwrap(); +/// assert_eq!(response, Resp2Frame::Null); +/// +/// Ok(()) +/// } +/// ``` +#[derive(Default)] +pub struct Resp2; + +impl Encoder for Resp2 { + type Error = RedisError; + + fn encode(&mut self, item: Resp2Frame, dst: &mut BytesMut) -> Result<(), Self::Error> { + #[cfg(feature = "network-logs")] + trace!("RESP2 codec encode: {:?}", item); + + 
resp2_encode(dst, &item).map(|_| ()).map_err(RedisError::from) + } +} + +impl Decoder for Resp2 { + type Error = RedisError; + type Item = Resp2Frame; + + fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { + if src.is_empty() { + return Ok(None); + } + let parsed = match resp2_decode(src)? { + Some((frame, _, _)) => frame, + None => return Ok(None), + }; + #[cfg(feature = "network-logs")] + trace!("RESP2 codec decode: {:?}", parsed); + + Ok(Some(parsed)) + } +} + +/// A framed codec for complete and streaming/chunked RESP3 frames with optional attributes. +/// +/// ```rust +/// use fred::{ +/// codec::{resp3_encode_command, Auth, Resp3, Resp3Frame, RespVersion}, +/// prelude::*, +/// }; +/// use futures::{SinkExt, StreamExt}; +/// use tokio::net::TcpStream; +/// use tokio_util::codec::Framed; +/// +/// // send `HELLO 3 AUTH foo bar` then `GET foo` +/// async fn example() -> Result<(), RedisError> { +/// let socket = TcpStream::connect("127.0.0.1:6379").await?; +/// let mut framed = Framed::new(socket, Resp3::default()); +/// +/// let hello = Resp3Frame::Hello { +/// version: RespVersion::RESP3, +/// auth: Some(Auth { +/// username: "foo".into(), +/// password: "bar".into(), +/// }), +/// }; +/// // or use the shorthand, but this likely only works for simple use cases +/// let get_foo = resp3_encode_command("GET foo"); +/// +/// // `Framed` implements both `Sink` and `Stream` +/// let _ = framed.send(hello).await?; +/// let response = framed.next().await; +/// println!("HELLO response: {:?}", response); +/// +/// let _ = framed.send(get_foo).await?; +/// let response = framed.next().await; +/// println!("GET foo: {:?}", response); +/// +/// Ok(()) +/// } +/// ``` +#[derive(Default)] +pub struct Resp3 { + streaming: Option, +} + +impl Encoder for Resp3 { + type Error = RedisError; + + fn encode(&mut self, item: Resp3Frame, dst: &mut BytesMut) -> Result<(), Self::Error> { + #[cfg(feature = "network-logs")] + trace!("RESP3 codec encode: {:?}", item); + 
+ resp3_encode(dst, &item).map(|_| ()).map_err(RedisError::from) + } +} + +impl Decoder for Resp3 { + type Error = RedisError; + type Item = Resp3Frame; + + // FIXME ideally this would refer to the corresponding fn in codec.rs, but that code is too tightly coupled to the + // private inner interface to expose here + fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { + if src.is_empty() { + return Ok(None); + } + let parsed = match resp3_decode(src)? { + Some((f, _, _)) => f, + None => return Ok(None), + }; + + if self.streaming.is_some() && parsed.is_streaming() { + return Err(RedisError::new( + RedisErrorKind::Protocol, + "Cannot start a stream while already inside a stream.", + )); + } + + let result = if let Some(ref mut state) = self.streaming { + // we started receiving streamed data earlier + state.add_frame(parsed.into_complete_frame()?); + + if state.is_finished() { + Some(state.into_frame()?) + } else { + None + } + } else { + // we're processing a complete frame or starting a new streamed frame + if parsed.is_streaming() { + self.streaming = Some(parsed.into_streaming_frame()?); + None + } else { + // we're not in the middle of a stream and we found a complete frame + Some(parsed.into_complete_frame()?) + } + }; + + if result.is_some() { + let _ = self.streaming.take(); + } + + #[cfg(feature = "network-logs")] + trace!("RESP3 codec decode: {:?}", result); + Ok(result) + } +} diff --git a/src/protocol/responders.rs b/src/protocol/responders.rs index 04db80bc..c8aca62e 100644 --- a/src/protocol/responders.rs +++ b/src/protocol/responders.rs @@ -16,6 +16,7 @@ use std::{ fmt, fmt::Formatter, iter::repeat, + mem, ops::DerefMut, sync::{atomic::AtomicUsize, Arc}, }; @@ -27,7 +28,7 @@ use parking_lot::RwLock; #[cfg(feature = "metrics")] use std::{cmp, time::Instant}; -const LAST_CURSOR: &'static str = "0"; +const LAST_CURSOR: &str = "0"; pub enum ResponseKind { /// Throw away the response frame and last command in the command buffer. 
@@ -112,14 +113,14 @@ impl ResponseKind { frames: frames.clone(), tx: tx.clone(), received: received.clone(), - index: index.clone(), - expected: expected.clone(), - error_early: error_early.clone(), + index: *index, + expected: *expected, + error_early: *error_early, }, ResponseKind::Multiple { received, tx, expected } => ResponseKind::Multiple { received: received.clone(), tx: tx.clone(), - expected: expected.clone(), + expected: *expected, }, ResponseKind::KeyScan(_) | ResponseKind::ValueScan(_) => return None, }) @@ -287,36 +288,18 @@ fn add_buffered_frame( Ok(()) } -/// Merge multiple potentially nested frames into one flat array of frames. +/// Check for errors while merging the provided frames into one Array frame. fn merge_multiple_frames(frames: &mut Vec, error_early: bool) -> Resp3Frame { - let inner_len = frames.iter().fold(0, |count, frame| { - count - + match frame { - Resp3Frame::Array { ref data, .. } => data.len(), - Resp3Frame::Push { ref data, .. } => data.len(), - _ => 1, + if error_early { + for frame in frames.iter() { + if frame.is_error() { + return frame.clone(); } - }); - - let mut out = Vec::with_capacity(inner_len); - for frame in frames.drain(..) { - // unwrap and return errors early - if error_early && frame.is_error() { - return frame; } - - match frame { - Resp3Frame::Array { data, .. } | Resp3Frame::Push { data, .. 
} => { - for inner_frame in data.into_iter() { - out.push(inner_frame); - } - }, - _ => out.push(frame), - }; } Resp3Frame::Array { - data: out, + data: mem::take(frames), attributes: None, } } @@ -391,7 +374,7 @@ fn parse_value_scan_frame(frame: Resp3Frame) -> Result<(Str, Vec), R let mut values = Vec::with_capacity(data.len()); for frame in data.into_iter() { - values.push(protocol_utils::frame_to_single_result(frame)?); + values.push(protocol_utils::frame_to_results(frame)?); } Ok((cursor, values)) @@ -586,7 +569,7 @@ pub fn respond_buffer( ); // errors are buffered like normal frames and are not returned early - if let Err(e) = add_buffered_frame(&server, &frames, index, frame) { + if let Err(e) = add_buffered_frame(server, &frames, index, frame) { respond_locked(inner, &tx, Err(e)); command.respond_to_router(inner, RouterResponse::Continue); _error!( @@ -614,7 +597,7 @@ pub fn respond_buffer( command.debug_id() ); - let frame = merge_multiple_frames(frames.lock().deref_mut(), error_early); + let frame = merge_multiple_frames(&mut frames.lock(), error_early); if frame.is_error() { let err = match frame.as_str() { Some(s) => protocol_utils::pretty_error(s), diff --git a/src/protocol/types.rs b/src/protocol/types.rs index 473ff275..7f554970 100644 --- a/src/protocol/types.rs +++ b/src/protocol/types.rs @@ -6,7 +6,6 @@ use crate::{ types::*, utils, }; -use arcstr::ArcStr; use bytes_utils::Str; use rand::Rng; pub use redis_protocol::{redis_keyslot, resp2::types::NULL, types::CRLF}; @@ -27,6 +26,7 @@ pub const REDIS_CLUSTER_SLOTS: u16 = 16384; #[cfg(any(feature = "enable-rustls", feature = "enable-native-tls"))] use std::{net::IpAddr, str::FromStr}; +/// Any kind of RESP frame. #[derive(Debug)] pub enum ProtocolFrame { Resp2(Resp2Frame), @@ -34,6 +34,7 @@ pub enum ProtocolFrame { } impl ProtocolFrame { + /// Convert the frame tp RESP3. 
pub fn into_resp3(self) -> Resp3Frame { // the `RedisValue::convert` logic already accounts for different encodings of maps and sets, so // we can just change everything to RESP3 above the protocol layer @@ -42,6 +43,11 @@ impl ProtocolFrame { ProtocolFrame::Resp3(frame) => frame, } } + + /// Whether the frame is encoded as a RESP3 frame. + pub fn is_resp3(&self) -> bool { + matches!(*self, ProtocolFrame::Resp3(_)) + } } impl From for ProtocolFrame { @@ -60,19 +66,20 @@ impl From for ProtocolFrame { #[derive(Debug, Clone)] pub struct Server { /// The hostname or IP address for the server. - pub host: ArcStr, + pub host: Str, /// The port for the server. pub port: u16, /// The server name used during the TLS handshake. + #[cfg(any(feature = "enable-rustls", feature = "enable-native-tls"))] #[cfg_attr(docsrs, doc(cfg(any(feature = "enable-rustls", feature = "enable-native-tls"))))] - pub tls_server_name: Option, + pub tls_server_name: Option, } impl Server { /// Create a new `Server` from parts with a TLS server name. #[cfg(any(feature = "enable-rustls", feature = "enable-native-tls"))] #[cfg_attr(docsrs, doc(cfg(any(feature = "enable-rustls", feature = "enable-native-tls"))))] - pub fn new_with_tls>(host: S, port: u16, tls_server_name: Option) -> Self { + pub fn new_with_tls>(host: S, port: u16, tls_server_name: Option) -> Self { Server { host: host.into(), port, @@ -81,10 +88,11 @@ impl Server { } /// Create a new `Server` from parts. - pub fn new>(host: S, port: u16) -> Self { + pub fn new>(host: S, port: u16) -> Self { Server { host: host.into(), port, + #[cfg(any(feature = "enable-rustls", feature = "enable-native-tls"))] tls_server_name: None, } } @@ -100,18 +108,19 @@ impl Server { Err(_) => return, }; if let Some(tls_server_name) = policy.map(&ip, default_host) { - self.tls_server_name = Some(ArcStr::from(tls_server_name)); + self.tls_server_name = Some(Str::from(tls_server_name)); } } /// Attempt to parse a `host:port` string. 
pub(crate) fn from_str(s: &str) -> Option { - let parts: Vec<&str> = s.trim().split(":").collect(); + let parts: Vec<&str> = s.trim().split(':').collect(); if parts.len() == 2 { - if let Some(port) = parts[1].parse::().ok() { + if let Ok(port) = parts[1].parse::() { Some(Server { host: parts[0].into(), port, + #[cfg(any(feature = "enable-rustls", feature = "enable-native-tls"))] tls_server_name: None, }) } else { @@ -126,14 +135,15 @@ impl Server { pub(crate) fn from_parts(server: &str, default_host: &str) -> Option { server_to_parts(server).ok().map(|(host, port)| { let host = if host.is_empty() { - ArcStr::from(default_host) + Str::from(default_host) } else { - ArcStr::from(host) + Str::from(host) }; Server { host, port, + #[cfg(any(feature = "enable-rustls", feature = "enable-native-tls"))] tls_server_name: None, } }) @@ -161,6 +171,7 @@ impl From<(String, u16)> for Server { Server { host: host.into(), port, + #[cfg(any(feature = "enable-native-tls", feature = "enable-rustls"))] tls_server_name: None, } } @@ -171,6 +182,7 @@ impl From<(&str, u16)> for Server { Server { host: host.into(), port, + #[cfg(any(feature = "enable-native-tls", feature = "enable-rustls"))] tls_server_name: None, } } @@ -251,6 +263,8 @@ pub struct Message { pub value: RedisValue, /// The type of message subscription. pub kind: MessageKind, + /// The server that sent the message. + pub server: Server, } pub struct KeyScanInner { @@ -332,7 +346,7 @@ impl ValueScanInner { out.insert(key, value); } - Ok(out.try_into()?) + out.try_into() } pub fn transform_zscan_result(mut data: Vec) -> Result, RedisError> { @@ -379,7 +393,7 @@ pub struct SlotRange { /// The primary server owner. pub primary: Server, /// The internal ID assigned by the server. - pub id: ArcStr, + pub id: Str, /// Replica node owners. 
#[cfg(feature = "replicas")] #[cfg_attr(docsrs, doc(cfg(feature = "replicas")))] @@ -398,6 +412,17 @@ impl ClusterRouting { ClusterRouting { data: Vec::new() } } + /// Create a new routing table from the result of the `CLUSTER SLOTS` command. + /// + /// The `default_host` value refers to the server that provided the response. + pub fn from_cluster_slots>(value: RedisValue, default_host: S) -> Result { + let default_host = default_host.into(); + let mut data = cluster::parse_cluster_slots(value, &default_host)?; + data.sort_by(|a, b| a.start.cmp(&b.start)); + + Ok(ClusterRouting { data }) + } + /// Read a set of unique hash slots that each map to a different primary/main node in the cluster. pub fn unique_hash_slots(&self) -> Vec { let mut out = BTreeMap::new(); @@ -425,7 +450,7 @@ impl ClusterRouting { &mut self, inner: &Arc, cluster_slots: RedisValue, - default_host: &str, + default_host: &Str, ) -> Result<(), RedisError> { self.data = cluster::parse_cluster_slots(cluster_slots, default_host)?; self.data.sort_by(|a, b| a.start.cmp(&b.start)); @@ -478,7 +503,7 @@ impl ClusterRouting { /// Read a random primary node hash slot range from the cluster cache. pub fn random_slot(&self) -> Option<&SlotRange> { - if self.data.len() > 0 { + if !self.data.is_empty() { let idx = rand::thread_rng().gen_range(0 .. self.data.len()); Some(&self.data[idx]) } else { @@ -499,29 +524,30 @@ impl ClusterRouting { #[cfg_attr(docsrs, doc(cfg(feature = "dns")))] pub trait Resolve: Send + Sync + 'static { /// Resolve a hostname. - async fn resolve(&self, host: String, port: u16) -> Result; + async fn resolve(&self, host: Str, port: u16) -> Result, RedisError>; } /// Default DNS resolver that uses `to_socket_addrs` under the hood. #[derive(Clone, Debug)] pub struct DefaultResolver { - id: ArcStr, + id: Str, } impl DefaultResolver { /// Create a new resolver using the system's default DNS resolution. 
- pub fn new(id: &ArcStr) -> Self { + pub fn new(id: &Str) -> Self { DefaultResolver { id: id.clone() } } } #[async_trait] impl Resolve for DefaultResolver { - async fn resolve(&self, host: String, port: u16) -> Result { + async fn resolve(&self, host: Str, port: u16) -> Result, RedisError> { let client_id = self.id.clone(); tokio::task::spawn_blocking(move || { - let ips: Vec = format!("{}:{}", host, port).to_socket_addrs()?.into_iter().collect(); + let addr = format!("{}:{}", host, port); + let ips: Vec = addr.to_socket_addrs()?.collect(); if ips.is_empty() { Err(RedisError::new( @@ -529,18 +555,8 @@ impl Resolve for DefaultResolver { format!("Failed to resolve {}:{}", host, port), )) } else { - let possible_addrs = ips.len(); - let addr = ips[0]; - - trace!( - "{}: Using {} among {} possible socket addresses for {}:{}", - client_id, - addr.ip(), - possible_addrs, - host, - port - ); - Ok(addr) + trace!("{}: Found {} addresses for {}", client_id, ips.len(), addr); + Ok(ips) } }) .await? diff --git a/src/protocol/utils.rs b/src/protocol/utils.rs index b744da84..efe358d6 100644 --- a/src/protocol/utils.rs +++ b/src/protocol/utils.rs @@ -8,7 +8,6 @@ use crate::{ }, types::*, utils, - utils::redis_string_to_f64, }; use bytes::Bytes; use bytes_utils::Str; @@ -16,15 +15,8 @@ use redis_protocol::{ resp2::types::Frame as Resp2Frame, resp3::types::{Auth, Frame as Resp3Frame, FrameMap, PUBSUB_PUSH_PREFIX}, }; -use semver::Version; use std::{borrow::Cow, collections::HashMap, convert::TryInto, ops::Deref, str, sync::Arc}; -macro_rules! parse_or_zero( - ($data:ident, $t:ty) => { - $data.parse::<$t>().ok().unwrap_or(0) - } -); - pub fn initial_buffer_size(inner: &Arc) -> usize { if inner.performance.load().as_ref().auto_pipeline { // TODO make this configurable @@ -34,14 +26,8 @@ pub fn initial_buffer_size(inner: &Arc) -> usize { } } -/// Read the major redis version, assuming version 6 if a version is not provided. 
-#[allow(dead_code)] -pub fn major_redis_version(version: &Option) -> u8 { - version.as_ref().map(|v| v.major as u8).unwrap_or(6) -} - pub fn parse_cluster_error(data: &str) -> Result<(ClusterErrorKind, u16, String), RedisError> { - let parts: Vec<&str> = data.split(" ").collect(); + let parts: Vec<&str> = data.split(' ').collect(); if parts.len() == 3 { let kind: ClusterErrorKind = parts[0].try_into()?; let slot: u16 = parts[1].parse()?; @@ -76,16 +62,8 @@ pub fn is_ok(frame: &Resp3Frame) -> bool { } } -/// Whether the provided frame is null. -pub fn is_null(frame: &Resp3Frame) -> bool { - match frame { - Resp3Frame::Null => true, - _ => false, - } -} - pub fn server_to_parts(server: &str) -> Result<(&str, u16), RedisError> { - let parts: Vec<&str> = server.split(":").collect(); + let parts: Vec<&str> = server.split(':').collect(); if parts.len() < 2 { return Err(RedisError::new(RedisErrorKind::IO, "Invalid server.")); } @@ -125,13 +103,13 @@ pub fn pretty_error(resp: &str) -> RedisError { let kind = { let mut parts = resp.split_whitespace(); - match parts.next().unwrap_or("").as_ref() { + match parts.next().unwrap_or("") { "" => RedisErrorKind::Unknown, "ERR" => RedisErrorKind::Unknown, "WRONGTYPE" => RedisErrorKind::InvalidArgument, "NOAUTH" | "WRONGPASS" => RedisErrorKind::Auth, "MOVED" | "ASK" | "CLUSTERDOWN" => RedisErrorKind::Cluster, - "Invalid" => match parts.next().unwrap_or("").as_ref() { + "Invalid" => match parts.next().unwrap_or("") { "argument(s)" | "Argument" => RedisErrorKind::InvalidArgument, "command" | "Command" => RedisErrorKind::InvalidCommand, _ => RedisErrorKind::Unknown, @@ -163,10 +141,11 @@ pub fn frame_into_string(frame: Resp3Frame) -> Result { } /// Parse the frame from a shard pubsub channel. -pub fn parse_shard_pubsub_frame(frame: &Resp3Frame) -> Option { +pub fn parse_shard_pubsub_frame(server: &Server, frame: &Resp3Frame) -> Option { let value = match frame { Resp3Frame::Array { ref data, .. } | Resp3Frame::Push { ref data, .. 
} => { if data.len() >= 3 && data.len() <= 5 { + // check both resp2 and resp3 formats let has_either_prefix = (data[0].as_str().map(|s| s == PUBSUB_PUSH_PREFIX).unwrap_or(false) && data[1].as_str().map(|s| s == "smessage").unwrap_or(false)) || (data[0].as_str().map(|s| s == "smessage").unwrap_or(false)); @@ -196,6 +175,7 @@ pub fn parse_shard_pubsub_frame(frame: &Resp3Frame) -> Option { channel, value, kind: MessageKind::SMessage, + server: server.clone(), }) } @@ -253,28 +233,33 @@ pub fn parse_message_fields(frame: &Resp3Frame) -> Result<(Str, RedisValue), Red Ok((channel, value)) } -/// Convert the frame to a `(channel, message)` tuple from the pubsub interface. -pub fn frame_to_pubsub(frame: Resp3Frame) -> Result { - if let Some(message) = parse_shard_pubsub_frame(&frame) { +/// Parse the frame as a pubsub message. +pub fn frame_to_pubsub(server: &Server, frame: Resp3Frame) -> Result { + if let Some(message) = parse_shard_pubsub_frame(server, &frame) { return Ok(message); } let kind = parse_message_kind(&frame)?; let (channel, value) = parse_message_fields(&frame)?; - Ok(Message { kind, channel, value }) + Ok(Message { + kind, + channel, + value, + server: server.clone(), + }) } /// Attempt to parse a RESP3 frame as a pubsub message in the RESP2 format. /// /// This can be useful in cases where the codec layer automatically upgrades to RESP3, /// but the contents of the pubsub message still use the RESP2 format. 
-pub fn parse_as_resp2_pubsub(frame: Resp3Frame) -> Result { - if let Some(message) = parse_shard_pubsub_frame(&frame) { +// TODO move and redo this in redis_protocol +pub fn parse_as_resp2_pubsub(server: &Server, frame: Resp3Frame) -> Result { + if let Some(message) = parse_shard_pubsub_frame(server, &frame) { return Ok(message); } // resp3 has an added "pubsub" simple string frame at the front - // TODO move and redo this in redis_protocol let mut out = Vec::with_capacity(frame.len() + 1); out.push(Resp3Frame::SimpleString { data: PUBSUB_PUSH_PREFIX.into(), @@ -288,7 +273,7 @@ pub fn parse_as_resp2_pubsub(frame: Resp3Frame) -> Result { attributes: None, }; - frame_to_pubsub(frame) + frame_to_pubsub(server, frame) } else { Err(RedisError::new( RedisErrorKind::Protocol, @@ -346,7 +331,7 @@ pub fn check_resp3_auth_error(frame: Resp3Frame) -> Resp3Frame { /// Try to parse the data as a string, and failing that return a byte slice. pub fn string_or_bytes(data: Bytes) -> RedisValue { - if let Some(s) = Str::from_inner(data.clone()).ok() { + if let Ok(s) = Str::from_inner(data.clone()) { RedisValue::String(s) } else { RedisValue::Bytes(data) @@ -377,27 +362,13 @@ pub fn frame_to_str(frame: &Resp3Frame) -> Option { } } -fn parse_nested_array(data: Vec) -> Result { - let mut out = Vec::with_capacity(data.len()); - - for frame in data.into_iter() { - out.push(frame_to_results(frame)?); - } - - if out.len() == 1 { - Ok(out.pop().unwrap()) - } else { - Ok(RedisValue::Array(out)) - } -} - fn parse_nested_map(data: FrameMap) -> Result { let mut out = HashMap::with_capacity(data.len()); // maybe make this smarter, but that would require changing the RedisMap type to use potentially non-hashable types // as keys... 
for (key, value) in data.into_iter() { - let key: RedisKey = frame_to_single_result(key)?.try_into()?; + let key: RedisKey = frame_to_results(key)?.try_into()?; let value = frame_to_results(value)?; out.insert(key, value); @@ -416,8 +387,6 @@ pub fn check_null_timeout(frame: &Resp3Frame) -> Result<(), RedisError> { } /// Parse the protocol frame into a redis value, with support for arbitrarily nested arrays. -/// -/// If the array contains one element then that element will be returned. pub fn frame_to_results(frame: Resp3Frame) -> Result { let value = match frame { Resp3Frame::Null => RedisValue::Null, @@ -441,75 +410,23 @@ pub fn frame_to_results(frame: Resp3Frame) -> Result { Resp3Frame::Double { data, .. } => data.into(), Resp3Frame::BigNumber { data, .. } => string_or_bytes(data), Resp3Frame::Boolean { data, .. } => data.into(), - Resp3Frame::Array { data, .. } => parse_nested_array(data)?, - Resp3Frame::Push { data, .. } => parse_nested_array(data)?, - Resp3Frame::Set { data, .. } => { - let mut out = Vec::with_capacity(data.len()); - for frame in data.into_iter() { - out.push(frame_to_results(frame)?); - } - - RedisValue::Array(out) - }, - Resp3Frame::Map { data, .. } => RedisValue::Map(parse_nested_map(data)?), - _ => { - return Err(RedisError::new( - RedisErrorKind::Protocol, - "Invalid response frame type.", - )) - }, - }; - - Ok(value) -} - -/// Parse the protocol frame into a redis value, with support for arbitrarily nested arrays. -/// -/// Unlike `frame_to_results` this will not unwrap single-element arrays. -pub fn frame_to_results_raw(frame: Resp3Frame) -> Result { - let value = match frame { - Resp3Frame::Null => RedisValue::Null, - Resp3Frame::SimpleString { data, .. } => { - let value = string_or_bytes(data); - - if value.as_str().map(|s| s == QUEUED).unwrap_or(false) { - RedisValue::Queued - } else { - value - } - }, - Resp3Frame::SimpleError { data, .. } => return Err(pretty_error(&data)), - Resp3Frame::BlobString { data, .. 
} => string_or_bytes(data), - Resp3Frame::BlobError { data, .. } => { - let parsed = String::from_utf8_lossy(&data); - return Err(pretty_error(parsed.as_ref())); - }, - Resp3Frame::VerbatimString { data, .. } => string_or_bytes(data), - Resp3Frame::Number { data, .. } => data.into(), - Resp3Frame::Double { data, .. } => data.into(), - Resp3Frame::BigNumber { data, .. } => string_or_bytes(data), - Resp3Frame::Boolean { data, .. } => data.into(), - Resp3Frame::Array { data, .. } | Resp3Frame::Push { data, .. } => { - let mut out = Vec::with_capacity(data.len()); - for frame in data.into_iter() { - out.push(frame_to_results_raw(frame)?); - } - - RedisValue::Array(out) - }, - Resp3Frame::Set { data, .. } => { - let mut out = Vec::with_capacity(data.len()); - for frame in data.into_iter() { - out.push(frame_to_results_raw(frame)?); - } - - RedisValue::Array(out) - }, + Resp3Frame::Array { data, .. } | Resp3Frame::Push { data, .. } => RedisValue::Array( + data + .into_iter() + .map(frame_to_results) + .collect::, _>>()?, + ), + Resp3Frame::Set { data, .. } => RedisValue::Array( + data + .into_iter() + .map(frame_to_results) + .collect::, _>>()?, + ), Resp3Frame::Map { data, .. } => { let mut out = HashMap::with_capacity(data.len()); for (key, value) in data.into_iter() { - let key: RedisKey = frame_to_single_result(key)?.try_into()?; - let value = frame_to_results_raw(value)?; + let key: RedisKey = frame_to_results(key)?.try_into()?; + let value = frame_to_results(value)?; out.insert(key, value); } @@ -527,72 +444,6 @@ pub fn frame_to_results_raw(frame: Resp3Frame) -> Result Ok(value) } -/// Parse the protocol frame into a single redis value, returning an error if the result contains nested arrays, an -/// array with more than one value, or any other aggregate type. -/// -/// If the array only contains one value then that value will be returned. 
-/// -/// This function is equivalent to [frame_to_results] but with an added validation layer if the result set is a nested -/// array, aggregate type, etc. -#[cfg(not(feature = "mocks"))] -pub fn frame_to_single_result(frame: Resp3Frame) -> Result { - match frame { - Resp3Frame::SimpleString { data, .. } => { - let value = string_or_bytes(data); - - if value.as_str().map(|s| s == QUEUED).unwrap_or(false) { - Ok(RedisValue::Queued) - } else { - Ok(value) - } - }, - Resp3Frame::SimpleError { data, .. } => Err(pretty_error(&data)), - Resp3Frame::Number { data, .. } => Ok(data.into()), - Resp3Frame::Double { data, .. } => Ok(data.into()), - Resp3Frame::BigNumber { data, .. } => Ok(string_or_bytes(data)), - Resp3Frame::Boolean { data, .. } => Ok(data.into()), - Resp3Frame::VerbatimString { data, .. } => Ok(string_or_bytes(data)), - Resp3Frame::BlobString { data, .. } => Ok(string_or_bytes(data)), - Resp3Frame::BlobError { data, .. } => { - // errors don't have a great way to represent non-utf8 strings... - let parsed = String::from_utf8_lossy(&data); - Err(pretty_error(parsed.as_ref())) - }, - Resp3Frame::Array { mut data, .. } | Resp3Frame::Push { mut data, .. } => { - if data.len() > 1 { - return Err(RedisError::new( - RedisErrorKind::Protocol, - "Could not convert multiple frames to RedisValue.", - )); - } else if data.is_empty() { - return Ok(RedisValue::Null); - } - - let first_frame = data.pop().unwrap(); - if first_frame.is_array() || first_frame.is_error() { - // there shouldn't be errors buried in arrays, nor should there be more than one layer of nested arrays - return Err(RedisError::new( - RedisErrorKind::Protocol, - "Invalid nested array or error.", - )); - } - - frame_to_single_result(first_frame) - }, - Resp3Frame::Map { .. } | Resp3Frame::Set { .. 
} => { - Err(RedisError::new(RedisErrorKind::Protocol, "Invalid aggregate type.")) - }, - Resp3Frame::Null => Ok(RedisValue::Null), - _ => Err(RedisError::new(RedisErrorKind::Protocol, "Unexpected frame kind.")), - } -} - -/// Remove the (often unwanted) validation and parsing layer when using the mocking layer. -#[cfg(feature = "mocks")] -pub fn frame_to_single_result(frame: Resp3Frame) -> Result { - frame_to_results(frame) -} - /// Flatten a single nested layer of arrays or sets into one array. pub fn flatten_frame(frame: Resp3Frame) -> Resp3Frame { match frame { @@ -665,7 +516,7 @@ pub fn frame_to_map(frame: Resp3Frame) -> Result { let mut inner = HashMap::with_capacity(data.len() / 2); while data.len() >= 2 { let value = frame_to_results(data.pop().unwrap())?; - let key = frame_to_single_result(data.pop().unwrap())?.try_into()?; + let key = frame_to_results(data.pop().unwrap())?.try_into()?; inner.insert(key, value); } @@ -817,380 +668,6 @@ pub fn expect_ok(value: &RedisValue) -> Result<(), RedisError> { } } -fn parse_u64(val: &Resp3Frame) -> u64 { - match *val { - Resp3Frame::Number { ref data, .. } => { - if *data < 0 { - 0 - } else { - *data as u64 - } - }, - Resp3Frame::Double { ref data, .. } => *data as u64, - Resp3Frame::BlobString { ref data, .. } | Resp3Frame::SimpleString { ref data, .. } => str::from_utf8(data) - .ok() - .and_then(|s| s.parse::().ok()) - .unwrap_or(0), - _ => 0, - } -} - -fn parse_f64(val: &Resp3Frame) -> f64 { - match *val { - Resp3Frame::Number { ref data, .. } => *data as f64, - Resp3Frame::Double { ref data, .. } => *data, - Resp3Frame::BlobString { ref data, .. } | Resp3Frame::SimpleString { ref data, .. } => str::from_utf8(data) - .ok() - .and_then(|s| redis_string_to_f64(s).ok()) - .unwrap_or(0.0), - _ => 0.0, - } -} - -fn parse_db_memory_stats(data: &Vec) -> Result { - if data.len() % 2 != 0 { - return Err(RedisError::new( - RedisErrorKind::Protocol, - "Invalid MEMORY STATS database response. 
Result must have an even number of frames.", - )); - } - - let mut out = DatabaseMemoryStats::default(); - for chunk in data.chunks(2) { - let key = match chunk[0].as_str() { - Some(s) => s, - None => continue, - }; - - match key.as_ref() { - "overhead.hashtable.main" => out.overhead_hashtable_main = parse_u64(&chunk[1]), - "overhead.hashtable.expires" => out.overhead_hashtable_expires = parse_u64(&chunk[1]), - _ => {}, - }; - } - - Ok(out) -} - -fn parse_memory_stat_field(stats: &mut MemoryStats, key: &str, value: &Resp3Frame) { - match key.as_ref() { - "peak.allocated" => stats.peak_allocated = parse_u64(value), - "total.allocated" => stats.total_allocated = parse_u64(value), - "startup.allocated" => stats.startup_allocated = parse_u64(value), - "replication.backlog" => stats.replication_backlog = parse_u64(value), - "clients.slaves" => stats.clients_slaves = parse_u64(value), - "clients.normal" => stats.clients_normal = parse_u64(value), - "aof.buffer" => stats.aof_buffer = parse_u64(value), - "lua.caches" => stats.lua_caches = parse_u64(value), - "overhead.total" => stats.overhead_total = parse_u64(value), - "keys.count" => stats.keys_count = parse_u64(value), - "keys.bytes-per-key" => stats.keys_bytes_per_key = parse_u64(value), - "dataset.bytes" => stats.dataset_bytes = parse_u64(value), - "dataset.percentage" => stats.dataset_percentage = parse_f64(value), - "peak.percentage" => stats.peak_percentage = parse_f64(value), - "allocator.allocated" => stats.allocator_allocated = parse_u64(value), - "allocator.active" => stats.allocator_active = parse_u64(value), - "allocator.resident" => stats.allocator_resident = parse_u64(value), - "allocator-fragmentation.ratio" => stats.allocator_fragmentation_ratio = parse_f64(value), - "allocator-fragmentation.bytes" => stats.allocator_fragmentation_bytes = parse_u64(value), - "allocator-rss.ratio" => stats.allocator_rss_ratio = parse_f64(value), - "allocator-rss.bytes" => stats.allocator_rss_bytes = parse_u64(value), - 
"rss-overhead.ratio" => stats.rss_overhead_ratio = parse_f64(value), - "rss-overhead.bytes" => stats.rss_overhead_bytes = parse_u64(value), - "fragmentation" => stats.fragmentation = parse_f64(value), - "fragmentation.bytes" => stats.fragmentation_bytes = parse_u64(value), - _ => {}, - } -} - -pub fn parse_memory_stats(data: &Vec) -> Result { - if data.len() % 2 != 0 { - return Err(RedisError::new( - RedisErrorKind::Protocol, - "Invalid MEMORY STATS response. Result must have an even number of frames.", - )); - } - - let mut out = MemoryStats::default(); - for chunk in data.chunks(2) { - let key = match chunk[0].as_str() { - Some(s) => s, - None => continue, - }; - - if key.starts_with("db.") { - let db = match key.split(".").last() { - Some(db) => match db.parse::().ok() { - Some(db) => db, - None => continue, - }, - None => continue, - }; - - let inner = match chunk[1] { - Resp3Frame::Array { ref data, .. } => data, - _ => continue, - }; - let parsed = parse_db_memory_stats(inner)?; - - out.db.insert(db, parsed); - } else { - parse_memory_stat_field(&mut out, key, &chunk[1]); - } - } - - Ok(out) -} - -fn parse_acl_getuser_flag(value: &Resp3Frame) -> Result, RedisError> { - if let Resp3Frame::Array { ref data, .. } = value { - let mut out = Vec::with_capacity(data.len()); - - for frame in data.iter() { - let flag = match frame.as_str() { - Some(s) => match s.as_ref() { - "on" => AclUserFlag::On, - "off" => AclUserFlag::Off, - "allcommands" => AclUserFlag::AllCommands, - "allkeys" => AclUserFlag::AllKeys, - "allchannels" => AclUserFlag::AllChannels, - "nopass" => AclUserFlag::NoPass, - _ => continue, - }, - None => continue, - }; - - out.push(flag); - } - - Ok(out) - } else { - Err(RedisError::new( - RedisErrorKind::Protocol, - "Invalid ACL user flags. Expected array.", - )) - } -} - -fn frames_to_strings(frames: &Resp3Frame) -> Result, RedisError> { - match frames { - Resp3Frame::Array { ref data, .. 
} => { - let mut out = Vec::with_capacity(data.len()); - - for frame in data.iter() { - let val = match frame.as_str() { - Some(v) => v.to_owned(), - None => continue, - }; - - out.push(val); - } - - Ok(out) - }, - Resp3Frame::SimpleString { ref data, .. } => Ok(vec![String::from_utf8(data.to_vec())?]), - Resp3Frame::BlobString { ref data, .. } => Ok(vec![String::from_utf8(data.to_vec())?]), - _ => Err(RedisError::new( - RedisErrorKind::Protocol, - "Expected string or array of frames.", - )), - } -} - -fn parse_acl_getuser_field(user: &mut AclUser, key: &str, value: &Resp3Frame) -> Result<(), RedisError> { - match key.as_ref() { - "passwords" => user.passwords = frames_to_strings(value)?, - "keys" => user.keys = frames_to_strings(value)?, - "channels" => user.channels = frames_to_strings(value)?, - "commands" => { - if let Some(commands) = value.as_str() { - user.commands = commands.split(" ").map(|s| s.to_owned()).collect(); - } - }, - _ => { - debug!("Skip ACL GETUSER field: {}", key); - }, - }; - - Ok(()) -} - -pub fn frame_map_or_set_to_nested_array(frame: Resp3Frame) -> Result { - match frame { - Resp3Frame::Map { data, .. } => { - let mut out = Vec::with_capacity(data.len() * 2); - for (key, value) in data.into_iter() { - out.push(key); - out.push(frame_map_or_set_to_nested_array(value)?); - } - - Ok(Resp3Frame::Array { - data: out, - attributes: None, - }) - }, - Resp3Frame::Set { data, .. 
} => { - let mut out = Vec::with_capacity(data.len()); - for frame in data.into_iter() { - out.push(frame_map_or_set_to_nested_array(frame)?); - } - - Ok(Resp3Frame::Array { - data: out, - attributes: None, - }) - }, - _ => Ok(frame), - } -} - -pub fn parse_acl_getuser_frames(frames: Vec) -> Result { - if frames.len() % 2 != 0 { - return Err(RedisError::new( - RedisErrorKind::Protocol, - "Invalid number of response frames.", - )); - } - - let mut user = AclUser::default(); - for chunk in frames.chunks(2) { - let key = match chunk[0].as_str() { - Some(s) => s, - None => continue, - }; - - if key == "flags" { - user.flags = parse_acl_getuser_flag(&chunk[1])?; - } else { - parse_acl_getuser_field(&mut user, key, &chunk[1])? - } - } - - Ok(user) -} - -fn parse_slowlog_entry(frames: Vec) -> Result { - if frames.len() < 4 { - return Err(RedisError::new( - RedisErrorKind::Protocol, - "Expected at least 4 response frames.", - )); - } - - let id = match frames[0] { - Resp3Frame::Number { ref data, .. } => *data, - _ => return Err(RedisError::new(RedisErrorKind::Protocol, "Expected integer ID.")), - }; - let timestamp = match frames[1] { - Resp3Frame::Number { ref data, .. } => *data, - _ => return Err(RedisError::new(RedisErrorKind::Protocol, "Expected integer timestamp.")), - }; - let duration = match frames[2] { - Resp3Frame::Number { ref data, .. } => *data as u64, - _ => return Err(RedisError::new(RedisErrorKind::Protocol, "Expected integer duration.")), - }; - let args = match frames[3] { - Resp3Frame::Array { ref data, .. 
} => data - .iter() - .filter_map(|frame| frame.as_str().map(|s| s.to_owned())) - .collect(), - _ => return Err(RedisError::new(RedisErrorKind::Protocol, "Expected arguments array.")), - }; - - let (ip, name) = if frames.len() == 6 { - let ip = match frames[4].as_str() { - Some(s) => s.to_owned(), - None => return Err(RedisError::new(RedisErrorKind::Protocol, "Expected IP address string.")), - }; - let name = match frames[5].as_str() { - Some(s) => s.to_owned(), - None => { - return Err(RedisError::new( - RedisErrorKind::Protocol, - "Expected client name string.", - )) - }, - }; - - (Some(ip), Some(name)) - } else { - (None, None) - }; - - Ok(SlowlogEntry { - id, - timestamp, - duration, - args, - ip, - name, - }) -} - -pub fn parse_slowlog_entries(frames: Vec) -> Result, RedisError> { - let mut out = Vec::with_capacity(frames.len()); - - for frame in frames.into_iter() { - if let Resp3Frame::Array { data, .. } = frame { - out.push(parse_slowlog_entry(data)?); - } else { - return Err(RedisError::new( - RedisErrorKind::Protocol, - "Expected array of slowlog fields.", - )); - } - } - - Ok(out) -} - -fn parse_cluster_info_line(info: &mut ClusterInfo, line: &str) -> Result<(), RedisError> { - let parts: Vec<&str> = line.split(":").collect(); - if parts.len() != 2 { - return Err(RedisError::new(RedisErrorKind::Protocol, "Expected key:value pair.")); - } - let (field, val) = (parts[0], parts[1]); - - match field.as_ref() { - "cluster_state" => match val.as_ref() { - "ok" => info.cluster_state = ClusterState::Ok, - "fail" => info.cluster_state = ClusterState::Fail, - _ => return Err(RedisError::new(RedisErrorKind::Protocol, "Invalid cluster state.")), - }, - "cluster_slots_assigned" => info.cluster_slots_assigned = parse_or_zero!(val, u16), - "cluster_slots_ok" => info.cluster_slots_ok = parse_or_zero!(val, u16), - "cluster_slots_pfail" => info.cluster_slots_pfail = parse_or_zero!(val, u16), - "cluster_slots_fail" => info.cluster_slots_fail = parse_or_zero!(val, u16), - 
"cluster_known_nodes" => info.cluster_known_nodes = parse_or_zero!(val, u16), - "cluster_size" => info.cluster_size = parse_or_zero!(val, u32), - "cluster_current_epoch" => info.cluster_current_epoch = parse_or_zero!(val, u64), - "cluster_my_epoch" => info.cluster_my_epoch = parse_or_zero!(val, u64), - "cluster_stats_messages_sent" => info.cluster_stats_messages_sent = parse_or_zero!(val, u64), - "cluster_stats_messages_received" => info.cluster_stats_messages_received = parse_or_zero!(val, u64), - _ => { - warn!("Invalid cluster info field: {}", line); - }, - }; - - Ok(()) -} - -pub fn parse_cluster_info(data: Resp3Frame) -> Result { - if let Some(data) = data.as_str() { - let mut out = ClusterInfo::default(); - - for line in data.lines().into_iter() { - let trimmed = line.trim(); - if !trimmed.is_empty() { - let _ = parse_cluster_info_line(&mut out, trimmed)?; - } - } - Ok(out) - } else { - Err(RedisError::new(RedisErrorKind::Protocol, "Expected string response.")) - } -} - /// Parse the replicas from the ROLE response returned from a master/primary node. #[cfg(feature = "replicas")] pub fn parse_master_role_replicas(data: RedisValue) -> Result, RedisError> { @@ -1220,169 +697,14 @@ pub fn parse_master_role_replicas(data: RedisValue) -> Result, Redis } } -fn frame_to_f64(frame: &Resp3Frame) -> Result { - match frame { - Resp3Frame::Double { ref data, .. } => Ok(*data), - _ => { - if let Some(s) = frame.as_str() { - utils::redis_string_to_f64(s) - } else { - Err(RedisError::new( - RedisErrorKind::Protocol, - "Expected bulk string or double.", - )) - } - }, - } -} - -pub fn parse_geo_position(frame: &Resp3Frame) -> Result { - if let Resp3Frame::Array { ref data, .. 
} = frame { - if data.len() == 2 { - let longitude = frame_to_f64(&data[0])?; - let latitude = frame_to_f64(&data[1])?; - - Ok(GeoPosition { longitude, latitude }) - } else { - Err(RedisError::new( - RedisErrorKind::Protocol, - "Expected array with 2 coordinates.", - )) - } +pub fn assert_array_len(data: &Vec, len: usize) -> Result<(), RedisError> { + if data.len() == len { + Ok(()) } else { - Err(RedisError::new(RedisErrorKind::Protocol, "Expected array.")) - } -} - -fn assert_frame_len(frames: &Vec, len: usize) -> Result<(), RedisError> { - if frames.len() != len { Err(RedisError::new( - RedisErrorKind::Protocol, - format!("Expected {} frames", len), + RedisErrorKind::Parse, + format!("Expected {} values.", len), )) - } else { - Ok(()) - } -} - -fn parse_geo_member(frame: &Resp3Frame) -> Result { - frame - .as_str() - .ok_or(RedisError::new(RedisErrorKind::Protocol, "Expected string")) - .map(|s| s.into()) -} - -fn parse_geo_dist(frame: &Resp3Frame) -> Result { - match frame { - Resp3Frame::Double { ref data, .. } => Ok(*data), - _ => frame - .as_str() - .ok_or(RedisError::new(RedisErrorKind::Protocol, "Expected double.")) - .and_then(|s| utils::redis_string_to_f64(s)), - } -} - -fn parse_geo_hash(frame: &Resp3Frame) -> Result { - if let Resp3Frame::Number { ref data, .. } = frame { - Ok(*data) - } else { - Err(RedisError::new(RedisErrorKind::Protocol, "Expected integer.")) - } -} - -pub fn parse_georadius_info( - frame: &Resp3Frame, - withcoord: bool, - withdist: bool, - withhash: bool, -) -> Result { - if let Resp3Frame::Array { ref data, .. 
} = frame { - let mut out = GeoRadiusInfo::default(); - - if withcoord && withdist && withhash { - // 4 elements: member, dist, hash, position - let _ = assert_frame_len(data, 4)?; - - out.member = parse_geo_member(&data[0])?; - out.distance = Some(parse_geo_dist(&data[1])?); - out.hash = Some(parse_geo_hash(&data[2])?); - out.position = Some(parse_geo_position(&data[3])?); - } else if withcoord && withdist { - // 3 elements: member, dist, position - let _ = assert_frame_len(data, 3)?; - - out.member = parse_geo_member(&data[0])?; - out.distance = Some(parse_geo_dist(&data[1])?); - out.position = Some(parse_geo_position(&data[2])?); - } else if withcoord && withhash { - // 3 elements: member, hash, position - let _ = assert_frame_len(data, 3)?; - - out.member = parse_geo_member(&data[0])?; - out.hash = Some(parse_geo_hash(&data[1])?); - out.position = Some(parse_geo_position(&data[2])?); - } else if withdist && withhash { - // 3 elements: member, dist, hash - let _ = assert_frame_len(data, 3)?; - - out.member = parse_geo_member(&data[0])?; - out.distance = Some(parse_geo_dist(&data[1])?); - out.hash = Some(parse_geo_hash(&data[2])?); - } else if withcoord { - // 2 elements: member, position - let _ = assert_frame_len(data, 2)?; - - out.member = parse_geo_member(&data[0])?; - out.position = Some(parse_geo_position(&data[1])?); - } else if withdist { - // 2 elements: member, dist - let _ = assert_frame_len(data, 2)?; - - out.member = parse_geo_member(&data[0])?; - out.distance = Some(parse_geo_dist(&data[1])?); - } else if withhash { - // 2 elements: member, hash - let _ = assert_frame_len(data, 2)?; - - out.member = parse_geo_member(&data[0])?; - out.hash = Some(parse_geo_hash(&data[1])?); - } - - Ok(out) - } else { - let member: RedisValue = match frame.as_str() { - Some(s) => s.into(), - None => { - return Err(RedisError::new( - RedisErrorKind::Protocol, - "Expected string or array of frames.", - )) - }, - }; - - Ok(GeoRadiusInfo { - member, - ..Default::default() 
- }) - } -} - -pub fn parse_georadius_result( - frame: Resp3Frame, - withcoord: bool, - withdist: bool, - withhash: bool, -) -> Result, RedisError> { - if let Resp3Frame::Array { data, .. } = frame { - let mut out = Vec::with_capacity(data.len()); - - for frame in data.into_iter() { - out.push(parse_georadius_info(&frame, withcoord, withdist, withhash)?); - } - - Ok(out) - } else { - Err(RedisError::new(RedisErrorKind::Protocol, "Expected array.")) } } @@ -1559,7 +881,7 @@ pub fn command_to_resp3_frame(command: &RedisCommand) -> Result { - let parts: Vec<&str> = kind.cmd.trim().split(" ").collect(); + let parts: Vec<&str> = kind.cmd.trim().split(' ').collect(); let mut bulk_strings = Vec::with_capacity(parts.len() + args.len()); for part in parts.into_iter() { @@ -1609,7 +931,7 @@ pub fn command_to_resp2_frame(command: &RedisCommand) -> Result { - let parts: Vec<&str> = kind.cmd.trim().split(" ").collect(); + let parts: Vec<&str> = kind.cmd.trim().split(' ').collect(); let mut bulk_strings = Vec::with_capacity(parts.len() + args.len()); for part in parts.into_iter() { @@ -1649,7 +971,7 @@ pub fn command_to_frame(command: &RedisCommand, is_resp3: bool) -> Result Resp3Frame { Resp3Frame::SimpleString { @@ -1672,80 +994,81 @@ mod tests { } } - fn string_vec(d: Vec<&str>) -> Vec { - d.into_iter().map(|s| s.to_owned()).collect() - } - #[test] fn should_parse_memory_stats() { // better from()/into() interfaces for frames coming in the next redis-protocol version... 
- let frames: Vec = vec![ - str_to_f("peak.allocated"), - int_to_f(934192), - str_to_f("total.allocated"), - int_to_f(872040), - str_to_f("startup.allocated"), - int_to_f(809912), - str_to_f("replication.backlog"), - int_to_f(0), - str_to_f("clients.slaves"), - int_to_f(0), - str_to_f("clients.normal"), - int_to_f(20496), - str_to_f("aof.buffer"), - int_to_f(0), - str_to_f("lua.caches"), - int_to_f(0), - str_to_f("db.0"), - Resp3Frame::Array { - data: vec![ - str_to_f("overhead.hashtable.main"), - int_to_f(72), - str_to_f("overhead.hashtable.expires"), - int_to_f(0), - ], - attributes: None, - }, - str_to_f("overhead.total"), - int_to_f(830480), - str_to_f("keys.count"), - int_to_f(1), - str_to_f("keys.bytes-per-key"), - int_to_f(62128), - str_to_f("dataset.bytes"), - int_to_f(41560), - str_to_f("dataset.percentage"), - str_to_f("66.894157409667969"), - str_to_f("peak.percentage"), - str_to_f("93.346977233886719"), - str_to_f("allocator.allocated"), - int_to_f(1022640), - str_to_f("allocator.active"), - int_to_f(1241088), - str_to_f("allocator.resident"), - int_to_f(5332992), - str_to_f("allocator-fragmentation.ratio"), - str_to_f("1.2136118412017822"), - str_to_f("allocator-fragmentation.bytes"), - int_to_f(218448), - str_to_f("allocator-rss.ratio"), - str_to_f("4.2970294952392578"), - str_to_f("allocator-rss.bytes"), - int_to_f(4091904), - str_to_f("rss-overhead.ratio"), - str_to_f("2.0268816947937012"), - str_to_f("rss-overhead.bytes"), - int_to_f(5476352), - str_to_f("fragmentation"), - str_to_f("13.007383346557617"), - str_to_f("fragmentation.bytes"), - int_to_f(9978328), - ]; - let memory_stats = parse_memory_stats(&frames).unwrap(); + let input = frame_to_results(Resp3Frame::Array { + data: vec![ + str_to_f("peak.allocated"), + int_to_f(934192), + str_to_f("total.allocated"), + int_to_f(872040), + str_to_f("startup.allocated"), + int_to_f(809912), + str_to_f("replication.backlog"), + int_to_f(0), + str_to_f("clients.slaves"), + int_to_f(0), + 
str_to_f("clients.normal"), + int_to_f(20496), + str_to_f("aof.buffer"), + int_to_f(0), + str_to_f("lua.caches"), + int_to_f(0), + str_to_f("db.0"), + Resp3Frame::Array { + data: vec![ + str_to_f("overhead.hashtable.main"), + int_to_f(72), + str_to_f("overhead.hashtable.expires"), + int_to_f(0), + ], + attributes: None, + }, + str_to_f("overhead.total"), + int_to_f(830480), + str_to_f("keys.count"), + int_to_f(1), + str_to_f("keys.bytes-per-key"), + int_to_f(62128), + str_to_f("dataset.bytes"), + int_to_f(41560), + str_to_f("dataset.percentage"), + str_to_f("66.894157409667969"), + str_to_f("peak.percentage"), + str_to_f("93.346977233886719"), + str_to_f("allocator.allocated"), + int_to_f(1022640), + str_to_f("allocator.active"), + int_to_f(1241088), + str_to_f("allocator.resident"), + int_to_f(5332992), + str_to_f("allocator-fragmentation.ratio"), + str_to_f("1.2136118412017822"), + str_to_f("allocator-fragmentation.bytes"), + int_to_f(218448), + str_to_f("allocator-rss.ratio"), + str_to_f("4.2970294952392578"), + str_to_f("allocator-rss.bytes"), + int_to_f(4091904), + str_to_f("rss-overhead.ratio"), + str_to_f("2.0268816947937012"), + str_to_f("rss-overhead.bytes"), + int_to_f(5476352), + str_to_f("fragmentation"), + str_to_f("13.007383346557617"), + str_to_f("fragmentation.bytes"), + int_to_f(9978328), + ], + attributes: None, + }) + .unwrap(); + let memory_stats: MemoryStats = input.convert().unwrap(); let expected_db_0 = DatabaseMemoryStats { - overhead_hashtable_expires: 0, - overhead_hashtable_main: 72, + overhead_hashtable_expires: 0, + overhead_hashtable_main: 72, + overhead_hashtable_slot_to_keys: 0, }; let mut expected_db = HashMap::new(); expected_db.insert(0, expected_db_0); @@ -1763,16 +1086,16 @@ mod tests { keys_count: 1, keys_bytes_per_key: 62128, dataset_bytes: 41560, - dataset_percentage: 66.894157409667969, - peak_percentage: 93.346977233886719, + dataset_percentage: 66.894_157_409_667_97, + peak_percentage: 93.346_977_233_886_72, 
allocator_allocated: 1022640, allocator_active: 1241088, allocator_resident: 5332992, allocator_fragmentation_ratio: 1.2136118412017822, allocator_fragmentation_bytes: 218448, - allocator_rss_ratio: 4.2970294952392578, + allocator_rss_ratio: 4.297_029_495_239_258, allocator_rss_bytes: 4091904, - rss_overhead_ratio: 2.0268816947937012, + rss_overhead_ratio: 2.026_881_694_793_701, rss_overhead_bytes: 5476352, fragmentation: 13.007383346557617, fragmentation_bytes: 9978328, @@ -1781,66 +1104,6 @@ mod tests { assert_eq!(memory_stats, expected); } - #[test] - fn should_parse_acl_getuser_response() { - // 127.0.0.1:6379> acl getuser alec - // 1) "flags" - // 2) 1) "on" - // 3) "passwords" - // 4) 1) "c56e8629954a900e993e84ed3d4b134b9450da1b411a711d047d547808c3ece5" - // 2) "39b039a94deaa548cf6382282c4591eccdc648706f9d608eceb687d452a31a45" - // 5) "commands" - // 6) "-@all +@sortedset +@geo +config|get" - // 7) "keys" - // 8) 1) "a" - // 2) "b" - // 3) "c" - // 9) "channels" - // 10) 1) "c1" - // 2) "c2" - - let input = vec![ - str_to_bs("flags"), - Resp3Frame::Array { - data: vec![str_to_bs("on")], - attributes: None, - }, - str_to_bs("passwords"), - Resp3Frame::Array { - data: vec![ - str_to_bs("c56e8629954a900e993e84ed3d4b134b9450da1b411a711d047d547808c3ece5"), - str_to_bs("39b039a94deaa548cf6382282c4591eccdc648706f9d608eceb687d452a31a45"), - ], - attributes: None, - }, - str_to_bs("commands"), - str_to_bs("-@all +@sortedset +@geo +config|get"), - str_to_bs("keys"), - Resp3Frame::Array { - data: vec![str_to_bs("a"), str_to_bs("b"), str_to_bs("c")], - attributes: None, - }, - str_to_bs("channels"), - Resp3Frame::Array { - data: vec![str_to_bs("c1"), str_to_bs("c2")], - attributes: None, - }, - ]; - let actual = parse_acl_getuser_frames(input).unwrap(); - - let expected = AclUser { - flags: vec![AclUserFlag::On], - passwords: string_vec(vec![ - "c56e8629954a900e993e84ed3d4b134b9450da1b411a711d047d547808c3ece5", - 
"39b039a94deaa548cf6382282c4591eccdc648706f9d608eceb687d452a31a45", - ]), - commands: string_vec(vec!["-@all", "+@sortedset", "+@geo", "+config|get"]), - keys: string_vec(vec!["a", "b", "c"]), - channels: string_vec(vec!["c1", "c2"]), - }; - assert_eq!(actual, expected); - } - #[test] fn should_parse_slowlog_entries_redis_3() { // redis 127.0.0.1:6379> slowlog get 2 @@ -1855,29 +1118,33 @@ mod tests { // 2) "get" // 3) "100" - let input = vec![ - Resp3Frame::Array { - data: vec![int_to_f(14), int_to_f(1309448221), int_to_f(15), Resp3Frame::Array { - data: vec![str_to_bs("ping")], + let input = frame_to_results(Resp3Frame::Array { + data: vec![ + Resp3Frame::Array { + data: vec![int_to_f(14), int_to_f(1309448221), int_to_f(15), Resp3Frame::Array { + data: vec![str_to_bs("ping")], + attributes: None, + }], attributes: None, - }], - attributes: None, - }, - Resp3Frame::Array { - data: vec![int_to_f(13), int_to_f(1309448128), int_to_f(30), Resp3Frame::Array { - data: vec![str_to_bs("slowlog"), str_to_bs("get"), str_to_bs("100")], + }, + Resp3Frame::Array { + data: vec![int_to_f(13), int_to_f(1309448128), int_to_f(30), Resp3Frame::Array { + data: vec![str_to_bs("slowlog"), str_to_bs("get"), str_to_bs("100")], + attributes: None, + }], attributes: None, - }], - attributes: None, - }, - ]; - let actual = parse_slowlog_entries(input).unwrap(); + }, + ], + attributes: None, + }) + .unwrap(); + let actual: Vec = input.convert().unwrap(); let expected = vec![ SlowlogEntry { id: 14, timestamp: 1309448221, - duration: 15, + duration: Duration::from_micros(15), args: vec!["ping".into()], ip: None, name: None, @@ -1885,7 +1152,7 @@ mod tests { SlowlogEntry { id: 13, timestamp: 1309448128, - duration: 30, + duration: Duration::from_micros(30), args: vec!["slowlog".into(), "get".into(), "100".into()], ip: None, name: None, @@ -1913,43 +1180,47 @@ mod tests { // 5) "127.0.0.1:58217" // 6) "worker-123" - let input = vec![ - Resp3Frame::Array { - data: vec![ - int_to_f(14), - 
int_to_f(1309448221), - int_to_f(15), - Resp3Frame::Array { - data: vec![str_to_bs("ping")], - attributes: None, - }, - str_to_bs("127.0.0.1:58217"), - str_to_bs("worker-123"), - ], - attributes: None, - }, - Resp3Frame::Array { - data: vec![ - int_to_f(13), - int_to_f(1309448128), - int_to_f(30), - Resp3Frame::Array { - data: vec![str_to_bs("slowlog"), str_to_bs("get"), str_to_bs("100")], - attributes: None, - }, - str_to_bs("127.0.0.1:58217"), - str_to_bs("worker-123"), - ], - attributes: None, - }, - ]; - let actual = parse_slowlog_entries(input).unwrap(); + let input = frame_to_results(Resp3Frame::Array { + data: vec![ + Resp3Frame::Array { + data: vec![ + int_to_f(14), + int_to_f(1309448221), + int_to_f(15), + Resp3Frame::Array { + data: vec![str_to_bs("ping")], + attributes: None, + }, + str_to_bs("127.0.0.1:58217"), + str_to_bs("worker-123"), + ], + attributes: None, + }, + Resp3Frame::Array { + data: vec![ + int_to_f(13), + int_to_f(1309448128), + int_to_f(30), + Resp3Frame::Array { + data: vec![str_to_bs("slowlog"), str_to_bs("get"), str_to_bs("100")], + attributes: None, + }, + str_to_bs("127.0.0.1:58217"), + str_to_bs("worker-123"), + ], + attributes: None, + }, + ], + attributes: None, + }) + .unwrap(); + let actual: Vec = input.convert().unwrap(); let expected = vec![ SlowlogEntry { id: 14, timestamp: 1309448221, - duration: 15, + duration: Duration::from_micros(15), args: vec!["ping".into()], ip: Some("127.0.0.1:58217".into()), name: Some("worker-123".into()), @@ -1957,7 +1228,7 @@ mod tests { SlowlogEntry { id: 13, timestamp: 1309448128, - duration: 30, + duration: Duration::from_micros(30), args: vec!["slowlog".into(), "get".into(), "100".into()], ip: Some("127.0.0.1:58217".into()), name: Some("worker-123".into()), @@ -1969,7 +1240,7 @@ mod tests { #[test] fn should_parse_cluster_info() { - let input = "cluster_state:fail + let input: RedisValue = "cluster_state:fail cluster_slots_assigned:16384 cluster_slots_ok:16384 cluster_slots_pfail:3 @@ 
-1979,7 +1250,8 @@ cluster_size:3 cluster_current_epoch:6 cluster_my_epoch:2 cluster_stats_messages_sent:1483972 -cluster_stats_messages_received:1483968"; +cluster_stats_messages_received:1483968" + .into(); let expected = ClusterInfo { cluster_state: ClusterState::Fail, @@ -1994,12 +1266,8 @@ cluster_stats_messages_received:1483968"; cluster_stats_messages_sent: 1483972, cluster_stats_messages_received: 1483968, }; + let actual: ClusterInfo = input.convert().unwrap(); - let actual = parse_cluster_info(Resp3Frame::BlobString { - data: input.as_bytes().into(), - attributes: None, - }) - .unwrap(); assert_eq!(actual, expected); } } diff --git a/src/router/centralized.rs b/src/router/centralized.rs index e7610bca..c560f10d 100644 --- a/src/router/centralized.rs +++ b/src/router/centralized.rs @@ -16,21 +16,21 @@ use crate::{ use std::sync::Arc; use tokio::task::JoinHandle; -pub async fn send_command( +pub async fn write( inner: &Arc, writer: &mut Option, command: RedisCommand, force_flush: bool, -) -> Result { +) -> Written { if let Some(writer) = writer.as_mut() { - Ok(utils::write_command(inner, writer, command, force_flush).await) + utils::write_command(inner, writer, command, force_flush).await } else { _debug!(inner, "Failed to read connection for {}", command.kind.to_str_debug()); - Ok(Written::Disconnect(( + Written::Disconnected(( None, Some(command), RedisError::new(RedisErrorKind::IO, "Missing connection."), - ))) + )) } } @@ -140,14 +140,12 @@ pub async fn process_response_frame( let _ = tx.send(RouterResponse::TransactionError((error, command))); } return Ok(()); + } else if command.kind.ends_transaction() { + command.respond_to_router(inner, RouterResponse::TransactionResult(frame)); + return Ok(()); } else { - if command.kind.ends_transaction() { - command.respond_to_router(inner, RouterResponse::TransactionResult(frame)); - return Ok(()); - } else { - command.respond_to_router(inner, RouterResponse::Continue); - return Ok(()); - } + 
command.respond_to_router(inner, RouterResponse::Continue); + return Ok(()); } } @@ -209,17 +207,11 @@ pub async fn initialize_connection( ServerConfig::Centralized { ref server } => server.clone(), _ => return Err(RedisError::new(RedisErrorKind::Config, "Expected centralized config.")), }; - let mut transport = connection::create( - inner, - server.host.as_str().to_owned(), - server.port, - None, - server.tls_server_name.as_ref(), - ) - .await?; - let _ = transport.setup(inner, None).await?; - - let (_, _writer) = connection::split_and_initialize(inner, transport, false, spawn_reader_task)?; + let mut transport = connection::create(inner, &server, None).await?; + transport.setup(inner, None).await?; + let (server, _writer) = connection::split_and_initialize(inner, transport, false, spawn_reader_task)?; + inner.notifications.broadcast_reconnect(server); + *writer = Some(_writer); Ok(()) }, diff --git a/src/router/clustered.rs b/src/router/clustered.rs index 6b68bf64..0c562b35 100644 --- a/src/router/clustered.rs +++ b/src/router/clustered.rs @@ -1,6 +1,5 @@ use crate::{ error::{RedisError, RedisErrorKind}, - globals::globals, interfaces, interfaces::Resp3Frame, modules::inner::RedisClientInner, @@ -16,88 +15,97 @@ use crate::{ types::ClusterStateChange, utils as client_utils, }; +use futures::future::try_join_all; +use parking_lot::Mutex; use std::{ - collections::{BTreeSet, HashMap}, + collections::{BTreeSet, HashMap, VecDeque}, iter::repeat, sync::Arc, }; use tokio::task::JoinHandle; -pub fn find_cluster_node<'a>( +/// Find the cluster node that should receive the command. 
+pub fn route_command<'a>( inner: &Arc, state: &'a ClusterRouting, command: &RedisCommand, ) -> Option<&'a Server> { - command - .cluster_hash() - .and_then(|slot| state.get_server(slot)) - .or_else(|| { - let node = state.random_node(); - _trace!( - inner, - "Using random cluster node `{:?}` for {}", - node, - command.kind.to_str_debug() - ); - node + if let Some(ref server) = command.cluster_node { + // this `_server` has a lifetime tied to `command`, so we switch `server` to refer to the record in `state` while + // we check whether that node exists in the cluster. we return None here if the command specifies a server that + // does not exist in the cluster. + _trace!(inner, "Routing with custom cluster node: {}", server); + state.slots().iter().find_map(|slot| { + if slot.primary == *server { + Some(&slot.primary) + } else { + None + } }) + } else { + command + .cluster_hash() + .and_then(|slot| state.get_server(slot)) + .or_else(|| { + // for some commands we know they can go to any node, but for others it may depend on the arguments provided. + if command.args().is_empty() || command.kind.use_random_cluster_node() { + let node = state.random_node(); + _trace!( + inner, + "Using random cluster node `{:?}` for {}", + node, + command.kind.to_str_debug() + ); + node + } else { + None + } + }) + } } /// Write a command to the cluster according to the [cluster hashing](https://redis.io/docs/reference/cluster-spec/) interface. 
-pub async fn send_command( +pub async fn write( inner: &Arc, writers: &mut HashMap, state: &ClusterRouting, - mut command: RedisCommand, + command: RedisCommand, force_flush: bool, -) -> Result { - // first check whether the caller specified a specific cluster node that should receive the command - let server = if let Some(ref _server) = command.cluster_node { - // this `_server` has a lifetime tied to `command`, so we switch `server` to refer to the record in `state` while - // we check whether that node exists in the cluster - let server = state.slots().iter().find_map(|slot| { - if slot.primary == *_server { - Some(&slot.primary) - } else { - None - } - }); - - if let Some(server) = server { - server - } else { - _debug!( - inner, - "Respond to caller with error from missing cluster node override ({})", - _server - ); - command.respond_to_caller(Err(RedisError::new( - RedisErrorKind::Cluster, - "Missing cluster node override.", - ))); - command.respond_to_router(inner, RouterResponse::Continue); +) -> Written { + let has_custom_server = command.cluster_node.is_some(); + let server = match route_command(inner, state, &command) { + Some(server) => server, + None => { + return if has_custom_server { + _debug!( + inner, + "Respond to caller with error from missing cluster node override ({:?})", + command.cluster_node + ); + command.finish( + inner, + Err(RedisError::new( + RedisErrorKind::Cluster, + "Missing cluster node override.", + )), + ); - return Ok(Written::Ignore); - } - } else { - // otherwise apply whichever cluster hash policy exists in the command - match find_cluster_node(inner, state, &command) { - Some(server) => server, - None => { + Written::Ignore + } else { // these errors usually mean the cluster is partially down or misconfigured _warn!( inner, "Possible cluster misconfiguration. 
Missing hash slot owner for {:?}.", command.cluster_hash() ); - return Ok(Written::Sync(command)); - }, - } + Written::NotFound(command) + }; + }, }; if let Some(writer) = writers.get_mut(server) { _debug!(inner, "Writing command `{}` to {}", command.kind.to_str_debug(), server); - Ok(utils::write_command(inner, writer, command, force_flush).await) + utils::write_command(inner, writer, command, force_flush).await } else { // a reconnect message should already be queued from the reader task _debug!( @@ -107,11 +115,11 @@ pub async fn send_command( command.kind.to_str_debug() ); - Ok(Written::Disconnect(( + Written::Disconnected(( Some(server.clone()), Some(command), RedisError::new(RedisErrorKind::IO, "Missing connection."), - ))) + )) } } @@ -164,7 +172,7 @@ pub async fn send_all_cluster_command( let mut cmd = command.duplicate(cmd_responder); cmd.skip_backpressure = true; - if let Written::Disconnect((server, _, err)) = utils::write_command(inner, writer, cmd, true).await { + if let Written::Disconnected((server, _, err)) = utils::write_command(inner, writer, cmd, true).await { _debug!( inner, "Exit all nodes command early ({}/{}: {:?}) from error: {:?}", @@ -193,8 +201,8 @@ pub fn parse_cluster_changes( for server in writers.keys() { old_servers.insert(server.clone()); } - let add = new_servers.difference(&old_servers).map(|s| s.clone()).collect(); - let remove = old_servers.difference(&new_servers).map(|s| s.clone()).collect(); + let add = new_servers.difference(&old_servers).cloned().collect(); + let remove = old_servers.difference(&new_servers).cloned().collect(); ClusterChange { add, remove } } @@ -444,14 +452,12 @@ pub async fn process_response_frame( let _ = tx.send(RouterResponse::TransactionError((error, command))); } return Ok(()); + } else if command.kind.ends_transaction() { + command.respond_to_router(inner, RouterResponse::TransactionResult(frame)); + return Ok(()); } else { - if command.kind.ends_transaction() { - command.respond_to_router(inner, 
RouterResponse::TransactionResult(frame)); - return Ok(()); - } else { - command.respond_to_router(inner, RouterResponse::Continue); - return Ok(()); - } + command.respond_to_router(inner, RouterResponse::Continue); + return Ok(()); } } @@ -507,21 +513,13 @@ pub async fn connect_any( } else { BTreeSet::new() }; - all_servers.extend(inner.config.server.hosts().into_iter().map(|server| server.clone())); + all_servers.extend(inner.config.server.hosts().into_iter().cloned()); _debug!(inner, "Attempting clustered connections to any of {:?}", all_servers); let num_servers = all_servers.len(); let mut last_error = None; for (idx, server) in all_servers.into_iter().enumerate() { - let connection = connection::create( - inner, - server.host.as_str().to_owned(), - server.port, - None, - server.tls_server_name.as_ref(), - ) - .await; - let mut connection = match connection { + let mut connection = match connection::create(inner, &server, None).await { Ok(connection) => connection, Err(e) => { last_error = Some(e); @@ -559,8 +557,6 @@ pub async fn cluster_slots_backchannel( inner: &Arc, cache: Option<&ClusterRouting>, ) -> Result { - let timeout = globals().default_connection_timeout_ms(); - let (response, host) = { let command: RedisCommand = RedisCommandKind::ClusterSlots.into(); @@ -571,10 +567,13 @@ pub async fn cluster_slots_backchannel( let default_host = transport.default_host.clone(); _trace!(inner, "Sending backchannel CLUSTER SLOTS to {}", transport.server); - client_utils::apply_timeout(transport.request_response(command, inner.is_resp3()), timeout) - .await - .ok() - .map(|frame| (frame, default_host)) + client_utils::apply_timeout( + transport.request_response(command, inner.is_resp3()), + inner.internal_command_timeout(), + ) + .await + .ok() + .map(|frame| (frame, default_host)) } else { None } @@ -591,8 +590,11 @@ pub async fn cluster_slots_backchannel( if frame.is_error() { // try connecting to any of the nodes, then try again let mut transport = 
connect_any(inner, old_cache).await?; - let frame = - client_utils::apply_timeout(transport.request_response(command, inner.is_resp3()), timeout).await?; + let frame = client_utils::apply_timeout( + transport.request_response(command, inner.is_resp3()), + inner.internal_command_timeout(), + ) + .await?; let host = transport.default_host.clone(); inner.update_backchannel(transport).await; @@ -604,14 +606,18 @@ pub async fn cluster_slots_backchannel( } else { // try connecting to any of the nodes, then try again let mut transport = connect_any(inner, old_cache).await?; - let frame = client_utils::apply_timeout(transport.request_response(command, inner.is_resp3()), timeout).await?; + let frame = client_utils::apply_timeout( + transport.request_response(command, inner.is_resp3()), + inner.internal_command_timeout(), + ) + .await?; let host = transport.default_host.clone(); inner.update_backchannel(transport).await; (frame, host) }; - (protocol_utils::frame_to_results_raw(frame)?, host) + (protocol_utils::frame_to_results(frame)?, host) }; _trace!(inner, "Recv CLUSTER SLOTS response: {:?}", response); if response.is_null() { @@ -624,10 +630,27 @@ pub async fn cluster_slots_backchannel( let mut new_cache = ClusterRouting::new(); _debug!(inner, "Rebuilding cluster state from host: {}", host); - new_cache.rebuild(inner, response, host.as_str())?; + new_cache.rebuild(inner, response, &host)?; Ok(new_cache) } +/// Check each connection and remove it from the writer map if it's not [working](RedisWriter::is_working). 
+pub async fn drop_broken_connections(writers: &mut HashMap) -> CommandBuffer { + let mut new_writers = HashMap::with_capacity(writers.len()); + let mut buffer = VecDeque::new(); + + for (server, writer) in writers.drain() { + if writer.is_working() { + new_writers.insert(server, writer); + } else { + buffer.extend(writer.graceful_close().await); + } + } + + *writers = new_writers; + buffer +} + /// Run `CLUSTER SLOTS`, update the cached routing table, and modify the connection map. pub async fn sync( inner: &Arc, @@ -647,8 +670,9 @@ pub async fn sync( .update_cluster_state(Some(state.clone())); *cache = state.clone(); + buffer.extend(drop_broken_connections(writers).await); // detect changes to the cluster topology - let changes = parse_cluster_changes(&state, &writers); + let changes = parse_cluster_changes(&state, writers); _debug!(inner, "Changing cluster connections: {:?}", changes); broadcast_cluster_change(inner, &changes); @@ -664,20 +688,26 @@ pub async fn sync( buffer.extend(commands); } + let mut connections_ft = Vec::with_capacity(changes.add.len()); + let new_writers = Arc::new(Mutex::new(HashMap::with_capacity(changes.add.len()))); // connect to each of the new nodes for server in changes.add.into_iter() { - _debug!(inner, "Connecting to cluster node {}", server); - let mut transport = connection::create( - inner, - server.host.as_str().to_owned(), - server.port, - None, - server.tls_server_name.as_ref(), - ) - .await?; - let _ = transport.setup(inner, None).await?; + let _inner = inner.clone(); + let _new_writers = new_writers.clone(); + connections_ft.push(async move { + _debug!(inner, "Connecting to cluster node {}", server); + let mut transport = connection::create(&_inner, &server, None).await?; + transport.setup(&_inner, None).await?; + + let (server, writer) = connection::split_and_initialize(&_inner, transport, false, spawn_reader_task)?; + inner.notifications.broadcast_reconnect(server.clone()); + _new_writers.lock().insert(server, writer); + 
Ok::<_, RedisError>(()) + }); + } - let (server, writer) = connection::split_and_initialize(inner, transport, false, spawn_reader_task)?; + let _ = try_join_all(connections_ft).await?; + for (server, writer) in new_writers.lock().drain() { writers.insert(server, writer); } diff --git a/src/router/commands.rs b/src/router/commands.rs index ba197912..79a0d906 100644 --- a/src/router/commands.rs +++ b/src/router/commands.rs @@ -7,19 +7,17 @@ use crate::{ utils as client_utils, }; use redis_protocol::resp3::types::Frame as Resp3Frame; -use std::{sync::Arc, time::Duration}; -use tokio::{sync::oneshot::Sender as OneshotSender, time::sleep}; +use std::sync::Arc; +use tokio::sync::oneshot::Sender as OneshotSender; -#[cfg(feature = "mocks")] -use crate::{modules::mocks::Mocks, protocol::utils as protocol_utils}; #[cfg(feature = "full-tracing")] use tracing_futures::Instrument; /// Wait for the response from the reader task, handling cluster redirections if needed. /// -/// Returns the command to be retried later if needed. +/// The command is returned if it failed to write but could be immediately retried. /// -/// Note: This does **not** handle transaction errors. +/// Errors from this function should end the connection task. 
async fn handle_router_response( inner: &Arc, router: &mut Router, @@ -39,30 +37,42 @@ async fn handle_router_response( match response { RouterResponse::Continue => Ok(None), RouterResponse::Ask((slot, server, mut command)) => { - let _ = utils::send_asking_with_policy(inner, router, &server, slot).await?; - command.hasher = ClusterHash::Custom(slot); - command.use_replica = false; - Ok(Some(command)) + if let Err(e) = command.decr_check_redirections() { + command.respond_to_caller(Err(e)); + Ok(None) + } else { + utils::send_asking_with_policy(inner, router, &server, slot).await?; + command.hasher = ClusterHash::Custom(slot); + command.use_replica = false; + command.attempts_remaining += 1; + Ok(Some(command)) + } }, RouterResponse::Moved((slot, server, mut command)) => { // check if slot belongs to server, if not then run sync cluster if !router.cluster_node_owns_slot(slot, &server) { - let _ = utils::sync_cluster_with_policy(inner, router).await?; + utils::sync_cluster_with_policy(inner, router).await?; } - command.hasher = ClusterHash::Custom(slot); - command.use_replica = false; - Ok(Some(command)) + if let Err(e) = command.decr_check_redirections() { + command.finish(inner, Err(e)); + Ok(None) + } else { + command.hasher = ClusterHash::Custom(slot); + command.use_replica = false; + command.attempts_remaining += 1; + Ok(Some(command)) + } }, - RouterResponse::ConnectionClosed((error, mut command)) => { - let command = if command.attempted >= inner.max_command_attempts() { - command.respond_to_caller(Err(error.clone())); + RouterResponse::ConnectionClosed((error, command)) => { + let command = if command.should_finish_with_error(inner) { + command.finish(inner, Err(error.clone())); None } else { Some(command) }; - let _ = utils::reconnect_with_policy(inner, router).await?; + utils::reconnect_with_policy(inner, router).await?; Ok(command) }, RouterResponse::TransactionError(_) | RouterResponse::TransactionResult(_) => { @@ -78,16 +88,15 @@ async fn 
handle_router_response( } } -/// Continuously write the command until it is sent or fails with a fatal error. -/// -/// If the connection closes the command will be queued to run later. The reader task will send a command to reconnect -/// some time later. +/// Continuously write the command until it is sent, queued to try later, or fails with a fatal error. async fn write_with_backpressure( inner: &Arc, router: &mut Router, command: RedisCommand, force_pipeline: bool, ) -> Result<(), RedisError> { + _trace!(inner, "Writing command: {:?}", command); + let mut _command: Option = Some(command); let mut _backpressure: Option = None; loop { @@ -95,8 +104,13 @@ async fn write_with_backpressure( Some(command) => command, None => return Err(RedisError::new(RedisErrorKind::Unknown, "Missing command.")), }; - // TODO clean this up - let rx = match _backpressure { + if let Err(e) = command.decr_check_attempted() { + command.finish(inner, Err(e)); + break; + } + + // apply backpressure first if needed. as a part of that check we may decide to block on the next command. 
+ let router_rx = match _backpressure { Some(backpressure) => match backpressure.wait(inner, &mut command).await { Ok(Some(rx)) => Some(rx), Ok(None) => { @@ -119,54 +133,72 @@ async fn write_with_backpressure( } }, }; - let closes_connection = command.kind.closes_connection(); let is_blocking = command.blocks_connection(); let use_replica = command.use_replica; let result = if use_replica { - router.write_replica_command(command, false).await + router.write_replica(command, false).await } else { - router.write_command(command, false).await + router.write(command, false).await }; match result { - Ok(Written::Backpressure((command, backpressure))) => { + Written::Backpressure((mut command, backpressure)) => { _debug!(inner, "Recv backpressure again for {}.", command.kind.to_str_debug()); + // backpressure doesn't count as a write attempt + command.attempts_remaining += 1; _command = Some(command); _backpressure = Some(backpressure); continue; }, - Ok(Written::Disconnect((server, command, error))) => { + Written::Disconnected((server, command, error)) => { _debug!(inner, "Handle disconnect for {:?} due to {:?}", server, error); let commands = router.connections.disconnect(inner, server.as_ref()).await; router.buffer_commands(commands); if let Some(command) = command { - router.buffer_command(command); + if command.should_finish_with_error(inner) { + command.finish(inner, Err(error)); + } else { + router.buffer_command(command); + } } router.sync_network_timeout_state(); break; }, - Ok(Written::Sync(command)) => { + Written::NotFound(mut command) => { + if let Err(e) = command.decr_check_redirections() { + command.finish(inner, Err(e)); + break; + } + _debug!(inner, "Perform cluster sync after missing hash slot lookup."); - // disconnecting from everything forces the caller into a reconnect loop - router.disconnect_all().await; - router.buffer_command(command); - break; + if let Err(error) = router.sync_cluster().await { + // try to sync the cluster once, and failing 
that buffer the command. a failed cluster sync will clear local + // cluster state and old connections, which then forces a reconnect from the reader tasks when the streams + // close. + _warn!(inner, "Failed to sync cluster after NotFound: {:?}", error); + router.buffer_command(command); + break; + } else { + _command = Some(command); + _backpressure = None; + continue; + } }, - Ok(Written::Ignore) => { + Written::Ignore => { _trace!(inner, "Ignore `Written` response."); break; }, - Ok(Written::SentAll) => { + Written::SentAll => { _trace!(inner, "Sent command to all servers."); let _ = router.check_and_flush().await; - if let Some(mut command) = handle_router_response(inner, router, rx).await? { + if let Some(command) = handle_router_response(inner, router, router_rx).await? { // commands that are sent to all nodes are not retried after a connection closing _warn!(inner, "Responding with canceled error after all nodes command failure."); - command.respond_to_caller(Err(RedisError::new_canceled())); + command.finish(inner, Err(RedisError::new_canceled())); break; } else { if closes_connection { @@ -177,7 +209,7 @@ async fn write_with_backpressure( break; } }, - Ok(Written::Sent((server, flushed))) => { + Written::Sent((server, flushed)) => { _trace!(inner, "Sent command to {}. Flushed: {}", server, flushed); if is_blocking { inner.backchannel.write().await.set_blocked(&server); @@ -196,7 +228,7 @@ async fn write_with_backpressure( } } - if let Some(command) = handle_router_response(inner, router, rx).await? { + if let Some(command) = handle_router_response(inner, router, router_rx).await? 
{ _command = Some(command); _backpressure = None; continue; @@ -209,15 +241,28 @@ async fn write_with_backpressure( break; } }, - Err(e) => { - if use_replica { - _debug!(inner, "Disconnect replicas after error writing to replica: {:?}", e); - // triggers a reconnect message from the reader tasks - router.disconnect_replicas().await; - break; - } else { - return Err(e); + Written::Error((error, command)) => { + _debug!(inner, "Fatal error writing command: {:?}", error); + if let Some(command) = command { + command.finish(inner, Err(error.clone())); } + inner.notifications.broadcast_error(error.clone()); + + return Err(error); + }, + #[cfg(feature = "replicas")] + Written::Fallback(command) => { + _error!( + inner, + "Unexpected replica response to {} ({})", + command.kind.to_str_debug(), + command.debug_id() + ); + command.finish( + inner, + Err(RedisError::new(RedisErrorKind::Replica, "Unexpected replica response.")), + ); + break; }, } } @@ -287,36 +332,20 @@ async fn process_ask( command.use_replica = false; command.hasher = ClusterHash::Custom(slot); - let mut _command = Some(command); - loop { - let mut command = match _command.take() { - Some(command) => command, - None => { - _warn!(inner, "Missing command following an ASKING redirect."); - return Ok(()); - }, - }; - - if let Err(e) = utils::send_asking_with_policy(inner, router, &server, slot).await { - command.respond_to_caller(Err(e.clone())); - return Err(e); - } - - if let Err(e) = command.incr_check_attempted(inner.max_command_attempts()) { - command.respond_to_caller(Err(e)); - break; - } - // TODO fix this for blocking commands - if let Err((error, command)) = router.write_direct(command, &server).await { - _warn!(inner, "Error retrying command after ASKING: {:?}", error); - _command = Some(command); - continue; - } else { - break; - } + if let Err(e) = command.decr_check_redirections() { + command.respond_to_caller(Err(e)); + return Ok(()); + } + if let Err(e) = utils::send_asking_with_policy(inner, 
router, &server, slot).await { + command.respond_to_caller(Err(e.clone())); + return Err(e); + } + if let Err(error) = write_with_backpressure_t(inner, router, command, false).await { + _debug!(inner, "Error sending command after ASKING: {:?}", error); + Err(error) + } else { + Ok(()) } - - Ok(()) } /// Sync the cluster state then retry the command. @@ -330,36 +359,21 @@ async fn process_moved( command.use_replica = false; command.hasher = ClusterHash::Custom(slot); - let mut _command = Some(command); - loop { - let mut command = match _command.take() { - Some(command) => command, - None => { - _warn!(inner, "Missing command following an MOVED redirect."); - return Ok(()); - }, - }; - - if let Err(e) = utils::sync_cluster_with_policy(inner, router).await { - command.respond_to_caller(Err(e.clone())); - return Err(e); - } - - if let Err(e) = command.incr_check_attempted(inner.max_command_attempts()) { - command.respond_to_caller(Err(e)); - break; - } - // TODO fix this for blocking commands - if let Err((error, command)) = router.write_direct(command, &server).await { - _warn!(inner, "Error retrying command after ASKING: {:?}", error); - _command = Some(command); - continue; - } else { - break; - } + _debug!(inner, "Syncing cluster after MOVED {} {}", slot, server); + if let Err(e) = utils::sync_cluster_with_policy(inner, router).await { + command.respond_to_caller(Err(e.clone())); + return Err(e); + } + if let Err(e) = command.decr_check_redirections() { + command.respond_to_caller(Err(e)); + return Ok(()); + } + if let Err(error) = write_with_backpressure_t(inner, router, command, false).await { + _debug!(inner, "Error sending command after MOVED: {:?}", error); + Err(error) + } else { + Ok(()) } - - Ok(()) } #[cfg(feature = "replicas")] @@ -467,101 +481,6 @@ fn process_connections( } /// Process any kind of router command. 
-#[cfg(feature = "mocks")] -fn process_command(inner: &Arc, command: RouterCommand) -> Result<(), RedisError> { - match command { - RouterCommand::Transaction { commands, tx, .. } => { - let mocked = commands.into_iter().skip(1).map(|c| c.to_mocked()).collect(); - - match inner.config.mocks.process_transaction(mocked) { - Ok(result) => { - let _ = tx.send(Ok(protocol_utils::mocked_value_to_frame(result))); - Ok(()) - }, - Err(err) => { - let _ = tx.send(Err(err)); - Ok(()) - }, - } - }, - RouterCommand::Pipeline { mut commands } => { - for mut command in commands.into_iter() { - let mocked = command.to_mocked(); - let result = inner - .config - .mocks - .process_command(mocked) - .map(|result| protocol_utils::mocked_value_to_frame(result)); - - let _ = command.respond_to_caller(result); - } - - Ok(()) - }, - RouterCommand::Command(mut command) => { - let result = inner - .config - .mocks - .process_command(command.to_mocked()) - .map(|result| protocol_utils::mocked_value_to_frame(result)); - let _ = command.respond_to_caller(result); - - Ok(()) - }, - _ => Err(RedisError::new(RedisErrorKind::Unknown, "Unimplemented.")), - } -} - -#[cfg(feature = "mocks")] -async fn process_commands(inner: &Arc, rx: &mut CommandReceiver) -> Result<(), RedisError> { - while let Some(command) = rx.recv().await { - inner.counters.decr_cmd_buffer_len(); - - _trace!(inner, "Recv mock command: {:?}", command); - if let Err(e) = process_command(inner, command) { - // errors on this interface end the client connection task - _error!(inner, "Ending early after error processing mock command: {:?}", e); - if e.is_canceled() { - break; - } else { - return Err(e); - } - } - } - - Ok(()) -} - -#[cfg(feature = "mocks")] -pub async fn start(inner: &Arc) -> Result<(), RedisError> { - sleep(Duration::from_millis(10)).await; - if !client_utils::check_and_set_client_state(&inner.state, ClientState::Disconnected, ClientState::Connecting) { - return Err(RedisError::new( - RedisErrorKind::Unknown, - 
"Connections are already initialized or connecting.", - )); - } - - _debug!(inner, "Starting mocking layer"); - let mut rx = match inner.take_command_rx() { - Some(rx) => rx, - None => { - return Err(RedisError::new( - RedisErrorKind::Config, - "Redis client is already initialized.", - )) - }, - }; - - inner.notifications.broadcast_connect(Ok(())); - inner.notifications.broadcast_reconnect(); - let result = process_commands(inner, &mut rx).await; - inner.store_command_rx(rx); - result -} - -/// Process any kind of router command. -#[cfg(not(feature = "mocks"))] async fn process_command( inner: &Arc, router: &mut Router, @@ -596,7 +515,6 @@ async fn process_command( } /// Start processing commands from the client front end. -#[cfg(not(feature = "mocks"))] async fn process_commands( inner: &Arc, router: &mut Router, @@ -627,43 +545,135 @@ async fn process_commands( } /// Start the command processing stream, initiating new connections in the process. -#[cfg(not(feature = "mocks"))] pub async fn start(inner: &Arc) -> Result<(), RedisError> { - sleep(Duration::from_millis(10)).await; - if !client_utils::check_and_set_client_state(&inner.state, ClientState::Disconnected, ClientState::Connecting) { - return Err(RedisError::new( - RedisErrorKind::Unknown, - "Connections are already initialized or connecting.", - )); + #[cfg(feature = "mocks")] + if let Some(ref mocks) = inner.config.mocks { + return mocking::start(inner, mocks).await; } + + let mut rx = match inner.take_command_rx() { + Some(rx) => rx, + None => { + // the `_lock` field on inner synchronizes the getters/setters on the command channel halves, so if this field + // is None then another task must have set and removed the receiver concurrently. 
+ return Err(RedisError::new( + RedisErrorKind::Config, + "Another connection task is already running.", + )); + }, + }; + inner.reset_reconnection_attempts(); let mut router = Router::new(inner); - _debug!(inner, "Initializing router with policy: {:?}", inner.reconnect_policy()); - if inner.config.fail_fast { + let result = if inner.config.fail_fast { if let Err(e) = router.connect().await { inner.notifications.broadcast_connect(Err(e.clone())); inner.notifications.broadcast_error(e.clone()); - return Err(e); + Err(e) } else { client_utils::set_client_state(&inner.state, ClientState::Connected); inner.notifications.broadcast_connect(Ok(())); - inner.notifications.broadcast_reconnect(); + Ok(()) } } else { - let _ = utils::reconnect_with_policy(inner, &mut router).await?; + utils::reconnect_with_policy(inner, &mut router).await + }; + + if let Err(error) = result { + inner.store_command_rx(rx, false); + Err(error) + } else { + let result = process_commands(inner, &mut router, &mut rx).await; + inner.store_command_rx(rx, false); + result } +} - let mut rx = match inner.take_command_rx() { - Some(rx) => rx, - None => { - return Err(RedisError::new( - RedisErrorKind::Config, - "Redis client is already initialized.", - )) - }, - }; - let result = process_commands(inner, &mut router, &mut rx).await; - inner.store_command_rx(rx); - result +#[cfg(feature = "mocks")] +mod mocking { + use super::*; + use crate::{modules::mocks::Mocks, protocol::utils as protocol_utils}; + + /// Process any kind of router command. + pub fn process_command(mocks: &Arc, command: RouterCommand) -> Result<(), RedisError> { + match command { + RouterCommand::Transaction { commands, tx, .. 
} => { + let mocked = commands.into_iter().skip(1).map(|c| c.to_mocked()).collect(); + + match mocks.process_transaction(mocked) { + Ok(result) => { + let _ = tx.send(Ok(protocol_utils::mocked_value_to_frame(result))); + Ok(()) + }, + Err(err) => { + let _ = tx.send(Err(err)); + Ok(()) + }, + } + }, + RouterCommand::Pipeline { commands } => { + for mut command in commands.into_iter() { + let mocked = command.to_mocked(); + let result = mocks + .process_command(mocked) + .map(|result| protocol_utils::mocked_value_to_frame(result)); + + let _ = command.respond_to_caller(result); + } + + Ok(()) + }, + RouterCommand::Command(mut command) => { + let result = mocks + .process_command(command.to_mocked()) + .map(|result| protocol_utils::mocked_value_to_frame(result)); + let _ = command.respond_to_caller(result); + + Ok(()) + }, + _ => Err(RedisError::new(RedisErrorKind::Unknown, "Unimplemented.")), + } + } + + pub async fn process_commands( + inner: &Arc, + mocks: &Arc, + rx: &mut CommandReceiver, + ) -> Result<(), RedisError> { + while let Some(command) = rx.recv().await { + inner.counters.decr_cmd_buffer_len(); + + _trace!(inner, "Recv mock command: {:?}", command); + if let Err(e) = process_command(mocks, command) { + // errors on this interface end the client connection task + _error!(inner, "Ending early after error processing mock command: {:?}", e); + if e.is_canceled() { + break; + } else { + return Err(e); + } + } + } + + Ok(()) + } + + pub async fn start(inner: &Arc, mocks: &Arc) -> Result<(), RedisError> { + _debug!(inner, "Starting mocking layer"); + let mut rx = match inner.take_command_rx() { + Some(rx) => rx, + None => { + return Err(RedisError::new( + RedisErrorKind::Config, + "Redis client is already initialized.", + )) + }, + }; + + inner.notifications.broadcast_connect(Ok(())); + let result = process_commands(inner, mocks, &mut rx).await; + inner.store_command_rx(rx, false); + result + } } diff --git a/src/router/mod.rs b/src/router/mod.rs index 
f204c9b5..9cc4d66a 100644 --- a/src/router/mod.rs +++ b/src/router/mod.rs @@ -21,6 +21,9 @@ use std::{ }; use tokio::sync::oneshot::channel as oneshot_channel; +#[cfg(feature = "replicas")] +use std::collections::HashSet; + pub mod centralized; pub mod clustered; pub mod commands; @@ -36,6 +39,8 @@ pub mod utils; use crate::router::replicas::Replicas; /// The result of an attempt to send a command to the server. +// This is not an ideal pattern, but it mostly comes from the requirement that the shared buffer interface take +// ownership over the command. pub enum Written { /// Apply backpressure to the command before retrying. Backpressure((RedisCommand, Backpressure)), @@ -43,12 +48,17 @@ pub enum Written { Sent((Server, bool)), /// Indicates that the command was sent to all servers. SentAll, - /// Disconnect from the provided server and retry the command later. - Disconnect((Option, Option, RedisError)), - /// Indicates that the result should be ignored since the command will not be retried. + /// The command could not be written since the connection is down. + Disconnected((Option, Option, RedisError)), + /// Ignore the result and move on to the next command. Ignore, - /// (Cluster only) Synchronize the cached cluster routing table and retry. - Sync(RedisCommand), + /// The command could not be routed to any server. + NotFound(RedisCommand), + /// A fatal error that should interrupt the router. + Error((RedisError, Option)), + /// Restart the write process on a primary node connection. 
+ #[cfg(feature = "replicas")] + Fallback(RedisCommand), } impl fmt::Display for Written { @@ -57,9 +67,12 @@ impl fmt::Display for Written { Written::Backpressure(_) => "Backpressure", Written::Sent(_) => "Sent", Written::SentAll => "SentAll", - Written::Disconnect(_) => "Disconnect", + Written::Disconnected(_) => "Disconnected", Written::Ignore => "Ignore", - Written::Sync(_) => "Sync", + Written::NotFound(_) => "NotFound", + Written::Error(_) => "Error", + #[cfg(feature = "replicas")] + Written::Fallback(_) => "Fallback", }) } } @@ -83,14 +96,14 @@ impl Backpressure { match self { Backpressure::Error(e) => Err(e), Backpressure::Wait(duration) => { - _debug!(inner, "Backpressure policy (wait): {}ms", duration.as_millis()); - trace::backpressure_event(&command, Some(duration.as_millis())); - let _ = inner.wait_with_interrupt(duration).await?; + _debug!(inner, "Backpressure policy (wait): {:?}", duration); + trace::backpressure_event(command, Some(duration.as_millis())); + inner.wait_with_interrupt(duration).await?; Ok(None) }, Backpressure::Block => { _debug!(inner, "Backpressure (block)"); - trace::backpressure_event(&command, None); + trace::backpressure_event(command, None); if !command.has_router_channel() { _trace!( inner, @@ -107,9 +120,10 @@ impl Backpressure { } } +/// Connection maps for the supported deployment types. pub enum Connections { Centralized { - /// The connection to the server. + /// The connection to the primary server. writer: Option, }, Clustered { @@ -181,8 +195,6 @@ impl Connections { } /// Whether or not the connection map has a connection to the provided server`. - /// - /// The connection is tested by calling `flush`. 
pub fn has_server_connection(&mut self, server: &Server) -> bool { match self { Connections::Centralized { ref mut writer } | Connections::Sentinel { ref mut writer } => { @@ -221,11 +233,7 @@ impl Connections { .as_mut() .and_then(|writer| if writer.server == *server { Some(writer) } else { None }) }, - Connections::Clustered { ref mut writers, .. } => { - writers - .iter_mut() - .find_map(|(_, writer)| if writer.server == *server { Some(writer) } else { None }) - }, + Connections::Clustered { ref mut writers, .. } => writers.get_mut(server), } } @@ -341,7 +349,7 @@ impl Connections { } } - /// Read a map of connection IDs (via `CLIENT ID`) for each inner connection. + /// Read a map of connection IDs (via `CLIENT ID`) for each inner connections. pub fn connection_ids(&self) -> HashMap { let mut out = HashMap::new(); @@ -400,39 +408,29 @@ impl Connections { } /// Send a command to the server(s). - pub async fn write_command( - &mut self, - inner: &Arc, - command: RedisCommand, - force_flush: bool, - ) -> Result { + pub async fn write(&mut self, inner: &Arc, command: RedisCommand, force_flush: bool) -> Written { match self { Connections::Clustered { ref mut writers, ref mut cache, - } => clustered::send_command(inner, writers, cache, command, force_flush).await, - Connections::Centralized { ref mut writer } => { - centralized::send_command(inner, writer, command, force_flush).await - }, - Connections::Sentinel { ref mut writer, .. } => { - centralized::send_command(inner, writer, command, force_flush).await - }, + } => clustered::write(inner, writers, cache, command, force_flush).await, + Connections::Centralized { ref mut writer } => centralized::write(inner, writer, command, force_flush).await, + Connections::Sentinel { ref mut writer, .. } => centralized::write(inner, writer, command, force_flush).await, } } /// Send a command to all servers in a cluster. 
- pub async fn write_all_cluster( - &mut self, - inner: &Arc, - command: RedisCommand, - ) -> Result { + pub async fn write_all_cluster(&mut self, inner: &Arc, command: RedisCommand) -> Written { if let Connections::Clustered { ref mut writers, .. } = self { - let _ = clustered::send_all_cluster_command(inner, writers, command).await?; - Ok(Written::SentAll) + if let Err(error) = clustered::send_all_cluster_command(inner, writers, command).await { + Written::Disconnected((None, None, error)) + } else { + Written::SentAll + } } else { - Err(RedisError::new( - RedisErrorKind::Config, - "Expected clustered configuration.", + Written::Error(( + RedisError::new(RedisErrorKind::Config, "Expected clustered configuration."), + None, )) } } @@ -454,15 +452,8 @@ impl Connections { /// Connect or reconnect to the provided `host:port`. pub async fn add_connection(&mut self, inner: &Arc, server: &Server) -> Result<(), RedisError> { if let Connections::Clustered { ref mut writers, .. } = self { - let mut transport = connection::create( - inner, - server.host.as_str().to_owned(), - server.port, - None, - server.tls_server_name.as_ref(), - ) - .await?; - let _ = transport.setup(inner, None).await?; + let mut transport = connection::create(inner, server, None).await?; + transport.setup(inner, None).await?; let (server, writer) = connection::split_and_initialize(inner, transport, false, clustered::spawn_reader_task)?; writers.insert(server, writer); @@ -504,9 +495,13 @@ impl Connections { /// A struct for routing commands to the server(s). pub struct Router { + /// The connection map for each deployment type. pub connections: Connections, + /// The inner client state associated with the router. pub inner: Arc, + /// Storage for commands that should be deferred or retried later. pub buffer: CommandBuffer, + /// The replica routing interface. 
#[cfg(feature = "replicas")] pub replicas: Replicas, } @@ -531,15 +526,24 @@ impl Router { } } + /// Sync the local connection state with the task that periodically scans for unresponsive connection timeouts. #[cfg(feature = "check-unresponsive")] pub fn sync_network_timeout_state(&self) { self.inner.network_timeouts.state().sync(&self.inner, &self.connections); + + #[cfg(feature = "replicas")] + self + .inner + .network_timeouts + .state() + .sync_replicas(&self.inner, &self.replicas.writers); } + /// Sync the local connection state with the task that periodically scans for unresponsive connection timeouts. #[cfg(not(feature = "check-unresponsive"))] pub fn sync_network_timeout_state(&self) {} - /// Read the connection identifier for the provided command. + /// Read the server that should receive the provided command. pub fn find_connection(&self, command: &RedisCommand) -> Option<&Server> { match self.connections { Connections::Centralized { ref writer } => writer.as_ref().map(|w| &w.server), @@ -548,59 +552,32 @@ impl Router { } } - /// Route and write the command to the server(s). - /// - /// If the command cannot be written: - /// * The command will be queued to run later. - /// * The associated connection will be dropped. - /// * The reader task for that connection will close, sending a `Reconnect` message to the router. - /// - /// Errors are handled internally, but may be returned if the command was queued to run later. - pub async fn write_command(&mut self, mut command: RedisCommand, force_flush: bool) -> Result { - if let Err(e) = command.incr_check_attempted(self.inner.max_command_attempts()) { - debug!( - "{}: Skipping command `{}` after too many failed attempts.", - self.inner.id, - command.kind.to_str_debug() - ); - command.respond_to_caller(Err(e)); - return Ok(Written::Ignore); - } - if command.attempted > 1 { + /// Attempt to send the command to the server. 
+ pub async fn write(&mut self, command: RedisCommand, force_flush: bool) -> Written { + let send_all_cluster_nodes = self.inner.config.server.is_clustered() + && (command.kind.is_all_cluster_nodes() || command.kind.closes_connection()); + + if command.write_attempts >= 1 { self.inner.counters.incr_redelivery_count(); } - - let send_all_cluster_nodes = command.kind.is_all_cluster_nodes() - || (command.kind.closes_connection() && self.inner.config.server.is_clustered()); - if send_all_cluster_nodes { self.connections.write_all_cluster(&self.inner, command).await } else { - match self.connections.write_command(&self.inner, command, force_flush).await { - Ok(result) => Ok(result), - Err((error, command)) => { - self.buffer_command(command); - Err(error) - }, - } + self.connections.write(&self.inner, command, force_flush).await } } /// Write a command to a replica node if possible, falling back to a primary node if configured. #[cfg(feature = "replicas")] - pub async fn write_replica_command( - &mut self, - mut command: RedisCommand, - force_flush: bool, - ) -> Result { + pub async fn write_replica(&mut self, mut command: RedisCommand, force_flush: bool) -> Written { if !command.use_replica { - return self.write_command(command, force_flush).await; + return self.write(command, force_flush).await; } let primary = match self.find_connection(&command) { Some(server) => server.clone(), None => { - if self.inner.config.replica.primary_fallback { + return if self.inner.connection.replica.primary_fallback { debug!( "{}: Fallback to primary node connection for {} ({})", self.inner.id, @@ -609,87 +586,47 @@ impl Router { ); command.use_replica = false; - return self.write_command(command, force_flush).await; + self.write(command, force_flush).await } else { - command.respond_to_caller(Err(RedisError::new( - RedisErrorKind::Replica, - "Missing primary node connection.", - ))); - return Ok(Written::Ignore); + command.finish( + &self.inner, + Err(RedisError::new( + 
RedisErrorKind::Replica, + "Missing primary node connection.", + )), + ); + + Written::Ignore } }, }; - if let Err(e) = command.incr_check_attempted(self.inner.max_command_attempts()) { - debug!( - "{}: Skipping replica command `{}` after too many failed attempts.", - self.inner.id, - command.kind.to_str_debug() - ); - command.respond_to_caller(Err(e)); - return Ok(Written::Ignore); - } - if command.attempted > 1 { - self.inner.counters.incr_redelivery_count(); - } - - let result = self - .replicas - .write_command(&self.inner, &primary, command, force_flush) - .await; + let result = self.replicas.write(&self.inner, &primary, command, force_flush).await; match result { - Ok(result) => { - if let Err(e) = self.replicas.check_and_flush().await { - error!("{}: Error flushing replica connections: {:?}", self.inner.id, e); - } - - Ok(result) - }, - Err((error, mut command)) => { - if self.inner.config.replica.primary_fallback { - debug!( - "{}: Fall back to primary node for {} ({}) after replica error: {:?}", - self.inner.id, - command.kind.to_str_debug(), - command.debug_id(), - error - ); + Written::Fallback(mut command) => { + debug!( + "{}: Fall back to primary node for {} ({}) after replica error", + self.inner.id, + command.kind.to_str_debug(), + command.debug_id(), + ); - command.use_replica = false; - return self.write_command(command, force_flush).await; - } else { - trace!( - "{}: Add {} ({}) to replica retry buffer.", - self.inner.id, - command.kind.to_str_debug(), - command.debug_id() - ); - self.replicas.add_to_retry_buffer(command); - } - Err(error) + utils::defer_replica_sync(&self.inner); + command.use_replica = false; + self.write(command, force_flush).await }, + _ => result, } } /// Write a command to a replica node if possible, falling back to a primary node if configured. 
#[cfg(not(feature = "replicas"))] - pub async fn write_replica_command( - &mut self, - command: RedisCommand, - force_flush: bool, - ) -> Result { - self.write_command(command, force_flush).await + pub async fn write_replica(&mut self, command: RedisCommand, force_flush: bool) -> Written { + self.write(command, force_flush).await } - /// Attempt to write the command to a specific server without backpressure, returning the error and command on - /// failure. - /// - /// The associated connection will be dropped if needed. The caller is responsible for returning errors. - pub async fn write_direct( - &mut self, - mut command: RedisCommand, - server: &Server, - ) -> Result<(), (RedisError, RedisCommand)> { + /// Attempt to write the command to a specific server without backpressure. + pub async fn write_direct(&mut self, mut command: RedisCommand, server: &Server) -> Written { debug!( "{}: Direct write `{}` command to {}, ID: {}", self.inner.id, @@ -701,14 +638,10 @@ impl Router { let writer = match self.connections.get_connection_mut(server) { Some(writer) => writer, None => { - let err = RedisError::new( - RedisErrorKind::Unknown, - format!("Failed to find connection for {}", server), - ); - return Err((err, command)); + trace!("{}: Missing connection to {}", self.inner.id, server); + return Written::NotFound(command); }, }; - let frame = match utils::prepare_command(&self.inner, &writer.counters, &mut command) { Ok((frame, _)) => frame, Err(e) => { @@ -718,118 +651,24 @@ impl Router { command.kind.to_str_debug() ); // do not retry commands that trigger frame encoding errors - command.respond_to_caller(Err(e)); - return Ok(()); + command.finish(&self.inner, Err(e)); + return Written::Ignore; }, }; - let blocks_connection = command.blocks_connection(); - // always flush the socket in this case + command.write_attempts += 1; writer.push_command(&self.inner, command); - if let Err(e) = writer.write_frame(frame, true).await { - let command = match 
writer.pop_recent_command() { - Some(cmd) => cmd, - None => { - error!( - "{}: Failed to take recent command off queue after write failure.", - self.inner.id - ); - return Ok(()); - }, - }; - - debug!( - "{}: Error sending command {}: {:?}", - self.inner.id, - command.kind.to_str_debug(), - e - ); - Err((e, command)) + if let Err(error) = writer.write_frame(frame, true).await { + let command = writer.pop_recent_command(); + debug!("{}: Error sending command: {:?}", self.inner.id, error); + Written::Disconnected((Some(writer.server.clone()), command, error)) } else { if blocks_connection { self.inner.backchannel.write().await.set_blocked(&writer.server); } - Ok(()) - } - } - - /// Write the command once without checking for backpressure, returning any connection errors and queueing the - /// command to run later if needed. - /// - /// The associated connection will be dropped if needed. - pub async fn write_once(&mut self, command: RedisCommand, server: &Server) -> Result<(), RedisError> { - let inner = self.inner.clone(); - _debug!( - inner, - "Writing `{}` command once to {}", - command.kind.to_str_debug(), - server - ); - let is_blocking = command.blocks_connection(); - let write_result = { - let writer = match self.connections.get_connection_mut(server) { - Some(writer) => writer, - None => { - return Err(RedisError::new( - RedisErrorKind::Unknown, - format!("Failed to find connection for {}", server), - )) - }, - }; - - utils::write_command(&inner, writer, command, true).await - }; - - match write_result { - Written::Disconnect((server, command, error)) => { - let buffer = self.connections.disconnect(&inner, server.as_ref()).await; - self.buffer_commands(buffer); - self.sync_network_timeout_state(); - - if let Some(command) = command { - _debug!( - inner, - "Dropping command after write failure in write_once: {}", - command.kind.to_str_debug() - ); - } - // the connection error is sent to the caller in `write_command` - Err(error) - }, - Written::Sync(command) 
=> { - _debug!(inner, "Missing hash slot. Disconnecting and syncing cluster."); - let buffer = self.connections.disconnect_all(&inner).await; - self.buffer_commands(buffer); - self.buffer_command(command); - self.sync_network_timeout_state(); - - Err(RedisError::new( - RedisErrorKind::Protocol, - "Invalid or missing hash slot.", - )) - }, - Written::SentAll => { - let _ = self.check_and_flush().await?; - Ok(()) - }, - Written::Sent((server, flushed)) => { - trace!("{}: Sent command to {} (flushed: {})", self.inner.id, server, flushed); - if is_blocking { - inner.backchannel.write().await.set_blocked(&server); - } - if !flushed { - let _ = self.check_and_flush().await?; - } - - Ok(()) - }, - Written::Ignore => Err(RedisError::new(RedisErrorKind::Unknown, "Could not send command.")), - Written::Backpressure(_) => Err(RedisError::new( - RedisErrorKind::Unknown, - "Unexpected backpressure flag.", - )), + Written::Sent((writer.server.clone(), true)) } } @@ -838,8 +677,8 @@ impl Router { pub async fn disconnect_all(&mut self) { let commands = self.connections.disconnect_all(&self.inner).await; self.buffer_commands(commands); - self.sync_network_timeout_state(); self.disconnect_replicas().await; + self.sync_network_timeout_state(); } /// Disconnect from all the servers, moving the in-flight messages to the internal command buffer and triggering a @@ -910,13 +749,13 @@ impl Router { self.sync_network_timeout_state(); if result.is_ok() { - self.retry_buffer().await; - if let Err(e) = self.sync_replicas().await { if !self.inner.ignore_replica_reconnect_errors() { return Err(e); } } + + self.retry_buffer().await; } result @@ -926,11 +765,22 @@ impl Router { #[cfg(feature = "replicas")] pub async fn sync_replicas(&mut self) -> Result<(), RedisError> { debug!("{}: Syncing replicas...", self.inner.id); - let _ = self.replicas.clear_connections(&self.inner).await?; - let replicas = self.connections.replica_map(&self.inner).await?; + self.replicas.drop_broken_connections().await; 
+ let old_connections = self.replicas.active_connections(); + let new_replica_map = self.connections.replica_map(&self.inner).await?; - for (mut replica, primary) in replicas.into_iter() { - let should_use = if let Some(filter) = self.inner.config.replica.filter.as_ref() { + let old_connections_idx: HashSet<_> = old_connections.iter().collect(); + let new_connections_idx: HashSet<_> = new_replica_map.keys().collect(); + let remove: Vec<_> = old_connections_idx.difference(&new_connections_idx).collect(); + + for server in remove.into_iter() { + debug!("{}: Dropping replica connection to {}", self.inner.id, server); + self.replicas.drop_writer(&server).await; + self.replicas.remove_replica(&server); + } + + for (mut replica, primary) in new_replica_map.into_iter() { + let should_use = if let Some(filter) = self.inner.connection.replica.filter.as_ref() { filter.filter(&primary, &replica).await } else { true @@ -951,8 +801,7 @@ impl Router { .server_state .write() .update_replicas(self.replicas.routing_table()); - - self.replicas.retry_buffer(&self.inner); + self.sync_network_timeout_state(); Ok(()) } @@ -963,12 +812,11 @@ impl Router { } /// Attempt to replay all queued commands on the internal buffer without backpressure. - /// - /// If a command cannot be written the underlying connections will close and the unsent commands will remain on the - /// internal buffer. pub async fn retry_buffer(&mut self) { - let mut commands: VecDeque = self.buffer.drain(..).collect(); - let mut failed_command = None; + let mut failed_commands: VecDeque<_> = VecDeque::new(); + let mut commands: VecDeque<_> = self.buffer.drain(..).collect(); + #[cfg(feature = "replicas")] + commands.extend(self.replicas.take_retry_buffer()); for mut command in commands.drain(..) 
{ if client_utils::read_bool_atomic(&command.timed_out) { @@ -980,51 +828,64 @@ impl Router { continue; } + if let Err(e) = command.decr_check_attempted() { + command.finish(&self.inner, Err(e)); + continue; + } command.skip_backpressure = true; trace!( - "{}: Retry `{}` ({}) command, attempt {}/{}", + "{}: Retry `{}` ({}) command, attempts left: {}", self.inner.id, command.kind.to_str_debug(), command.debug_id(), - command.attempted, - self.inner.max_command_attempts() + command.attempts_remaining, ); - match self.write_command(command, true).await { - Ok(Written::Disconnect((server, command, error))) => { + + let result = if command.use_replica { + self.write_replica(command, true).await + } else { + self.write(command, true).await + }; + + match result { + Written::Disconnected((server, command, error)) => { if let Some(command) = command { - failed_command = Some(command); + failed_commands.push_back(command); } - warn!( - "{}: Disconnect from {:?} while replaying command: {:?}", - self.inner.id, server, error + debug!( + "{}: Disconnect while retrying after write error: {:?}", + &self.inner.id, error ); - self.disconnect_all().await; // triggers a reconnect if needed - break; + // triggers a reconnect if needed + self.connections.disconnect(&self.inner, server.as_ref()).await; + continue; }, - Ok(Written::Sync(command)) => { - failed_command = Some(command); + Written::NotFound(command) => { + failed_commands.push_back(command); - warn!("{}: Disconnect and re-sync cluster state.", self.inner.id); - self.disconnect_all().await; // triggers a reconnect if needed + warn!( + "{}: Disconnect and re-sync cluster state after routing error while retrying commands.", + self.inner.id + ); + // triggers a reconnect if needed + self.disconnect_all().await; break; }, - Err(error) => { + Written::Error((error, command)) => { warn!("{}: Error replaying command: {:?}", self.inner.id, error); - self.disconnect_all().await; // triggers a reconnect if needed + if let 
Some(command) = command { + command.finish(&self.inner, Err(error)); + } + self.disconnect_all().await; break; }, - Ok(written) => { - warn!("{}: Unexpected retry result: {}", self.inner.id, written); - continue; - }, + _ => {}, } } - if let Some(command) = failed_command { - self.buffer_command(command); - } - self.buffer_commands(commands); + failed_commands.extend(commands); + self.buffer_commands(failed_commands); } /// Check each connection for pending frames that have not been flushed, and flush the connection if needed. @@ -1071,11 +932,11 @@ impl Router { .unwrap_or(true); if should_sync { - let _ = self.sync_cluster().await?; + self.sync_cluster().await?; } } else if *kind == ClusterErrorKind::Ask { if !self.connections.has_server_connection(server) { - let _ = self.connections.add_connection(&self.inner, server).await?; + self.connections.add_connection(&self.inner, server).await?; self .inner .backchannel @@ -1089,9 +950,16 @@ impl Router { let (tx, rx) = oneshot_channel(); let mut command = RedisCommand::new_asking(slot); command.response = ResponseKind::Respond(Some(tx)); + command.skip_backpressure = true; + + match self.write_direct(command, server).await { + Written::Error((error, _)) => return Err(error), + Written::Disconnected((_, _, error)) => return Err(error), + Written::NotFound(_) => return Err(RedisError::new(RedisErrorKind::Cluster, "Connection not found.")), + _ => {}, + }; - let _ = self.write_once(command, &server).await?; - let _ = rx.await??; + let _ = client_utils::apply_timeout(rx, self.inner.internal_command_timeout()).await??; } Ok(()) diff --git a/src/router/replicas.rs b/src/router/replicas.rs index ec509d55..dd65bd50 100644 --- a/src/router/replicas.rs +++ b/src/router/replicas.rs @@ -1,12 +1,11 @@ #[cfg(all(feature = "replicas", any(feature = "enable-native-tls", feature = "enable-rustls")))] -use crate::types::{HostMapping, TlsHostMapping}; +use crate::types::TlsHostMapping; #[cfg(feature = "replicas")] use crate::{ 
error::{RedisError, RedisErrorKind}, - interfaces, modules::inner::RedisClientInner, protocol::{ - command::{RedisCommand, RouterCommand}, + command::RedisCommand, connection, connection::{CommandBuffer, RedisWriter}, }, @@ -15,7 +14,8 @@ use crate::{ }; #[cfg(feature = "replicas")] use std::{ - collections::{BTreeSet, HashMap, VecDeque}, + collections::{HashMap, VecDeque}, + convert::identity, fmt, fmt::Formatter, sync::Arc, @@ -98,13 +98,54 @@ impl Default for ReplicaConfig { } } +/// A container for round-robin routing among replica nodes. +// This implementation optimizes for next() at the cost of add() and remove() +#[derive(Clone, Debug, PartialEq, Eq, Default)] +#[cfg(feature = "replicas")] +#[cfg_attr(docsrs, doc(cfg(feature = "replicas")))] +pub struct ReplicaRouter { + counter: usize, + servers: Vec, +} + +#[cfg(feature = "replicas")] +impl ReplicaRouter { + /// Read the server that should receive the next command. + pub fn next(&mut self) -> Option<&Server> { + self.counter = (self.counter + 1) % self.servers.len(); + self.servers.get(self.counter) + } + + /// Conditionally add the server to the replica set. + pub fn add(&mut self, server: Server) { + if !self.servers.contains(&server) { + self.servers.push(server); + } + } + + /// Remove the server from the replica set. + pub fn remove(&mut self, server: &Server) { + self.servers = self.servers.drain(..).filter(|_server| server != _server).collect(); + } + + /// The size of the replica set. + pub fn len(&self) -> usize { + self.servers.len() + } + + /// Iterate over the replica set. + pub fn iter(&self) -> impl Iterator { + self.servers.iter() + } +} + /// A container for round-robin routing to replica servers. #[cfg(feature = "replicas")] #[cfg_attr(docsrs, doc(cfg(feature = "replicas")))] #[derive(Clone, Debug, Eq, PartialEq, Default)] pub struct ReplicaSet { /// A map of primary server IDs to a counter and set of replica server IDs. 
- servers: HashMap)>, + servers: HashMap, } #[cfg(feature = "replicas")] @@ -119,41 +160,63 @@ impl ReplicaSet { /// Add a replica node to the routing table. pub fn add(&mut self, primary: Server, replica: Server) { - let (_, replicas) = self.servers.entry(primary).or_insert((0, Vec::new())); - - if !replicas.contains(&replica) { - replicas.push(replica); - } + self + .servers + .entry(primary) + .or_insert(ReplicaRouter::default()) + .add(replica); } /// Remove a replica node mapping from the routing table. pub fn remove(&mut self, primary: &Server, replica: &Server) { - if let Some((count, mut replicas)) = self.servers.remove(primary) { - replicas = replicas.drain(..).filter(|node| node != replica).collect(); + let should_remove = if let Some(router) = self.servers.get_mut(primary) { + router.remove(replica); + router.len() == 0 + } else { + false + }; - if !replicas.is_empty() { - self.servers.insert(primary.clone(), (count, replicas)); - } + if should_remove { + self.servers.remove(primary); } } + /// Remove the replica from all routing sets. + pub fn remove_replica(&mut self, replica: &Server) { + self.servers = self + .servers + .drain() + .filter_map(|(primary, mut routing)| { + routing.remove(replica); + + if routing.len() > 0 { + Some((primary, routing)) + } else { + None + } + }) + .collect(); + } + /// Read the server ID of the next replica that should receive a command. pub fn next_replica(&mut self, primary: &Server) -> Option<&Server> { - self.servers.get_mut(primary).and_then(|(idx, replicas)| { - *idx += 1; - replicas.get(*idx % replicas.len()) - }) + self.servers.get_mut(primary).and_then(|router| router.next()) } /// Read all the replicas associated with the provided primary node. 
- pub fn replicas(&self, primary: &Server) -> Option<&Vec> { - self.servers.get(primary).map(|(_, replicas)| replicas) + pub fn replicas(&self, primary: &Server) -> impl Iterator { + self + .servers + .get(primary) + .map(|router| router.iter()) + .into_iter() + .flat_map(identity) } /// Return a map of replica nodes to primary nodes. pub fn to_map(&self) -> HashMap { let mut out = HashMap::with_capacity(self.servers.len()); - for (primary, (_, replicas)) in self.servers.iter() { + for (primary, replicas) in self.servers.iter() { for replica in replicas.iter() { out.insert(replica.clone(), primary.clone()); } @@ -165,7 +228,7 @@ impl ReplicaSet { /// Read the set of all known replica nodes for all primary nodes. pub fn all_replicas(&self) -> Vec { let mut out = Vec::with_capacity(self.servers.len()); - for (_, (_, replicas)) in self.servers.iter() { + for (_, replicas) in self.servers.iter() { for replica in replicas.iter() { out.push(replica.clone()); } @@ -183,9 +246,9 @@ impl ReplicaSet { /// A struct for routing commands to replica nodes. #[cfg(feature = "replicas")] pub struct Replicas { - writers: HashMap, - routing: ReplicaSet, - buffer: CommandBuffer, + pub(crate) writers: HashMap, + routing: ReplicaSet, + buffer: VecDeque, } #[cfg(feature = "replicas")] @@ -199,31 +262,6 @@ impl Replicas { } } - pub fn add_to_retry_buffer(&mut self, command: RedisCommand) { - self.buffer.push_back(command); - } - - /// Retry the commands in the cached retry buffer by sending them to the router again. - pub fn retry_buffer(&mut self, inner: &Arc) { - let retry_count = inner.config.replica.connection_error_count; - for mut command in self.buffer.drain(..) 
{ - if retry_count > 0 && command.attempted >= retry_count { - _trace!( - inner, - "Switch {} ({}) to fall back to primary after retry.", - command.kind.to_str_debug(), - command.debug_id() - ); - command.attempted = 0; - command.use_replica = false; - } - - if let Err(e) = interfaces::send_to_router(inner, RouterCommand::Command(command)) { - _error!(inner, "Error sending replica command to router: {:?}", e); - } - } - } - /// Sync the connection map in place based on the cached routing table. pub async fn sync_connections(&mut self, inner: &Arc) -> Result<(), RedisError> { for (_, writer) in self.writers.drain() { @@ -259,15 +297,8 @@ impl Replicas { primary ); - if !inner.config.replica.lazy_connections || force { - let mut transport = connection::create( - inner, - replica.host.as_str().to_owned(), - replica.port, - None, - replica.tls_server_name.as_ref(), - ) - .await?; + if !inner.connection.replica.lazy_connections || force { + let mut transport = connection::create(inner, &replica, None).await?; let _ = transport.setup(inner, None).await?; let (_, writer) = if inner.config.server.is_clustered() { @@ -284,6 +315,19 @@ impl Replicas { Ok(()) } + /// Drop the socket associated with the provided server. + pub async fn drop_writer(&mut self, replica: &Server) { + if let Some(writer) = self.writers.remove(replica) { + let commands = writer.graceful_close().await; + self.buffer.extend(commands); + } + } + + /// Remove the replica from the routing table. + pub fn remove_replica(&mut self, replica: &Server) { + self.routing.remove_replica(replica); + } + /// Close the replica connection and optionally remove the replica from the routing table. 
pub async fn remove_connection( &mut self, @@ -298,10 +342,7 @@ impl Replicas { replica, primary ); - if let Some(writer) = self.writers.remove(replica) { - let commands = writer.graceful_close().await; - self.buffer.extend(commands); - } + self.drop_writer(replica).await; if !keep_routable { self.routing.remove(primary, replica); @@ -320,11 +361,9 @@ impl Replicas { /// Whether a working connection exists to any replica for the provided primary node. pub fn has_replica_connection(&self, primary: &Server) -> bool { - if let Some(replicas) = self.routing.replicas(primary) { - for replica in replicas.iter() { - if self.has_connection(replica) { - return true; - } + for replica in self.routing.replicas(primary) { + if self.has_connection(replica) { + return true; } } @@ -341,118 +380,85 @@ impl Replicas { self.routing.to_map() } - /// Discover and connect to replicas via the `ROLE` command. - pub async fn sync_by_role( - &mut self, - inner: &Arc, - primary: &mut RedisWriter, - ) -> Result<(), RedisError> { - for replica in primary.discover_replicas(inner).await? { - self.routing.add(primary.server.clone(), replica); - } - - Ok(()) - } - - /// Discover and connect to replicas by inspecting the cached `CLUSTER SLOTS` state. - pub fn sync_by_cached_cluster_state( - &mut self, - inner: &Arc, - primary: &Server, - ) -> Result<(), RedisError> { - let replicas: Vec = inner.with_cluster_state(|state| { - Ok( - state - .slots() - .iter() - .fold(BTreeSet::new(), |mut replicas, slot| { - if slot.primary == *primary { - replicas.extend(slot.replicas.clone()); - } - - replicas - }) - .into_iter() - .collect(), - ) - })?; - - for replica in replicas.into_iter() { - self.routing.add(primary.clone(), replica); + /// Check the active connections and drop any without a working reader task. 
+ pub async fn drop_broken_connections(&mut self) { + let mut new_writers = HashMap::with_capacity(self.writers.len()); + for (server, writer) in self.writers.drain() { + if writer.is_working() { + new_writers.insert(server, writer); + } else { + let commands = writer.graceful_close().await; + self.buffer.extend(commands); + } } - Ok(()) + self.writers = new_writers; } - /// Check if the provided connection has any known replica nodes, and if so add them to the cached routing table. - pub async fn check_replicas( - &mut self, - inner: &Arc, - primary: &mut RedisWriter, - ) -> Result<(), RedisError> { - if inner.config.server.is_clustered() { - if let Err(_) = self.sync_by_cached_cluster_state(inner, &primary.server) { - _warn!(inner, "Failed to discover replicas via cached CLUSTER SLOTS."); - self.sync_by_role(inner, primary).await - } else { - Ok(()) - } - } else { - self.sync_by_role(inner, primary).await - } + /// Read the set of all active connections. + pub fn active_connections(&self) -> Vec { + self + .writers + .iter() + .filter_map(|(server, writer)| { + if writer.is_working() { + Some(server.clone()) + } else { + None + } + }) + .collect() } /// Send a command to one of the replicas associated with the provided primary server. 
- pub async fn write_command( + pub async fn write( &mut self, inner: &Arc, primary: &Server, mut command: RedisCommand, force_flush: bool, - ) -> Result { + ) -> Written { let replica = match self.routing.next_replica(primary) { Some(replica) => replica.clone(), None => { - // these errors indicate we do not know of any replica node associated with the primary node - - return if inner.config.replica.primary_fallback { - // FIXME this is ugly and leaks implementation details to the caller - Err(( - RedisError::new(RedisErrorKind::Replica, "Missing replica node."), - command, - )) + // we do not know of any replica node associated with the primary node + return if inner.connection.replica.primary_fallback { + Written::Fallback(command) } else { - command.respond_to_caller(Err(RedisError::new(RedisErrorKind::Replica, "Missing replica node."))); - Ok(Written::Ignore) + command.finish( + inner, + Err(RedisError::new(RedisErrorKind::Replica, "Missing replica node.")), + ); + Written::Ignore }; }, }; let writer = match self.writers.get_mut(&replica) { Some(writer) => writer, None => { - // these errors indicate that we know a replica node _should_ exist, but we are not connected or cannot + // these errors indicate that we know a replica node should exist, but we are not connected or cannot // connect to it. in this case we want to hide the error, trigger a reconnect, and retry the command later. - - if inner.config.replica.lazy_connections { + if inner.connection.replica.lazy_connections { _debug!(inner, "Lazily adding {} replica connection", replica); if let Err(e) = self.add_connection(inner, primary.clone(), replica.clone(), true).await { - return Err((e, command)); + // we tried connecting once but failed. 
+ return Written::Disconnected((Some(replica.clone()), Some(command), e)); } match self.writers.get_mut(&replica) { Some(writer) => writer, None => { - return Err(( - RedisError::new(RedisErrorKind::Replica, "Missing replica node connection."), - command, - )) + // the connection should be here if self.add_connection succeeded + return Written::Disconnected(( + Some(replica.clone()), + Some(command), + RedisError::new(RedisErrorKind::Replica, "Missing connection."), + )); }, } } else { - return Err(( - RedisError::new(RedisErrorKind::Replica, "Missing replica node connection."), - command, - )); + // we don't have a connection to the replica and we're not configured to lazily create new ones + return Written::NotFound(command); } }, }; @@ -461,13 +467,12 @@ impl Replicas { Err(e) => { _warn!(inner, "Frame encoding error for {}", command.kind.to_str_debug()); // do not retry commands that trigger frame encoding errors - command.respond_to_caller(Err(e)); - return Ok(Written::Ignore); + command.finish(inner, Err(e)); + return Written::Ignore; }, }; let blocks_connection = command.blocks_connection(); - // always flush the socket in this case _debug!( inner, "Sending {} ({}) to replica {}", @@ -475,26 +480,37 @@ impl Replicas { command.debug_id(), replica ); + command.write_attempts += 1; writer.push_command(inner, command); if let Err(e) = writer.write_frame(frame, should_flush).await { let command = match writer.pop_recent_command() { Some(cmd) => cmd, None => { _error!(inner, "Failed to take recent command off queue after write failure."); - return Ok(Written::Ignore); + return Written::Ignore; }, }; - _debug!(inner, "Error sending command {}: {:?}", command.kind.to_str_debug(), e); - Err((e, command)) + _debug!( + inner, + "Error sending replica command {}: {:?}", + command.kind.to_str_debug(), + e + ); + Written::Disconnected((Some(writer.server.clone()), Some(command), e)) } else { if blocks_connection { 
inner.backchannel.write().await.set_blocked(&writer.server); } - Ok(Written::Sent((writer.server.clone(), true))) + Written::Sent((writer.server.clone(), should_flush)) } } + + /// Take the commands stored for retry later. + pub fn take_retry_buffer(&mut self) -> CommandBuffer { + self.buffer.drain(..).collect() + } } #[cfg(all(feature = "replicas", any(feature = "enable-native-tls", feature = "enable-rustls")))] @@ -511,7 +527,7 @@ pub fn map_replica_tls_names(inner: &Arc, primary: &Server, re return; } - replica.set_tls_server_name(policy, primary.host.as_str()); + replica.set_tls_server_name(policy, &primary.host); } #[cfg(all( diff --git a/src/router/responses.rs b/src/router/responses.rs index adcebc2c..a3592dc5 100644 --- a/src/router/responses.rs +++ b/src/router/responses.rs @@ -14,19 +14,19 @@ use crate::globals::globals; #[cfg(feature = "client-tracking")] use crate::types::Invalidation; -const KEYSPACE_PREFIX: &'static str = "__keyspace@"; -const KEYEVENT_PREFIX: &'static str = "__keyevent@"; +const KEYSPACE_PREFIX: &str = "__keyspace@"; +const KEYEVENT_PREFIX: &str = "__keyevent@"; #[cfg(feature = "client-tracking")] const INVALIDATION_CHANNEL: &'static str = "__redis__:invalidate"; fn parse_keyspace_notification(channel: &str, message: &RedisValue) -> Option { if channel.starts_with(KEYEVENT_PREFIX) { - let parts: Vec<&str> = channel.splitn(2, "@").collect(); + let parts: Vec<&str> = channel.splitn(2, '@').collect(); if parts.len() < 2 { return None; } - let suffix: Vec<&str> = parts[1].splitn(2, ":").collect(); + let suffix: Vec<&str> = parts[1].splitn(2, ':').collect(); if suffix.len() < 2 { return None; } @@ -43,12 +43,12 @@ fn parse_keyspace_notification(channel: &str, message: &RedisValue) -> Option = channel.splitn(2, "@").collect(); + let parts: Vec<&str> = channel.splitn(2, '@').collect(); if parts.len() < 2 { return None; } - let suffix: Vec<&str> = parts[1].splitn(2, ":").collect(); + let suffix: Vec<&str> = parts[1].splitn(2, ':').collect(); 
if suffix.len() < 2 { return None; } @@ -69,6 +69,7 @@ fn parse_keyspace_notification(channel: &str, message: &RedisValue) -> Option bool { s == "message" || s == "pmessage" || s == "smessage" } @@ -90,15 +91,20 @@ fn check_pubsub_formats(frame: &Resp3Frame) -> (bool, bool) { // so here we check the frame contents according to the RESP2 pubsub rules let resp3 = (data.len() == 3 || data.len() == 4) && data[0].as_str().map(check_message_prefix).unwrap_or(false); - (false, resp3) + (resp3, false) } /// Try to parse the frame in either RESP2 or RESP3 pubsub formats. -fn parse_pubsub_message(frame: Resp3Frame, is_resp3: bool, is_resp2: bool) -> Result { +fn parse_pubsub_message( + server: &Server, + frame: Resp3Frame, + is_resp3: bool, + is_resp2: bool, +) -> Result { if is_resp3 { - protocol_utils::frame_to_pubsub(frame) + protocol_utils::frame_to_pubsub(server, frame) } else if is_resp2 { - protocol_utils::parse_as_resp2_pubsub(frame) + protocol_utils::parse_as_resp2_pubsub(server, frame) } else { Err(RedisError::new(RedisErrorKind::Protocol, "Invalid pubsub message.")) } @@ -190,10 +196,11 @@ pub fn check_pubsub_message(inner: &Arc, server: &Server, fram let span = trace::create_pubsub_span(inner, &frame); _trace!(inner, "Processing pubsub message from {}.", server); let parsed_frame = if let Some(ref span) = span { - let _enter = span.enter(); - parse_pubsub_message(frame, is_resp3_pubsub, is_resp2_pubsub) + #[allow(clippy::let_unit_value)] + let _ = span.enter(); + parse_pubsub_message(server, frame, is_resp3_pubsub, is_resp2_pubsub) } else { - parse_pubsub_message(frame, is_resp3_pubsub, is_resp2_pubsub) + parse_pubsub_message(server, frame, is_resp3_pubsub, is_resp2_pubsub) }; let message = match parsed_frame { @@ -209,12 +216,10 @@ pub fn check_pubsub_message(inner: &Arc, server: &Server, fram if is_pubsub_invalidation(&message) { broadcast_pubsub_invalidation(inner, message, server); + } else if let Some(event) = parse_keyspace_notification(&message.channel, 
&message.value) { + inner.notifications.broadcast_keyspace(event); } else { - if let Some(event) = parse_keyspace_notification(&message.channel, &message.value) { - inner.notifications.broadcast_keyspace(event); - } else { - inner.notifications.broadcast_pubsub(message); - } + inner.notifications.broadcast_pubsub(message); } None @@ -230,7 +235,7 @@ pub async fn check_and_set_unblocked_flag(inner: &Arc, command /// Parse the response frame to see if it's an auth error. fn parse_redis_auth_error(frame: &Resp3Frame) -> Option { if frame.is_error() { - match protocol_utils::frame_to_single_result(frame.clone()) { + match protocol_utils::frame_to_results(frame.clone()) { Ok(_) => None, Err(e) => match e.kind() { RedisErrorKind::Auth => Some(e), @@ -275,13 +280,13 @@ fn is_clusterdown_error(frame: &Resp3Frame) -> Option<&str> { match frame { Resp3Frame::SimpleError { data, .. } => { if data.trim().starts_with("CLUSTERDOWN") { - Some(&data) + Some(data) } else { None } }, Resp3Frame::BlobError { data, .. 
} => { - let parsed = match str::from_utf8(&data) { + let parsed = match str::from_utf8(data) { Ok(s) => s, Err(_) => return None, }; diff --git a/src/router/sentinel.rs b/src/router/sentinel.rs index a1450983..5f49a818 100644 --- a/src/router/sentinel.rs +++ b/src/router/sentinel.rs @@ -2,7 +2,6 @@ use crate::{ error::{RedisError, RedisErrorKind}, - globals::globals, modules::inner::RedisClientInner, protocol::{ command::{RedisCommand, RedisCommandKind}, @@ -11,28 +10,30 @@ use crate::{ }, router::{centralized, Connections}, types::{RedisValue, Server, ServerConfig}, + utils, }; +use bytes_utils::Str; use std::{ collections::{HashMap, HashSet}, sync::Arc, }; -pub static CONFIG: &'static str = "CONFIG"; -pub static SET: &'static str = "SET"; -pub static CKQUORUM: &'static str = "CKQUORUM"; -pub static FLUSHCONFIG: &'static str = "FLUSHCONFIG"; -pub static FAILOVER: &'static str = "FAILOVER"; -pub static GET_MASTER_ADDR_BY_NAME: &'static str = "GET-MASTER-ADDR-BY-NAME"; -pub static INFO_CACHE: &'static str = "INFO-CACHE"; -pub static MASTERS: &'static str = "MASTERS"; -pub static MASTER: &'static str = "MASTER"; -pub static MONITOR: &'static str = "MONITOR"; -pub static MYID: &'static str = "MYID"; -pub static PENDING_SCRIPTS: &'static str = "PENDING-SCRIPTS"; -pub static REMOVE: &'static str = "REMOVE"; -pub static REPLICAS: &'static str = "REPLICAS"; -pub static SENTINELS: &'static str = "SENTINELS"; -pub static SIMULATE_FAILURE: &'static str = "SIMULATE-FAILURE"; +pub static CONFIG: &str = "CONFIG"; +pub static SET: &str = "SET"; +pub static CKQUORUM: &str = "CKQUORUM"; +pub static FLUSHCONFIG: &str = "FLUSHCONFIG"; +pub static FAILOVER: &str = "FAILOVER"; +pub static GET_MASTER_ADDR_BY_NAME: &str = "GET-MASTER-ADDR-BY-NAME"; +pub static INFO_CACHE: &str = "INFO-CACHE"; +pub static MASTERS: &str = "MASTERS"; +pub static MASTER: &str = "MASTER"; +pub static MONITOR: &str = "MONITOR"; +pub static MYID: &str = "MYID"; +pub static PENDING_SCRIPTS: &str = 
"PENDING-SCRIPTS"; +pub static REMOVE: &str = "REMOVE"; +pub static REPLICAS: &str = "REPLICAS"; +pub static SENTINELS: &str = "SENTINELS"; +pub static SIMULATE_FAILURE: &str = "SIMULATE-FAILURE"; macro_rules! stry ( ($expr:expr) => { @@ -158,25 +159,17 @@ async fn read_sentinels( /// Connect to any of the sentinel nodes provided on the associated `RedisConfig`. async fn connect_to_sentinel(inner: &Arc) -> Result { let (hosts, (username, password)) = read_sentinel_nodes_and_auth(inner)?; - let timeout = globals().sentinel_connection_timeout_ms() as u64; for server in hosts.into_iter() { _debug!(inner, "Connecting to sentinel {}", server); - let mut transport = try_or_continue!( - connection::create( - inner, - server.host.as_str().to_owned(), - server.port, - Some(timeout), - server.tls_server_name.as_ref() + let mut transport = try_or_continue!(connection::create(inner, &server, None).await); + try_or_continue!( + utils::apply_timeout( + transport.authenticate(&inner.id, username.clone(), password.clone(), false), + inner.internal_command_timeout() ) .await ); - let _ = try_or_continue!( - transport - .authenticate(&inner.id, username.clone(), password.clone(), false) - .await - ); return Ok(transport); } @@ -208,19 +201,29 @@ async fn discover_primary_node( static_val!(GET_MASTER_ADDR_BY_NAME), service_name.into(), ]); - let frame = sentinel.request_response(command, false).await?; + let frame = utils::apply_timeout( + sentinel.request_response(command, false), + inner.internal_command_timeout(), + ) + .await?; let response = stry!(protocol_utils::frame_to_results(frame)); - let (host, port): (String, u16) = if response.is_null() { + let server = if response.is_null() { return Err(RedisError::new( RedisErrorKind::Sentinel, "Missing primary address in response from sentinel node.", )); } else { - stry!(response.convert()) + let (host, port): (Str, u16) = stry!(response.convert()); + Server { + host, + port, + #[cfg(any(feature = "enable-rustls", feature = 
"enable-native-tls"))] + tls_server_name: None, + } }; - let mut transport = stry!(connection::create(inner, host, port, None, None).await); - let _ = stry!(transport.setup(inner, None).await); + let mut transport = stry!(connection::create(inner, &server, None).await); + stry!(transport.setup(inner, None).await); Ok(transport) } @@ -318,9 +321,19 @@ pub async fn initialize_connection( Connections::Sentinel { writer } => { let mut sentinel = connect_to_sentinel(inner).await?; let mut transport = discover_primary_node(inner, &mut sentinel).await?; - let _ = check_primary_node_role(inner, &mut transport).await?; - let _ = update_cached_client_state(inner, writer, sentinel, transport).await?; + let server = transport.server.clone(); + + utils::apply_timeout( + async { + check_primary_node_role(inner, &mut transport).await?; + update_cached_client_state(inner, writer, sentinel, transport).await?; + Ok::<_, RedisError>(()) + }, + inner.internal_command_timeout(), + ) + .await?; + inner.notifications.broadcast_reconnect(server); Ok(()) }, _ => Err(RedisError::new( diff --git a/src/router/transactions.rs b/src/router/transactions.rs index 9f53436f..4cc5bf26 100644 --- a/src/router/transactions.rs +++ b/src/router/transactions.rs @@ -2,19 +2,13 @@ use crate::{ error::{RedisError, RedisErrorKind}, interfaces::Resp3Frame, modules::inner::RedisClientInner, - router::{utils, Router}, protocol::{ - command::{ - ClusterErrorKind, - RouterReceiver, - RouterResponse, - RedisCommand, - RedisCommandKind, - ResponseSender, - }, + command::{ClusterErrorKind, RedisCommand, RedisCommandKind, ResponseSender, RouterReceiver, RouterResponse}, responders::ResponseKind, }, + router::{utils, Router, Written}, types::{ClusterHash, Server}, + utils as client_utils, }; use std::sync::Arc; @@ -49,17 +43,25 @@ async fn write_command( ) -> Result { _trace!( inner, - "Sending trx command {} to {}", + "Sending trx command {} ({}) to {}", command.kind.to_str_debug(), + command.debug_id(), server ); - 
if let Err(e) = router.write_once(command, server).await { + let timeout_dur = command.timeout_dur.unwrap_or_else(|| inner.default_command_timeout()); + let result = match router.write_direct(command, server).await { + Written::Error((error, _)) => Err(error), + Written::Disconnected((_, _, error)) => Err(error), + Written::NotFound(_) => Err(RedisError::new(RedisErrorKind::Cluster, "Connection not found.")), + _ => Ok(()), + }; + if let Err(e) = result { _debug!(inner, "Error writing trx command: {:?}", e); return Ok(TransactionResponse::Retry(e)); } - match rx.await? { + match client_utils::apply_timeout(rx, timeout_dur).await? { RouterResponse::Continue => Ok(TransactionResponse::Continue), RouterResponse::Ask((slot, server, _)) => { Ok(TransactionResponse::Redirection((ClusterErrorKind::Ask, slot, server))) @@ -113,7 +115,7 @@ async fn send_discard( write_command(inner, router, server, command, true, rx).await } -fn update_hash_slot(commands: &mut Vec, slot: u16) { +fn update_hash_slot(commands: &mut [RedisCommand], slot: u16) { for command in commands.iter_mut() { command.hasher = ClusterHash::Custom(slot); } @@ -134,26 +136,41 @@ pub async fn run( let _ = tx.send(Ok(Resp3Frame::Null)); return Ok(()); } + // each of the commands should have the same options + let max_attempts = if commands[0].attempts_remaining == 0 { + inner.max_command_attempts() + } else { + commands[0].attempts_remaining + }; + let max_redirections = if commands[0].redirections_remaining == 0 { + inner.connection.max_redirections + } else { + commands[0].redirections_remaining + }; let mut attempted = 0; + let mut redirections = 0; 'outer: loop { _debug!(inner, "Starting transaction {} (attempted: {})", id, attempted); - let server = match router.find_connection(&commands[0]) { - Some(server) => server.clone(), - None => { - let _ = if inner.config.server.is_clustered() { - // optimistically sync the cluster, then fall back to a full reconnect - if router.sync_cluster().await.is_err() { + 
let server = if let Some(server) = commands[0].cluster_node.as_ref() { + server.clone() + } else { + match router.find_connection(&commands[0]) { + Some(server) => server.clone(), + None => { + if inner.config.server.is_clustered() { + // optimistically sync the cluster, then fall back to a full reconnect + if router.sync_cluster().await.is_err() { + utils::reconnect_with_policy(inner, router).await? + } + } else { utils::reconnect_with_policy(inner, router).await? - } - } else { - utils::reconnect_with_policy(inner, router).await? - }; + }; - attempted += 1; - continue; - }, + continue; + }, + } }; let mut idx = 0; @@ -176,22 +193,29 @@ pub async fn run( Ok(TransactionResponse::Retry(error)) => { _debug!(inner, "Retrying trx {} after WATCH error: {:?}.", id, error); - if attempted >= inner.max_command_attempts() { + attempted += 1; + if attempted >= max_attempts { let _ = tx.send(Err(error)); return Ok(()); } else { - let _ = utils::reconnect_with_policy(inner, router).await?; + utils::reconnect_with_policy(inner, router).await?; } - attempted += 1; continue 'outer; }, Ok(TransactionResponse::Redirection((kind, slot, server))) => { + redirections += 1; + if redirections > max_redirections { + let _ = tx.send(Err(RedisError::new( + RedisErrorKind::Cluster, + "Too many cluster redirections.", + ))); + return Ok(()); + } + _debug!(inner, "Recv {} redirection to {} for WATCH in trx {}", kind, server, id); update_hash_slot(&mut commands, slot); - let _ = utils::cluster_redirect_with_policy(inner, router, kind, slot, &server).await?; - - attempted += 1; + utils::cluster_redirect_with_policy(inner, router, kind, slot, &server).await?; continue 'outer; }, Ok(TransactionResponse::Finished(frame)) => { @@ -206,7 +230,10 @@ pub async fn run( }; } - // start sending the trx commands + if attempted > 0 { + inner.counters.incr_redelivery_count(); + } + // send each of the commands. 
the first one is always MULTI 'inner: while idx < commands.len() { let command = commands[idx].duplicate(ResponseKind::Skip); let rx = command.create_router_channel(); @@ -222,24 +249,32 @@ pub async fn run( _warn!(inner, "Error sending DISCARD in trx {}: {:?}", id, e); } - if attempted >= inner.max_command_attempts() { + attempted += 1; + if attempted >= max_attempts { let _ = tx.send(Err(error)); return Ok(()); } else { - let _ = utils::reconnect_with_policy(inner, router).await?; + utils::reconnect_with_policy(inner, router).await?; } - attempted += 1; continue 'outer; }, Ok(TransactionResponse::Redirection((kind, slot, server))) => { + redirections += 1; + if redirections > max_redirections { + let _ = tx.send(Err(RedisError::new( + RedisErrorKind::Cluster, + "Too many cluster redirections.", + ))); + return Ok(()); + } + update_hash_slot(&mut commands, slot); if let Err(e) = send_discard(inner, router, &server, id).await { _warn!(inner, "Error sending DISCARD in trx {}: {:?}", id, e); } - let _ = utils::cluster_redirect_with_policy(inner, router, kind, slot, &server).await?; + utils::cluster_redirect_with_policy(inner, router, kind, slot, &server).await?; - attempted += 1; continue 'outer; }, Ok(TransactionResponse::Finished(frame)) => { @@ -266,14 +301,14 @@ pub async fn run( _warn!(inner, "Error sending DISCARD in trx {}: {:?}", id, e); } - if attempted >= inner.max_command_attempts() { + attempted += 1; + if attempted >= max_attempts { let _ = tx.send(Err(error)); return Ok(()); } else { - let _ = utils::reconnect_with_policy(inner, router).await?; + utils::reconnect_with_policy(inner, router).await?; } - attempted += 1; continue 'outer; }, Ok(TransactionResponse::Redirection((kind, slot, dest))) => { diff --git a/src/router/types.rs b/src/router/types.rs index 310b249f..a4a1e7d6 100644 --- a/src/router/types.rs +++ b/src/router/types.rs @@ -1,5 +1,7 @@ use crate::protocol::types::Server; +#[cfg(all(feature = "replicas", feature = "check-unresponsive"))] 
+use crate::protocol::connection::RedisWriter; #[cfg(feature = "check-unresponsive")] use crate::{ globals::globals, @@ -59,6 +61,8 @@ impl ConnectionState { let guard = self.interrupts.read(); for server in servers.into_iter() { + inner.notifications.broadcast_unresponsive(server.clone()); + if let Some(tx) = guard.get(&server) { _debug!(inner, "Interrupting reader task for {}", server); let _ = tx.send(()); @@ -99,14 +103,28 @@ impl ConnectionState { }; } + /// Add the replica connections to the internal connection map. + #[cfg(feature = "replicas")] + pub fn sync_replicas(&self, inner: &Arc, replicas: &HashMap) { + _debug!( + inner, + "Syncing replica connection state with unresponsive network task." + ); + let mut guard = self.commands.write(); + for (server, writer) in replicas.iter() { + guard.insert(server.clone(), writer.buffer.clone()); + } + } + pub fn unresponsive_connections(&self, inner: &Arc) -> VecDeque { _debug!(inner, "Checking unresponsive connections..."); let now = Instant::now(); - let timeout_duration = inner.with_perf_config(|perf| { - _trace!(inner, "Using network timeout: {}", perf.network_timeout_ms); - Duration::from_millis(perf.network_timeout_ms) - }); + _trace!( + inner, + "Using network timeout: {:?}", + inner.connection.unresponsive_timeout + ); let mut unresponsive = VecDeque::new(); for (server, commands) in self.commands.read().iter() { @@ -132,7 +150,7 @@ impl ConnectionState { continue; } let command_duration = now.duration_since(last_command_sent); - if command_duration > timeout_duration { + if command_duration > inner.connection.unresponsive_timeout { _warn!( inner, "Server {} unresponsive after {} ms", diff --git a/src/router/utils.rs b/src/router/utils.rs index 1383ab7c..020b1404 100644 --- a/src/router/utils.rs +++ b/src/router/utils.rs @@ -24,6 +24,8 @@ use tokio::{ sync::{mpsc::UnboundedReceiver, oneshot::channel as oneshot_channel}, }; +#[cfg(feature = "replicas")] +use crate::{interfaces, 
protocol::command::RouterCommand}; #[cfg(feature = "check-unresponsive")] use futures::future::Either; #[cfg(feature = "check-unresponsive")] @@ -49,12 +51,12 @@ pub fn check_backpressure( BackpressurePolicy::Drain => Ok(Some(Backpressure::Block)), BackpressurePolicy::Sleep { disable_backpressure_scaling, - min_sleep_duration_ms, + min_sleep_duration, } => { let duration = if disable_backpressure_scaling { - Duration::from_millis(min_sleep_duration_ms) + min_sleep_duration } else { - Duration::from_millis(cmp::max(min_sleep_duration_ms, in_flight as u64)) + Duration::from_millis(cmp::max(min_sleep_duration.as_millis() as u64, in_flight as u64)) }; Ok(Some(Backpressure::Wait(duration))) @@ -137,7 +139,7 @@ pub async fn write_command( }, Err(e) => { // return manual backpressure errors directly to the caller - command.respond_to_caller(Err(e)); + command.finish(inner, Err(e)); return Written::Ignore; }, _ => {}, @@ -148,7 +150,7 @@ pub async fn write_command( Err(e) => { _warn!(inner, "Frame encoding error for {}", command.kind.to_str_debug()); // do not retry commands that trigger frame encoding errors - command.respond_to_caller(Err(e)); + command.finish(inner, Err(e)); return Written::Ignore; }, }; @@ -160,24 +162,12 @@ pub async fn write_command( command.debug_id(), writer.server ); + command.write_attempts += 1; writer.push_command(inner, command); if let Err(e) = writer.write_frame(frame, should_flush).await { - let mut command = match writer.pop_recent_command() { - Some(cmd) => cmd, - None => { - _error!(inner, "Failed to take recent command off queue after write failure."); - return Written::Ignore; - }, - }; - - _debug!(inner, "Error sending command {}: {:?}", command.kind.to_str_debug(), e); - if command.should_send_write_error(inner) { - command.respond_to_caller(Err(e.clone())); - Written::Disconnect((Some(writer.server.clone()), None, e)) - } else { - inner.notifications.broadcast_error(e.clone()); - Written::Disconnect((Some(writer.server.clone()), 
Some(command), e)) - } + let command = writer.pop_recent_command(); + _debug!(inner, "Error sending command: {:?}", e); + Written::Disconnected((Some(writer.server.clone()), command, e)) } else { Written::Sent((writer.server.clone(), should_flush)) } @@ -219,17 +209,18 @@ pub fn check_final_write_attempt(inner: &Arc, buffer: &SharedB let mut guard = buffer.lock(); let commands = guard .drain(..) - .filter_map(|mut command| { - if command.has_router_channel() { - if command.attempted >= inner.max_command_attempts() { - let error = error - .clone() - .unwrap_or(RedisError::new(RedisErrorKind::IO, "Connection Closed")); - command.respond_to_caller(Err(error)); - None - } else { - Some(command) - } + .filter_map(|command| { + if command.should_finish_with_error(inner) { + command.finish( + inner, + Err( + error + .clone() + .unwrap_or(RedisError::new(RedisErrorKind::IO, "Connection Closed")), + ), + ); + + None } else { Some(command) } @@ -294,7 +285,7 @@ pub fn next_reconnection_delay(inner: &Arc) -> Result, router: &mut Router) inner.notifications.broadcast_error(e.clone()); Err(e) } else { - // try to flush any previously in-flight commands - router.retry_buffer().await; - if let Err(e) = router.sync_replicas().await { _warn!(inner, "Error syncing replicas: {:?}", e); if !inner.ignore_replica_reconnect_errors() { @@ -321,23 +309,26 @@ pub async fn reconnect_once(inner: &Arc, router: &mut Router) return Err(e); } } + // try to flush any previously in-flight commands + router.retry_buffer().await; client_utils::set_client_state(&inner.state, ClientState::Connected); inner.notifications.broadcast_connect(Ok(())); - inner.notifications.broadcast_reconnect(); inner.reset_reconnection_attempts(); Ok(()) } } /// Reconnect to the server(s) until the max reconnect policy attempts are reached. +/// +/// Errors from this function should end the connection task. 
pub async fn reconnect_with_policy(inner: &Arc, router: &mut Router) -> Result<(), RedisError> { let mut delay = utils::next_reconnection_delay(inner)?; loop { if !delay.is_zero() { _debug!(inner, "Sleeping for {} ms.", delay.as_millis()); - let _ = inner.wait_with_interrupt(delay).await?; + inner.wait_with_interrupt(delay).await?; } if let Err(e) = reconnect_once(inner, router).await { @@ -363,12 +354,12 @@ pub async fn cluster_redirect_with_policy( slot: u16, server: &Server, ) -> Result<(), RedisError> { - let mut delay = inner.with_perf_config(|perf| Duration::from_millis(perf.cluster_cache_update_delay_ms as u64)); + let mut delay = inner.connection.cluster_cache_update_delay; loop { if !delay.is_zero() { _debug!(inner, "Sleeping for {} ms.", delay.as_millis()); - let _ = inner.wait_with_interrupt(delay).await?; + inner.wait_with_interrupt(delay).await?; } if let Err(e) = router.cluster_redirection(&kind, slot, server).await { @@ -383,19 +374,21 @@ pub async fn cluster_redirect_with_policy( Ok(()) } -/// Repeatedly try to send `ASKING` to the provided server, reconnecting as needed. +/// Repeatedly try to send `ASKING` to the provided server, reconnecting as needed. +/// +/// Errors from this function should end the connection task. 
pub async fn send_asking_with_policy( inner: &Arc, router: &mut Router, server: &Server, slot: u16, ) -> Result<(), RedisError> { - let mut delay = inner.with_perf_config(|perf| Duration::from_millis(perf.cluster_cache_update_delay_ms as u64)); + let mut delay = inner.connection.cluster_cache_update_delay; loop { if !delay.is_zero() { _debug!(inner, "Sleeping for {} ms.", delay.as_millis()); - let _ = inner.wait_with_interrupt(delay).await?; + inner.wait_with_interrupt(delay).await?; } if !router.connections.has_server_connection(server) { @@ -411,20 +404,25 @@ pub async fn send_asking_with_policy( command.skip_backpressure = true; command.response = ResponseKind::Respond(Some(tx)); - if let Err(error) = router.write_once(command, server).await { + let result = match router.write_direct(command, server).await { + Written::Error((error, _)) => Err(error), + Written::Disconnected((_, _, error)) => Err(error), + Written::NotFound(_) => Err(RedisError::new(RedisErrorKind::Cluster, "Connection not found.")), + _ => Ok(()), + }; + + if let Err(error) = result { if error.should_not_reconnect() { break; + } else if let Err(_) = reconnect_once(inner, router).await { + delay = utils::next_reconnection_delay(inner)?; + continue; } else { - if let Err(_) = reconnect_once(inner, router).await { - delay = utils::next_reconnection_delay(inner)?; - continue; - } else { - delay = Duration::from_millis(0); - continue; - } + delay = Duration::from_millis(0); + continue; } } else { - match rx.await { + match client_utils::apply_timeout(rx, inner.internal_command_timeout()).await { Ok(Err(e)) => { // error writing the command _debug!(inner, "Reconnect once after error from ASKING: {:?}", e); @@ -490,13 +488,15 @@ pub async fn sync_replicas_with_policy(inner: &Arc, router: &m } /// Repeatedly try to sync the cluster state, reconnecting as needed until the max reconnection attempts is reached. +/// +/// Errors from this function should end the connection task. 
pub async fn sync_cluster_with_policy(inner: &Arc, router: &mut Router) -> Result<(), RedisError> { - let mut delay = inner.with_perf_config(|config| Duration::from_millis(config.cluster_cache_update_delay_ms as u64)); + let mut delay = inner.connection.cluster_cache_update_delay; loop { if !delay.is_zero() { _debug!(inner, "Sleeping for {} ms.", delay.as_millis()); - let _ = inner.wait_with_interrupt(delay).await?; + inner.wait_with_interrupt(delay).await?; } if let Err(e) = router.sync_cluster().await { @@ -521,6 +521,15 @@ pub async fn sync_cluster_with_policy(inner: &Arc, router: &mu Ok(()) } +#[cfg(feature = "replicas")] +pub fn defer_replica_sync(inner: &Arc) { + let (tx, _) = oneshot_channel(); + let cmd = RouterCommand::SyncReplicas { tx }; + if let Err(_) = interfaces::send_to_router(inner, cmd) { + _warn!(inner, "Failed to start deferred replica sync.") + } +} + #[cfg(feature = "check-unresponsive")] pub async fn next_frame( inner: &Arc, diff --git a/src/trace/disabled.rs b/src/trace/disabled.rs index ab30684f..f4f18641 100644 --- a/src/trace/disabled.rs +++ b/src/trace/disabled.rs @@ -15,12 +15,12 @@ pub struct Span {} #[cfg(not(feature = "full-tracing"))] impl Span { - pub fn enter(&self) -> () { - () + pub fn enter(&self) { + } pub fn record(&self, _field: &Q, _value: &V) -> &Self { - &self + self } } diff --git a/src/trace/enabled.rs b/src/trace/enabled.rs index 04b8865c..37cc68f9 100644 --- a/src/trace/enabled.rs +++ b/src/trace/enabled.rs @@ -3,7 +3,7 @@ use crate::{ protocol::{command::RedisCommand, utils as protocol_utils}, }; use redis_protocol::resp3::types::Frame; -use std::{fmt, sync::Arc}; +use std::{fmt, ops::Deref, sync::Arc}; pub use tracing::span::Span; use tracing::{event, field::Empty, Id as TraceId, Level}; @@ -61,7 +61,7 @@ pub fn create_command_span(inner: &Arc) -> Span { inner.tracing_span_level(), "redis_command", module = "fred", - client_id = inner.id.as_str(), + client_id = &inner.id.deref(), cmd = Empty, req_size = Empty, 
res_size = Empty @@ -97,7 +97,7 @@ pub fn create_pubsub_span(inner: &Arc, frame: &Frame) -> Optio parent: None, "parse_pubsub", module = "fred", - client_id = &inner.id.as_str(), + client_id = &inner.id.deref(), res_size = &protocol_utils::resp3_frame_size(frame), channel = Empty ); @@ -106,7 +106,6 @@ pub fn create_pubsub_span(inner: &Arc, frame: &Frame) -> Optio } else { None } - } #[cfg(not(feature = "full-tracing"))] diff --git a/src/types/acl.rs b/src/types/acl.rs deleted file mode 100644 index 0d8b2944..00000000 --- a/src/types/acl.rs +++ /dev/null @@ -1,132 +0,0 @@ -use crate::types::RedisValue; - -/// ACL rules describing the keys a user can access. -#[derive(Clone, Debug, Eq, PartialEq)] -pub enum AclKeyPattern { - AllKeys, - Custom(String), -} - -impl AclKeyPattern { - pub(crate) fn to_value(&self) -> RedisValue { - match *self { - AclKeyPattern::AllKeys => RedisValue::from_static_str("allkeys"), - AclKeyPattern::Custom(ref pat) => format!("~{}", pat).into(), - } - } -} - -/// ACL rules describing the channels a user can access. -#[derive(Clone, Debug, Eq, PartialEq)] -pub enum AclChannelPattern { - AllChannels, - Custom(String), -} - -impl AclChannelPattern { - pub(crate) fn to_value(&self) -> RedisValue { - match *self { - AclChannelPattern::AllChannels => RedisValue::from_static_str("allchannels"), - AclChannelPattern::Custom(ref pat) => format!("&{}", pat).into(), - } - } -} - -/// ACL rules describing the commands a user can access. 
-#[derive(Clone, Debug, Eq, PartialEq)] -pub enum AclCommandPattern { - AllCommands, - NoCommands, - Custom { - command: String, - subcommand: Option, - }, -} - -impl AclCommandPattern { - pub(crate) fn to_value(&self, prefix: &'static str) -> RedisValue { - match *self { - AclCommandPattern::AllCommands => RedisValue::from_static_str("allcommands"), - AclCommandPattern::NoCommands => RedisValue::from_static_str("nocommands"), - AclCommandPattern::Custom { - ref command, - ref subcommand, - } => { - if let Some(subcommand) = subcommand { - format!("{}{}|{}", prefix, command, subcommand).into() - } else { - format!("{}{}", prefix, command).into() - } - }, - } - } -} - -/// ACL rules associated with a user. -/// -/// -#[derive(Clone, Debug, Eq, PartialEq)] -pub enum AclRule { - On, - Off, - Reset, - ResetChannels, - ResetKeys, - AddKeys(AclKeyPattern), - AddChannels(AclChannelPattern), - AddCommands(AclCommandPattern), - RemoveCommands(AclCommandPattern), - AddCategory(String), - RemoveCategory(String), - NoPass, - AddPassword(String), - AddHashedPassword(String), - RemovePassword(String), - RemoveHashedPassword(String), -} - -impl AclRule { - pub(crate) fn to_value(&self) -> RedisValue { - match self { - AclRule::On => RedisValue::from_static_str("on"), - AclRule::Off => RedisValue::from_static_str("off"), - AclRule::Reset => RedisValue::from_static_str("reset"), - AclRule::ResetChannels => RedisValue::from_static_str("resetchannels"), - AclRule::ResetKeys => RedisValue::from_static_str("resetkeys"), - AclRule::NoPass => RedisValue::from_static_str("nopass"), - AclRule::AddPassword(ref pass) => format!(">{}", pass).into(), - AclRule::RemovePassword(ref pass) => format!("<{}", pass).into(), - AclRule::AddHashedPassword(ref pass) => format!("#{}", pass).into(), - AclRule::RemoveHashedPassword(ref pass) => format!("!{}", pass).into(), - AclRule::AddCategory(ref cat) => format!("+@{}", cat).into(), - AclRule::RemoveCategory(ref cat) => format!("-@{}", cat).into(), - 
AclRule::AddKeys(ref pat) => pat.to_value(), - AclRule::AddChannels(ref pat) => pat.to_value(), - AclRule::AddCommands(ref pat) => pat.to_value("+"), - AclRule::RemoveCommands(ref pat) => pat.to_value("-"), - } - } -} - -/// A flag from the ACL GETUSER command. -#[derive(Clone, Debug, Eq, PartialEq)] -pub enum AclUserFlag { - On, - Off, - AllKeys, - AllChannels, - AllCommands, - NoPass, -} - -/// An ACL user from the ACL GETUSER command. -/// -/// -#[derive(Clone, Debug, Eq, PartialEq, Default)] -pub struct AclUser { - pub flags: Vec, - pub passwords: Vec, - pub commands: Vec, - pub keys: Vec, - pub channels: Vec, -} diff --git a/src/types/args.rs b/src/types/args.rs index bf972785..bdcb7f9a 100644 --- a/src/types/args.rs +++ b/src/types/args.rs @@ -2,7 +2,7 @@ use crate::{ error::{RedisError, RedisErrorKind}, interfaces::{ClientLike, Resp3Frame}, protocol::{connection::OK, utils as protocol_utils}, - types::{FromRedis, FromRedisKey, GeoPosition, XReadResponse, XReadValue, NIL, QUEUED}, + types::{FromRedis, FromRedisKey, GeoPosition, XReadResponse, XReadValue, QUEUED}, utils, }; use bytes::Bytes; @@ -21,7 +21,7 @@ use std::{ str, }; -use crate::types::{Function, Server}; +use crate::types::{Function, GeoRadiusInfo, Server}; #[cfg(feature = "serde-json")] use serde_json::Value; @@ -259,10 +259,10 @@ impl TryFrom for RedisKey { RedisValue::Bytes(b) => RedisKey { key: b }, RedisValue::Boolean(b) => match b { true => RedisKey { - key: TRUE_STR.clone().into_inner().into(), + key: TRUE_STR.clone().into_inner(), }, false => RedisKey { - key: FALSE_STR.clone().into_inner().into(), + key: FALSE_STR.clone().into_inner(), }, }, RedisValue::Queued => utils::static_str(QUEUED).into(), @@ -296,14 +296,6 @@ impl<'a> From<&'a [u8]> for RedisKey { } } -// doing this prevents MultipleKeys from being generic in its `From` implementations since the compiler cant know what -// to do with `Vec`. 
-// impl From> for RedisKey { -// fn from(b: Vec) -> Self { -// RedisKey { key: b.into() } -// } -// } - impl From for RedisKey { fn from(s: String) -> Self { RedisKey { key: s.into() } @@ -366,28 +358,6 @@ impl_from_str_for_redis_key!(isize); impl_from_str_for_redis_key!(f32); impl_from_str_for_redis_key!(f64); -#[cfg(feature = "serde-json")] -#[cfg_attr(docsrs, doc(cfg(feature = "serde-json")))] -impl TryFrom for RedisKey { - type Error = RedisError; - - fn try_from(value: Value) -> Result { - let value: RedisKey = match value { - Value::String(s) => s.into(), - Value::Bool(b) => b.to_string().into(), - Value::Number(n) => n.to_string().into(), - _ => { - return Err(RedisError::new( - RedisErrorKind::InvalidArgument, - "Cannot convert to key from JSON.", - )) - }, - }; - - Ok(value) - } -} - /// A map of `(RedisKey, RedisValue)` pairs. #[derive(Clone, Debug, Eq, PartialEq)] pub struct RedisMap { @@ -403,7 +373,7 @@ impl RedisMap { /// Replace the value an empty map, returning the original value. pub fn take(&mut self) -> Self { RedisMap { - inner: mem::replace(&mut self.inner, HashMap::new()), + inner: std::mem::take(&mut self.inner), } } @@ -522,31 +492,6 @@ where } } -#[cfg(feature = "serde-json")] -#[cfg_attr(docsrs, doc(cfg(feature = "serde-json")))] -impl TryFrom for RedisMap { - type Error = RedisError; - - fn try_from(value: Value) -> Result { - if let Value::Object(map) = value { - let mut inner = HashMap::with_capacity(map.len()); - for (key, value) in map.into_iter() { - let key: RedisKey = key.into(); - let value: RedisValue = value.try_into()?; - - inner.insert(key, value); - } - - Ok(RedisMap { inner }) - } else { - Err(RedisError::new( - RedisErrorKind::InvalidArgument, - "Cannot convert non-object JSON value to map.", - )) - } - } -} - /// The kind of value from Redis. #[derive(Clone, Debug, Eq, PartialEq)] pub enum RedisValueKind { @@ -590,7 +535,7 @@ pub enum RedisValue { Double(f64), /// A string value. 
String(Str), - /// A value to represent non-UTF8 strings or byte arrays. + /// A byte array value. Bytes(Bytes), /// A `nil` value. Null, @@ -600,10 +545,11 @@ pub enum RedisValue { Map(RedisMap), /// An ordered list of values. /// - /// In RESP2 mode the server may send map structures as an array of key/value pairs. + /// In RESP2 mode the server usually sends map structures as an array of key/value pairs. Array(Vec), } +#[allow(clippy::match_like_matches_macro)] impl PartialEq for RedisValue { fn eq(&self, other: &Self) -> bool { use RedisValue::*; @@ -651,7 +597,7 @@ impl PartialEq for RedisValue { impl Eq for RedisValue {} -impl<'a> RedisValue { +impl RedisValue { /// Create a new `RedisValue::Bytes` from a static byte slice without copying. pub fn from_static(b: &'static [u8]) -> Self { RedisValue::Bytes(Bytes::from_static(b)) @@ -704,37 +650,26 @@ impl<'a> RedisValue { /// Check if the value is null. pub fn is_null(&self) -> bool { - match *self { - RedisValue::Null => true, - _ => false, - } + matches!(*self, RedisValue::Null) } /// Check if the value is an integer. pub fn is_integer(&self) -> bool { - match *self { - RedisValue::Integer(_) => true, - _ => false, - } + matches!(self, RedisValue::Integer(_)) } /// Check if the value is a string. pub fn is_string(&self) -> bool { - match *self { - RedisValue::String(_) => true, - _ => false, - } + matches!(*self, RedisValue::String(_)) } /// Check if the value is an array of bytes. pub fn is_bytes(&self) -> bool { - match *self { - RedisValue::Bytes(_) => true, - _ => false, - } + matches!(*self, RedisValue::Bytes(_)) } /// Whether or not the value is a boolean value or can be parsed as a boolean value. + #[allow(clippy::match_like_matches_macro)] pub fn is_boolean(&self) -> bool { match *self { RedisValue::Boolean(_) => true, @@ -761,18 +696,12 @@ impl<'a> RedisValue { /// Check if the value is a `QUEUED` response. 
pub fn is_queued(&self) -> bool { - match *self { - RedisValue::Queued => true, - _ => false, - } + matches!(*self, RedisValue::Queued) } /// Whether or not the value is an array or map. pub fn is_aggregate_type(&self) -> bool { - match *self { - RedisValue::Array(_) | RedisValue::Map(_) => true, - _ => false, - } + matches!(*self, RedisValue::Array(_) | RedisValue::Map(_)) } /// Whether or not the value is a `RedisMap`. @@ -780,10 +709,7 @@ impl<'a> RedisValue { /// See [is_maybe_map](Self::is_maybe_map) for a function that also checks for arrays that likely represent a map in /// RESP2 mode. pub fn is_map(&self) -> bool { - match *self { - RedisValue::Map(_) => true, - _ => false, - } + matches!(*self, RedisValue::Map(_)) } /// Whether or not the value is a `RedisMap` or an array with an even number of elements where each even-numbered @@ -801,10 +727,7 @@ impl<'a> RedisValue { /// Whether or not the value is an array. pub fn is_array(&self) -> bool { - match *self { - RedisValue::Array(_) => true, - _ => false, - } + matches!(*self, RedisValue::Array(_)) } /// Read and return the inner value as a `u64`, if possible. 
@@ -825,6 +748,10 @@ impl<'a> RedisValue { None } }, + #[cfg(feature = "default-nil-types")] + RedisValue::Null => Some(0), + #[cfg(not(feature = "default-nil-types"))] + RedisValue::Null => None, _ => None, } } @@ -841,6 +768,10 @@ impl<'a> RedisValue { None } }, + #[cfg(feature = "default-nil-types")] + RedisValue::Null => Some(0), + #[cfg(not(feature = "default-nil-types"))] + RedisValue::Null => None, _ => None, } } @@ -863,6 +794,10 @@ impl<'a> RedisValue { None } }, + #[cfg(feature = "default-nil-types")] + RedisValue::Null => Some(0), + #[cfg(not(feature = "default-nil-types"))] + RedisValue::Null => None, _ => None, } } @@ -880,6 +815,10 @@ impl<'a> RedisValue { None } }, + #[cfg(feature = "default-nil-types")] + RedisValue::Null => Some(0.0), + #[cfg(not(feature = "default-nil-types"))] + RedisValue::Null => None, _ => None, } } @@ -900,6 +839,10 @@ impl<'a> RedisValue { None } }, + #[cfg(feature = "default-nil-types")] + RedisValue::Null => Some(String::new()), + #[cfg(not(feature = "default-nil-types"))] + RedisValue::Null => None, _ => None, } } @@ -923,6 +866,10 @@ impl<'a> RedisValue { None } }, + #[cfg(feature = "default-nil-types")] + RedisValue::Null => Some(Str::new()), + #[cfg(not(feature = "default-nil-types"))] + RedisValue::Null => None, _ => None, } } @@ -946,6 +893,10 @@ impl<'a> RedisValue { None } }, + #[cfg(feature = "default-nil-types")] + RedisValue::Null => Some(Str::new()), + #[cfg(not(feature = "default-nil-types"))] + RedisValue::Null => None, _ => None, } } @@ -961,6 +912,10 @@ impl<'a> RedisValue { RedisValue::Bytes(ref b) => str::from_utf8(b).ok().map(|s| s.to_owned()), RedisValue::Integer(ref i) => Some(i.to_string()), RedisValue::Queued => Some(QUEUED.to_owned()), + #[cfg(feature = "default-nil-types")] + RedisValue::Null => Some(String::new()), + #[cfg(not(feature = "default-nil-types"))] + RedisValue::Null => None, _ => None, } } @@ -972,11 +927,14 @@ impl<'a> RedisValue { let s: Cow = match *self { RedisValue::Double(ref f) 
=> Cow::Owned(f.to_string()), RedisValue::Boolean(ref b) => Cow::Owned(b.to_string()), - RedisValue::String(ref s) => Cow::Borrowed(s.deref().as_ref()), + RedisValue::String(ref s) => Cow::Borrowed(s.deref()), RedisValue::Integer(ref i) => Cow::Owned(i.to_string()), - RedisValue::Null => Cow::Borrowed(NIL), RedisValue::Queued => Cow::Borrowed(QUEUED), - RedisValue::Bytes(ref b) => return str::from_utf8(b).ok().map(|s| Cow::Borrowed(s)), + RedisValue::Bytes(ref b) => return str::from_utf8(b).ok().map(Cow::Borrowed), + #[cfg(feature = "default-nil-types")] + RedisValue::Null => Cow::Borrowed(""), + #[cfg(not(feature = "default-nil-types"))] + RedisValue::Null => return None, _ => return None, }; @@ -988,11 +946,14 @@ impl<'a> RedisValue { let s: Cow = match *self { RedisValue::Boolean(ref b) => Cow::Owned(b.to_string()), RedisValue::Double(ref f) => Cow::Owned(f.to_string()), - RedisValue::String(ref s) => Cow::Borrowed(s.deref().as_ref()), + RedisValue::String(ref s) => Cow::Borrowed(s.deref()), RedisValue::Integer(ref i) => Cow::Owned(i.to_string()), - RedisValue::Null => Cow::Borrowed(NIL), RedisValue::Queued => Cow::Borrowed(QUEUED), RedisValue::Bytes(ref b) => String::from_utf8_lossy(b), + #[cfg(feature = "default-nil-types")] + RedisValue::Null => Cow::Borrowed(""), + #[cfg(not(feature = "default-nil-types"))] + RedisValue::Null => return None, _ => return None, }; @@ -1023,7 +984,6 @@ impl<'a> RedisValue { b"false" | b"FALSE" | b"f" | b"F" | b"0" => Some(false), _ => None, }, - RedisValue::Null => Some(false), RedisValue::Array(ref inner) => { if inner.len() == 1 { inner.first().and_then(|v| v.as_bool()) @@ -1031,48 +991,61 @@ impl<'a> RedisValue { None } }, + #[cfg(feature = "default-nil-types")] + RedisValue::Null => Some(false), + #[cfg(not(feature = "default-nil-types"))] + RedisValue::Null => None, _ => None, } } /// Attempt to convert this value to a Redis map if it's an array with an even number of elements. 
pub fn into_map(self) -> Result { - if let RedisValue::Map(map) = self { - return Ok(map); - } + match self { + RedisValue::Map(map) => Ok(map), + RedisValue::Array(mut values) => { + if values.len() % 2 != 0 { + return Err(RedisError::new( + RedisErrorKind::Unknown, + "Expected an even number of elements.", + )); + } + let mut inner = HashMap::with_capacity(values.len() / 2); + while values.len() >= 2 { + let value = values.pop().unwrap(); + let key: RedisKey = values.pop().unwrap().try_into()?; - if let RedisValue::Array(mut values) = self { - if values.len() % 2 != 0 { - return Err(RedisError::new( - RedisErrorKind::Unknown, - "Expected an even number of elements.", - )); - } - let mut inner = HashMap::with_capacity(values.len() / 2); - while values.len() >= 2 { - let value = values.pop().unwrap(); - let key: RedisKey = values.pop().unwrap().try_into()?; + inner.insert(key, value); + } - inner.insert(key, value); - } + Ok(RedisMap { inner }) + }, + #[cfg(feature = "default-nil-types")] + RedisValue::Null => Ok(RedisMap::new()), + _ => Err(RedisError::new(RedisErrorKind::Unknown, "Could not convert to map.")), + } + } - Ok(RedisMap { inner }) - } else { - Err(RedisError::new(RedisErrorKind::Unknown, "Expected array.")) + pub(crate) fn into_multiple_values(self) -> Vec { + match self { + RedisValue::Array(values) => values, + RedisValue::Map(map) => map + .inner() + .into_iter() + .flat_map(|(k, v)| [RedisValue::Bytes(k.into_bytes()), v]) + .collect(), + RedisValue::Null => Vec::new(), + _ => vec![self], } } /// Convert the array value to a set, if possible. 
pub fn into_set(self) -> Result, RedisError> { - if let RedisValue::Array(values) = self { - let mut out = HashSet::with_capacity(values.len()); - - for value in values.into_iter() { - out.insert(value); - } - Ok(out) - } else { - Err(RedisError::new(RedisErrorKind::Unknown, "Expected array.")) + match self { + RedisValue::Array(values) => Ok(values.into_iter().collect()), + #[cfg(feature = "default-nil-types")] + RedisValue::Null => Ok(HashSet::new()), + _ => Err(RedisError::new_parse("Could not convert to set.")), } } @@ -1105,7 +1078,6 @@ impl<'a> RedisValue { let v = match self { RedisValue::String(s) => s.to_string().into_bytes(), RedisValue::Bytes(b) => b.to_vec(), - RedisValue::Null => NULL.as_bytes().to_vec(), RedisValue::Queued => QUEUED.as_bytes().to_vec(), RedisValue::Array(mut inner) => { if inner.len() == 1 { @@ -1115,6 +1087,10 @@ impl<'a> RedisValue { } }, RedisValue::Integer(i) => i.to_string().into_bytes(), + #[cfg(feature = "default-nil-types")] + RedisValue::Null => Vec::new(), + #[cfg(not(feature = "default-nil-types"))] + RedisValue::Null => return None, _ => return None, }; @@ -1126,7 +1102,6 @@ impl<'a> RedisValue { let v = match self { RedisValue::String(s) => s.inner().clone(), RedisValue::Bytes(b) => b, - RedisValue::Null => Bytes::from_static(NULL.as_bytes()), RedisValue::Queued => Bytes::from_static(QUEUED.as_bytes()), RedisValue::Array(mut inner) => { if inner.len() == 1 { @@ -1136,6 +1111,10 @@ impl<'a> RedisValue { } }, RedisValue::Integer(i) => i.to_string().into(), + #[cfg(feature = "default-nil-types")] + RedisValue::Null => Bytes::new(), + #[cfg(not(feature = "default-nil-types"))] + RedisValue::Null => return None, _ => return None, }; @@ -1150,6 +1129,26 @@ impl<'a> RedisValue { } } + /// Whether the value is an array with one element. 
+ pub(crate) fn is_single_element_vec(&self) -> bool { + if let RedisValue::Array(ref d) = self { + d.len() == 1 + } else { + false + } + } + + /// Pop the first value in the inner array or return the original value. + /// + /// This uses unwrap. Use [is_single_element_vec] first. + pub(crate) fn pop_or_take(self) -> Self { + if let RedisValue::Array(mut values) = self { + values.pop().unwrap() + } else { + self + } + } + /// Flatten adjacent nested arrays to the provided depth. /// /// See the [XREAD](crate::interfaces::StreamsInterface::xread) documentation for an example of when this might be @@ -1225,7 +1224,29 @@ impl<'a> RedisValue { /// /// Null values are returned as `None` to work more easily with the result of the `GEOPOS` command. pub fn as_geo_position(&self) -> Result, RedisError> { - utils::value_to_geo_pos(self) + if self.is_null() { + Ok(None) + } else { + GeoPosition::try_from(self.clone()).map(Some) + } + } + + /// Parse the value as the response to any of the relevant GEO commands that return an array of + /// [GeoRadiusInfo](crate::types::GeoRadiusInfo) values, such as `GEOSEARCH`, `GEORADIUS`, etc. + pub fn into_geo_radius_result( + self, + withcoord: bool, + withdist: bool, + withhash: bool, + ) -> Result, RedisError> { + match self { + RedisValue::Array(data) => data + .into_iter() + .map(|value| GeoRadiusInfo::from_redis_value(value, withcoord, withdist, withhash)) + .collect(), + RedisValue::Null => Ok(Vec::new()), + _ => Err(RedisError::new(RedisErrorKind::Parse, "Expected array.")), + } } /// Replace this value with `RedisValue::Null`, returning the original value. @@ -1234,34 +1255,6 @@ impl<'a> RedisValue { } /// Attempt to convert this value to any value that implements the [FromRedis](crate::types::FromRedis) trait. 
- /// - /// ```rust - /// # use fred::types::RedisValue; - /// # use std::collections::HashMap; - /// let foo: usize = RedisValue::String("123".into()).convert()?; - /// let foo: i64 = RedisValue::String("123".into()).convert()?; - /// let foo: String = RedisValue::String("123".into()).convert()?; - /// let foo: Vec = RedisValue::Bytes(vec![102, 111, 111].into()).convert()?; - /// let foo: Vec = RedisValue::String("foo".into()).convert()?; - /// let foo: Vec = RedisValue::Array(vec!["a".into(), "b".into()]).convert()?; - /// let foo: HashMap = - /// RedisValue::Array(vec!["a".into(), 1.into(), "b".into(), 2.into()]).convert()?; - /// let foo: (String, i64) = RedisValue::Array(vec!["a".into(), 1.into()]).convert()?; - /// let foo: Vec<(String, i64)> = - /// RedisValue::Array(vec!["a".into(), 1.into(), "b".into(), 2.into()]).convert()?; - /// // ... - /// ``` - /// **Performance Considerations** - /// - /// The backing data type for potentially large values is either [Str](https://docs.rs/bytes-utils/latest/bytes_utils/string/type.Str.html) or [Bytes](https://docs.rs/bytes/latest/bytes/struct.Bytes.html). - /// - /// These values represent views into the buffer that receives data from the Redis server. As a result it is - /// possible for callers to utilize `RedisValue` types in such a way that the underlying data is never moved or - /// copied. - /// - /// If the values are huge or performance is a concern and callers do not need to modify the underlying data it is - /// recommended to convert to `Str` or `Bytes` whenever possible. Converting to `String`, `Vec`, etc will - /// result in at least a move, if not a copy, of the underlying data. pub fn convert(self) -> Result where R: FromRedis, @@ -1274,17 +1267,17 @@ impl<'a> RedisValue { /// Some use cases require using `RedisValue` types as keys in a `HashMap`, etc. Trying to do so with an aggregate /// type can panic, and this function can be used to more gracefully handle this situation. 
pub fn can_hash(&self) -> bool { - match self.kind() { + matches!( + self.kind(), RedisValueKind::String - | RedisValueKind::Boolean - | RedisValueKind::Double - | RedisValueKind::Integer - | RedisValueKind::Bytes - | RedisValueKind::Null - | RedisValueKind::Array - | RedisValueKind::Queued => true, - _ => false, - } + | RedisValueKind::Boolean + | RedisValueKind::Double + | RedisValueKind::Integer + | RedisValueKind::Bytes + | RedisValueKind::Null + | RedisValueKind::Array + | RedisValueKind::Queued + ) } /// Convert the value to JSON. @@ -1494,9 +1487,80 @@ where } } -impl FromIterator for RedisValue { - fn from_iter>(iter: I) -> Self { - RedisValue::Array(iter.into_iter().collect()) +impl<'a, T, const N: usize> TryFrom<&'a [T; N]> for RedisValue +where + T: TryInto + Clone, + T::Error: Into, +{ + type Error = RedisError; + + fn try_from(value: &'a [T; N]) -> Result { + let values = value + .iter() + .map(|v| v.clone().try_into().map_err(|e| e.into())) + .collect::, RedisError>>()?; + + Ok(RedisValue::Array(values)) + } +} + +impl TryFrom<[T; N]> for RedisValue +where + T: TryInto + Clone, + T::Error: Into, +{ + type Error = RedisError; + + fn try_from(value: [T; N]) -> Result { + let values = value + .into_iter() + .map(|v| v.try_into().map_err(|e| e.into())) + .collect::, RedisError>>()?; + + Ok(RedisValue::Array(values)) + } +} + +impl TryFrom> for RedisValue +where + T: TryInto, + T::Error: Into, +{ + type Error = RedisError; + + fn try_from(value: Vec) -> Result { + let values = value + .into_iter() + .map(|v| v.try_into().map_err(|e| e.into())) + .collect::, RedisError>>()?; + + Ok(RedisValue::Array(values)) + } +} + +impl TryFrom> for RedisValue +where + T: TryInto, + T::Error: Into, +{ + type Error = RedisError; + + fn try_from(value: VecDeque) -> Result { + let values = value + .into_iter() + .map(|v| v.try_into().map_err(|e| e.into())) + .collect::, RedisError>>()?; + + Ok(RedisValue::Array(values)) + } +} + +impl FromIterator for RedisValue +where + V: 
Into, +{ + fn from_iter>(iter: I) -> Self { + RedisValue::Array(iter.into_iter().map(|v| v.into()).collect()) } } @@ -1550,45 +1614,6 @@ impl From<()> for RedisValue { } } -#[cfg(feature = "serde-json")] -#[cfg_attr(docsrs, doc(cfg(feature = "serde-json")))] -impl TryFrom for RedisValue { - type Error = RedisError; - - fn try_from(v: Value) -> Result { - let value = match v { - Value::Null => RedisValue::Null, - Value::String(s) => RedisValue::String(s.into()), - Value::Bool(b) => RedisValue::Boolean(b), - Value::Number(n) => { - if n.is_i64() { - RedisValue::Integer(n.as_i64().unwrap()) - } else if n.is_f64() { - RedisValue::Double(n.as_f64().unwrap()) - } else { - return Err(RedisError::new(RedisErrorKind::InvalidArgument, "Invalid JSON number.")); - } - }, - Value::Array(a) => { - let mut out = Vec::with_capacity(a.len()); - for value in a.into_iter() { - out.push(value.try_into()?); - } - RedisValue::Array(out) - }, - Value::Object(m) => { - let mut out: HashMap = HashMap::with_capacity(m.len()); - for (key, value) in m.into_iter() { - out.insert(key.into(), value.try_into()?); - } - RedisValue::Map(RedisMap { inner: out }) - }, - }; - - Ok(value) - } -} - impl TryFrom for RedisValue { type Error = RedisError; diff --git a/src/types/builder.rs b/src/types/builder.rs new file mode 100644 index 00000000..0265a772 --- /dev/null +++ b/src/types/builder.rs @@ -0,0 +1,287 @@ +use crate::{ + clients::{RedisClient, RedisPool}, + error::{RedisError, RedisErrorKind}, + prelude::ReconnectPolicy, + types::{ConnectionConfig, PerformanceConfig, RedisConfig, ServerConfig}, +}; + +#[cfg(feature = "subscriber-client")] +use crate::clients::SubscriberClient; +#[cfg(feature = "sentinel-client")] +use crate::{clients::SentinelClient, types::SentinelConfig}; + +/// A client and pool builder interface. 
+/// +/// ```rust +/// # use std::time::Duration; +/// # use redis_protocol::resp3::types::RespVersion; +/// # use fred::prelude::*; +/// fn example() -> Result<(), RedisError> { +/// // use default values +/// let client = Builder::default_centralized().build()?; +/// +/// // or initialize from a URL or config +/// let config = RedisConfig::from_url("redis://localhost:6379/1")?; +/// let mut builder = Builder::from_config(config); +/// // or modify values in place (creating defaults if needed) +/// builder +/// .with_performance_config(|config| { +/// config.auto_pipeline = true; +/// }) +/// .with_config(|config| { +/// config.version = RespVersion::RESP3; +/// config.fail_fast = true; +/// }) +/// .with_connection_config(|config| { +/// config.tcp = TcpConfig { +/// nodelay: Some(true), +/// ..Default::default() +/// }; +/// config.internal_command_timeout = Duration::from_secs(10); +/// }); +/// // or overwrite configuration structs in place +/// builder.set_policy(ReconnectPolicy::new_exponential(0, 100, 30_000, 2)); +/// builder.set_performance_config(PerformanceConfig::default()); +/// +/// // reuse the builder as needed to create any kind of client +/// let client = builder.build()?; +/// let pool = builder.build_pool(3)?; +/// let subscriber = builder.build_subscriber_client()?; +/// +/// // ... +/// +/// Ok(()) +/// } +/// ``` +#[derive(Clone, Debug)] +pub struct Builder { + config: Option, + performance: PerformanceConfig, + connection: ConnectionConfig, + policy: Option, + #[cfg(feature = "sentinel-client")] + sentinel: Option, +} + +impl Default for Builder { + fn default() -> Self { + Builder { + config: None, + performance: PerformanceConfig::default(), + connection: ConnectionConfig::default(), + policy: None, + #[cfg(feature = "sentinel-client")] + sentinel: None, + } + } +} + +impl Builder { + /// Create a new builder instance with default config values for a centralized deployment. 
+ pub fn default_centralized() -> Self { + Builder { + config: Some(RedisConfig { + server: ServerConfig::default_centralized(), + ..Default::default() + }), + ..Default::default() + } + } + + /// Create a new builder instance with default config values for a clustered deployment. + pub fn default_clustered() -> Self { + Builder { + config: Some(RedisConfig { + server: ServerConfig::default_clustered(), + ..Default::default() + }), + ..Default::default() + } + } + + /// Create a new builder instance from the provided client config. + pub fn from_config(config: RedisConfig) -> Self { + Builder { + config: Some(config), + ..Default::default() + } + } + + /// Read the client config. + pub fn get_config(&self) -> Option<&RedisConfig> { + self.config.as_ref() + } + + /// Read the reconnection policy. + pub fn get_policy(&self) -> Option<&ReconnectPolicy> { + self.policy.as_ref() + } + + /// Read the performance config. + pub fn get_performance_config(&self) -> &PerformanceConfig { + &self.performance + } + + /// Read the connection config. + pub fn get_connection_config(&self) -> &ConnectionConfig { + &self.connection + } + + /// Read the sentinel client config. + #[cfg(feature = "sentinel-client")] + #[cfg_attr(docsrs, doc(cfg(feature = "sentinel-client")))] + pub fn get_sentinel_config(&self) -> Option<&RedisConfig> { + self.config.as_ref() + } + + /// Overwrite the client config on the builder. + pub fn set_config(&mut self, config: RedisConfig) -> &mut Self { + self.config = Some(config); + self + } + + /// Overwrite the reconnection policy on the builder. + pub fn set_policy(&mut self, policy: ReconnectPolicy) -> &mut Self { + self.policy = Some(policy); + self + } + + /// Overwrite the performance config on the builder. + pub fn set_performance_config(&mut self, config: PerformanceConfig) -> &mut Self { + self.performance = config; + self + } + + /// Overwrite the connection config on the builder. 
+ pub fn set_connection_config(&mut self, config: ConnectionConfig) -> &mut Self { + self.connection = config; + self + } + + /// Overwrite the sentinel config on the builder. + #[cfg(feature = "sentinel-client")] + #[cfg_attr(docsrs, doc(cfg(feature = "sentinel-client")))] + pub fn set_sentinel_config(&mut self, config: SentinelConfig) -> &mut Self { + self.sentinel = Some(config); + self + } + + /// Modify the client config in place, creating a new one with default centralized values first if needed. + pub fn with_config(&mut self, func: F) -> &mut Self + where + F: FnOnce(&mut RedisConfig), + { + if let Some(config) = self.config.as_mut() { + func(config); + } else { + let mut config = RedisConfig::default(); + func(&mut config); + self.config = Some(config); + } + + self + } + + /// Modify the performance config in place, creating a new one with default values first if needed. + pub fn with_performance_config(&mut self, func: F) -> &mut Self + where + F: FnOnce(&mut PerformanceConfig), + { + func(&mut self.performance); + self + } + + /// Modify the connection config in place, creating a new one with default values first if needed. + pub fn with_connection_config(&mut self, func: F) -> &mut Self + where + F: FnOnce(&mut ConnectionConfig), + { + func(&mut self.connection); + self + } + + /// Modify the sentinel config in place, creating a new one with default values first if needed. + #[cfg(feature = "sentinel-client")] + #[cfg_attr(docsrs, doc(cfg(feature = "sentinel-client")))] + pub fn with_sentinel_config(&mut self, func: F) -> &mut Self + where + F: FnOnce(&mut SentinelConfig), + { + if let Some(config) = self.sentinel.as_mut() { + func(config); + } else { + let mut config = SentinelConfig::default(); + func(&mut config); + self.sentinel = Some(config); + } + + self + } + + /// Create a new client. 
+ pub fn build(&self) -> Result { + if let Some(config) = self.config.as_ref() { + Ok(RedisClient::new( + config.clone(), + Some(self.performance.clone()), + Some(self.connection.clone()), + self.policy.clone(), + )) + } else { + Err(RedisError::new(RedisErrorKind::Config, "Missing client configuration.")) + } + } + + /// Create a new client pool. + pub fn build_pool(&self, size: usize) -> Result { + if let Some(config) = self.config.as_ref() { + RedisPool::new( + config.clone(), + Some(self.performance.clone()), + Some(self.connection.clone()), + self.policy.clone(), + size, + ) + } else { + Err(RedisError::new(RedisErrorKind::Config, "Missing client configuration.")) + } + } + + /// Create a new subscriber client. + #[cfg(feature = "subscriber-client")] + #[cfg_attr(docsrs, doc(cfg(feature = "subscriber-client")))] + pub fn build_subscriber_client(&self) -> Result { + if let Some(config) = self.config.as_ref() { + Ok(SubscriberClient::new( + config.clone(), + Some(self.performance.clone()), + Some(self.connection.clone()), + self.policy.clone(), + )) + } else { + Err(RedisError::new(RedisErrorKind::Config, "Missing client configuration.")) + } + } + + /// Create a new sentinel client. + /// + /// This is only necessary if callers need to communicate directly with sentinel nodes. Use a + /// `ServerConfig::Sentinel` to interact with Redis servers behind a sentinel layer. 
+ #[cfg(feature = "sentinel-client")] + #[cfg_attr(docsrs, doc(cfg(feature = "sentinel-client")))] + pub fn build_sentinel_client(&self) -> Result { + if let Some(config) = self.sentinel.as_ref() { + Ok(SentinelClient::new( + config.clone(), + Some(self.performance.clone()), + Some(self.connection.clone()), + self.policy.clone(), + )) + } else { + Err(RedisError::new( + RedisErrorKind::Config, + "Missing sentinel client configuration.", + )) + } + } +} diff --git a/src/types/cluster.rs b/src/types/cluster.rs index 58d660d3..d21f2aad 100644 --- a/src/types/cluster.rs +++ b/src/types/cluster.rs @@ -1,9 +1,49 @@ -use crate::utils; +pub use crate::protocol::types::{ClusterRouting, SlotRange}; +use crate::{ + error::{RedisError, RedisErrorKind}, + types::RedisValue, + utils, +}; use bytes_utils::Str; -pub use crate::protocol::types::{ClusterRouting, SlotRange}; +macro_rules! parse_or_zero( + ($data:ident, $t:ty) => { + $data.parse::<$t>().ok().unwrap_or(0) + } +); -/// The state of the cluster from the CLUSTER INFO command. 
+fn parse_cluster_info_line(info: &mut ClusterInfo, line: &str) -> Result<(), RedisError> { + let parts: Vec<&str> = line.split(':').collect(); + if parts.len() != 2 { + return Err(RedisError::new(RedisErrorKind::Protocol, "Expected key:value pair.")); + } + let (field, val) = (parts[0], parts[1]); + + match field { + "cluster_state" => match val { + "ok" => info.cluster_state = ClusterState::Ok, + "fail" => info.cluster_state = ClusterState::Fail, + _ => return Err(RedisError::new(RedisErrorKind::Protocol, "Invalid cluster state.")), + }, + "cluster_slots_assigned" => info.cluster_slots_assigned = parse_or_zero!(val, u16), + "cluster_slots_ok" => info.cluster_slots_ok = parse_or_zero!(val, u16), + "cluster_slots_pfail" => info.cluster_slots_pfail = parse_or_zero!(val, u16), + "cluster_slots_fail" => info.cluster_slots_fail = parse_or_zero!(val, u16), + "cluster_known_nodes" => info.cluster_known_nodes = parse_or_zero!(val, u16), + "cluster_size" => info.cluster_size = parse_or_zero!(val, u32), + "cluster_current_epoch" => info.cluster_current_epoch = parse_or_zero!(val, u64), + "cluster_my_epoch" => info.cluster_my_epoch = parse_or_zero!(val, u64), + "cluster_stats_messages_sent" => info.cluster_stats_messages_sent = parse_or_zero!(val, u64), + "cluster_stats_messages_received" => info.cluster_stats_messages_received = parse_or_zero!(val, u64), + _ => { + warn!("Invalid cluster info field: {}", line); + }, + }; + + Ok(()) +} + +/// The state of the cluster from the `CLUSTER INFO` command. #[derive(Clone, Debug, Eq, PartialEq)] pub enum ClusterState { Ok, @@ -16,7 +56,7 @@ impl Default for ClusterState { } } -/// A parsed response from the CLUSTER INFO command. +/// A parsed response from the `CLUSTER INFO` command. 
/// /// #[derive(Clone, Debug, Eq, PartialEq, Default)] @@ -34,6 +74,26 @@ pub struct ClusterInfo { pub cluster_stats_messages_received: u64, } +impl TryFrom for ClusterInfo { + type Error = RedisError; + + fn try_from(value: RedisValue) -> Result { + if let Some(data) = value.as_bytes_str() { + let mut out = ClusterInfo::default(); + + for line in data.lines() { + let trimmed = line.trim(); + if !trimmed.is_empty() { + parse_cluster_info_line(&mut out, trimmed)?; + } + } + Ok(out) + } else { + Err(RedisError::new(RedisErrorKind::Protocol, "Expected string response.")) + } + } +} + /// Options for the CLUSTER FAILOVER command. #[derive(Clone, Debug, Eq, PartialEq)] pub enum ClusterFailoverFlag { diff --git a/src/types/config.rs b/src/types/config.rs index 50cb8e78..40c7aec6 100644 --- a/src/types/config.rs +++ b/src/types/config.rs @@ -1,9 +1,11 @@ -use crate::{error::RedisError, types::RespVersion, utils}; -use std::cmp; +pub use crate::protocol::types::Server; +use crate::{error::RedisError, protocol::command::RedisCommand, types::RespVersion, utils}; +use socket2::TcpKeepalive; +use std::{cmp, time::Duration}; use url::Url; #[cfg(feature = "mocks")] -use crate::mocks::{Echo, Mocks}; +use crate::mocks::Mocks; #[cfg(feature = "mocks")] use std::sync::Arc; @@ -15,7 +17,6 @@ pub use crate::protocol::tls::{HostMapping, TlsConfig, TlsConnector, TlsHostMapp #[cfg_attr(docsrs, doc(cfg(feature = "replicas")))] pub use crate::router::replicas::{ReplicaConfig, ReplicaFilter}; -pub use crate::protocol::types::Server; /// The default amount of jitter when waiting to reconnect. pub const DEFAULT_JITTER_MS: u32 = 100; @@ -245,8 +246,8 @@ pub enum BackpressurePolicy { /// If `0` then no backpressure will be applied, but backpressure errors will not be surfaced to callers unless /// `disable_auto_backpressure` is `true`. 
/// - /// Default: 50 ms - min_sleep_duration_ms: u64, + /// Default: 10 ms + min_sleep_duration: Duration, }, /// Wait for all in-flight commands to finish before sending the next command. Drain, @@ -263,7 +264,7 @@ impl BackpressurePolicy { pub fn default_sleep() -> Self { BackpressurePolicy::Sleep { disable_backpressure_scaling: false, - min_sleep_duration_ms: 50, + min_sleep_duration: Duration::from_millis(10), } } } @@ -273,7 +274,8 @@ impl BackpressurePolicy { pub struct BackpressureConfig { /// Whether or not to disable the automatic backpressure features when pipelining is enabled. /// - /// If `true` then `RedisErrorKind::Backpressure` errors may be surfaced to callers. + /// If `true` then `RedisErrorKind::Backpressure` errors may be surfaced to callers. Callers can set this to `true` + /// and `max_in_flight_commands` to `0` to effectively disable the backpressure logic. /// /// Default: `false` pub disable_auto_backpressure: bool, @@ -297,6 +299,89 @@ impl Default for BackpressureConfig { } } +/// TCP configuration options. +#[derive(Clone, Debug, Default)] +pub struct TcpConfig { + /// Set the [TCP_NODELAY](https://docs.rs/tokio/latest/tokio/net/struct.TcpStream.html#method.set_nodelay) value. + pub nodelay: Option, + /// Set the [SO_LINGER](https://docs.rs/tokio/latest/tokio/net/struct.TcpStream.html#method.set_linger) value. + pub linger: Option, + /// Set the [IP_TTL](https://docs.rs/tokio/latest/tokio/net/struct.TcpStream.html#method.set_ttl) value. + pub ttl: Option, + /// Set the [TCP keepalive values](https://docs.rs/socket2/latest/socket2/struct.Socket.html#method.set_tcp_keepalive). + pub keepalive: Option, +} + +impl PartialEq for TcpConfig { + fn eq(&self, other: &Self) -> bool { + self.nodelay == other.nodelay && self.linger == other.linger && self.ttl == other.ttl + } +} + +impl Eq for TcpConfig {} + +/// Configuration options related to the creation or management of TCP connection. 
+#[derive(Clone, Debug, Eq, PartialEq)] +pub struct ConnectionConfig { + /// The timeout to apply when attempting to create a new TCP connection. + /// + /// This also includes the TLS handshake if using any of the TLS features. + /// + /// Default: 10 sec + pub connection_timeout: Duration, + /// The timeout to apply when sending internal commands such as `AUTH`, `SELECT`, `CLUSTER SLOTS`, `READONLY`, etc. + /// + /// Default: 10 sec + pub internal_command_timeout: Duration, + /// The amount of time to wait after a `MOVED` error is received before the client will update the cached cluster + /// state. + /// + /// Default: `0` + pub cluster_cache_update_delay: Duration, + /// The maximum number of times the client will attempt to send a command. + /// + /// This value will be incremented whenever the connection closes while the command is in-flight. + /// + /// Default: `3` + pub max_command_attempts: u32, + /// The maximum number of times the client will attempt to follow a `MOVED` or `ASK` redirection per command. + /// + /// Default: `5` + pub max_redirections: u32, + /// The amount of time a command can wait without a response before the corresponding connection is considered + /// unresponsive. This will trigger a reconnection and in-flight commands will be retried. + /// + /// Default: 10 sec + #[cfg(feature = "check-unresponsive")] + #[cfg_attr(docsrs, doc(cfg(feature = "check-unresponsive")))] + pub unresponsive_timeout: Duration, + /// Configuration options for replica nodes. + /// + /// Default: `None` + #[cfg(feature = "replicas")] + #[cfg_attr(docsrs, doc(cfg(feature = "replicas")))] + pub replica: ReplicaConfig, + /// TCP connection options. 
+ pub tcp: TcpConfig, +} + +impl Default for ConnectionConfig { + fn default() -> Self { + ConnectionConfig { + connection_timeout: Duration::from_millis(10_000), + internal_command_timeout: Duration::from_millis(10_000), + max_redirections: 5, + max_command_attempts: 3, + cluster_cache_update_delay: Duration::from_millis(0), + tcp: TcpConfig::default(), + #[cfg(feature = "check-unresponsive")] + unresponsive_timeout: Duration::from_millis(10_000), + #[cfg(feature = "replicas")] + replica: ReplicaConfig::default(), + } + } +} + /// Configuration options that can affect the performance of the client. #[derive(Clone, Debug, Eq, PartialEq)] pub struct PerformanceConfig { @@ -306,52 +391,31 @@ pub struct PerformanceConfig { /// whereas this flag can automatically pipeline commands across tasks. /// /// Default: `true` - pub auto_pipeline: bool, - /// The maximum number of times the client will attempt to send a command. - /// - /// This value be incremented whenever the connection closes while the command is in-flight or following a - /// MOVED/ASK error. - /// - /// Default: `3` - pub max_command_attempts: u32, + pub auto_pipeline: bool, /// Configuration options for backpressure features in the client. - pub backpressure: BackpressureConfig, - /// An optional timeout (in milliseconds) to apply to all commands. + pub backpressure: BackpressureConfig, + /// An optional timeout to apply to all commands. /// - /// If `0` this will disable any timeout being applied to commands. + /// If `0` this will disable any timeout being applied to commands. Callers can also set timeouts on individual + /// commands via the [with_options](crate::interfaces::ClientLike::with_options) interface. /// /// Default: `0` - pub default_command_timeout_ms: u64, - /// The maximum number of frames that will be passed to a socket before flushing the socket. + pub default_command_timeout: Duration, + /// The maximum number of frames that will be fed to a socket before flushing. 
/// /// Note: in some circumstances the client with always flush the socket (`QUIT`, `EXEC`, etc). /// - /// Default: 500 - pub max_feed_count: u64, - /// The amount of time, in milliseconds, to wait after a `MOVED` error is received before the client will update - /// the cached cluster state. - /// - /// Default: 0 ms - pub cluster_cache_update_delay_ms: u32, - /// The amount of time a command can wait without a response before a connection is considered unresponsive. - /// - /// Default: 60000 ms (1 min) - #[cfg(feature = "check-unresponsive")] - #[cfg_attr(docsrs, doc(cfg(feature = "check-unresponsive")))] - pub network_timeout_ms: u64, + /// Default: 200 + pub max_feed_count: u64, } impl Default for PerformanceConfig { fn default() -> Self { PerformanceConfig { - auto_pipeline: true, - backpressure: BackpressureConfig::default(), - max_command_attempts: 3, - default_command_timeout_ms: 0, - max_feed_count: 500, - cluster_cache_update_delay_ms: 0, - #[cfg(feature = "check-unresponsive")] - network_timeout_ms: 60_000, + auto_pipeline: true, + backpressure: BackpressureConfig::default(), + default_command_timeout: Duration::from_millis(0), + max_feed_count: 200, } } } @@ -366,7 +430,7 @@ pub struct RedisConfig { /// Normally the reconnection logic only applies to connections that close unexpectedly, but this flag can apply /// the same logic to the first connection as it is being created. /// - /// Note: Callers should use caution setting this to `false` since it can make debugging configuration issues more + /// Callers should use caution setting this to `false` since it can make debugging configuration issues more /// difficult. /// /// Default: `true` @@ -410,28 +474,20 @@ pub struct RedisConfig { pub database: Option, /// TLS configuration options. /// - /// See the `tls` examples on Github for more information. 
- /// /// Default: `None` #[cfg(any(feature = "enable-native-tls", feature = "enable-rustls"))] #[cfg_attr(docsrs, doc(cfg(any(feature = "enable-native-tls", feature = "enable-rustls"))))] pub tls: Option, - /// Configuration of tracing for this client. + /// Tracing configuration options. #[cfg(feature = "partial-tracing")] #[cfg_attr(docsrs, doc(cfg(feature = "partial-tracing")))] pub tracing: TracingConfig, - /// Configuration options for replica nodes. - /// - /// Default: `None` - #[cfg(feature = "replicas")] - #[cfg_attr(docsrs, doc(cfg(feature = "replicas")))] - pub replica: ReplicaConfig, /// An optional [mocking layer](crate::mocks) to intercept and process commands. /// - /// Default: [Echo](crate::mocks::Echo) + /// Default: `None` #[cfg(feature = "mocks")] #[cfg_attr(docsrs, doc(cfg(feature = "mocks")))] - pub mocks: Arc, + pub mocks: Option>, } impl PartialEq for RedisConfig { @@ -462,10 +518,8 @@ impl Default for RedisConfig { tls: None, #[cfg(feature = "partial-tracing")] tracing: TracingConfig::default(), - #[cfg(feature = "replicas")] - replica: ReplicaConfig::default(), #[cfg(feature = "mocks")] - mocks: Arc::new(Echo), + mocks: None, } } } @@ -555,7 +609,7 @@ impl RedisConfig { /// * `redis-sentinel` - TCP connected to a centralized server behind a sentinel layer. /// * `rediss-sentinel` - TLS connected to a centralized server behind a sentinel layer. /// - /// **Note: The `rediss` scheme prefix requires the `enable-native-tls` or `enable-rustls` feature.** + /// **The `rediss` scheme prefix requires the `enable-native-tls` or `enable-rustls` feature.** /// /// # Query Parameters /// @@ -822,34 +876,25 @@ impl ServerConfig { /// Whether or not the config is for a cluster. pub fn is_clustered(&self) -> bool { - match self { - ServerConfig::Clustered { .. } => true, - _ => false, - } + matches!(*self, ServerConfig::Clustered { .. }) } /// Whether or not the config is for a centralized server behind a sentinel node(s). 
pub fn is_sentinel(&self) -> bool { - match self { - ServerConfig::Sentinel { .. } => true, - _ => false, - } + matches!(*self, ServerConfig::Sentinel { .. }) } /// Whether or not the config is for a centralized server. pub fn is_centralized(&self) -> bool { - match self { - ServerConfig::Centralized { .. } => true, - _ => false, - } + matches!(*self, ServerConfig::Centralized { .. }) } /// Read the server hosts or sentinel hosts if using the sentinel interface. pub fn hosts(&self) -> Vec<&Server> { match *self { ServerConfig::Centralized { ref server } => vec![server], - ServerConfig::Clustered { ref hosts } => hosts.iter().map(|s| s).collect(), - ServerConfig::Sentinel { ref hosts, .. } => hosts.iter().map(|s| s).collect(), + ServerConfig::Clustered { ref hosts } => hosts.iter().collect(), + ServerConfig::Sentinel { ref hosts, .. } => hosts.iter().collect(), } } } @@ -973,8 +1018,117 @@ impl From for RedisConfig { tls: config.tls, #[cfg(feature = "partial-tracing")] tracing: config.tracing, - #[cfg(feature = "replicas")] - replica: ReplicaConfig::default(), + #[cfg(feature = "mocks")] + mocks: None, + } + } +} + +/// Options to configure or overwrite for individual commands. +/// +/// Fields left as `None` will use the value from the corresponding client or global config option. +/// +/// ```rust +/// # use fred::prelude::*; +/// async fn example() -> Result<(), RedisError> { +/// let options = Options { +/// max_attempts: Some(10), +/// max_redirections: Some(2), +/// ..Default::default() +/// }; +/// +/// let client = RedisClient::default(); +/// let _ = client.connect(); +/// let _ = client.wait_for_connect().await?; +/// let _: () = client.with_options(&options).get("foo").await?; +/// +/// Ok(()) +/// } +/// ``` +/// +/// See [WithOptions](crate::clients::WithOptions) for more information. +#[derive(Clone, Debug, Eq, PartialEq, Default)] +pub struct Options { + /// Set the max number of write attempts for a command. 
+ pub max_attempts: Option, + /// Set the max number of cluster redirections to follow for a command. + pub max_redirections: Option, + /// Set the timeout duration for a command. + /// + /// This interface is more* cancellation-safe than a simple [timeout](https://docs.rs/tokio/latest/tokio/time/fn.timeout.html) call. + /// + /// * But it's not perfect. There's no reliable mechanism to cancel a command once it has been written + /// to the connection. + pub timeout: Option, + /// The cluster node that should receive the command. + /// + /// The caller will receive a `RedisErrorKind::Cluster` error if the provided server does not exist. + /// + /// The client will still follow redirection errors via this interface. Callers may not notice this, but incorrect + /// server arguments here could result in unnecessary calls to refresh the cached cluster routing table. + pub cluster_node: Option, + /// Whether to skip backpressure checks for a command. + pub no_backpressure: bool, + /// Whether to send `CLIENT CACHING yes|no` before the command. + #[cfg(feature = "client-tracking")] + #[cfg_attr(docsrs, doc(cfg(feature = "client-tracking")))] + pub caching: Option, +} + +impl Options { + /// Set the non-null values from `other` onto `self`. + pub fn extend(&mut self, other: &Self) -> &mut Self { + if let Some(val) = other.max_attempts { + self.max_attempts = Some(val); + } + if let Some(val) = other.max_redirections { + self.max_redirections = Some(val); + } + if let Some(val) = other.timeout { + self.timeout = Some(val); + } + if let Some(ref val) = other.cluster_node { + self.cluster_node = Some(val.clone()); + } + self.no_backpressure |= other.no_backpressure; + + #[cfg(feature = "client-tracking")] + if let Some(val) = other.caching { + self.caching = Some(val); + } + + self + } + + /// Create options from a command. 
+ pub(crate) fn from_command(cmd: &RedisCommand) -> Self { + Options { + max_attempts: Some(cmd.attempts_remaining), + max_redirections: Some(cmd.redirections_remaining), + timeout: cmd.timeout_dur, + no_backpressure: cmd.skip_backpressure, + cluster_node: cmd.cluster_node.clone(), + #[cfg(feature = "client-tracking")] + caching: cmd.caching.clone(), + } + } + + /// Overwrite the configuration options on the provided command. + pub(crate) fn apply(&self, command: &mut RedisCommand) { + command.skip_backpressure = self.no_backpressure; + command.timeout_dur = self.timeout; + command.cluster_node = self.cluster_node.clone(); + + #[cfg(feature = "client-tracking")] + { + command.caching = self.caching.clone(); + } + + if let Some(attempts) = self.max_attempts { + command.attempts_remaining = attempts; + } + if let Some(redirections) = self.max_redirections { + command.redirections_remaining = redirections; } } } diff --git a/src/types/geo.rs b/src/types/geo.rs index 84118423..02ac4378 100644 --- a/src/types/geo.rs +++ b/src/types/geo.rs @@ -1,4 +1,4 @@ -use crate::{error::RedisError, types::RedisValue, utils}; +use crate::{error::RedisError, protocol::utils as protocol_utils, types::RedisValue, utils}; use bytes_utils::Str; use std::{ collections::VecDeque, @@ -29,6 +29,15 @@ impl From<(f64, f64)> for GeoPosition { } } +impl TryFrom for GeoPosition { + type Error = RedisError; + + fn try_from(value: RedisValue) -> Result { + let (longitude, latitude): (f64, f64) = value.convert()?; + Ok(GeoPosition { longitude, latitude }) + } +} + /// Units for the GEO DIST command. 
#[derive(Clone, Debug, Eq, PartialEq)] pub enum GeoUnit { @@ -56,13 +65,6 @@ pub struct GeoValue { pub member: RedisValue, } -impl GeoValue { - pub fn new>(coordinates: GeoPosition, member: V) -> Self { - let member = member.into(); - GeoValue { coordinates, member } - } -} - impl TryFrom<(f64, f64, T)> for GeoValue where T: TryInto, @@ -147,3 +149,74 @@ impl PartialEq for GeoRadiusInfo { } impl Eq for GeoRadiusInfo {} + +impl GeoRadiusInfo { + /// Parse the value with context from the calling command. + pub fn from_redis_value( + value: RedisValue, + withcoord: bool, + withdist: bool, + withhash: bool, + ) -> Result { + if let RedisValue::Array(mut data) = value { + let mut out = GeoRadiusInfo::default(); + data.reverse(); + + if withcoord && withdist && withhash { + // 4 elements: member, dist, hash, position + protocol_utils::assert_array_len(&data, 4)?; + + out.member = data.pop().unwrap(); + out.distance = data.pop().unwrap().convert()?; + out.hash = data.pop().unwrap().convert()?; + out.position = data.pop().unwrap().convert()?; + } else if withcoord && withdist { + // 3 elements: member, dist, position + protocol_utils::assert_array_len(&data, 3)?; + + out.member = data.pop().unwrap(); + out.distance = data.pop().unwrap().convert()?; + out.position = data.pop().unwrap().convert()?; + } else if withcoord && withhash { + // 3 elements: member, hash, position + protocol_utils::assert_array_len(&data, 3)?; + + out.member = data.pop().unwrap(); + out.hash = data.pop().unwrap().convert()?; + out.position = data.pop().unwrap().convert()?; + } else if withdist && withhash { + // 3 elements: member, dist, hash + protocol_utils::assert_array_len(&data, 3)?; + + out.member = data.pop().unwrap(); + out.distance = data.pop().unwrap().convert()?; + out.hash = data.pop().unwrap().convert()?; + } else if withcoord { + // 2 elements: member, position + protocol_utils::assert_array_len(&data, 2)?; + + out.member = data.pop().unwrap(); + out.position = 
data.pop().unwrap().convert()?; + } else if withdist { + // 2 elements: member, dist + protocol_utils::assert_array_len(&data, 2)?; + + out.member = data.pop().unwrap(); + out.distance = data.pop().unwrap().convert()?; + } else if withhash { + // 2 elements: member, hash + protocol_utils::assert_array_len(&data, 2)?; + + out.member = data.pop().unwrap(); + out.hash = data.pop().unwrap().convert()?; + } + + Ok(out) + } else { + Ok(GeoRadiusInfo { + member: value, + ..Default::default() + }) + } + } +} diff --git a/src/types/misc.rs b/src/types/misc.rs index 9520faab..8be76d7c 100644 --- a/src/types/misc.rs +++ b/src/types/misc.rs @@ -1,14 +1,14 @@ -pub use crate::protocol::hashers::ClusterHash; +pub use crate::protocol::{ + hashers::ClusterHash, + types::{Message, MessageKind}, +}; use crate::{ error::{RedisError, RedisErrorKind}, - types::Server, - utils, + types::{RedisKey, RedisValue, Server}, + utils::{self, convert_or_default}, }; use bytes_utils::Str; -use std::{collections::HashMap, convert::TryFrom, fmt}; - -pub use crate::protocol::types::{Message, MessageKind}; -use crate::types::RedisKey; +use std::{collections::HashMap, convert::TryFrom, fmt, time::Duration}; /// Arguments passed to the SHUTDOWN command. /// @@ -101,7 +101,7 @@ pub struct CustomCommand { /// Cluster clients will use the default policy if not provided. pub cluster_hash: ClusterHash, /// Whether or not the command should block the connection while waiting on a response. 
- pub is_blocking: bool, + pub blocking: bool, } impl CustomCommand { @@ -114,9 +114,9 @@ impl CustomCommand { H: Into, { CustomCommand { - cmd: cmd.into(), + cmd: cmd.into(), cluster_hash: cluster_hash.into(), - is_blocking: blocking, + blocking, } } @@ -126,16 +126,16 @@ impl CustomCommand { H: Into, { CustomCommand { - cmd: utils::static_str(cmd), + cmd: utils::static_str(cmd), cluster_hash: cluster_hash.into(), - is_blocking: blocking, + blocking, } } } /// An enum describing the possible ways in which a Redis cluster can change state. /// -/// See [on_cluster_change](crate::interfaces::ClientLike::on_cluster_change) for more information. +/// See [on_cluster_change](crate::interfaces::EventInterface::on_cluster_change) for more information. #[derive(Clone, Debug, Eq, PartialEq)] pub enum ClusterStateChange { /// A node was added to the cluster. @@ -228,19 +228,44 @@ impl fmt::Display for ClientState { /// #[derive(Clone, Debug, Eq, PartialEq)] pub struct DatabaseMemoryStats { - pub overhead_hashtable_main: u64, - pub overhead_hashtable_expires: u64, + pub overhead_hashtable_main: u64, + pub overhead_hashtable_expires: u64, + pub overhead_hashtable_slot_to_keys: u64, } impl Default for DatabaseMemoryStats { fn default() -> Self { DatabaseMemoryStats { - overhead_hashtable_expires: 0, - overhead_hashtable_main: 0, + overhead_hashtable_expires: 0, + overhead_hashtable_main: 0, + overhead_hashtable_slot_to_keys: 0, } } } +fn parse_database_memory_stat(stats: &mut DatabaseMemoryStats, key: &str, value: RedisValue) { + match key { + "overhead.hashtable.main" => stats.overhead_hashtable_main = convert_or_default(value), + "overhead.hashtable.expires" => stats.overhead_hashtable_expires = convert_or_default(value), + "overhead.hashtable.slot-to-keys" => stats.overhead_hashtable_slot_to_keys = convert_or_default(value), + _ => {}, + }; +} + +impl TryFrom for DatabaseMemoryStats { + type Error = RedisError; + + fn try_from(value: RedisValue) -> Result { + let values: 
HashMap = value.convert()?; + let mut out = DatabaseMemoryStats::default(); + + for (key, value) in values.into_iter() { + parse_database_memory_stat(&mut out, &key, value); + } + Ok(out) + } +} + /// The parsed result of the MEMORY STATS command. /// /// @@ -340,6 +365,64 @@ impl PartialEq for MemoryStats { impl Eq for MemoryStats {} +fn parse_memory_stat_field(stats: &mut MemoryStats, key: &str, value: RedisValue) { + match key { + "peak.allocated" => stats.peak_allocated = convert_or_default(value), + "total.allocated" => stats.total_allocated = convert_or_default(value), + "startup.allocated" => stats.startup_allocated = convert_or_default(value), + "replication.backlog" => stats.replication_backlog = convert_or_default(value), + "clients.slaves" => stats.clients_slaves = convert_or_default(value), + "clients.normal" => stats.clients_normal = convert_or_default(value), + "aof.buffer" => stats.aof_buffer = convert_or_default(value), + "lua.caches" => stats.lua_caches = convert_or_default(value), + "overhead.total" => stats.overhead_total = convert_or_default(value), + "keys.count" => stats.keys_count = convert_or_default(value), + "keys.bytes-per-key" => stats.keys_bytes_per_key = convert_or_default(value), + "dataset.bytes" => stats.dataset_bytes = convert_or_default(value), + "dataset.percentage" => stats.dataset_percentage = convert_or_default(value), + "peak.percentage" => stats.peak_percentage = convert_or_default(value), + "allocator.allocated" => stats.allocator_allocated = convert_or_default(value), + "allocator.active" => stats.allocator_active = convert_or_default(value), + "allocator.resident" => stats.allocator_resident = convert_or_default(value), + "allocator-fragmentation.ratio" => stats.allocator_fragmentation_ratio = convert_or_default(value), + "allocator-fragmentation.bytes" => stats.allocator_fragmentation_bytes = convert_or_default(value), + "allocator-rss.ratio" => stats.allocator_rss_ratio = convert_or_default(value), + 
"allocator-rss.bytes" => stats.allocator_rss_bytes = convert_or_default(value), + "rss-overhead.ratio" => stats.rss_overhead_ratio = convert_or_default(value), + "rss-overhead.bytes" => stats.rss_overhead_bytes = convert_or_default(value), + "fragmentation" => stats.fragmentation = convert_or_default(value), + "fragmentation.bytes" => stats.fragmentation_bytes = convert_or_default(value), + _ => { + if key.starts_with("db.") { + let db = match key.split('.').last().and_then(|v| v.parse::().ok()) { + Some(db) => db, + None => return, + }; + let parsed: DatabaseMemoryStats = match value.convert().ok() { + Some(db) => db, + None => return, + }; + + stats.db.insert(db, parsed); + } + }, + } +} + +impl TryFrom for MemoryStats { + type Error = RedisError; + + fn try_from(value: RedisValue) -> Result { + let values: HashMap = value.convert()?; + let mut out = MemoryStats::default(); + + for (key, value) in values.into_iter() { + parse_memory_stat_field(&mut out, &key, value); + } + Ok(out) + } +} + /// The output of an entry in the slow queries log. 
/// /// @@ -347,10 +430,62 @@ impl Eq for MemoryStats {} pub struct SlowlogEntry { pub id: i64, pub timestamp: i64, - pub duration: u64, - pub args: Vec, - pub ip: Option, - pub name: Option, + pub duration: Duration, + pub args: Vec, + pub ip: Option, + pub name: Option, +} + +impl TryFrom for SlowlogEntry { + type Error = RedisError; + + fn try_from(value: RedisValue) -> Result { + if let RedisValue::Array(values) = value { + if values.len() < 4 { + return Err(RedisError::new( + RedisErrorKind::Protocol, + "Expected at least 4 response values.", + )); + } + + let id = values[0] + .as_i64() + .ok_or(RedisError::new(RedisErrorKind::Protocol, "Expected integer ID."))?; + let timestamp = values[1] + .as_i64() + .ok_or(RedisError::new(RedisErrorKind::Protocol, "Expected integer timestamp."))?; + let duration = values[2] + .as_u64() + .map(Duration::from_micros) + .ok_or(RedisError::new(RedisErrorKind::Protocol, "Expected integer duration."))?; + let args = values[3].clone().into_multiple_values(); + + let (ip, name) = if values.len() == 6 { + let ip = values[4] + .as_bytes_str() + .ok_or(RedisError::new(RedisErrorKind::Protocol, "Expected IP address string."))?; + let name = values[5].as_bytes_str().ok_or(RedisError::new( + RedisErrorKind::Protocol, + "Expected client name string.", + ))?; + + (Some(ip), Some(name)) + } else { + (None, None) + }; + + Ok(SlowlogEntry { + id, + timestamp, + duration, + args, + ip, + name, + }) + } else { + Err(RedisError::new_parse("Expected array.")) + } + } } /// Flags for the SCRIPT DEBUG command. 
@@ -432,7 +567,7 @@ impl FnPolicy { } pub(crate) fn from_str(s: &str) -> Result { - Ok(match s.as_ref() { + Ok(match s { "flush" | "FLUSH" => FnPolicy::Flush, "append" | "APPEND" => FnPolicy::Append, "replace" | "REPLACE" => FnPolicy::Replace, @@ -483,6 +618,6 @@ impl TryFrom<&Str> for FnPolicy { type Error = RedisError; fn try_from(value: &Str) -> Result { - FnPolicy::from_str(&value) + FnPolicy::from_str(value) } } diff --git a/src/types/mod.rs b/src/types/mod.rs index 58f0ba95..b03bb287 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -1,11 +1,10 @@ use crate::error::RedisError; pub use crate::modules::response::{FromRedis, FromRedisKey}; -pub use arcstr::ArcStr; pub use redis_protocol::resp3::types::{Frame, RespVersion}; use tokio::task::JoinHandle; -mod acl; mod args; +mod builder; mod client; mod cluster; mod config; @@ -18,8 +17,8 @@ mod scripts; mod sorted_sets; mod streams; -pub use acl::*; pub use args::*; +pub use builder::*; pub use client::*; pub use cluster::*; pub use config::*; @@ -41,8 +40,7 @@ pub use crate::modules::metrics::Stats; #[cfg_attr(docsrs, doc(cfg(feature = "dns")))] pub use crate::protocol::types::Resolve; -pub(crate) static QUEUED: &'static str = "QUEUED"; -pub(crate) static NIL: &'static str = "nil"; +pub(crate) static QUEUED: &str = "QUEUED"; /// The ANY flag used on certain GEO commands. pub type Any = bool; diff --git a/src/types/multiple.rs b/src/types/multiple.rs index 2436becf..3e4257bc 100644 --- a/src/types/multiple.rs +++ b/src/types/multiple.rs @@ -1,12 +1,5 @@ -use crate::{ - error::RedisError, - types::{RedisKey, RedisValue}, -}; -use std::{ - collections::VecDeque, - convert::{TryFrom, TryInto}, - iter::FromIterator, -}; +use crate::types::{RedisKey, RedisValue}; +use std::{collections::VecDeque, iter::FromIterator}; /// Convenience struct for commands that take 1 or more keys. 
/// @@ -62,6 +55,17 @@ where } } +impl<'a, K, const N: usize> From<&'a [K; N]> for MultipleKeys +where + K: Into + Clone, +{ + fn from(value: &'a [K; N]) -> Self { + MultipleKeys { + keys: value.iter().map(|k| k.clone().into()).collect(), + } + } +} + impl From> for MultipleKeys where T: Into, @@ -90,109 +94,11 @@ impl From<()> for MultipleKeys { } } -/// Convenience struct for commands that take 1 or more strings. +/// Convenience interface for commands that take 1 or more strings. pub type MultipleStrings = MultipleKeys; -/// Convenience struct for commands that take 1 or more values. -/// -/// **Note: this can be used to represent an empty set of values by using `None` for any function that takes -/// `Into`.** This is most useful for `EVAL` and `EVALSHA`. -#[derive(Clone, Debug, Eq, PartialEq)] -pub struct MultipleValues { - values: Vec, -} - -impl MultipleValues { - pub fn inner(self) -> Vec { - self.values - } - - pub fn len(&self) -> usize { - self.values.len() - } - - /// Convert this a nested `RedisValue`. - pub fn into_values(self) -> RedisValue { - RedisValue::Array(self.values) - } -} - -impl From> for MultipleValues { - fn from(val: Option) -> Self { - let values = if let Some(val) = val { vec![val] } else { vec![] }; - MultipleValues { values } - } -} - -// https://github.com/rust-lang/rust/issues/50133 -// FIXME there has to be a way around this issue? -// impl TryFrom for MultipleValues -// where -// T: TryInto, -// T::Error: Into, -// { -// type Error = RedisError; -// -// fn try_from(d: T) -> Result { -// Ok(MultipleValues { values: vec![to!(d)?] 
}) -// } -// } - -// TODO consider supporting conversion from tuples with a reasonable size - -impl From for MultipleValues -where - T: Into, -{ - fn from(d: T) -> Self { - MultipleValues { values: vec![d.into()] } - } -} - -impl FromIterator for MultipleValues -where - T: Into, -{ - fn from_iter>(iter: I) -> Self { - MultipleValues { - values: iter.into_iter().map(|v| v.into()).collect(), - } - } -} - -impl TryFrom> for MultipleValues -where - T: TryInto, - T::Error: Into, -{ - type Error = RedisError; - - fn try_from(d: Vec) -> Result { - let mut values = Vec::with_capacity(d.len()); - for value in d.into_iter() { - values.push(to!(value)?); - } - - Ok(MultipleValues { values }) - } -} - -impl TryFrom> for MultipleValues -where - T: TryInto, - T::Error: Into, -{ - type Error = RedisError; - - fn try_from(d: VecDeque) -> Result { - let mut values = Vec::with_capacity(d.len()); - for value in d.into_iter() { - values.push(to!(value)?); - } - - Ok(MultipleValues { values }) - } -} +/// Convenience interface for commands that take 1 or more values. +pub type MultipleValues = RedisValue; /// A convenience struct for functions that take one or more hash slot values. pub struct MultipleHashSlots { diff --git a/src/types/scripts.rs b/src/types/scripts.rs index 49ffd0f0..8cca37d7 100644 --- a/src/types/scripts.rs +++ b/src/types/scripts.rs @@ -1,10 +1,11 @@ use crate::{ clients::RedisClient, interfaces::{FunctionInterface, LuaInterface}, - prelude::{FromRedis, RedisError, RedisErrorKind, RedisResult}, + prelude::{FromRedis, RedisError, RedisResult}, types::{MultipleKeys, MultipleValues, RedisValue}, - util::sha1_hash, }; +#[cfg(feature = "sha-1")] +use crate::{prelude::RedisErrorKind, util::sha1_hash}; use bytes_utils::Str; use std::{ cmp::Ordering, @@ -59,6 +60,8 @@ impl PartialOrd for Script { impl Script { /// Create a new `Script` from a lua script. 
+ #[cfg(feature = "sha-1")] + #[cfg_attr(docsrs, doc(cfg(feature = "sha-1")))] pub fn from_lua>(lua: S) -> Self { let lua: Str = lua.into(); let hash = Str::from(sha1_hash(&lua)); @@ -86,6 +89,8 @@ impl Script { /// Call `SCRIPT LOAD` on all the associated servers. This must be /// called once before calling [evalsha](Self::evalsha). + #[cfg(feature = "sha-1")] + #[cfg_attr(docsrs, doc(cfg(feature = "sha-1")))] pub async fn load(&self, client: &RedisClient) -> RedisResult<()> { if let Some(ref lua) = self.lua { client.script_load_cluster::<(), _>(lua.clone()).await @@ -108,6 +113,8 @@ impl Script { /// Send `EVALSHA` to the server with the provided arguments. Automatically `SCRIPT LOAD` in case /// of `NOSCRIPT` error and try `EVALSHA` again. + #[cfg(feature = "sha-1")] + #[cfg_attr(docsrs, doc(cfg(feature = "sha-1")))] pub async fn evalsha_with_reload(&self, client: &RedisClient, keys: K, args: V) -> RedisResult where R: FromRedis, @@ -140,6 +147,7 @@ pub enum FunctionFlag { impl FunctionFlag { /// Parse the string representation of the flag. 
+ #[allow(clippy::should_implement_trait)] pub fn from_str(s: &str) -> Option { Some(match s { "allow-oom" => FunctionFlag::AllowOOM, @@ -303,7 +311,7 @@ impl Library { .as_functions(&name)?; Ok(Library { - name: name.into(), + name, functions: functions.into_iter().map(|f| (f.name.clone(), f)).collect(), }) } @@ -322,7 +330,7 @@ impl Library { .as_functions(&name)?; Ok(Library { - name: name.into(), + name, functions: functions.into_iter().map(|f| (f.name.clone(), f)).collect(), }) } diff --git a/src/types/streams.rs b/src/types/streams.rs index afab4390..aa88b078 100644 --- a/src/types/streams.rs +++ b/src/types/streams.rs @@ -30,7 +30,7 @@ impl<'a> TryFrom<&'a str> for XCapTrim { type Error = RedisError; fn try_from(s: &'a str) -> Result { - Ok(match s.as_ref() { + Ok(match s { "=" => XCapTrim::Exact, "~" => XCapTrim::AlmostExact, _ => { @@ -185,7 +185,7 @@ impl<'a> TryFrom<&'a str> for XCapKind { type Error = RedisError; fn try_from(value: &'a str) -> Result { - Ok(match value.as_ref() { + Ok(match value { "MAXLEN" => XCapKind::MaxLen, "MINID" => XCapKind::MinID, _ => { @@ -289,14 +289,14 @@ impl XID { XID::Auto => utils::static_str("*"), XID::Max => utils::static_str("$"), XID::NewInGroup => utils::static_str(">"), - XID::Manual(s) => s.into(), + XID::Manual(s) => s, } } } impl<'a> From<&'a str> for XID { fn from(value: &'a str) -> Self { - match value.as_ref() { + match value { "*" => XID::Auto, "$" => XID::Max, ">" => XID::NewInGroup, diff --git a/src/utils.rs b/src/utils.rs index 5be5fd28..a65befa9 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,14 +1,14 @@ use crate::{ error::{RedisError, RedisErrorKind}, + globals::globals, interfaces::ClientLike, - modules::inner::RedisClientInner, + modules::inner::{CommandSender, RedisClientInner}, protocol::{ command::{RedisCommand, RedisCommandKind}, responders::ResponseKind, utils as protocol_utils, }, types::*, - utils, }; use arc_swap::ArcSwap; use bytes::Bytes; @@ -37,27 +37,24 @@ use std::{ use tokio::{ 
sync::{ broadcast::{channel as broadcast_channel, Sender as BroadcastSender}, - oneshot::{channel as oneshot_channel, Receiver as OneshotReceiver}, + oneshot::channel as oneshot_channel, }, time::sleep, }; use url::Url; -use crate::globals::globals; #[cfg(any(feature = "enable-native-tls", feature = "enable-rustls"))] use crate::protocol::tls::{TlsConfig, TlsConnector}; #[cfg(any(feature = "full-tracing", feature = "partial-tracing"))] use crate::trace; -#[cfg(feature = "serde-json")] -use serde_json::Value; #[cfg(any(feature = "full-tracing", feature = "partial-tracing"))] use tracing_futures::Instrument; -const REDIS_TLS_SCHEME: &'static str = "rediss"; -const REDIS_CLUSTER_SCHEME_SUFFIX: &'static str = "-cluster"; -const REDIS_SENTINEL_SCHEME_SUFFIX: &'static str = "-sentinel"; -const SENTINEL_NAME_QUERY: &'static str = "sentinelServiceName"; -const CLUSTER_NODE_QUERY: &'static str = "node"; +const REDIS_TLS_SCHEME: &str = "rediss"; +const REDIS_CLUSTER_SCHEME_SUFFIX: &str = "-cluster"; +const REDIS_SENTINEL_SCHEME_SUFFIX: &str = "-sentinel"; +const SENTINEL_NAME_QUERY: &str = "sentinelServiceName"; +const CLUSTER_NODE_QUERY: &str = "node"; #[cfg(feature = "sentinel-auth")] const SENTINEL_USERNAME_QUERY: &'static str = "sentinelUsername"; #[cfg(feature = "sentinel-auth")] @@ -155,6 +152,13 @@ pub fn random_string(len: usize) -> String { .collect() } +pub fn convert_or_default(value: RedisValue) -> R +where + R: FromRedis + Default, +{ + value.convert().ok().unwrap_or_default() +} + pub fn random_u64(max: u64) -> u64 { rand::thread_rng().gen_range(0 .. 
max) } @@ -215,49 +219,20 @@ pub fn set_mutex(locked: &Mutex, value: T) -> T { mem::replace(&mut *locked.lock(), value) } -pub fn take_mutex(locked: &Mutex>) -> Option { - locked.lock().take() -} - +/// pub fn check_lex_str(val: String, kind: &ZRangeKind) -> String { - let formatted = val.starts_with("(") || val.starts_with("[") || val == "+" || val == "-"; + let formatted = val.starts_with('(') || val.starts_with('[') || val == "+" || val == "-"; if formatted { val + } else if *kind == ZRangeKind::Exclusive { + format!("({}", val) } else { - if *kind == ZRangeKind::Exclusive { - format!("({}", val) - } else { - format!("[{}", val) - } - } -} - -pub fn value_to_f64(value: &RedisValue) -> Result { - value.as_f64().ok_or(RedisError::new( - RedisErrorKind::Unknown, - "Could not parse value as float.", - )) -} - -pub fn value_to_geo_pos(value: &RedisValue) -> Result, RedisError> { - if let RedisValue::Array(value) = value { - if value.len() == 2 { - let longitude = value_to_f64(&value[0])?; - let latitude = value_to_f64(&value[1])?; - - Ok(Some(GeoPosition { longitude, latitude })) - } else { - Err(RedisError::new( - RedisErrorKind::Unknown, - "Expected array with 2 elements.", - )) - } - } else { - Ok(None) + format!("[{}", val) } } +/// Parse the response from `FUNCTION LIST`. fn parse_functions(value: &RedisValue) -> Result, RedisError> { if let RedisValue::Array(functions) = value { let mut out = Vec::with_capacity(functions.len()); @@ -287,6 +262,7 @@ fn parse_functions(value: &RedisValue) -> Result, RedisError> { } } +/// Check and parse the response to `FUNCTION LIST`. 
pub fn value_to_functions(value: &RedisValue, name: &str) -> Result, RedisError> { if let RedisValue::Array(ref libraries) = value { for library in libraries.iter() { @@ -310,17 +286,17 @@ pub fn value_to_functions(value: &RedisValue, name: &str) -> Result(ft: Fut, timeout: u64) -> Result +pub async fn apply_timeout(ft: Fut, timeout: Duration) -> Result where E: Into, Fut: Future>, { - if timeout > 0 { - let sleep_ft = sleep(Duration::from_millis(timeout)); + if !timeout.is_zero() { + let sleep_ft = sleep(timeout); pin_mut!(sleep_ft); pin_mut!(ft); - trace!("Using timeout: {} ms", timeout); + trace!("Using timeout: {:?}", timeout); match select(ft, sleep_ft).await { Either::Left((lhs, _)) => lhs.map_err(|e| e.into()), Either::Right((_, _)) => Err(RedisError::new(RedisErrorKind::Timeout, "Request timed out.")), @@ -330,13 +306,6 @@ where } } -pub async fn wait_for_response( - rx: OneshotReceiver>, - timeout: u64, -) -> Result { - apply_timeout(rx, timeout).await? -} - pub fn has_blocking_error_policy(inner: &Arc) -> bool { inner.config.blocking == Blocking::Error } @@ -345,11 +314,11 @@ pub fn has_blocking_interrupt_policy(inner: &Arc) -> bool { inner.config.blocking == Blocking::Interrupt } +/// Whether the router should check and interrupt the blocked command. async fn should_enforce_blocking_policy(inner: &Arc, command: &RedisCommand) -> bool { - // TODO switch the blocked flag to an AtomicBool to avoid this locked check on each command !command.kind.closes_connection() - && (inner.config.blocking == Blocking::Error || inner.config.blocking == Blocking::Interrupt) - && inner.backchannel.read().await.is_blocked() + && ((inner.config.blocking == Blocking::Error || inner.config.blocking == Blocking::Interrupt) + && inner.backchannel.read().await.is_blocked()) } /// Interrupt the currently blocked connection (if found) with the provided flag. 
@@ -383,13 +352,13 @@ pub async fn interrupt_blocked_connection( let command = RedisCommand::new(RedisCommandKind::ClientUnblock, args); let frame = backchannel_request_response(inner, command, true).await?; - protocol_utils::frame_to_single_result(frame).map(|_| ()) + protocol_utils::frame_to_results(frame).map(|_| ()) } /// Check the status of the connection (usually before sending a command) to determine whether the connection should /// be unblocked automatically. async fn check_blocking_policy(inner: &Arc, command: &RedisCommand) -> Result<(), RedisError> { - if should_enforce_blocking_policy(inner, &command).await { + if should_enforce_blocking_policy(inner, command).await { _debug!( inner, "Checking to enforce blocking policy for {}", @@ -411,6 +380,15 @@ async fn check_blocking_policy(inner: &Arc, command: &RedisCom Ok(()) } +/// Prepare the command options, returning the timeout duration to apply. +pub fn prepare_command(client: &C, command: &mut RedisCommand) -> Duration { + client.change_command(command); + command.inherit_options(client.inner()); + command + .timeout_dur + .unwrap_or_else(|| client.inner().default_command_timeout()) +} + /// Send a command to the server using the default response handler. 
pub async fn basic_request_response(client: &C, func: F) -> Result where @@ -424,13 +402,14 @@ where command.response = ResponseKind::Respond(Some(tx)); let timed_out = command.timed_out.clone(); - let _ = check_blocking_policy(inner, &command).await?; - let _ = disallow_nested_values(&command)?; - let _ = client.send_command(command)?; + let timeout_dur = prepare_command(client, &mut command); + check_blocking_policy(inner, &command).await?; + client.send_command(command)?; - wait_for_response(rx, inner.default_command_timeout()) + apply_timeout(rx, timeout_dur) + .and_then(|r| async { r }) .map_err(move |error| { - utils::set_bool_atomic(&timed_out, true); + set_bool_atomic(&timed_out, true); error }) .await @@ -462,7 +441,6 @@ where let req_size = protocol_utils::args_size(&command.args()); args_span.record("num_args", &command.args().len()); - let _ = disallow_nested_values(&command)?; (command, rx, req_size) }; cmd_span.record("cmd", &command.kind.to_str_debug()); @@ -480,12 +458,14 @@ where command.traces.cmd = Some(cmd_span.clone()); command.traces.queued = Some(queued_span); + let timeout_dur = prepare_command(client, &mut command); let _ = check_blocking_policy(inner, &command).await?; let _ = client.send_command(command)?; - wait_for_response(rx, inner.default_command_timeout()) + apply_timeout(rx, timeout_dur) + .and_then(|r| async { r }) .map_err(move |error| { - utils::set_bool_atomic(&timed_out, true); + set_bool_atomic(&timed_out, true); error }) .and_then(|frame| async move { @@ -530,19 +510,6 @@ pub fn check_empty_keys(keys: &MultipleKeys) -> Result<(), RedisError> { } } -pub fn disallow_nested_values(cmd: &RedisCommand) -> Result<(), RedisError> { - for arg in cmd.args().iter() { - if arg.is_map() || arg.is_array() { - return Err(RedisError::new( - RedisErrorKind::InvalidArgument, - format!("Invalid argument type: {:?}", arg.kind()), - )); - } - } - - Ok(()) -} - /// Check for a scan pattern without a hash tag, or with a wildcard in the hash tag. 
/// /// These patterns will result in scanning a random node if used against a clustered redis. @@ -619,19 +586,6 @@ where Ok(out) } -#[cfg(feature = "serde-json")] -pub fn parse_nested_json(s: &str) -> Option { - let trimmed = s.trim(); - let is_maybe_json = - (trimmed.starts_with("{") && trimmed.ends_with("}")) || (trimmed.starts_with("[") && trimmed.ends_with("]")); - - if is_maybe_json { - serde_json::from_str(s).ok() - } else { - None - } -} - pub fn flatten_nested_array_values(value: RedisValue, depth: usize) -> RedisValue { if depth == 0 { return value; @@ -673,8 +627,8 @@ pub fn flatten_nested_array_values(value: RedisValue, depth: usize) -> RedisValu } pub fn is_maybe_array_map(arr: &Vec) -> bool { - if arr.len() > 0 && arr.len() % 2 == 0 { - arr.chunks(2).fold(true, |b, chunk| b && !chunk[0].is_aggregate_type()) + if !arr.is_empty() && arr.len() % 2 == 0 { + arr.chunks(2).all(|chunk| !chunk[0].is_aggregate_type()) } else { false } @@ -790,7 +744,7 @@ pub fn parse_url_other_nodes(url: &Url) -> Result, RedisError> { for (key, value) in url.query_pairs().into_iter() { if key == CLUSTER_NODE_QUERY { - let parts: Vec<&str> = value.split(":").collect(); + let parts: Vec<&str> = value.split(':').collect(); if parts.len() != 2 { return Err(RedisError::new( RedisErrorKind::Config, @@ -842,16 +796,23 @@ pub fn parse_url_sentinel_password(url: &Url) -> Option { }) } -#[cfg(feature = "check-unresponsive")] -pub fn abort_network_timeout_task(inner: &Arc) { - if let Some(jh) = inner.network_timeouts.take_handle() { - _trace!(inner, "Shut down network timeout task."); - jh.abort(); - } +pub async fn clear_backchannel_state(inner: &Arc) { + inner.backchannel.write().await.clear_router_state(inner).await; } -#[cfg(not(feature = "check-unresponsive"))] -pub fn abort_network_timeout_task(_: &Arc) {} +/// Send QUIT to the servers and clean up the old router task's state. 
+pub fn close_router_channel(inner: &Arc, command_tx: Arc) { + set_client_state(&inner.state, ClientState::Disconnecting); + inner.notifications.broadcast_close(); + inner.reset_server_state(); + + let command = RedisCommand::new(RedisCommandKind::Quit, vec![]); + inner.counters.incr_cmd_buffer_len(); + if let Err(_) = command_tx.send(command.into()) { + inner.counters.decr_cmd_buffer_len(); + _warn!(inner, "Failed to send QUIT when dropping old command channel."); + } +} #[cfg(test)] mod tests { diff --git a/tests/README.md b/tests/README.md index 1449781e..bd0500e9 100644 --- a/tests/README.md +++ b/tests/README.md @@ -4,7 +4,7 @@ Tests are organized by category, similar to the [commands](../src/commands) fold By default, most tests run 8 times based on the following configuration parameters: clustered vs centralized servers, pipelined vs non-pipelined clients, and RESP2 vs RESP3 mode. Helper macros exist to make this easy so each test only has to be written once. -**The tests require Redis version >=6.2** As of writing the default version used is 7.0.5. +**The tests require Redis version >=6.2** As of writing the default version used is 7.2.1. ## Installation @@ -22,10 +22,12 @@ The runner scripts will set up the Redis servers and run the tests inside docker * [all-features](runners/all-features.sh) will run tests with all features (except sentinel tests). * [default-features](runners/default-features.sh) will run tests with default features (except sentinel tests). +* [default-nil-types](runners/default-nil-types.sh) will run tests with `default-nil-types`. * [no-features](runners/no-features.sh) will run the tests without any of the feature flags. * [sentinel-features](runners/sentinel-features.sh) will run the centralized tests against a sentinel deployment. This is the only test runner that requires the sentinel deployment via docker-compose. 
* [cluster-rustls](runners/cluster-rustls.sh) will set up a cluster with TLS enabled and run the cluster tests against it with `rustls`. * [cluster-native-tls](runners/cluster-native-tls.sh) will set up a cluster with TLS enabled and run the cluster tests against it with `native-tls`. +* [redis-stack](runners/redis-stack.sh) will set up a centralized `redis/redis-stack` container and run with `redis-stack` features. * [everything](runners/everything.sh) will run all of the above scripts. These scripts will pass through any extra argv so callers can filter tests as needed. diff --git a/tests/docker/compose/base.yml b/tests/docker/compose/base.yml index 0a8a525a..22a9f2d2 100644 --- a/tests/docker/compose/base.yml +++ b/tests/docker/compose/base.yml @@ -7,6 +7,7 @@ services: - redis-main - redis-cluster-6 - redis-sentinel-3 + - redis-stack-main container_name: "debug" build: context: ../../../ @@ -14,7 +15,7 @@ services: args: REDIS_VERSION: "${REDIS_VERSION}" networks: - - app-tier + - fred-tests command: - "/bin/bash" environment: @@ -27,6 +28,8 @@ services: FRED_REDIS_CLUSTER_PORT: "${FRED_REDIS_CLUSTER_PORT}" FRED_REDIS_SENTINEL_HOST: "${FRED_REDIS_SENTINEL_HOST}" FRED_REDIS_SENTINEL_PORT: "${FRED_REDIS_SENTINEL_PORT}" + FRED_REDIS_STACK_HOST: "${FRED_REDIS_STACK_HOST}" + FRED_REDIS_STACK_PORT: "${FRED_REDIS_STACK_PORT}" REDIS_USERNAME: "${REDIS_USERNAME}" REDIS_PASSWORD: "${REDIS_PASSWORD}" REDIS_SENTINEL_PASSWORD: "${REDIS_SENTINEL_PASSWORD}" diff --git a/tests/docker/compose/centralized.yml b/tests/docker/compose/centralized.yml index b3c96824..c05e2f63 100644 --- a/tests/docker/compose/centralized.yml +++ b/tests/docker/compose/centralized.yml @@ -1,7 +1,7 @@ version: '2' networks: - app-tier: + fred-tests: driver: bridge services: @@ -14,7 +14,7 @@ services: ports: - "6379:${FRED_REDIS_CENTRALIZED_PORT}" networks: - - app-tier + - fred-tests volumes: - '../../../tests/users.acl:/opt/bitnami/redis/mounted-etc/users.acl' - 
'../../../tests/docker/overrides/default.conf:/opt/bitnami/redis/mounted-etc/overrides.conf' \ No newline at end of file diff --git a/tests/docker/compose/cluster-tls.yml b/tests/docker/compose/cluster-tls.yml index 11730ce4..3dc86920 100644 --- a/tests/docker/compose/cluster-tls.yml +++ b/tests/docker/compose/cluster-tls.yml @@ -1,7 +1,7 @@ version: '2' networks: - app-tier: + fred-tests: driver: bridge services: @@ -22,7 +22,7 @@ services: ports: - "40001:${FRED_REDIS_CLUSTER_TLS_PORT}" networks: - - app-tier + - fred-tests volumes: - '../../../tests/tmp/creds:/opt/bitnami/redis/mounted-etc/creds' - '../../../tests/users.acl:/opt/bitnami/redis/mounted-etc/users.acl' @@ -44,7 +44,7 @@ services: ports: - "40002:${FRED_REDIS_CLUSTER_TLS_PORT}" networks: - - app-tier + - fred-tests volumes: - '../../../tests/tmp/creds:/opt/bitnami/redis/mounted-etc/creds' - '../../../tests/users.acl:/opt/bitnami/redis/mounted-etc/users.acl' @@ -66,7 +66,7 @@ services: ports: - "40003:${FRED_REDIS_CLUSTER_TLS_PORT}" networks: - - app-tier + - fred-tests volumes: - '../../../tests/tmp/creds:/opt/bitnami/redis/mounted-etc/creds' - '../../../tests/users.acl:/opt/bitnami/redis/mounted-etc/users.acl' @@ -88,7 +88,7 @@ services: ports: - "40004:${FRED_REDIS_CLUSTER_TLS_PORT}" networks: - - app-tier + - fred-tests volumes: - '../../../tests/tmp/creds:/opt/bitnami/redis/mounted-etc/creds' - '../../../tests/users.acl:/opt/bitnami/redis/mounted-etc/users.acl' @@ -110,7 +110,7 @@ services: ports: - "40005:${FRED_REDIS_CLUSTER_TLS_PORT}" networks: - - app-tier + - fred-tests volumes: - '../../../tests/tmp/creds:/opt/bitnami/redis/mounted-etc/creds' - '../../../tests/users.acl:/opt/bitnami/redis/mounted-etc/users.acl' @@ -140,7 +140,7 @@ services: ports: - "40006:${FRED_REDIS_CLUSTER_TLS_PORT}" networks: - - app-tier + - fred-tests volumes: - '../../../tests/tmp/creds:/opt/bitnami/redis/mounted-etc/creds' - '../../../tests/users.acl:/opt/bitnami/redis/mounted-etc/users.acl' diff --git 
a/tests/docker/compose/cluster.yml b/tests/docker/compose/cluster.yml index 92136410..a1a66428 100644 --- a/tests/docker/compose/cluster.yml +++ b/tests/docker/compose/cluster.yml @@ -1,7 +1,7 @@ version: '2' networks: - app-tier: + fred-tests: driver: bridge services: @@ -16,7 +16,7 @@ services: ports: - "30001:${FRED_REDIS_CLUSTER_PORT}" networks: - - app-tier + - fred-tests volumes: - '../../../tests/users.acl:/opt/bitnami/redis/mounted-etc/users.acl' - '../../../tests/docker/overrides/default.conf:/opt/bitnami/redis/mounted-etc/overrides.conf' @@ -31,7 +31,7 @@ services: ports: - "30002:${FRED_REDIS_CLUSTER_PORT}" networks: - - app-tier + - fred-tests volumes: - '../../../tests/users.acl:/opt/bitnami/redis/mounted-etc/users.acl' - '../../../tests/docker/overrides/default.conf:/opt/bitnami/redis/mounted-etc/overrides.conf' @@ -46,7 +46,7 @@ services: ports: - "30003:${FRED_REDIS_CLUSTER_PORT}" networks: - - app-tier + - fred-tests volumes: - '../../../tests/users.acl:/opt/bitnami/redis/mounted-etc/users.acl' - '../../../tests/docker/overrides/default.conf:/opt/bitnami/redis/mounted-etc/overrides.conf' @@ -61,7 +61,7 @@ services: ports: - "30004:${FRED_REDIS_CLUSTER_PORT}" networks: - - app-tier + - fred-tests volumes: - '../../../tests/users.acl:/opt/bitnami/redis/mounted-etc/users.acl' - '../../../tests/docker/overrides/default.conf:/opt/bitnami/redis/mounted-etc/overrides.conf' @@ -76,7 +76,7 @@ services: ports: - "30005:${FRED_REDIS_CLUSTER_PORT}" networks: - - app-tier + - fred-tests volumes: - '../../../tests/users.acl:/opt/bitnami/redis/mounted-etc/users.acl' - '../../../tests/docker/overrides/default.conf:/opt/bitnami/redis/mounted-etc/overrides.conf' @@ -99,7 +99,7 @@ services: ports: - "30006:${FRED_REDIS_CLUSTER_PORT}" networks: - - app-tier + - fred-tests volumes: - '../../../tests/users.acl:/opt/bitnami/redis/mounted-etc/users.acl' - '../../../tests/docker/overrides/default.conf:/opt/bitnami/redis/mounted-etc/overrides.conf' \ No newline at end of 
file diff --git a/tests/docker/compose/redis-stack.yml b/tests/docker/compose/redis-stack.yml new file mode 100644 index 00000000..7bc91db0 --- /dev/null +++ b/tests/docker/compose/redis-stack.yml @@ -0,0 +1,23 @@ +version: '2' + +networks: + fred-tests: + driver: bridge + +services: + redis-stack-main: + container_name: "redis-stack-main" + image: 'redis/redis-stack:latest' + environment: + - 'ALLOW_EMPTY_PASSWORD=yes' + - 'REDIS_ARGS="--requirepass ${REDIS_PASSWORD}"' + # - 'REDISEARCH_ARGS=""' + # - 'REDISJSON_ARGS=""' + # - 'REDISGRAPH_ARGS=""' + # - 'REDISTIMESERIES_ARGS=""' + # - 'REDISBLOOM_ARGS=""' + ports: + - "6382:${FRED_REDIS_STACK_PORT}" + - "8001:8001" + networks: + - fred-tests \ No newline at end of file diff --git a/tests/docker/compose/sentinel.yml b/tests/docker/compose/sentinel.yml index e2b70daf..cdfa77e8 100644 --- a/tests/docker/compose/sentinel.yml +++ b/tests/docker/compose/sentinel.yml @@ -1,7 +1,7 @@ version: '2' networks: - app-tier: + fred-tests: driver: bridge services: @@ -16,7 +16,7 @@ services: ports: - "6380:6380" networks: - - app-tier + - fred-tests volumes: - '../../../tests/users.acl:/opt/bitnami/redis/mounted-etc/users.acl' redis-sentinel-replica: @@ -30,13 +30,14 @@ services: - 'REDIS_MASTER_PASSWORD=${REDIS_PASSWORD}' - 'ALLOW_EMPTY_PASSWORD=yes' - 'REDIS_REPLICATION_MODE=slave' + - 'REDIS_REPLICA_PORT=6381' - 'REDIS_MASTER_HOST=redis-sentinel-main' - 'REDIS_MASTER_PORT_NUMBER=6380' - 'REDIS_ACLFILE=/opt/bitnami/redis/mounted-etc/users.acl' ports: - "6381:6381" networks: - - app-tier + - fred-tests volumes: - '../../../tests/users.acl:/opt/bitnami/redis/mounted-etc/users.acl' redis-sentinel-1: @@ -57,7 +58,7 @@ services: ports: - '26379:26379' networks: - - app-tier + - fred-tests volumes: - '../../../tests/users.acl:/opt/bitnami/redis/mounted-etc/users.acl' redis-sentinel-2: @@ -79,7 +80,7 @@ services: ports: - '26380:26380' networks: - - app-tier + - fred-tests volumes: - 
'../../../tests/users.acl:/opt/bitnami/redis/mounted-etc/users.acl' redis-sentinel-3: @@ -101,6 +102,6 @@ services: ports: - '26381:26381' networks: - - app-tier + - fred-tests volumes: - '../../../tests/users.acl:/opt/bitnami/redis/mounted-etc/users.acl' \ No newline at end of file diff --git a/tests/docker/runners/bash/all-features.sh b/tests/docker/runners/bash/all-features.sh index 2604af50..ef645267 100755 --- a/tests/docker/runners/bash/all-features.sh +++ b/tests/docker/runners/bash/all-features.sh @@ -11,8 +11,8 @@ do done FEATURES="network-logs pool-prefer-active custom-reconnect-errors ignore-auth-error serde-json blocking-encoding - full-tracing reconnect-on-auth-error monitor metrics sentinel-client subscriber-client no-client-setname - dns debug-ids check-unresponsive replicas client-tracking" + full-tracing reconnect-on-auth-error monitor metrics sentinel-client subscriber-client dns debug-ids + check-unresponsive replicas client-tracking codec sha-1 auto-client-setname" if [ -z "$FRED_CI_NEXTEST" ]; then cargo test --release --lib --tests --features "$FEATURES" -- --test-threads=1 "$@" diff --git a/tests/docker/runners/bash/default-nil-types.sh b/tests/docker/runners/bash/default-nil-types.sh new file mode 100755 index 00000000..7747a257 --- /dev/null +++ b/tests/docker/runners/bash/default-nil-types.sh @@ -0,0 +1,19 @@ +#!/bin/bash + +declare -a arr=("REDIS_VERSION" "REDIS_USERNAME" "REDIS_PASSWORD" "REDIS_SENTINEL_PASSWORD") + +for env in "${arr[@]}" +do + if [ -z "$env" ]; then + echo "$env must be set. Run `source tests/environ` if needed." 
+ exit 1 + fi +done + +FEATURES="network-logs serde-json debug-ids replicas client-tracking default-nil-types" + +if [ -z "$FRED_CI_NEXTEST" ]; then + cargo test --release --lib --tests --features "$FEATURES" -- --test-threads=1 "$@" +else + cargo nextest run --release --lib --tests --features "$FEATURES" --test-threads=1 "$@" +fi \ No newline at end of file diff --git a/tests/docker/runners/bash/redis-stack.sh b/tests/docker/runners/bash/redis-stack.sh new file mode 100755 index 00000000..bcf36f1a --- /dev/null +++ b/tests/docker/runners/bash/redis-stack.sh @@ -0,0 +1,19 @@ +#!/bin/bash + +declare -a arr=("REDIS_VERSION" "REDIS_PASSWORD" "FRED_REDIS_STACK_HOST" "FRED_REDIS_STACK_PORT") + +for env in "${arr[@]}" +do + if [ -z "$env" ]; then + echo "$env must be set. Run `source tests/environ` if needed." + exit 1 + fi +done + +FEATURES="network-logs serde-json debug-ids redis-stack" + +if [ -z "$FRED_CI_NEXTEST" ]; then + cargo test --release --lib --tests --features "$FEATURES" -- --test-threads=1 "$@" +else + cargo nextest run --release --lib --tests --features "$FEATURES" --test-threads=1 "$@" +fi \ No newline at end of file diff --git a/tests/docker/runners/bash/sentinel-features.sh b/tests/docker/runners/bash/sentinel-features.sh index 3416b8e3..15856451 100755 --- a/tests/docker/runners/bash/sentinel-features.sh +++ b/tests/docker/runners/bash/sentinel-features.sh @@ -10,7 +10,7 @@ do fi done -FEATURES="network-logs debug-ids sentinel-tests sentinel-auth replicas" +FEATURES="network-logs debug-ids sentinel-auth replicas" if [ -z "$FRED_CI_NEXTEST" ]; then cargo test --release --lib --tests --features "$FEATURES" -- --test-threads=1 "$@" diff --git a/tests/docker/runners/compose/all-features.yml b/tests/docker/runners/compose/all-features.yml index fe15c077..b79b0b4b 100644 --- a/tests/docker/runners/compose/all-features.yml +++ b/tests/docker/runners/compose/all-features.yml @@ -12,7 +12,8 @@ services: args: REDIS_VERSION: "${REDIS_VERSION}" networks: - - 
app-tier + - fred-tests + privileged: true command: - "/project/tests/docker/runners/bash/all-features.sh" - "${TEST_ARGV}" diff --git a/tests/docker/runners/compose/cluster-native-tls.yml b/tests/docker/runners/compose/cluster-native-tls.yml index 1aa4bb8c..18243bd6 100644 --- a/tests/docker/runners/compose/cluster-native-tls.yml +++ b/tests/docker/runners/compose/cluster-native-tls.yml @@ -11,7 +11,8 @@ services: args: REDIS_VERSION: "${REDIS_VERSION}" networks: - - app-tier + - fred-tests + privileged: true command: - "/project/tests/docker/runners/bash/cluster-tls.sh" - "${TEST_ARGV}" diff --git a/tests/docker/runners/compose/cluster-rustls.yml b/tests/docker/runners/compose/cluster-rustls.yml index 56ee15d5..008f4209 100644 --- a/tests/docker/runners/compose/cluster-rustls.yml +++ b/tests/docker/runners/compose/cluster-rustls.yml @@ -11,7 +11,8 @@ services: args: REDIS_VERSION: "${REDIS_VERSION}" networks: - - app-tier + - fred-tests + privileged: true command: - "/project/tests/docker/runners/bash/cluster-rustls.sh" - "${TEST_ARGV}" diff --git a/tests/docker/runners/compose/default-features.yml b/tests/docker/runners/compose/default-features.yml index 6b3c67b9..e0dbe794 100644 --- a/tests/docker/runners/compose/default-features.yml +++ b/tests/docker/runners/compose/default-features.yml @@ -12,10 +12,11 @@ services: args: REDIS_VERSION: "${REDIS_VERSION}" networks: - - app-tier + - fred-tests command: - "/project/tests/docker/runners/bash/default-features.sh" - "${TEST_ARGV}" + privileged: true environment: RUST_LOG: "${RUST_LOG}" CIRCLECI_TESTS: "${CIRCLECI_TESTS}" diff --git a/tests/docker/runners/compose/default-nil-types.yml b/tests/docker/runners/compose/default-nil-types.yml new file mode 100644 index 00000000..c5e34385 --- /dev/null +++ b/tests/docker/runners/compose/default-nil-types.yml @@ -0,0 +1,32 @@ +version: '2' + +services: + default-nil-types-tests: + depends_on: + - redis-main + - redis-cluster-6 + container_name: "default-nil-types-tests" + 
build: + context: ../../../ + dockerfile: tests/docker/runners/images/base.dockerfile + args: + REDIS_VERSION: "${REDIS_VERSION}" + networks: + - fred-tests + command: + - "/project/tests/docker/runners/bash/default-nil-types.sh" + - "${TEST_ARGV}" + privileged: true + environment: + RUST_LOG: "${RUST_LOG}" + CIRCLECI_TESTS: "${CIRCLECI_TESTS}" + REDIS_VERSION: "${REDIS_VERSION}" + FRED_REDIS_CENTRALIZED_HOST: "${FRED_REDIS_CENTRALIZED_HOST}" + FRED_REDIS_CENTRALIZED_PORT: "${FRED_REDIS_CENTRALIZED_PORT}" + FRED_REDIS_CLUSTER_HOST: "${FRED_REDIS_CLUSTER_HOST}" + FRED_REDIS_CLUSTER_PORT: "${FRED_REDIS_CLUSTER_PORT}" + REDIS_USERNAME: "${REDIS_USERNAME}" + REDIS_PASSWORD: "${REDIS_PASSWORD}" + volumes: + - "../../..:/project" + - "~/.cargo/registry:/usr/local/cargo/registry" \ No newline at end of file diff --git a/tests/docker/runners/compose/no-features.yml b/tests/docker/runners/compose/no-features.yml index eb31f485..44f116c7 100644 --- a/tests/docker/runners/compose/no-features.yml +++ b/tests/docker/runners/compose/no-features.yml @@ -12,7 +12,8 @@ services: args: REDIS_VERSION: "${REDIS_VERSION}" networks: - - app-tier + - fred-tests + privileged: true command: - "/project/tests/docker/runners/bash/no-features.sh" - "${TEST_ARGV}" diff --git a/tests/docker/runners/compose/redis-stack.yml b/tests/docker/runners/compose/redis-stack.yml new file mode 100644 index 00000000..87d6f8b6 --- /dev/null +++ b/tests/docker/runners/compose/redis-stack.yml @@ -0,0 +1,28 @@ +version: '2' + +services: + redis-stack-tests: + depends_on: + - redis-stack-main + container_name: "redis-stack-tests" + build: + context: ../../../ + dockerfile: tests/docker/runners/images/base.dockerfile + args: + REDIS_VERSION: "${REDIS_VERSION}" + networks: + - fred-tests + command: + - "/project/tests/docker/runners/bash/redis-stack.sh" + - "${TEST_ARGV}" + privileged: true + environment: + RUST_LOG: "${RUST_LOG}" + CIRCLECI_TESTS: "${CIRCLECI_TESTS}" + REDIS_VERSION: "${REDIS_VERSION}" + 
FRED_REDIS_STACK_HOST: "${FRED_REDIS_STACK_HOST}" + FRED_REDIS_STACK_PORT: "${FRED_REDIS_STACK_PORT}" + REDIS_PASSWORD: "${REDIS_PASSWORD}" + volumes: + - "../../..:/project" + - "~/.cargo/registry:/usr/local/cargo/registry" \ No newline at end of file diff --git a/tests/docker/runners/compose/sentinel-features.yml b/tests/docker/runners/compose/sentinel-features.yml index d197ca46..44b6484e 100644 --- a/tests/docker/runners/compose/sentinel-features.yml +++ b/tests/docker/runners/compose/sentinel-features.yml @@ -13,7 +13,8 @@ services: args: REDIS_VERSION: "${REDIS_VERSION}" networks: - - app-tier + - fred-tests + privileged: true command: - "/project/tests/docker/runners/bash/sentinel-features.sh" - "${TEST_ARGV}" @@ -26,6 +27,7 @@ services: REDIS_USERNAME: "${REDIS_USERNAME}" REDIS_PASSWORD: "${REDIS_PASSWORD}" REDIS_SENTINEL_PASSWORD: "${REDIS_SENTINEL_PASSWORD}" + FRED_SENTINEL_TESTS: "1" volumes: - "../../..:/project" - "~/.cargo/registry:/usr/local/cargo/registry" \ No newline at end of file diff --git a/tests/docker/runners/images/base.dockerfile b/tests/docker/runners/images/base.dockerfile index 0cb7d7c8..562dd1c9 100644 --- a/tests/docker/runners/images/base.dockerfile +++ b/tests/docker/runners/images/base.dockerfile @@ -1,4 +1,4 @@ -FROM rust:1.72.0-slim-buster +FROM rust:1.72.1-slim-buster WORKDIR /project @@ -17,7 +17,7 @@ ARG FRED_REDIS_SENTINEL_HOST ARG FRED_REDIS_SENTINEL_PORT ARG CIRCLECI_TESTS -RUN USER=root apt-get update && apt-get install -y build-essential libssl-dev dnsutils +RUN USER=root apt-get update && apt-get install -y build-essential libssl-dev dnsutils curl pkg-config RUN echo "REDIS_VERSION=$REDIS_VERSION" # For debugging diff --git a/tests/docker/runners/images/ci.dockerfile b/tests/docker/runners/images/ci.dockerfile index 912ff2bc..a1fc5550 100644 --- a/tests/docker/runners/images/ci.dockerfile +++ b/tests/docker/runners/images/ci.dockerfile @@ -1,4 +1,4 @@ -FROM rust:1.72.0-slim-buster +FROM rust:1.72.1-slim-buster WORKDIR 
/project # circleci doesn't mount volumes with a remote docker engine so we have to copy everything diff --git a/tests/docker/runners/images/debug.dockerfile b/tests/docker/runners/images/debug.dockerfile index c00b644a..562dd1c9 100644 --- a/tests/docker/runners/images/debug.dockerfile +++ b/tests/docker/runners/images/debug.dockerfile @@ -1,4 +1,4 @@ -FROM rust:1.72.0-slim-buster +FROM rust:1.72.1-slim-buster WORKDIR /project diff --git a/tests/environ b/tests/environ index b6f3951a..2c519f89 100644 --- a/tests/environ +++ b/tests/environ @@ -2,13 +2,14 @@ . ./tests/scripts/utils.sh if [ -z "$REDIS_VERSION" ]; then - export REDIS_VERSION=7.0.9 + export REDIS_VERSION=7.2.1 fi if [ -z "$CARGO_HTTP_DEBUG" ]; then export CARGO_HTTP_DEBUG=false fi +# FIXME: changing the redis stack port here doesn't work. this might just be a limitation on the image's config interface though export ROOT=$PWD \ RUST_BACKTRACE=full \ FRED_REDIS_CLUSTER_HOST=redis-cluster-1 \ @@ -19,6 +20,8 @@ export ROOT=$PWD \ FRED_REDIS_CENTRALIZED_PORT=6379 \ FRED_REDIS_SENTINEL_HOST=redis-sentinel-1 \ FRED_REDIS_SENTINEL_PORT=26379 \ + FRED_REDIS_STACK_HOST=redis-stack-main \ + FRED_REDIS_STACK_PORT=6379 \ FRED_TEST_TLS_CREDS=$PWD/tests/tmp/creds \ REDIS_USERNAME=foo \ REDIS_PASSWORD=bar \ diff --git a/tests/integration/acl/mod.rs b/tests/integration/acl/mod.rs index b5614429..4d15fa7e 100644 --- a/tests/integration/acl/mod.rs +++ b/tests/integration/acl/mod.rs @@ -1,21 +1,20 @@ -use super::utils::read_env_var; +use super::utils::{read_env_var, should_use_sentinel_config}; use fred::{ clients::RedisClient, error::RedisError, interfaces::*, - types::{AclUserFlag, RedisConfig}, + types::{RedisConfig, RedisValue}, }; +use std::collections::HashMap; // the docker image we use for sentinel tests doesn't allow for configuring users, just passwords, // so for the tests here we just use an empty username so it uses the `default` user -#[cfg(feature = "sentinel-tests")] fn read_redis_username() -> Option { - 
None -} - -#[cfg(not(feature = "sentinel-tests"))] -fn read_redis_username() -> Option { - read_env_var("REDIS_USERNAME") + if should_use_sentinel_config() { + None + } else { + read_env_var("REDIS_USERNAME") + } } fn check_env_creds() -> (Option, Option) { @@ -26,31 +25,32 @@ fn check_env_creds() -> (Option, Option) { pub async fn should_auth_as_test_user(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { let (username, password) = check_env_creds(); if let Some(password) = password { - let _ = client.auth(username, password).await?; - let _: () = client.get("foo").await?; + client.auth(username, password).await?; + client.get("foo").await?; } Ok(()) } -// note: currently this only works in CI against the centralized server +// FIXME currently this only works in CI against the centralized server pub async fn should_auth_as_test_user_via_config(_: RedisClient, mut config: RedisConfig) -> Result<(), RedisError> { let (username, password) = check_env_creds(); if let Some(password) = password { config.username = username; config.password = Some(password); - let client = RedisClient::new(config, None, None); - let _ = client.connect(); - let _ = client.wait_for_connect().await?; - let _: () = client.get("foo").await?; + let client = RedisClient::new(config, None, None, None); + client.connect(); + client.wait_for_connect().await?; + client.get("foo").await?; } Ok(()) } pub async fn should_run_acl_getuser(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { - let user = client.acl_getuser("default").await?.unwrap(); - assert!(user.flags.contains(&AclUserFlag::On)); + let user: HashMap = client.acl_getuser("default").await?; + let flags: Vec = user.get("flags").unwrap().clone().convert()?; + assert!(flags.contains(&"on".to_string())); Ok(()) } diff --git a/tests/integration/centralized.rs b/tests/integration/centralized.rs index f6226a28..7034ae50 100644 --- a/tests/integration/centralized.rs +++ b/tests/integration/centralized.rs @@ -1,5 +1,5 
@@ mod keys { - + centralized_test!(keys, should_handle_missing_keys); centralized_test!(keys, should_set_and_get_a_value); centralized_test!(keys, should_set_and_del_a_value); centralized_test!(keys, should_set_with_get_argument); @@ -50,6 +50,10 @@ mod other { centralized_test!(other, should_pipeline_try_all); centralized_test!(other, should_use_all_cluster_nodes_repeatedly); centralized_test!(other, should_gracefully_quit); + centralized_test!(other, should_support_options_with_pipeline); + centralized_test!(other, should_reuse_pipeline); + centralized_test!(other, should_manually_connect_twice); + centralized_test!(other, should_support_options_with_trx); //#[cfg(feature = "dns")] // centralized_test!(other, should_use_trust_dns); @@ -66,6 +70,11 @@ mod other { centralized_test!(other, should_replica_set_and_get_not_lazy); #[cfg(feature = "replicas")] centralized_test!(other, should_pipeline_with_replicas); + + #[cfg(feature = "codec")] + centralized_test!(other, should_use_resp3_codec_example); + #[cfg(feature = "codec")] + centralized_test!(other, should_use_resp2_codec_example); } mod pool { @@ -165,11 +174,16 @@ pub mod memory { pub mod lua { + #[cfg(feature = "sha-1")] centralized_test!(lua, should_load_script); centralized_test!(lua, should_eval_echo_script); + #[cfg(feature = "sha-1")] centralized_test!(lua, should_eval_get_script); + #[cfg(feature = "sha-1")] centralized_test!(lua, should_evalsha_echo_script); + #[cfg(feature = "sha-1")] centralized_test!(lua, should_evalsha_with_reload_echo_script); + #[cfg(feature = "sha-1")] centralized_test!(lua, should_evalsha_get_script); centralized_test!(lua, should_function_load_scripts); @@ -182,7 +196,9 @@ pub mod lua { centralized_test!(lua, should_function_fcall_echo); centralized_test!(lua, should_function_fcall_ro_echo); + #[cfg(feature = "sha-1")] centralized_test!(lua, should_create_lua_script_helper_from_code); + #[cfg(feature = "sha-1")] centralized_test!(lua, 
should_create_lua_script_helper_from_hash); centralized_test!(lua, should_create_function_from_code); centralized_test!(lua, should_create_function_from_name); @@ -306,3 +322,21 @@ mod tracking { centralized_test!(tracking, should_invalidate_foo_resp3); centralized_test!(tracking, should_invalidate_foo_resp2_centralized); } + +// The CI settings for redis-stack only support centralized configs for now. +#[cfg(feature = "redis-json")] +mod redis_json { + centralized_test!(redis_json, should_get_and_set_basic_obj); + centralized_test!(redis_json, should_get_and_set_stringified_obj); + centralized_test!(redis_json, should_array_append); + centralized_test!(redis_json, should_modify_arrays); + centralized_test!(redis_json, should_pop_and_trim_arrays); + centralized_test!(redis_json, should_get_set_del_obj); + centralized_test!(redis_json, should_merge_objects); + centralized_test!(redis_json, should_mset_and_mget); + centralized_test!(redis_json, should_incr_numbers); + centralized_test!(redis_json, should_inspect_objects); + centralized_test!(redis_json, should_modify_strings); + centralized_test!(redis_json, should_toggle_boolean); + centralized_test!(redis_json, should_get_value_type); +} diff --git a/tests/integration/cluster/mod.rs b/tests/integration/cluster/mod.rs index 4b0fe83f..73fb2816 100644 --- a/tests/integration/cluster/mod.rs +++ b/tests/integration/cluster/mod.rs @@ -9,9 +9,9 @@ pub async fn should_use_each_cluster_node(client: RedisClient, _: RedisConfig) - .with_cluster_node(server) .client_info::() .await? 
- .split(" ") + .split(' ') .find_map(|s| { - let parts: Vec<&str> = s.split("=").collect(); + let parts: Vec<&str> = s.split('=').collect(); if parts[0] == "laddr" { Some(parts[1].to_owned()) } else { diff --git a/tests/integration/clustered.rs b/tests/integration/clustered.rs index a01f8e1a..c1b96b88 100644 --- a/tests/integration/clustered.rs +++ b/tests/integration/clustered.rs @@ -1,5 +1,5 @@ mod keys { - + cluster_test!(keys, should_handle_missing_keys); cluster_test!(keys, should_set_and_get_a_value); cluster_test!(keys, should_set_and_del_a_value); cluster_test!(keys, should_set_with_get_argument); @@ -54,6 +54,10 @@ mod other { cluster_test!(other, should_pipeline_try_all); cluster_test!(other, should_use_all_cluster_nodes_repeatedly); cluster_test!(other, should_gracefully_quit); + cluster_test!(other, should_support_options_with_pipeline); + cluster_test!(other, should_reuse_pipeline); + cluster_test!(other, should_manually_connect_twice); + cluster_test!(other, should_support_options_with_trx); //#[cfg(feature = "dns")] // cluster_test!(other, should_use_trust_dns); @@ -70,6 +74,11 @@ mod other { cluster_test!(other, should_use_cluster_replica_without_redirection); #[cfg(feature = "replicas")] cluster_test!(other, should_pipeline_with_replicas); + + #[cfg(feature = "codec")] + cluster_test!(other, should_use_resp3_codec_example); + #[cfg(feature = "codec")] + cluster_test!(other, should_use_resp2_codec_example); } mod pool { @@ -173,12 +182,18 @@ pub mod memory { pub mod lua { + #[cfg(feature = "sha-1")] cluster_test!(lua, should_load_script); + #[cfg(feature = "sha-1")] cluster_test!(lua, should_load_script_cluster); cluster_test!(lua, should_eval_echo_script); + #[cfg(feature = "sha-1")] cluster_test!(lua, should_eval_get_script); + #[cfg(feature = "sha-1")] cluster_test!(lua, should_evalsha_echo_script); + #[cfg(feature = "sha-1")] cluster_test!(lua, should_evalsha_with_reload_echo_script); + #[cfg(feature = "sha-1")] cluster_test!(lua, 
should_evalsha_get_script); cluster_test!(lua, should_function_load_scripts); @@ -191,7 +206,9 @@ pub mod lua { cluster_test!(lua, should_function_fcall_echo); cluster_test!(lua, should_function_fcall_ro_echo); + #[cfg(feature = "sha-1")] cluster_test!(lua, should_create_lua_script_helper_from_code); + #[cfg(feature = "sha-1")] cluster_test!(lua, should_create_lua_script_helper_from_hash); cluster_test!(lua, should_create_function_from_code); cluster_test!(lua, should_create_function_from_name); @@ -264,6 +281,7 @@ pub mod geo { cluster_test!(geo, should_geosearch_values); } +#[cfg(not(feature = "redis-stack"))] pub mod acl { cluster_test!(acl, should_run_acl_getuser); } diff --git a/tests/integration/docker.rs b/tests/integration/docker.rs new file mode 100644 index 00000000..fb957cee --- /dev/null +++ b/tests/integration/docker.rs @@ -0,0 +1,210 @@ +#![allow(dead_code)] +use crate::integration::{ + docker::env::{COMPOSE_NETWORK_NAME, NETWORK_NAME}, + utils, +}; +use bollard::{ + container::{ + Config, + CreateContainerOptions, + LogOutput, + NetworkingConfig, + RemoveContainerOptions, + StartContainerOptions, + }, + errors::Error as BollardError, + exec::{CreateExecOptions, StartExecResults}, + network::{ConnectNetworkOptions, ListNetworksOptions}, + ClientVersion, + Docker, + API_DEFAULT_VERSION, +}; +use bytes::Bytes; +use fred::{prelude::*, types::ClusterRouting}; +use futures::stream::StreamExt; +use redis_protocol::resp2::decode::decode as resp2_decode; +use std::collections::HashMap; + +macro_rules! 
e ( + ($arg:expr) => ($arg.map_err(|e| RedisError::new(RedisErrorKind::Unknown, format!("{:?}", e)))) +); + +pub mod env { + use fred::error::{RedisError, RedisErrorKind}; + use std::env; + + // compat check + pub const COMPOSE_NETWORK_NAME: &str = "compose_fred-tests"; + pub const NETWORK_NAME: &str = "fred-tests"; + + pub const CENTRALIZED_HOST: &str = "FRED_REDIS_CENTRALIZED_HOST"; + pub const CENTRALIZED_PORT: &str = "FRED_REDIS_CENTRALIZED_PORT"; + pub const CLUSTER_HOST: &str = "FRED_REDIS_CLUSTER_HOST"; + pub const CLUSTER_PORT: &str = "FRED_REDIS_CLUSTER_PORT"; + pub const CLUSTER_TLS_HOST: &str = "FRED_REDIS_CLUSTER_TLS_HOST"; + pub const CLUSTER_TLS_PORT: &str = "FRED_REDIS_CLUSTER_TLS_PORT"; + pub const SENTINEL_HOST: &str = "FRED_REDIS_SENTINEL_HOST"; + pub const SENTINEL_PORT: &str = "FRED_REDIS_SENTINEL_PORT"; + + pub fn read(name: &str) -> Option { + env::var_os(name).and_then(|s| s.into_string().ok()) + } + + pub fn try_read(name: &str) -> Result { + read(name).ok_or(RedisError::new(RedisErrorKind::Unknown, "Failed to read env")) + } +} + +/// Read the name of the network, which may have a different prefix on older docker installs. +pub async fn read_network_name(docker: &Docker) -> Result { + let networks = e!(docker.list_networks(None::>).await)?; + + for network in networks.into_iter() { + if let Some(ref name) = network.name { + if name == NETWORK_NAME || name == COMPOSE_NETWORK_NAME { + return Ok(name.to_owned()); + } + } + } + Err(RedisError::new( + RedisErrorKind::Unknown, + "Failed to read fred test network.", + )) +} + +/// Run a command in the bitnami redis container. 
+pub async fn run_in_redis_container(docker: &Docker, command: Vec) -> Result, RedisError> { + let redis_version = env::try_read("REDIS_VERSION")?; + + let redis_container_config = Config { + image: Some(format!("bitnami/redis:{}", redis_version)), + tty: Some(true), + ..Default::default() + }; + debug!("Creating test cli container..."); + let container_id = e!( + docker + .create_container( + Some(CreateContainerOptions { + name: "redis-cli-tmp".to_owned(), + ..Default::default() + }), + redis_container_config, + ) + .await + )? + .id; + debug!("Starting test cli container..."); + e!( + docker + .start_container(&container_id, None::>) + .await + )?; + + let test_network = read_network_name(docker).await?; + debug!("Connecting container to the test network..."); + e!( + docker + .connect_network(&test_network, ConnectNetworkOptions { + container: container_id.clone(), + ..Default::default() + }) + .await + )?; + + debug!("Running command: {:?}", command); + let exec = e!( + docker + .create_exec(&container_id, CreateExecOptions { + attach_stdout: Some(true), + attach_stderr: Some(true), + cmd: Some(command), + ..Default::default() + }) + .await + )? + .id; + let exec_state = e!(docker.start_exec(&exec, None).await)?; + + let mut out = Vec::with_capacity(1024); + if let StartExecResults::Attached { mut output, .. 
} = exec_state { + while let Some(Ok(msg)) = output.next().await { + match msg { + LogOutput::StdOut { message } => out.extend(&message), + LogOutput::StdErr { message } => { + warn!("stderr from cli container: {}", String::from_utf8_lossy(&message)); + }, + _ => {}, + }; + } + } else { + return Err(RedisError::new(RedisErrorKind::Unknown, "Missing start exec result")); + } + + debug!("Cleaning up cli container..."); + let result = e!( + docker + .remove_container( + &container_id, + Some(RemoveContainerOptions { + force: true, + ..Default::default() + }), + ) + .await + ); + if let Err(e) = result { + error!("Failed to remove cli container: {:?}", e); + } + + Ok(out) +} + +/// Read the cluster state via CLUSTER SLOTS. +// This tries to run: +// +// docker run -it --name redis-cli-tmp --rm --network compose_fred-tests bitnami/redis:7.0.9 +// redis-cli -h redis-cluster-1 -p 30001 -a bar --raw CLUSTER SLOTS +pub async fn inspect_cluster(tls: bool) -> Result { + let docker = e!(Docker::connect_with_http("", 10, API_DEFAULT_VERSION))?; + + debug!("Connected to docker"); + let password = env::try_read("REDIS_PASSWORD")?; + + let cluster_slots: Vec = if tls { + let (host, port) = ( + env::try_read(env::CLUSTER_TLS_HOST)?, + env::try_read(env::CLUSTER_TLS_PORT)?, + ); + + // TODO add ca/cert/key argv + format!( + "redis-cli -h {} -p {} -a {} --raw --tls CLUSTER SLOTS", + host, port, password + ) + .split(' ') + .map(|s| s.to_owned()) + .collect() + } else { + let (host, port) = (env::try_read(env::CLUSTER_HOST)?, env::try_read(env::CLUSTER_PORT)?); + + format!("redis-cli -h {} -p {} -a {} --raw CLUSTER SLOTS", host, port, password) + .split(' ') + .map(|s| s.to_owned()) + .collect() + }; + + let result = run_in_redis_container(&docker, cluster_slots).await?; + debug!("CLUSTER SLOTS response: {}", String::from_utf8_lossy(&result)); + let parsed: RedisValue = match resp2_decode(&Bytes::from(result))? 
{ + Some((frame, _)) => redis_protocol::resp2_frame_to_resp3(frame).try_into()?, + None => { + return Err(RedisError::new( + RedisErrorKind::Unknown, + "Failed to read cluster slots.", + )) + }, + }; + + ClusterRouting::from_cluster_slots(parsed, "") +} diff --git a/tests/integration/geo/mod.rs b/tests/integration/geo/mod.rs index dd18dc79..0bbf9fc7 100644 --- a/tests/integration/geo/mod.rs +++ b/tests/integration/geo/mod.rs @@ -1,5 +1,7 @@ -use fred::prelude::*; -use fred::types::{GeoPosition, GeoRadiusInfo, GeoUnit, GeoValue, SortOrder}; +use fred::{ + prelude::*, + types::{GeoPosition, GeoRadiusInfo, GeoUnit, GeoValue, SortOrder}, +}; use std::convert::TryInto; fn loose_eq(lhs: f64, rhs: f64, precision: u32) -> bool { @@ -20,7 +22,7 @@ async fn create_fake_data(client: &RedisClient, key: &str) -> Result Resul pub async fn should_geopos_values(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { let expected = create_fake_data(&client, "foo").await?; - let result = client.geopos("foo", vec!["Palermo", "Catania"]).await?; + let result: RedisValue = client.geopos("foo", vec!["Palermo", "Catania"]).await?; let result: Vec = result .into_array() .into_iter() @@ -69,12 +71,12 @@ pub async fn should_geopos_values(client: RedisClient, _: RedisConfig) -> Result } } - let result = client.geopos("foo", "Palermo").await?; - let result = result.as_geo_position().unwrap().unwrap(); + let result: Vec = client.geopos("foo", "Palermo").await?; + let result = result[0].as_geo_position().unwrap().unwrap(); assert!(loose_eq_pos(&result, &expected[0])); - let result = client.geopos("foo", "Catania").await?; - let result = result.as_geo_position().unwrap().unwrap(); + let result: Vec = client.geopos("foo", "Catania").await?; + let result = result[0].as_geo_position().unwrap().unwrap(); assert!(loose_eq_pos(&result, &expected[1])); Ok(()) @@ -101,7 +103,7 @@ pub async fn should_georadius_values(client: RedisClient, _: RedisConfig) -> Res let _ = create_fake_data(&client, 
"foo").await?; let result = client - .georadius( + .georadius::( "foo", (15.0, 37.0), 200.0, @@ -114,25 +116,26 @@ pub async fn should_georadius_values(client: RedisClient, _: RedisConfig) -> Res None, None, ) - .await?; + .await? + .into_geo_radius_result(false, true, false)?; let expected: Vec = vec![ GeoRadiusInfo { - member: "Palermo".into(), + member: "Palermo".into(), distance: Some(190.4424), position: None, - hash: None, + hash: None, }, GeoRadiusInfo { - member: "Catania".into(), + member: "Catania".into(), distance: Some(56.4413), position: None, - hash: None, + hash: None, }, ]; assert_eq!(result, expected); let result = client - .georadius( + .georadius::( "foo", (15.0, 37.0), 200.0, @@ -145,25 +148,26 @@ pub async fn should_georadius_values(client: RedisClient, _: RedisConfig) -> Res None, None, ) - .await?; + .await? + .into_geo_radius_result(true, false, false)?; let expected: Vec = vec![ GeoRadiusInfo { - member: "Palermo".into(), + member: "Palermo".into(), distance: None, - position: Some((13.36138933897018433, 38.11555639549629859).into()), - hash: None, + position: Some((13.361_389_338_970_184, 38.115_556_395_496_3).into()), + hash: None, }, GeoRadiusInfo { - member: "Catania".into(), + member: "Catania".into(), distance: None, - position: Some((15.08726745843887329, 37.50266842333162032).into()), - hash: None, + position: Some((15.087_267_458_438_873, 37.502_668_423_331_62).into()), + hash: None, }, ]; assert_eq!(result, expected); let result = client - .georadius( + .georadius::( "foo", (15.0, 37.0), 200.0, @@ -176,19 +180,20 @@ pub async fn should_georadius_values(client: RedisClient, _: RedisConfig) -> Res None, None, ) - .await?; + .await? 
+ .into_geo_radius_result(true, true, false)?; let expected: Vec = vec![ GeoRadiusInfo { - member: "Palermo".into(), + member: "Palermo".into(), distance: Some(190.4424), - position: Some((13.36138933897018433, 38.11555639549629859).into()), - hash: None, + position: Some((13.361_389_338_970_184, 38.115_556_395_496_3).into()), + hash: None, }, GeoRadiusInfo { - member: "Catania".into(), + member: "Catania".into(), distance: Some(56.4413), - position: Some((15.08726745843887329, 37.50266842333162032).into()), - hash: None, + position: Some((15.087_267_458_438_873, 37.502_668_423_331_62).into()), + hash: None, }, ]; assert_eq!(result, expected); @@ -199,10 +204,10 @@ pub async fn should_georadius_values(client: RedisClient, _: RedisConfig) -> Res pub async fn should_georadiusbymember_values(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { let _ = create_fake_data(&client, "foo").await?; let agrigento: GeoValue = (13.583333, 37.316667, "Agrigento").try_into()?; - let _ = client.geoadd("foo", None, false, agrigento).await?; + client.geoadd("foo", None, false, agrigento).await?; let result = client - .georadiusbymember( + .georadiusbymember::( "foo", "Agrigento", 100.0, @@ -215,19 +220,20 @@ pub async fn should_georadiusbymember_values(client: RedisClient, _: RedisConfig None, None, ) - .await?; + .await? 
+ .into_geo_radius_result(false, false, false)?; let expected = vec![ GeoRadiusInfo { - member: "Agrigento".into(), + member: "Agrigento".into(), distance: None, position: None, - hash: None, + hash: None, }, GeoRadiusInfo { - member: "Palermo".into(), + member: "Palermo".into(), distance: None, position: None, - hash: None, + hash: None, }, ]; assert_eq!(result, expected); @@ -241,11 +247,11 @@ pub async fn should_geosearch_values(client: RedisClient, _: RedisConfig) -> Res (12.758489, 38.788135, "edge1").try_into()?, (17.241510, 38.788135, "edge2").try_into()?, ]; - let _ = client.geoadd("foo", None, false, values).await?; + client.geoadd("foo", None, false, values).await?; let lonlat: GeoPosition = (15.0, 37.0).into(); let result = client - .geosearch( + .geosearch::( "foo", None, Some(lonlat.clone()), @@ -257,25 +263,26 @@ pub async fn should_geosearch_values(client: RedisClient, _: RedisConfig) -> Res false, false, ) - .await?; + .await? + .into_geo_radius_result(false, false, false)?; let expected = vec![ GeoRadiusInfo { - member: "Catania".into(), + member: "Catania".into(), distance: None, position: None, - hash: None, + hash: None, }, GeoRadiusInfo { - member: "Palermo".into(), + member: "Palermo".into(), distance: None, position: None, - hash: None, + hash: None, }, ]; assert_eq!(result, expected); let result = client - .geosearch( + .geosearch::( "foo", None, Some(lonlat), @@ -287,31 +294,32 @@ pub async fn should_geosearch_values(client: RedisClient, _: RedisConfig) -> Res true, false, ) - .await?; + .await? 
+ .into_geo_radius_result(true, true, false)?; let expected = vec![ GeoRadiusInfo { - member: "Catania".into(), + member: "Catania".into(), distance: Some(56.4413), - position: Some((15.08726745843887329, 37.50266842333162032).into()), - hash: None, + position: Some((15.087_267_458_438_873, 37.502_668_423_331_62).into()), + hash: None, }, GeoRadiusInfo { - member: "Palermo".into(), + member: "Palermo".into(), distance: Some(190.4424), - position: Some((13.36138933897018433, 38.11555639549629859).into()), - hash: None, + position: Some((13.361_389_338_970_184, 38.115_556_395_496_3).into()), + hash: None, }, GeoRadiusInfo { - member: "edge2".into(), + member: "edge2".into(), distance: Some(279.7403), - position: Some((17.24151045083999634, 38.78813451624225195).into()), - hash: None, + position: Some((17.241_510_450_839_996, 38.788_134_516_242_25).into()), + hash: None, }, GeoRadiusInfo { - member: "edge1".into(), + member: "edge1".into(), distance: Some(279.7405), - position: Some((12.7584877610206604, 38.78813451624225195).into()), - hash: None, + position: Some((12.758_487_761_020_66, 38.788_134_516_242_25).into()), + hash: None, }, ]; assert_eq!(result, expected); diff --git a/tests/integration/hashes/mod.rs b/tests/integration/hashes/mod.rs index b3a8a4c8..c4e36f6a 100644 --- a/tests/integration/hashes/mod.rs +++ b/tests/integration/hashes/mod.rs @@ -1,10 +1,12 @@ -use fred::clients::RedisClient; -use fred::error::RedisError; -use fred::interfaces::*; -use fred::types::{RedisConfig, RedisValue}; +use fred::{ + clients::RedisClient, + error::RedisError, + interfaces::*, + types::{RedisConfig, RedisValue}, +}; use std::collections::{HashMap, HashSet}; -fn assert_contains<'a, T: Eq + PartialEq>(values: Vec, item: &'a T) { +fn assert_contains(values: Vec, item: &T) { for value in values.iter() { if value == item { return; @@ -54,7 +56,7 @@ pub async fn should_hset_and_hget(client: RedisClient, _: RedisConfig) -> Result pub async fn should_hset_and_hdel(client: 
RedisClient, _: RedisConfig) -> Result<(), RedisError> { check_null!(client, "foo"); - let result: i64 = client.hset("foo", vec![("a", 1.into()), ("b", 2), ("c", 3)]).await?; + let result: i64 = client.hset("foo", vec![("a", 1), ("b", 2), ("c", 3)]).await?; assert_eq!(result, 3); let result: i64 = client.hdel("foo", vec!["a", "b"]).await?; assert_eq!(result, 2); @@ -69,7 +71,7 @@ pub async fn should_hset_and_hdel(client: RedisClient, _: RedisConfig) -> Result pub async fn should_hexists(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { check_null!(client, "foo"); - let _: () = client.hset("foo", ("a", 1)).await?; + client.hset("foo", ("a", 1)).await?; let a: bool = client.hexists("foo", "a").await?; assert!(a); let b: bool = client.hexists("foo", "b").await?; @@ -81,7 +83,7 @@ pub async fn should_hexists(client: RedisClient, _: RedisConfig) -> Result<(), R pub async fn should_hgetall(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { check_null!(client, "foo"); - let _: () = client.hset("foo", vec![("a", 1), ("b", 2), ("c", 3)]).await?; + client.hset("foo", vec![("a", 1), ("b", 2), ("c", 3)]).await?; let values: HashMap = client.hgetall("foo").await?; assert_eq!(values.len(), 3); @@ -119,7 +121,7 @@ pub async fn should_hincryby_float(client: RedisClient, _: RedisConfig) -> Resul pub async fn should_get_keys(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { check_null!(client, "foo"); - let _: () = client.hset("foo", vec![("a", 1), ("b", 2), ("c", 3)]).await?; + client.hset("foo", vec![("a", 1), ("b", 2), ("c", 3)]).await?; let keys = client.hkeys("foo").await?; assert_diff_len(vec!["a", "b", "c"], keys, 0); @@ -130,7 +132,7 @@ pub async fn should_get_keys(client: RedisClient, _: RedisConfig) -> Result<(), pub async fn should_hmset(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { check_null!(client, "foo"); - let _: () = client.hmset("foo", vec![("a", 1), ("b", 2), ("c", 3)]).await?; + client.hmset("foo", 
vec![("a", 1), ("b", 2), ("c", 3)]).await?; let a: i64 = client.hget("foo", "a").await?; assert_eq!(a, 1); @@ -145,7 +147,7 @@ pub async fn should_hmset(client: RedisClient, _: RedisConfig) -> Result<(), Red pub async fn should_hmget(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { check_null!(client, "foo"); - let _: () = client.hmset("foo", vec![("a", 1), ("b", 2), ("c", 3)]).await?; + client.hmset("foo", vec![("a", 1), ("b", 2), ("c", 3)]).await?; let result: Vec = client.hmget("foo", vec!["a", "b"]).await?; assert_eq!(result, vec![1, 2]); @@ -156,9 +158,9 @@ pub async fn should_hmget(client: RedisClient, _: RedisConfig) -> Result<(), Red pub async fn should_hsetnx(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { check_null!(client, "foo"); - let _: () = client.hset("foo", ("a", 1)).await?; + client.hset("foo", ("a", 1)).await?; let result: bool = client.hsetnx("foo", "a", 2).await?; - assert_eq!(result, false); + assert!(!result); let result: i64 = client.hget("foo", "a").await?; assert_eq!(result, 1); let result: bool = client.hsetnx("foo", "b", 2).await?; @@ -172,7 +174,7 @@ pub async fn should_hsetnx(client: RedisClient, _: RedisConfig) -> Result<(), Re pub async fn should_get_random_field(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { check_null!(client, "foo"); - let _: () = client.hmset("foo", vec![("a", 1), ("b", 2), ("c", 3)]).await?; + client.hmset("foo", vec![("a", 1), ("b", 2), ("c", 3)]).await?; let field: String = client.hrandfield("foo", None).await?; assert_contains(vec!["a", "b", "c"], &field.as_str()); @@ -200,7 +202,7 @@ pub async fn should_get_strlen(client: RedisClient, _: RedisConfig) -> Result<() check_null!(client, "foo"); let expected = "abcdefhijklmnopqrstuvwxyz"; - let _: () = client.hset("foo", ("a", expected)).await?; + client.hset("foo", ("a", expected)).await?; let len: usize = client.hstrlen("foo", "a").await?; assert_eq!(len, expected.len()); @@ -211,7 +213,7 @@ pub async fn 
should_get_strlen(client: RedisClient, _: RedisConfig) -> Result<() pub async fn should_get_values(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { check_null!(client, "foo"); - let _: () = client.hmset("foo", vec![("a", "1"), ("b", "2")]).await?; + client.hmset("foo", vec![("a", "1"), ("b", "2")]).await?; let values: RedisValue = client.hvals("foo").await?; assert_diff_len(vec!["1", "2"], values, 0); diff --git a/tests/integration/hyperloglog/mod.rs b/tests/integration/hyperloglog/mod.rs index 58997884..40bede38 100644 --- a/tests/integration/hyperloglog/mod.rs +++ b/tests/integration/hyperloglog/mod.rs @@ -36,7 +36,7 @@ pub async fn should_pfmerge_elements(client: RedisClient, _: RedisConfig) -> Res let result: i64 = client.pfadd("bar{1}", vec!["c", "d", "e"]).await?; assert_eq!(result, 1); - let _: () = client.pfmerge("baz{1}", vec!["foo{1}", "bar{1}"]).await?; + client.pfmerge("baz{1}", vec!["foo{1}", "bar{1}"]).await?; let result: i64 = client.pfcount("baz{1}").await?; assert_eq!(result, 5); diff --git a/tests/integration/keys/mod.rs b/tests/integration/keys/mod.rs index 1b153b68..45059369 100644 --- a/tests/integration/keys/mod.rs +++ b/tests/integration/keys/mod.rs @@ -1,18 +1,30 @@ +use bytes::Bytes; use fred::{ - clients::RedisClient, + clients::{RedisClient, RedisPool}, error::RedisError, interfaces::*, - pool::RedisPool, types::{Expiration, ReconnectPolicy, RedisConfig, RedisMap, RedisValue}, }; use futures::{pin_mut, StreamExt}; use std::{collections::HashMap, time::Duration}; use tokio::{self, time::sleep}; +#[cfg(feature = "default-nil-types")] +pub async fn should_handle_missing_keys(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { + assert!(client.get::("foo").await?.is_empty()); + Ok(()) +} + +#[cfg(not(feature = "default-nil-types"))] +pub async fn should_handle_missing_keys(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { + assert!(client.get::("foo").await.is_err()); + Ok(()) +} + pub async fn 
should_set_and_get_a_value(client: RedisClient, _config: RedisConfig) -> Result<(), RedisError> { check_null!(client, "foo"); - let _: () = client.set("foo", "bar", None, None, false).await?; + client.set("foo", "bar", None, None, false).await?; assert_eq!(client.get::("foo").await?, "bar"); Ok(()) @@ -34,7 +46,7 @@ pub async fn should_set_and_del_a_value(client: RedisClient, _config: RedisConfi pub async fn should_set_with_get_argument(client: RedisClient, _config: RedisConfig) -> Result<(), RedisError> { check_null!(client, "foo"); - let _: () = client.set("foo", "bar", None, None, false).await?; + client.set("foo", "bar", None, None, false).await?; let result: String = client.set("foo", "baz", None, None, true).await?; assert_eq!(result, "bar"); @@ -49,9 +61,9 @@ pub async fn should_rename(client: RedisClient, _config: RedisConfig) -> Result< check_null!(client, "{foo}.1"); check_null!(client, "{foo}.2"); - let _: () = client.set("{foo}.1", "baz", None, None, false).await?; + client.set("{foo}.1", "baz", None, None, false).await?; - let _: () = client.rename("{foo}.1", "{foo}.2").await?; + client.rename("{foo}.1", "{foo}.2").await?; let result: String = client.get("{foo}.2").await?; assert_eq!(result, "baz"); check_null!(client, "{foo}.1"); @@ -68,9 +80,9 @@ pub async fn should_renamenx(client: RedisClient, _config: RedisConfig) -> Resul check_null!(client, "{foo}.1"); check_null!(client, "{foo}.2"); - let _: () = client.set("{foo}.1", "baz", None, None, false).await?; + client.set("{foo}.1", "baz", None, None, false).await?; - let _: () = client.renamenx("{foo}.1", "{foo}.2").await?; + client.renamenx("{foo}.1", "{foo}.2").await?; let result: String = client.get("{foo}.2").await?; assert_eq!(result, "baz"); check_null!(client, "{foo}.1"); @@ -78,7 +90,10 @@ pub async fn should_renamenx(client: RedisClient, _config: RedisConfig) -> Resul Ok(()) } -pub async fn should_error_renamenx_does_not_exist(client: RedisClient, _config: RedisConfig) -> Result<(), 
RedisError> { +pub async fn should_error_renamenx_does_not_exist( + client: RedisClient, + _config: RedisConfig, +) -> Result<(), RedisError> { check_null!(client, "{foo}"); client.renamenx("{foo}", "{foo}.bar").await } @@ -86,10 +101,15 @@ pub async fn should_error_renamenx_does_not_exist(client: RedisClient, _config: pub async fn should_unlink(client: RedisClient, _config: RedisConfig) -> Result<(), RedisError> { check_null!(client, "{foo}1"); - let _: () = client.set("{foo}1", "bar", None, None, false).await?; + client.set("{foo}1", "bar", None, None, false).await?; assert_eq!(client.get::("{foo}1").await?, "bar"); - assert_eq!(client.unlink::(vec!["{foo}1", "{foo}", "{foo}:something"]).await?, 1); + assert_eq!( + client + .unlink::(vec!["{foo}1", "{foo}", "{foo}:something"]) + .await?, + 1 + ); check_null!(client, "{foo}1"); Ok(()) @@ -134,7 +154,7 @@ pub async fn should_mset_a_non_empty_map(client: RedisClient, _config: RedisConf map.insert("b{1}".into(), 2.into()); map.insert("c{1}".into(), 3.into()); - let _ = client.mset(map).await?; + client.mset(map).await?; let a: i64 = client.get("a{1}").await?; let b: i64 = client.get("b{1}").await?; let c: i64 = client.get("c{1}").await?; @@ -153,9 +173,9 @@ pub async fn should_error_mset_empty_map(client: RedisClient, _config: RedisConf pub async fn should_expire_key(client: RedisClient, _config: RedisConfig) -> Result<(), RedisError> { check_null!(client, "foo"); - let _: () = client.set("foo", "bar", None, None, false).await?; + client.set("foo", "bar", None, None, false).await?; - let _: () = client.expire("foo", 1).await?; + client.expire("foo", 1).await?; sleep(Duration::from_millis(1500)).await; let foo: Option = client.get("foo").await?; assert!(foo.is_none()); @@ -165,7 +185,7 @@ pub async fn should_expire_key(client: RedisClient, _config: RedisConfig) -> Res pub async fn should_persist_key(client: RedisClient, _config: RedisConfig) -> Result<(), RedisError> { check_null!(client, "foo"); - let _: () = 
client.set("foo", "bar", Some(Expiration::EX(5)), None, false).await?; + client.set("foo", "bar", Some(Expiration::EX(5)), None, false).await?; let removed: bool = client.persist("foo").await?; assert!(removed); @@ -178,7 +198,7 @@ pub async fn should_persist_key(client: RedisClient, _config: RedisConfig) -> Re pub async fn should_check_ttl(client: RedisClient, _config: RedisConfig) -> Result<(), RedisError> { check_null!(client, "foo"); - let _: () = client.set("foo", "bar", Some(Expiration::EX(5)), None, false).await?; + client.set("foo", "bar", Some(Expiration::EX(5)), None, false).await?; let ttl: i64 = client.ttl("foo").await?; assert!(ttl > 0 && ttl < 6); @@ -188,7 +208,7 @@ pub async fn should_check_ttl(client: RedisClient, _config: RedisConfig) -> Resu pub async fn should_check_pttl(client: RedisClient, _config: RedisConfig) -> Result<(), RedisError> { check_null!(client, "foo"); - let _: () = client.set("foo", "bar", Some(Expiration::EX(5)), None, false).await?; + client.set("foo", "bar", Some(Expiration::EX(5)), None, false).await?; let ttl: i64 = client.pttl("foo").await?; assert!(ttl > 0 && ttl < 5001); @@ -199,8 +219,8 @@ pub async fn should_check_pttl(client: RedisClient, _config: RedisConfig) -> Res pub async fn should_dump_key(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { check_null!(client, "foo"); - let _: () = client.set("foo", "abc123", None, None, false).await?; - let dump = client.dump("foo").await?; + client.set("foo", "abc123", None, None, false).await?; + let dump: RedisValue = client.dump("foo").await?; assert!(dump.is_bytes()); Ok(()) @@ -210,11 +230,11 @@ pub async fn should_dump_and_restore_key(client: RedisClient, _: RedisConfig) -> check_null!(client, "foo"); let expected = "abc123"; - let _: () = client.set("foo", expected, None, None, false).await?; + client.set("foo", expected, None, None, false).await?; let dump = client.dump("foo").await?; - let _: () = client.del("foo").await?; + client.del("foo").await?; - 
let _ = client.restore("foo", 0, dump, false, false, None, None).await?; + client.restore("foo", 0, dump, false, false, None, None).await?; let value: String = client.get("foo").await?; assert_eq!(value, expected); @@ -223,12 +243,12 @@ pub async fn should_dump_and_restore_key(client: RedisClient, _: RedisConfig) -> pub async fn should_modify_ranges(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { check_null!(client, "foo"); - let _: () = client.set("foo", "0123456789", None, None, false).await?; + client.set("foo", "0123456789", None, None, false).await?; let range: String = client.getrange("foo", 0, 4).await?; assert_eq!(range, "01234"); - let _: () = client.setrange("foo", 4, "abc").await?; + client.setrange("foo", 4, "abc").await?; let value: String = client.get("foo").await?; assert_eq!(value, "0123abc789"); @@ -254,7 +274,7 @@ pub async fn should_getdel_value(client: RedisClient, _: RedisConfig) -> Result< let value: Option = client.getdel("foo").await?; assert!(value.is_none()); - let _: () = client.set("foo", "bar", None, None, false).await?; + client.set("foo", "bar", None, None, false).await?; let value: String = client.getdel("foo").await?; assert_eq!(value, "bar"); let value: Option = client.get("foo").await?; @@ -267,7 +287,7 @@ pub async fn should_get_strlen(client: RedisClient, _: RedisConfig) -> Result<() check_null!(client, "foo"); let expected = "abcdefghijklmnopqrstuvwxyz"; - let _: () = client.set("foo", expected, None, None, false).await?; + client.set("foo", expected, None, None, false).await?; let len: usize = client.strlen("foo").await?; assert_eq!(len, expected.len()); @@ -281,7 +301,7 @@ pub async fn should_mget_values(client: RedisClient, _: RedisConfig) -> Result<( let expected: Vec<(&str, RedisValue)> = vec![("a{1}", 1.into()), ("b{1}", 2.into()), ("c{1}", 3.into())]; for (key, value) in expected.iter() { - let _: () = client.set(*key, value.clone(), None, None, false).await?; + client.set(*key, value.clone(), None, 
None, false).await?; } let values: Vec = client.mget(vec!["a{1}", "b{1}", "c{1}"]).await?; assert_eq!(values, vec![1, 2, 3]); @@ -303,8 +323,8 @@ pub async fn should_msetnx_values(client: RedisClient, _: RedisConfig) -> Result assert_eq!(a, 1); assert_eq!(b, 2); - let _: () = client.del(vec!["a{1}", "b{1}"]).await?; - let _: () = client.set("a{1}", 3, None, None, false).await?; + client.del(vec!["a{1}", "b{1}"]).await?; + client.set("a{1}", 3, None, None, false).await?; let values: i64 = client.msetnx(expected.clone()).await?; assert_eq!(values, 0); @@ -318,14 +338,14 @@ pub async fn should_copy_values(client: RedisClient, _: RedisConfig) -> Result<( check_null!(client, "a{1}"); check_null!(client, "b{1}"); - let _: () = client.set("a{1}", "bar", None, None, false).await?; + client.set("a{1}", "bar", None, None, false).await?; let result: i64 = client.copy("a{1}", "b{1}", None, false).await?; assert_eq!(result, 1); let b: String = client.get("b{1}").await?; assert_eq!(b, "bar"); - let _: () = client.set("a{1}", "baz", None, None, false).await?; + client.set("a{1}", "baz", None, None, false).await?; let result: i64 = client.copy("a{1}", "b{1}", None, false).await?; assert_eq!(result, 0); @@ -341,11 +361,11 @@ pub async fn should_get_keys_from_pool_in_a_stream( client: RedisClient, config: RedisConfig, ) -> Result<(), RedisError> { - let _ = client.set("foo", "bar", None, None, false).await?; + client.set("foo", "bar", None, None, false).await?; - let pool = RedisPool::new(config, None, None, 5)?; - let _ = pool.connect(); - let _ = pool.wait_for_connect().await?; + let pool = RedisPool::new(config, None, None, None, 5)?; + pool.connect(); + pool.wait_for_connect().await?; let stream = tokio_stream::wrappers::IntervalStream::new(tokio::time::interval(Duration::from_millis(100))).then(move |_| { diff --git a/tests/integration/lists/mod.rs b/tests/integration/lists/mod.rs index 429c9c97..0e235936 100644 --- a/tests/integration/lists/mod.rs +++ 
b/tests/integration/lists/mod.rs @@ -11,7 +11,7 @@ const COUNT: i64 = 10; async fn create_count_data(client: &RedisClient, key: &str) -> Result, RedisError> { let mut values = Vec::with_capacity(COUNT as usize); for idx in 0 .. COUNT { - let _: () = client.rpush(key, idx).await?; + client.rpush(key, idx).await?; values.push(idx.to_string().into()); } @@ -20,8 +20,8 @@ async fn create_count_data(client: &RedisClient, key: &str) -> Result Result<(), RedisError> { let publisher = client.clone_new(); - let _ = publisher.connect(); - let _ = publisher.wait_for_connect().await?; + publisher.connect(); + publisher.wait_for_connect().await?; let jh = tokio::spawn(async move { for idx in 0 .. COUNT { @@ -46,8 +46,8 @@ pub async fn should_blpop_values(client: RedisClient, _: RedisConfig) -> Result< pub async fn should_brpop_values(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { let publisher = client.clone_new(); - let _ = publisher.connect(); - let _ = publisher.wait_for_connect().await?; + publisher.connect(); + publisher.wait_for_connect().await?; let jh = tokio::spawn(async move { for idx in 0 .. COUNT { @@ -72,8 +72,8 @@ pub async fn should_brpop_values(client: RedisClient, _: RedisConfig) -> Result< pub async fn should_brpoplpush_values(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { let publisher = client.clone_new(); - let _ = publisher.connect(); - let _ = publisher.wait_for_connect().await?; + publisher.connect(); + publisher.wait_for_connect().await?; let jh = tokio::spawn(async move { for idx in 0 .. COUNT { @@ -100,8 +100,8 @@ pub async fn should_brpoplpush_values(client: RedisClient, _: RedisConfig) -> Re pub async fn should_blmove_values(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { let publisher = client.clone_new(); - let _ = publisher.connect(); - let _ = publisher.wait_for_connect().await?; + publisher.connect(); + publisher.wait_for_connect().await?; let jh = tokio::spawn(async move { for idx in 0 .. 
COUNT { @@ -145,7 +145,7 @@ pub async fn should_linsert_values(client: RedisClient, _: RedisConfig) -> Resul let result: usize = client.llen("foo").await?; assert_eq!(result, 0); - let _: () = client.lpush("foo", 0).await?; + client.lpush("foo", 0).await?; let mut expected: Vec = vec!["0".into()]; for idx in 1 .. COUNT { let result: i64 = client.linsert("foo", ListLocation::After, idx - 1, idx).await?; @@ -219,7 +219,7 @@ pub async fn should_lpushx_values(client: RedisClient, _: RedisConfig) -> Result let result: i64 = client.lpushx("foo", 0).await?; assert_eq!(result, 0); - let _: () = client.lpush("foo", 0).await?; + client.lpush("foo", 0).await?; for idx in 0 .. COUNT { let result: i64 = client.lpushx("foo", idx).await?; assert_eq!(result, idx + 2); @@ -273,7 +273,7 @@ pub async fn should_lset_values(client: RedisClient, _: RedisConfig) -> Result<( expected.reverse(); for idx in 0 .. COUNT { - let _: () = client.lset("foo", idx, COUNT - (idx + 1)).await?; + client.lset("foo", idx, COUNT - (idx + 1)).await?; } let result: Vec = client.lrange("foo", 0, COUNT).await?; assert_eq!(result, expected); @@ -284,16 +284,16 @@ pub async fn should_lset_values(client: RedisClient, _: RedisConfig) -> Result<( pub async fn should_ltrim_values(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { let expected = create_count_data(&client, "foo").await?; - let _: () = client.ltrim("foo", 0, COUNT).await?; + client.ltrim("foo", 0, COUNT).await?; let result: Vec = client.lrange("foo", 0, COUNT).await?; assert_eq!(result, expected); for idx in 0 .. COUNT { - let _: () = client.ltrim("foo", 0, idx).await?; + client.ltrim("foo", 0, idx).await?; let result: Vec = client.lrange("foo", 0, COUNT).await?; assert_eq!(result, expected[0 .. 
(idx + 1) as usize]); - let _: () = client.del("foo").await?; + client.del("foo").await?; let _ = create_count_data(&client, "foo").await?; } @@ -361,7 +361,7 @@ pub async fn should_rpushx_values(client: RedisClient, _: RedisConfig) -> Result let result: i64 = client.rpushx("foo", 0).await?; assert_eq!(result, 0); - let _: () = client.rpush("foo", 0).await?; + client.rpush("foo", 0).await?; for idx in 0 .. COUNT { let result: i64 = client.rpushx("foo", idx).await?; assert_eq!(result, idx + 2); diff --git a/tests/integration/lua/mod.rs b/tests/integration/lua/mod.rs index c37ecf1a..2212b2b3 100644 --- a/tests/integration/lua/mod.rs +++ b/tests/integration/lua/mod.rs @@ -9,9 +9,11 @@ use std::{ ops::Deref, }; -static ECHO_SCRIPT: &'static str = "return {KEYS[1],KEYS[2],ARGV[1],ARGV[2]}"; -static GET_SCRIPT: &'static str = "return redis.call('get', KEYS[1])"; +static ECHO_SCRIPT: &str = "return {KEYS[1],KEYS[2],ARGV[1],ARGV[2]}"; +#[cfg(feature = "sha-1")] +static GET_SCRIPT: &str = "return redis.call('get', KEYS[1])"; +#[cfg(feature = "sha-1")] pub async fn load_script(client: &RedisClient, script: &str) -> Result { if client.is_clustered() { client.script_load_cluster(script).await @@ -28,6 +30,7 @@ pub async fn flush_scripts(client: &RedisClient) -> Result<(), RedisError> { } } +#[cfg(feature = "sha-1")] pub async fn should_load_script(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { let script_hash = util::sha1_hash(ECHO_SCRIPT); let hash: String = client.script_load(ECHO_SCRIPT).await?; @@ -36,6 +39,7 @@ pub async fn should_load_script(client: RedisClient, _: RedisConfig) -> Result<( Ok(()) } +#[cfg(feature = "sha-1")] pub async fn should_load_script_cluster(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { let script_hash = util::sha1_hash(ECHO_SCRIPT); let hash: String = client.script_load_cluster(ECHO_SCRIPT).await?; @@ -44,16 +48,18 @@ pub async fn should_load_script_cluster(client: RedisClient, _: RedisConfig) -> Ok(()) } 
+#[cfg(feature = "sha-1")] pub async fn should_evalsha_echo_script(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { let hash = load_script(&client, ECHO_SCRIPT).await?; let result: Vec = client.evalsha(hash, vec!["a{1}", "b{1}"], vec!["c{1}", "d{1}"]).await?; assert_eq!(result, vec!["a{1}", "b{1}", "c{1}", "d{1}"]); - let _ = flush_scripts(&client).await?; + flush_scripts(&client).await?; Ok(()) } +#[cfg(feature = "sha-1")] pub async fn should_evalsha_with_reload_echo_script(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { let script = Script::from_lua(ECHO_SCRIPT); @@ -62,23 +68,24 @@ pub async fn should_evalsha_with_reload_echo_script(client: RedisClient, _: Redi .await?; assert_eq!(result, vec!["a{1}", "b{1}", "c{1}", "d{1}"]); - let _ = flush_scripts(&client).await?; + flush_scripts(&client).await?; Ok(()) } +#[cfg(feature = "sha-1")] pub async fn should_evalsha_get_script(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { let script_hash = util::sha1_hash(GET_SCRIPT); let hash = load_script(&client, GET_SCRIPT).await?; assert_eq!(hash, script_hash); - let result: Option = client.evalsha(&script_hash, vec!["foo"], None).await?; + let result: Option = client.evalsha(&script_hash, vec!["foo"], ()).await?; assert!(result.is_none()); - let _: () = client.set("foo", "bar", None, None, false).await?; - let result: String = client.evalsha(&script_hash, vec!["foo"], None).await?; + client.set("foo", "bar", None, None, false).await?; + let result: String = client.evalsha(&script_hash, vec!["foo"], ()).await?; assert_eq!(result, "bar"); - let _ = flush_scripts(&client).await?; + flush_scripts(&client).await?; Ok(()) } @@ -88,26 +95,27 @@ pub async fn should_eval_echo_script(client: RedisClient, _: RedisConfig) -> Res .await?; assert_eq!(result, vec!["a{1}", "b{1}", "c{1}", "d{1}"]); - let _ = flush_scripts(&client).await?; + flush_scripts(&client).await?; Ok(()) } +#[cfg(feature = "sha-1")] pub async fn 
should_eval_get_script(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { - let result: Option = client.eval(GET_SCRIPT, vec!["foo"], None).await?; + let result: Option = client.eval(GET_SCRIPT, vec!["foo"], ()).await?; assert!(result.is_none()); let hash = util::sha1_hash(GET_SCRIPT); - let result: Option = client.evalsha(&hash, vec!["foo"], None).await?; + let result: Option = client.evalsha(&hash, vec!["foo"], ()).await?; assert!(result.is_none()); - let _: () = client.set("foo", "bar", None, None, false).await?; - let result: String = client.eval(GET_SCRIPT, vec!["foo"], None).await?; + client.set("foo", "bar", None, None, false).await?; + let result: String = client.eval(GET_SCRIPT, vec!["foo"], ()).await?; assert_eq!(result, "bar"); - let result: String = client.evalsha(&hash, vec!["foo"], None).await?; + let result: String = client.evalsha(&hash, vec!["foo"], ()).await?; assert_eq!(result, "bar"); - let _ = flush_scripts(&client).await?; + flush_scripts(&client).await?; Ok(()) } @@ -121,7 +129,7 @@ pub async fn should_function_load_scripts(client: RedisClient, _: RedisConfig) - assert_eq!(echo, "echolib"); let getset: String = client.function_load(true, getset_fn).await?; assert_eq!(getset, "getsetlib"); - let _ = client.function_load_cluster(true, echo_fn).await?; + client.function_load_cluster(true, echo_fn).await?; Ok(()) } @@ -130,11 +138,11 @@ pub async fn should_function_dump_and_restore(client: RedisClient, _: RedisConfi check_redis_7!(client); let echo_fn = include_str!("../../scripts/lua/echo.lua"); - let _ = client.function_load_cluster(true, echo_fn).await?; + client.function_load_cluster(true, echo_fn).await?; let fns: Bytes = client.function_dump().await?; - let _ = client.function_flush_cluster(false).await?; - let _: () = client.function_restore_cluster(fns, FnPolicy::default()).await?; + client.function_flush_cluster(false).await?; + client.function_restore_cluster(fns, FnPolicy::default()).await?; let mut fns: Vec> = 
client.function_list(Some("echolib"), false).await?; assert_eq!(fns.len(), 1); @@ -148,11 +156,11 @@ pub async fn should_function_flush(client: RedisClient, _: RedisConfig) -> Resul check_redis_7!(client); let echo_fn = include_str!("../../scripts/lua/echo.lua"); - let _ = client.function_load_cluster(true, echo_fn).await?; + client.function_load_cluster(true, echo_fn).await?; let fns: RedisValue = client.function_list(Some("echolib"), false).await?; assert!(!fns.is_null()); - let _ = client.function_flush_cluster(false).await?; + client.function_flush_cluster(false).await?; let fns: RedisValue = client.function_list(Some("echolib"), false).await?; assert!(fns.is_null() || fns.array_len() == Some(0)); @@ -163,11 +171,11 @@ pub async fn should_function_delete(client: RedisClient, _: RedisConfig) -> Resu check_redis_7!(client); let echo_fn = include_str!("../../scripts/lua/echo.lua"); - let _ = client.function_load_cluster(true, echo_fn).await?; + client.function_load_cluster(true, echo_fn).await?; let fns: RedisValue = client.function_list(Some("echolib"), false).await?; assert!(!fns.is_null()); - let _ = client.function_delete_cluster("echolib").await?; + client.function_delete_cluster("echolib").await?; let fns: RedisValue = client.function_list(Some("echolib"), false).await?; assert!(fns.is_null() || fns.array_len() == Some(0)); @@ -178,9 +186,9 @@ pub async fn should_function_list(client: RedisClient, _: RedisConfig) -> Result check_redis_7!(client); let echo_fn = include_str!("../../scripts/lua/echo.lua"); - let _ = client.function_load_cluster(true, echo_fn).await?; + client.function_load_cluster(true, echo_fn).await?; let getset_fn = include_str!("../../scripts/lua/getset.lua"); - let _ = client.function_load_cluster(true, getset_fn).await?; + client.function_load_cluster(true, getset_fn).await?; let mut fns: Vec> = client.function_list(Some("echolib"), false).await?; assert_eq!(fns.len(), 1); @@ -194,9 +202,9 @@ pub async fn 
should_function_list_multiple(client: RedisClient, _: RedisConfig) check_redis_7!(client); let echo_fn = include_str!("../../scripts/lua/echo.lua"); - let _ = client.function_load_cluster(true, echo_fn).await?; + client.function_load_cluster(true, echo_fn).await?; let getset_fn = include_str!("../../scripts/lua/getset.lua"); - let _ = client.function_load_cluster(true, getset_fn).await?; + client.function_load_cluster(true, getset_fn).await?; let fns: Vec> = client.function_list(None::, false).await?; @@ -223,9 +231,9 @@ pub async fn should_function_fcall_getset(client: RedisClient, _: RedisConfig) - check_redis_7!(client); let getset_fn = include_str!("../../scripts/lua/getset.lua"); - let _ = client.function_load_cluster(true, getset_fn).await?; + client.function_load_cluster(true, getset_fn).await?; - let _: () = client.set("foo{1}", "bar", None, None, false).await?; + client.set("foo{1}", "bar", None, None, false).await?; let old: String = client.fcall("getset", vec!["foo{1}"], vec!["baz"]).await?; assert_eq!(old, "bar"); let new: String = client.get("foo{1}").await?; @@ -238,7 +246,7 @@ pub async fn should_function_fcall_echo(client: RedisClient, _: RedisConfig) -> check_redis_7!(client); let echo_fn = include_str!("../../scripts/lua/echo.lua"); - let _ = client.function_load_cluster(true, echo_fn).await?; + client.function_load_cluster(true, echo_fn).await?; let result: Vec = client .fcall("echo", vec!["key1{1}", "key2{1}"], vec!["arg1", "arg2"]) @@ -252,7 +260,7 @@ pub async fn should_function_fcall_ro_echo(client: RedisClient, _: RedisConfig) check_redis_7!(client); let echo_fn = include_str!("../../scripts/lua/echo.lua"); - let _ = client.function_load_cluster(true, echo_fn).await?; + client.function_load_cluster(true, echo_fn).await?; let result: Vec = client .fcall_ro("echo", vec!["key1{1}", "key2{1}"], vec!["arg1", "arg2"]) @@ -262,12 +270,13 @@ pub async fn should_function_fcall_ro_echo(client: RedisClient, _: RedisConfig) Ok(()) } +#[cfg(feature = 
"sha-1")] pub async fn should_create_lua_script_helper_from_code( client: RedisClient, _: RedisConfig, ) -> Result<(), RedisError> { let script = Script::from_lua(ECHO_SCRIPT); - let _ = script.load(&client).await?; + script.load(&client).await?; let result: Vec = script .evalsha(&client, vec!["foo{1}", "bar{1}"], vec!["3", "4"]) @@ -276,6 +285,7 @@ pub async fn should_create_lua_script_helper_from_code( Ok(()) } +#[cfg(feature = "sha-1")] pub async fn should_create_lua_script_helper_from_hash( client: RedisClient, _: RedisConfig, @@ -306,7 +316,7 @@ pub async fn should_create_function_from_code(client: RedisClient, _: RedisConfi pub async fn should_create_function_from_name(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { check_redis_7!(client); let echo_lib = include_str!("../../scripts/lua/echo.lua"); - let _: () = client.function_load_cluster(true, echo_lib).await?; + client.function_load_cluster(true, echo_lib).await?; let lib = Library::from_name(&client, "echolib").await?; let func = lib.functions().get("echo").expect("Failed to read echo function"); diff --git a/tests/integration/memory/mod.rs b/tests/integration/memory/mod.rs index 2a751c1a..a951fa3e 100644 --- a/tests/integration/memory/mod.rs +++ b/tests/integration/memory/mod.rs @@ -1,31 +1,30 @@ -use fred::prelude::*; +use fred::{prelude::*, types::MemoryStats}; pub async fn should_run_memory_doctor(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { - let _ = client.memory_doctor().await?; + client.memory_doctor().await?; Ok(()) } pub async fn should_run_memory_malloc_stats(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { - let _ = client.memory_malloc_stats().await?; + client.memory_malloc_stats().await?; Ok(()) } pub async fn should_run_memory_purge(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { - let _ = client.memory_purge().await?; + client.memory_purge().await?; Ok(()) } pub async fn should_run_memory_stats(client: RedisClient, _: 
RedisConfig) -> Result<(), RedisError> { - let stats = client.memory_stats().await?; + let stats: MemoryStats = client.memory_stats().await?; assert!(stats.total_allocated > 0); Ok(()) } pub async fn should_run_memory_usage(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { - let _: () = client.set("foo", "bar", None, None, false).await?; - let amt = client.memory_usage("foo", None).await?; - assert!(amt.unwrap() > 0); + client.set("foo", "bar", None, None, false).await?; + assert!(client.memory_usage::("foo", None).await? > 0); Ok(()) } diff --git a/tests/integration/mod.rs b/tests/integration/mod.rs index 200c6c2f..5e39bdf7 100644 --- a/tests/integration/mod.rs +++ b/tests/integration/mod.rs @@ -1,5 +1,6 @@ #[macro_use] pub mod utils; +pub mod docker; mod acl; mod client; @@ -22,22 +23,51 @@ mod slowlog; mod sorted_sets; mod streams; +#[cfg(feature = "redis-json")] +mod redis_json; + #[cfg(feature = "client-tracking")] mod tracking; +#[cfg(not(feature = "mocks"))] pub mod centralized; +#[cfg(not(feature = "mocks"))] pub mod clustered; mod macro_tests { - use fred::{b, s}; + use fred::{cmd, types::ClusterHash}; + use socket2::TcpKeepalive; #[test] - fn should_use_static_str_macro() { - let _s = s!("foo"); + fn should_use_cmd_macro() { + let command = cmd!("GET"); + assert_eq!(command.cmd, "GET"); + assert_eq!(command.cluster_hash, ClusterHash::FirstKey); + assert!(!command.blocking); + let command = cmd!("GET", blocking: true); + assert_eq!(command.cmd, "GET"); + assert_eq!(command.cluster_hash, ClusterHash::FirstKey); + assert!(command.blocking); + let command = cmd!("GET", hash: ClusterHash::FirstValue); + assert_eq!(command.cmd, "GET"); + assert_eq!(command.cluster_hash, ClusterHash::FirstValue); + assert!(!command.blocking); + let command = cmd!("GET", hash: ClusterHash::FirstValue, blocking: true); + assert_eq!(command.cmd, "GET"); + assert_eq!(command.cluster_hash, ClusterHash::FirstValue); + assert!(command.blocking); } +} - #[test] - fn 
should_use_static_bytes_macro() { - let _b = b!(b"foo"); +mod docker_tests { + use super::*; + + #[tokio::test] + async fn should_read_docker_state() { + // pretty_env_logger::try_init().unwrap(); + // FIXME need a portable way to expose the docker socket + // let routing = docker::inspect_cluster(false).await.unwrap(); + // println!("routing {:?}", routing.slots()); + // panic!("meh"); } } diff --git a/tests/integration/multi/mod.rs b/tests/integration/multi/mod.rs index 352f91a8..5358eb32 100644 --- a/tests/integration/multi/mod.rs +++ b/tests/integration/multi/mod.rs @@ -8,8 +8,8 @@ use fred::{ pub async fn should_run_get_set_trx(client: RedisClient, _config: RedisConfig) -> Result<(), RedisError> { let trx = client.multi(); - let _r1: () = trx.set("foo", "bar", None, None, false).await?; - let _r2: () = trx.get("foo").await?; + trx.set("foo", "bar", None, None, false).await?; + trx.get("foo").await?; let results: Vec = trx.exec(true).await?; assert_eq!(results, vec!["OK", "bar"]); @@ -17,20 +17,20 @@ pub async fn should_run_get_set_trx(client: RedisClient, _config: RedisConfig) - } pub async fn should_run_error_get_set_trx(client: RedisClient, _config: RedisConfig) -> Result<(), RedisError> { - let _: () = client.set("foo", "bar", None, None, false).await?; + client.set("foo", "bar", None, None, false).await?; let trx = client.multi(); - let _: () = trx.incr("foo").await?; - let _: () = trx.exec(true).await?; + trx.incr("foo").await?; + trx.exec(true).await?; Ok(()) } pub async fn should_fail_with_hashslot_error(client: RedisClient, _config: RedisConfig) -> Result<(), RedisError> { let trx = client.multi(); - let _: () = trx.set("foo", "bar", None, None, false).await?; - let _: () = trx.set("bar", "baz", None, None, false).await?; - let _: () = trx.exec(true).await?; + trx.set("foo", "bar", None, None, false).await?; + trx.set("bar", "baz", None, None, false).await?; + trx.exec(true).await?; Ok(()) } diff --git a/tests/integration/other/mod.rs 
b/tests/integration/other/mod.rs index 7e536cea..9ed43d9a 100644 --- a/tests/integration/other/mod.rs +++ b/tests/integration/other/mod.rs @@ -1,12 +1,20 @@ use super::utils; use async_trait::async_trait; use fred::{ - clients::RedisClient, + clients::{RedisClient, RedisPool}, error::{RedisError, RedisErrorKind}, interfaces::*, - pool::RedisPool, prelude::{Blocking, RedisValue}, - types::{BackpressureConfig, ClientUnblockFlag, PerformanceConfig, RedisConfig, RedisKey, RedisMap, ServerConfig}, + types::{ + BackpressureConfig, + ClientUnblockFlag, + Options, + PerformanceConfig, + RedisConfig, + RedisKey, + RedisMap, + ServerConfig, + }, }; use futures::future::try_join; use parking_lot::RwLock; @@ -25,6 +33,8 @@ use tokio::time::sleep; #[cfg(feature = "subscriber-client")] use fred::clients::SubscriberClient; +#[cfg(feature = "codec")] +use fred::codec::*; #[cfg(feature = "replicas")] use fred::types::ReplicaConfig; #[cfg(feature = "dns")] @@ -36,6 +46,13 @@ use std::net::{IpAddr, SocketAddr}; #[cfg(feature = "dns")] use trust_dns_resolver::{config::*, TokioAsyncResolver}; +#[cfg(feature = "codec")] +use futures::{SinkExt, StreamExt}; +#[cfg(feature = "codec")] +use tokio::net::TcpStream; +#[cfg(feature = "codec")] +use tokio_util::codec::{Decoder, Encoder, Framed}; + fn hash_to_btree(vals: &RedisMap) -> BTreeMap { vals .iter() @@ -53,9 +70,9 @@ pub fn incr_atomic(size: &Arc) -> usize { pub async fn should_smoke_test_from_redis_impl(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { let nested_values: RedisMap = vec![("a", 1), ("b", 2)].try_into()?; - let _ = client.set("foo", "123", None, None, false).await?; - let _ = client.set("baz", "456", None, None, false).await?; - let _ = client.hset("bar", &nested_values).await?; + client.set("foo", "123", None, None, false).await?; + client.set("baz", "456", None, None, false).await?; + client.hset("bar", &nested_values).await?; let foo: usize = client.get("foo").await?; assert_eq!(foo, 123); @@ -83,12 
+100,12 @@ pub async fn should_smoke_test_from_redis_impl(client: RedisClient, _: RedisConf pub async fn should_automatically_unblock(_: RedisClient, mut config: RedisConfig) -> Result<(), RedisError> { config.blocking = Blocking::Interrupt; - let client = RedisClient::new(config, None, None); - let _ = client.connect(); - let _ = client.wait_for_connect().await?; + let client = RedisClient::new(config, None, None, None); + client.connect(); + client.wait_for_connect().await?; let unblock_client = client.clone(); - let _ = tokio::spawn(async move { + tokio::spawn(async move { sleep(Duration::from_secs(1)).await; let _: () = unblock_client.ping().await.expect("Failed to ping"); }); @@ -103,7 +120,7 @@ pub async fn should_manually_unblock(client: RedisClient, _: RedisConfig) -> Res let connections_ids = client.connection_ids().await; let unblock_client = client.clone(); - let _ = tokio::spawn(async move { + tokio::spawn(async move { sleep(Duration::from_secs(1)).await; for (_, id) in connections_ids.into_iter() { @@ -121,12 +138,12 @@ pub async fn should_manually_unblock(client: RedisClient, _: RedisConfig) -> Res pub async fn should_error_when_blocked(_: RedisClient, mut config: RedisConfig) -> Result<(), RedisError> { config.blocking = Blocking::Error; - let client = RedisClient::new(config, None, None); - let _ = client.connect(); - let _ = client.wait_for_connect().await?; + let client = RedisClient::new(config, None, None, None); + client.connect(); + client.wait_for_connect().await?; let error_client = client.clone(); - let _ = tokio::spawn(async move { + tokio::spawn(async move { sleep(Duration::from_secs(1)).await; let result = error_client.ping::<()>().await; @@ -187,9 +204,9 @@ pub async fn should_run_flushall_cluster(client: RedisClient, _: RedisConfig) -> let count: i64 = 200; for idx in 0 .. 
count { - let _: () = client.set(format!("foo-{}", idx), idx, None, None, false).await?; + client.set(format!("foo-{}", idx), idx, None, None, false).await?; } - let _ = client.flushall_cluster().await?; + client.flushall_cluster().await?; for idx in 0 .. count { let value: Option = client.get(format!("foo-{}", idx)).await?; @@ -213,7 +230,7 @@ pub async fn should_safely_change_protocols_repeatedly( return Ok::<_, RedisError>(()); } let foo = String::from("foo"); - let _ = other.incr(&foo).await?; + other.incr(&foo).await?; sleep(Duration::from_millis(10)).await; } }); @@ -225,7 +242,7 @@ pub async fn should_safely_change_protocols_repeatedly( } else { RespVersion::RESP3 }; - let _ = client.hello(version, None).await?; + client.hello(version, None).await?; sleep(Duration::from_millis(500)).await; } let _ = mem::replace(&mut *done.write(), true); @@ -240,16 +257,15 @@ pub async fn should_test_high_concurrency_pool(_: RedisClient, mut config: Redis config.blocking = Blocking::Block; let perf = PerformanceConfig { auto_pipeline: true, - // default_command_timeout_ms: 20_000, backpressure: BackpressureConfig { max_in_flight_commands: 100_000_000, ..Default::default() }, ..Default::default() }; - let pool = RedisPool::new(config, Some(perf), None, 28)?; - let _ = pool.connect(); - let _ = pool.wait_for_connect().await?; + let pool = RedisPool::new(config, Some(perf), None, None, 28)?; + pool.connect(); + pool.wait_for_connect().await?; let num_tasks = 11641; let mut tasks = Vec::with_capacity(num_tasks); @@ -274,7 +290,7 @@ pub async fn should_test_high_concurrency_pool(_: RedisClient, mut config: Redis } } - println!("Task {} finished.", idx); + // println!("Task {} finished.", idx); Ok::<_, RedisError>(()) })); } @@ -336,8 +352,8 @@ pub async fn should_pipeline_last(client: RedisClient, _: RedisConfig) -> Result pub async fn should_pipeline_try_all(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { let pipeline = client.pipeline(); - let _: () = 
pipeline.incr("foo").await?; - let _: () = pipeline.hgetall("foo").await?; + pipeline.incr("foo").await?; + pipeline.hgetall("foo").await?; let results = pipeline.try_all::().await; assert_eq!(results[0].clone().unwrap(), 1); @@ -350,14 +366,14 @@ pub async fn should_use_all_cluster_nodes_repeatedly(client: RedisClient, _: Red let other = client.clone(); let jh1 = tokio::spawn(async move { for _ in 0 .. 200 { - let _ = other.flushall_cluster().await?; + other.flushall_cluster().await?; } Ok::<_, RedisError>(()) }); let jh2 = tokio::spawn(async move { for _ in 0 .. 200 { - let _ = client.flushall_cluster().await?; + client.flushall_cluster().await?; } Ok::<_, RedisError>(()) @@ -371,7 +387,7 @@ pub async fn should_use_all_cluster_nodes_repeatedly(client: RedisClient, _: Red pub async fn should_use_tracing_get_set(client: RedisClient, mut config: RedisConfig) -> Result<(), RedisError> { config.tracing = TracingConfig::new(true); let (perf, policy) = (client.perf_config(), client.client_reconnect_policy()); - let client = RedisClient::new(config, Some(perf), policy); + let client = RedisClient::new(config, Some(perf), None, policy); let _ = client.connect(); let _ = client.wait_for_connect().await?; @@ -435,7 +451,7 @@ pub async fn should_use_tracing_get_set(client: RedisClient, mut config: RedisCo #[cfg(feature = "subscriber-client")] pub async fn should_ping_with_subscriber_client(client: RedisClient, config: RedisConfig) -> Result<(), RedisError> { let (perf, policy) = (client.perf_config(), client.client_reconnect_policy()); - let client = SubscriberClient::new(config, Some(perf), policy); + let client = SubscriberClient::new(config, Some(perf), None, policy); let _ = client.connect(); let _ = client.wait_for_connect().await?; @@ -458,13 +474,11 @@ pub async fn should_replica_set_and_get(client: RedisClient, _: RedisConfig) -> } #[cfg(feature = "replicas")] -pub async fn should_replica_set_and_get_not_lazy( - client: RedisClient, - mut config: RedisConfig, -) -> 
Result<(), RedisError> { - let (perf, policy) = (client.perf_config(), client.client_reconnect_policy()); - config.replica.lazy_connections = false; - let client = RedisClient::new(config, Some(perf), policy); +pub async fn should_replica_set_and_get_not_lazy(client: RedisClient, config: RedisConfig) -> Result<(), RedisError> { + let policy = client.client_reconnect_policy(); + let mut connection = client.connection_config().clone(); + connection.replica.lazy_connections = false; + let client = RedisClient::new(config, None, Some(connection), policy); let _ = client.connect(); let _ = client.wait_for_connect().await?; @@ -496,20 +510,19 @@ pub async fn should_pipeline_with_replicas(client: RedisClient, _: RedisConfig) #[cfg(feature = "replicas")] pub async fn should_use_cluster_replica_without_redirection( client: RedisClient, - _: RedisConfig, + config: RedisConfig, ) -> Result<(), RedisError> { - let mut config = client.client_config(); - config.replica = ReplicaConfig { + let mut connection = client.connection_config().clone(); + connection.replica = ReplicaConfig { lazy_connections: true, primary_fallback: false, ignore_reconnection_errors: true, ..ReplicaConfig::default() }; - let mut perf = client.perf_config(); - perf.max_command_attempts = 1; + connection.max_redirections = 0; let policy = client.client_reconnect_policy(); - let client = RedisClient::new(config, Some(perf), policy); + let client = RedisClient::new(config, None, Some(connection), policy); let _ = client.connect(); let _ = client.wait_for_connect().await?; @@ -522,11 +535,125 @@ pub async fn should_use_cluster_replica_without_redirection( pub async fn should_gracefully_quit(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { let client = client.clone_new(); let connection = client.connect(); - let _ = client.wait_for_connect().await?; + client.wait_for_connect().await?; let _: i64 = client.incr("foo").await?; - let _ = client.quit().await?; + client.quit().await?; let _ = 
connection.await; Ok(()) } + +pub async fn should_support_options_with_pipeline(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { + let options = Options { + timeout: Some(Duration::from_millis(100)), + max_attempts: Some(42), + max_redirections: Some(43), + ..Default::default() + }; + + let pipeline = client.pipeline().with_options(&options); + pipeline.blpop("foo", 2.0).await?; + let results = pipeline.try_all::().await; + assert_eq!(results[0].clone().unwrap_err().kind(), &RedisErrorKind::Timeout); + + Ok(()) +} + +pub async fn should_reuse_pipeline(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { + let pipeline = client.pipeline(); + pipeline.incr("foo").await?; + pipeline.incr("foo").await?; + assert_eq!(pipeline.last::().await?, 2); + assert_eq!(pipeline.last::().await?, 4); + Ok(()) +} + +pub async fn should_support_options_with_trx(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { + let options = Options { + max_attempts: Some(1), + timeout: Some(Duration::from_secs(1)), + ..Default::default() + }; + let trx = client.multi().with_options(&options); + + let _: () = trx.get("foo{1}").await?; + let _: () = trx.set("foo{1}", "bar", None, None, false).await?; + let _: () = trx.get("foo{1}").await?; + let (first, second, third): (Option, bool, String) = trx.exec(true).await?; + + assert_eq!(first, None); + assert_eq!(second, true); + assert_eq!(third, "bar"); + Ok(()) +} + +pub async fn should_manually_connect_twice(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { + let client = client.clone_new(); + let _old_connection = client.connect(); + client.wait_for_connect().await?; + + let _blpop_jh = tokio::spawn({ + let client = client.clone(); + async move { client.blpop::, _>("foo", 5.0).await } + }); + + sleep(Duration::from_millis(100)).await; + let new_connection = client.connect(); + client.wait_for_connect().await?; + + assert_eq!(client.incr::("bar").await?, 1); + client.quit().await?; + let _ = 
new_connection.await?; + Ok(()) +} + +#[cfg(feature = "codec")] +pub async fn should_use_resp3_codec_example(_: RedisClient, config: RedisConfig) -> Result<(), RedisError> { + let addr = format!("{}", config.server.hosts().first().unwrap()); + let socket = TcpStream::connect(addr).await?; + let mut framed = Framed::new(socket, Resp3::default()); + + let hello = Resp3Frame::Hello { + version: RespVersion::RESP3, + auth: Some(Auth { + username: utils::read_redis_username().into(), + password: utils::read_redis_password().into(), + }), + }; + let echo_foo = resp3_encode_command("ECHO foo"); + + let _ = framed.send(hello).await?; + let response = framed.next().await.unwrap().unwrap(); + assert_eq!(response.kind(), Resp3FrameKind::Map); + + let _ = framed.send(echo_foo).await?; + let response = framed.next().await.unwrap().unwrap(); + assert_eq!(response.as_str().unwrap(), "foo"); + + Ok(()) +} + +#[cfg(feature = "codec")] +pub async fn should_use_resp2_codec_example(_: RedisClient, config: RedisConfig) -> Result<(), RedisError> { + let addr = format!("{}", config.server.hosts().first().unwrap()); + let socket = TcpStream::connect(addr).await?; + let mut framed = Framed::new(socket, Resp2::default()); + + let auth = resp2_encode_command(&format!( + "AUTH {} {}", + utils::read_redis_username(), + utils::read_redis_password() + )); + let echo_foo = resp2_encode_command("ECHO foo"); + + let _ = framed.send(auth).await?; + let response = framed.next().await.unwrap().unwrap(); + assert_eq!(response.as_str().unwrap(), "OK"); + + let _ = framed.send(echo_foo).await?; + let response = framed.next().await.unwrap().unwrap(); + assert_eq!(response.as_str().unwrap(), "foo"); + + Ok(()) +} diff --git a/tests/integration/pool/mod.rs b/tests/integration/pool/mod.rs index 81d875be..9e31828d 100644 --- a/tests/integration/pool/mod.rs +++ b/tests/integration/pool/mod.rs @@ -1,16 +1,21 @@ -use fred::{clients::RedisClient, error::RedisError, interfaces::*, pool::RedisPool, 
types::RedisConfig}; +use fred::{ + clients::{RedisClient, RedisPool}, + error::RedisError, + interfaces::*, + types::RedisConfig, +}; async fn create_and_ping_pool(config: &RedisConfig, count: usize) -> Result<(), RedisError> { - let pool = RedisPool::new(config.clone(), None, None, count)?; - let _ = pool.connect(); - let _ = pool.wait_for_connect().await?; + let pool = RedisPool::new(config.clone(), None, None, None, count)?; + pool.connect(); + pool.wait_for_connect().await?; for client in pool.clients().iter() { - let _: () = client.ping().await?; + client.ping().await?; } - let _: () = pool.ping().await?; - let _ = pool.quit_pool().await; + pool.ping().await?; + let _ = pool.quit().await; Ok(()) } diff --git a/tests/integration/pubsub/mod.rs b/tests/integration/pubsub/mod.rs index 1e8e1753..c6f33da4 100644 --- a/tests/integration/pubsub/mod.rs +++ b/tests/integration/pubsub/mod.rs @@ -1,19 +1,20 @@ +use super::utils::should_use_sentinel_config; use fred::{interfaces::PubsubInterface, prelude::*}; use futures::{Stream, StreamExt}; use std::{collections::HashMap, time::Duration}; use tokio::time::sleep; -const CHANNEL1: &'static str = "foo"; -const CHANNEL2: &'static str = "bar"; -const CHANNEL3: &'static str = "baz"; -const FAKE_MESSAGE: &'static str = "wibble"; +const CHANNEL1: &str = "foo"; +const CHANNEL2: &str = "bar"; +const CHANNEL3: &str = "baz"; +const FAKE_MESSAGE: &str = "wibble"; const NUM_MESSAGES: i64 = 20; pub async fn should_publish_and_recv_messages(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { let subscriber_client = client.clone_new(); - let _ = subscriber_client.connect(); - let _ = subscriber_client.wait_for_connect().await?; - let _ = subscriber_client.subscribe(CHANNEL1).await?; + subscriber_client.connect(); + subscriber_client.wait_for_connect().await?; + subscriber_client.subscribe(CHANNEL1).await?; let subscriber_jh = tokio::spawn(async move { let mut message_stream = subscriber_client.on_message(); @@ -33,7 +34,7 
@@ pub async fn should_publish_and_recv_messages(client: RedisClient, _: RedisConfi sleep(Duration::from_secs(1)).await; for idx in 0 .. NUM_MESSAGES { // https://redis.io/commands/publish#return-value - let _: () = client.publish(CHANNEL1, format!("{}-{}", FAKE_MESSAGE, idx)).await?; + client.publish(CHANNEL1, format!("{}-{}", FAKE_MESSAGE, idx)).await?; // pubsub messages may arrive out of order due to cross-cluster broadcasting sleep(Duration::from_millis(50)).await; @@ -48,9 +49,9 @@ pub async fn should_psubscribe_and_recv_messages(client: RedisClient, _: RedisCo let subscriber_channels = channels.clone(); let subscriber_client = client.clone_new(); - let _ = subscriber_client.connect(); - let _ = subscriber_client.wait_for_connect().await?; - let _ = subscriber_client.psubscribe(channels.clone()).await?; + subscriber_client.connect(); + subscriber_client.wait_for_connect().await?; + subscriber_client.psubscribe(channels.clone()).await?; let subscriber_jh = tokio::spawn(async move { let mut message_stream = subscriber_client.on_message(); @@ -72,7 +73,7 @@ pub async fn should_psubscribe_and_recv_messages(client: RedisClient, _: RedisCo let channel = channels[idx as usize % channels.len()]; // https://redis.io/commands/publish#return-value - let _: () = client.publish(channel, format!("{}-{}", FAKE_MESSAGE, idx)).await?; + client.publish(channel, format!("{}-{}", FAKE_MESSAGE, idx)).await?; // pubsub messages may arrive out of order due to cross-cluster broadcasting sleep(Duration::from_millis(50)).await; @@ -85,11 +86,11 @@ pub async fn should_psubscribe_and_recv_messages(client: RedisClient, _: RedisCo pub async fn should_unsubscribe_from_all(publisher: RedisClient, _: RedisConfig) -> Result<(), RedisError> { let subscriber = publisher.clone_new(); let connection = subscriber.connect(); - let _ = subscriber.wait_for_connect().await?; - let _ = subscriber.subscribe(vec![CHANNEL1, CHANNEL2, CHANNEL3]).await?; + subscriber.wait_for_connect().await?; + 
subscriber.subscribe(vec![CHANNEL1, CHANNEL2, CHANNEL3]).await?; let mut message_stream = subscriber.on_message(); - let _ = tokio::spawn(async move { + tokio::spawn(async move { while let Ok(message) = message_stream.recv().await { // unsubscribe without args will result in 3 messages in this case, and none should show up here panic!("Recv unexpected pubsub message: {:?}", message); @@ -98,7 +99,7 @@ pub async fn should_unsubscribe_from_all(publisher: RedisClient, _: RedisConfig) Ok::<_, RedisError>(()) }); - let _ = subscriber.unsubscribe(()).await?; + subscriber.unsubscribe(()).await?; sleep(Duration::from_secs(1)).await; // do some incr commands to make sure the response buffer is flushed correctly by this point @@ -106,47 +107,47 @@ pub async fn should_unsubscribe_from_all(publisher: RedisClient, _: RedisConfig) assert_eq!(subscriber.incr::("abc{1}").await?, 2); assert_eq!(subscriber.incr::("abc{1}").await?, 3); - let _ = subscriber.quit().await?; + subscriber.quit().await?; let _ = connection.await?; Ok(()) } pub async fn should_get_pubsub_channels(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { let subscriber = client.clone_new(); - let _ = subscriber.connect(); - let _ = subscriber.wait_for_connect().await?; + subscriber.connect(); + subscriber.wait_for_connect().await?; let channels: Vec = client.pubsub_channels("*").await?; - #[cfg(feature = "sentinel-tests")] - assert_eq!(channels.len(), 1); // "__sentinel__:hello" is always there - #[cfg(not(feature = "sentinel-tests"))] - assert!(channels.is_empty()); - - let _: () = subscriber.subscribe("foo").await?; - let _: () = subscriber.subscribe("bar").await?; + let expected_len = if should_use_sentinel_config() { + // "__sentinel__:hello" is always there + 1 + } else { + 0 + }; + assert_eq!(channels.len(), expected_len); + + subscriber.subscribe("foo").await?; + subscriber.subscribe("bar").await?; let mut channels: Vec = client.pubsub_channels("*").await?; channels.sort(); - #[cfg(feature = 
"sentinel-tests")] - assert_eq!(channels, vec![ - "__sentinel__:hello".into(), - "bar".to_string(), - "foo".to_string() - ]); - #[cfg(not(feature = "sentinel-tests"))] - assert_eq!(channels, vec!["bar".to_string(), "foo".to_string()]); - + let expected = if should_use_sentinel_config() { + vec!["__sentinel__:hello".into(), "bar".to_string(), "foo".to_string()] + } else { + vec!["bar".to_string(), "foo".to_string()] + }; + assert_eq!(channels, expected); Ok(()) } pub async fn should_get_pubsub_numpat(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { let subscriber = client.clone_new(); - let _ = subscriber.connect(); - let _ = subscriber.wait_for_connect().await?; + subscriber.connect(); + subscriber.wait_for_connect().await?; assert_eq!(client.pubsub_numpat::().await?, 0); - let _: () = subscriber.psubscribe("foo*").await?; - let _: () = subscriber.psubscribe("bar*").await?; + subscriber.psubscribe("foo*").await?; + subscriber.psubscribe("bar*").await?; assert_eq!(client.pubsub_numpat::().await?, 2); Ok(()) @@ -154,8 +155,8 @@ pub async fn should_get_pubsub_numpat(client: RedisClient, _: RedisConfig) -> Re pub async fn should_get_pubsub_nunmsub(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { let subscriber = client.clone_new(); - let _ = subscriber.connect(); - let _ = subscriber.wait_for_connect().await?; + subscriber.connect(); + subscriber.wait_for_connect().await?; let mut expected: HashMap = HashMap::new(); expected.insert("foo".into(), 0); @@ -163,8 +164,8 @@ pub async fn should_get_pubsub_nunmsub(client: RedisClient, _: RedisConfig) -> R let channels: HashMap = client.pubsub_numsub(vec!["foo", "bar"]).await?; assert_eq!(channels, expected); - let _: () = subscriber.subscribe("foo").await?; - let _: () = subscriber.subscribe("bar").await?; + subscriber.subscribe("foo").await?; + subscriber.subscribe("bar").await?; let channels: HashMap = client.pubsub_numsub(vec!["foo", "bar"]).await?; let mut expected: HashMap = 
HashMap::new(); @@ -177,14 +178,14 @@ pub async fn should_get_pubsub_nunmsub(client: RedisClient, _: RedisConfig) -> R pub async fn should_get_pubsub_shard_channels(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { let subscriber = client.clone_new(); - let _ = subscriber.connect(); - let _ = subscriber.wait_for_connect().await?; + subscriber.connect(); + subscriber.wait_for_connect().await?; let channels: Vec = client.pubsub_shardchannels("{1}*").await?; assert!(channels.is_empty()); - let _: () = subscriber.ssubscribe("{1}foo").await?; - let _: () = subscriber.ssubscribe("{1}bar").await?; + subscriber.ssubscribe("{1}foo").await?; + subscriber.ssubscribe("{1}bar").await?; let mut channels: Vec = client.pubsub_shardchannels("{1}*").await?; channels.sort(); @@ -195,8 +196,8 @@ pub async fn should_get_pubsub_shard_channels(client: RedisClient, _: RedisConfi pub async fn should_get_pubsub_shard_numsub(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { let subscriber = client.clone_new(); - let _ = subscriber.connect(); - let _ = subscriber.wait_for_connect().await?; + subscriber.connect(); + subscriber.wait_for_connect().await?; let mut expected: HashMap = HashMap::new(); expected.insert("foo{1}".into(), 0); @@ -204,8 +205,8 @@ pub async fn should_get_pubsub_shard_numsub(client: RedisClient, _: RedisConfig) let channels: HashMap = client.pubsub_shardnumsub(vec!["foo{1}", "bar{1}"]).await?; assert_eq!(channels, expected); - let _: () = subscriber.ssubscribe("foo{1}").await?; - let _: () = subscriber.ssubscribe("bar{1}").await?; + subscriber.ssubscribe("foo{1}").await?; + subscriber.ssubscribe("bar{1}").await?; let channels: HashMap = client.pubsub_shardnumsub(vec!["foo{1}", "bar{1}"]).await?; let mut expected: HashMap = HashMap::new(); diff --git a/tests/integration/redis_json/mod.rs b/tests/integration/redis_json/mod.rs new file mode 100644 index 00000000..890eebbd --- /dev/null +++ b/tests/integration/redis_json/mod.rs @@ -0,0 +1,199 @@ 
+use fred::{ + clients::RedisClient, + error::RedisError, + interfaces::RedisJsonInterface, + json_quote, + types::{RedisConfig, RedisValue}, + util::NONE, +}; +use serde_json::{json, Value}; + +pub async fn should_get_and_set_basic_obj(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { + let value: Value = client.json_get("foo", NONE, NONE, NONE, "$").await?; + assert_eq!(value, Value::Null); + + let value = json!({ + "a": "b", + "c": 1 + }); + let _: () = client.json_set("foo", "$", value.clone(), None).await?; + let result: Value = client.json_get("foo", NONE, NONE, NONE, "$").await?; + assert_eq!(value, result[0]); + + Ok(()) +} + +pub async fn should_get_and_set_stringified_obj(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { + let value: Value = client.json_get("foo", NONE, NONE, NONE, "$").await?; + assert_eq!(value, Value::Null); + + let value = json!({ + "a": "b", + "c": 1 + }); + let _: () = client + .json_set("foo", "$", serde_json::to_string(&value)?, None) + .await?; + let result: Value = client.json_get("foo", NONE, NONE, NONE, "$").await?; + assert_eq!(value, result[0]); + + Ok(()) +} + +pub async fn should_array_append(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { + let _: () = client.json_set("foo", "$", json!(["a", "b"]), None).await?; + + // need to double quote string values + let size: i64 = client + .json_arrappend("foo", "$", vec![json_quote!("c"), json_quote!("d")]) + .await?; + assert_eq!(size, 4); + let size: i64 = client.json_arrappend("foo", "$", vec![json!({"e": "f"})]).await?; + assert_eq!(size, 5); + let len: i64 = client.json_arrlen("foo", NONE).await?; + assert_eq!(len, 5); + + let result: Value = client.json_get("foo", NONE, NONE, NONE, "$").await?; + assert_eq!(result[0], json!(["a", "b", "c", "d", {"e": "f"}])); + + Ok(()) +} + +pub async fn should_modify_arrays(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { + let _: () = client.json_set("foo", "$", json!(["a", 
"d"]), None).await?; + let len: i64 = client + .json_arrinsert("foo", "$", 1, vec![json_quote!("b"), json_quote!("c")]) + .await?; + assert_eq!(len, 4); + let idx: usize = client.json_arrindex("foo", "$", json_quote!("b"), None, None).await?; + assert_eq!(idx, 1); + let len: usize = client.json_arrlen("foo", NONE).await?; + assert_eq!(len, 4); + + Ok(()) +} + +pub async fn should_pop_and_trim_arrays(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { + let _: () = client.json_set("foo", "$", json!(["a", "b"]), None).await?; + let val: Value = client.json_arrpop("foo", NONE, None).await?; + assert_eq!(val, json!("b")); + + let _: () = client.json_set("foo", "$", json!(["a", "b", "c", "d"]), None).await?; + let len: usize = client.json_arrtrim("foo", "$", 0, -2).await?; + assert_eq!(len, 3); + + let vals: Value = client.json_get("foo", NONE, NONE, NONE, "$").await?; + assert_eq!(vals[0], json!(["a", "b", "c"])); + + Ok(()) +} + +pub async fn should_get_set_del_obj(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { + let value = json!({ + "a": "b", + "c": 1, + "d": true + }); + let _: () = client.json_set("foo", "$", value.clone(), None).await?; + let result: Value = client.json_get("foo", NONE, NONE, NONE, "$").await?; + assert_eq!(value, result[0]); + + let count: i64 = client.json_del("foo", "$..c").await?; + assert_eq!(count, 1); + + let result: Value = client.json_get("foo", NONE, NONE, NONE, "$").await?; + assert_eq!(result[0], json!({ "a": "b", "d": true })); + + Ok(()) +} + +pub async fn should_merge_objects(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { + let foo = json!({ "a": "b", "c": { "d": "e" } }); + let bar = json!({ "a": "b1", "c": { "d1": "e1" }, "y": "z" }); + let expected = json!({ "a": "b1", "c": {"d": "e", "d1": "e1"}, "y": "z" }); + + let _: () = client.json_set("foo", "$", foo.clone(), None).await?; + let _: () = client.json_merge("foo", "$", bar.clone()).await?; + let merged: Value = 
client.json_get("foo", NONE, NONE, NONE, "$").await?; + assert_eq!(merged[0], expected); + + Ok(()) +} + +pub async fn should_mset_and_mget(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { + let values = vec![json!({ "a": "b" }), json!({ "c": "d" })]; + let args = vec![("foo{1}", "$", values[0].clone()), ("bar{1}", "$", values[1].clone())]; + let _: () = client.json_mset(args).await?; + + let result: Value = client.json_mget(vec!["foo{1}", "bar{1}"], "$").await?; + // response is nested: Array [Array [Object {"a": String("b")}], Array [Object {"c": String("d")}]] + assert_eq!(result, json!([[values[0]], [values[1]]])); + + Ok(()) +} + +pub async fn should_incr_numbers(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { + let _: () = client.json_set("foo", "$", json!({ "a": 1 }), None).await?; + let vals: Value = client.json_numincrby("foo", "$.a", 2).await?; + assert_eq!(vals[0], 3); + + Ok(()) +} + +pub async fn should_inspect_objects(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { + let value = json!({ + "a": "b", + "e": { + "f": "g", + "h": "i", + "j": [{ "k": "l" }] + } + }); + let _: () = client.json_set("foo", "$", value.clone(), None).await?; + let keys: Vec> = client.json_objkeys("foo", Some("$")).await?; + assert_eq!(keys[0], vec!["a".to_string(), "e".to_string()]); + let keys: Vec> = client.json_objkeys("foo", Some("$.e")).await?; + assert_eq!(keys[0], vec!["f".to_string(), "h".to_string(), "j".to_string()]); + + let len: usize = client.json_objlen("foo", NONE).await?; + assert_eq!(len, 2); + let len: usize = client.json_objlen("foo", Some("$.e")).await?; + assert_eq!(len, 3); + + Ok(()) +} + +pub async fn should_modify_strings(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { + let _: () = client.json_set("foo", "$", json!({ "a": "abc123" }), None).await?; + let len: usize = client.json_strlen("foo", Some("$.a")).await?; + assert_eq!(len, 6); + + let len: usize = client.json_strappend("foo", 
Some("$.a"), json_quote!("456")).await?; + assert_eq!(len, 9); + let len: usize = client.json_strlen("foo", Some("$.a")).await?; + assert_eq!(len, 9); + let value: Value = client.json_get("foo", NONE, NONE, NONE, "$").await?; + assert_eq!(value[0], json!({ "a": "abc123456" })); + + Ok(()) +} + +pub async fn should_toggle_boolean(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { + let _: () = client.json_set("foo", "$", json!({ "a": 1, "b": true }), None).await?; + let new_val: bool = client.json_toggle("foo", "$.b").await?; + assert_eq!(new_val, false); + + Ok(()) +} + +pub async fn should_get_value_type(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { + let _: () = client.json_set("foo", "$", json!({ "a": 1, "b": true }), None).await?; + let val: String = client.json_type("foo", NONE).await?; + assert_eq!(val, "object"); + let val: String = client.json_type("foo", Some("$.a")).await?; + assert_eq!(val, "integer"); + let val: String = client.json_type("foo", Some("$.b")).await?; + assert_eq!(val, "boolean"); + + Ok(()) +} diff --git a/tests/integration/scanning/mod.rs b/tests/integration/scanning/mod.rs index c0de8ad2..be22a274 100644 --- a/tests/integration/scanning/mod.rs +++ b/tests/integration/scanning/mod.rs @@ -6,7 +6,7 @@ const SCAN_KEYS: i64 = 100; pub async fn should_scan_keyspace(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { for idx in 0 .. 
SCAN_KEYS { - let _ = client + client .set(format!("foo-{}-{}", idx, "{1}"), idx, None, None, false) .await?; } @@ -19,14 +19,14 @@ pub async fn should_scan_keyspace(client: RedisClient, _: RedisConfig) -> Result // scanning wont return results in any particular order, so we just check the format of the key for key in results.into_iter() { - let parts: Vec<&str> = key.as_str().unwrap().split("-").collect(); + let parts: Vec<&str> = key.as_str().unwrap().split('-').collect(); assert!(parts[1].parse::().is_ok()); } } else { panic!("Empty results in scan."); } - let _ = result.next()?; + result.next()?; Ok(count) }) .await?; @@ -38,7 +38,7 @@ pub async fn should_scan_keyspace(client: RedisClient, _: RedisConfig) -> Result pub async fn should_hscan_hash(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { for idx in 0 .. SCAN_KEYS { let value = (format!("bar-{}", idx), idx); - let _ = client.hset("foo", value).await?; + client.hset("foo", value).await?; } let count = client @@ -49,14 +49,14 @@ pub async fn should_hscan_hash(client: RedisClient, _: RedisConfig) -> Result<() // scanning wont return results in any particular order, so we just check the format of the key for (key, _) in results.iter() { - let parts: Vec<&str> = key.as_str().unwrap().split("-").collect(); + let parts: Vec<&str> = key.as_str().unwrap().split('-').collect(); assert!(parts[1].parse::().is_ok()); } } else { panic!("Empty results in hscan."); } - let _ = result.next()?; + result.next()?; Ok(count) }) .await?; @@ -67,7 +67,7 @@ pub async fn should_hscan_hash(client: RedisClient, _: RedisConfig) -> Result<() pub async fn should_sscan_set(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { for idx in 0 .. 
SCAN_KEYS { - let _ = client.sadd("foo", idx).await?; + client.sadd("foo", idx).await?; } let count = client @@ -83,7 +83,7 @@ pub async fn should_sscan_set(client: RedisClient, _: RedisConfig) -> Result<(), panic!("Empty sscan result"); } - let _ = result.next()?; + result.next()?; Ok(count) }) .await?; @@ -95,7 +95,7 @@ pub async fn should_sscan_set(client: RedisClient, _: RedisConfig) -> Result<(), pub async fn should_zscan_sorted_set(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { for idx in 0 .. SCAN_KEYS { let (score, value) = (idx as f64, format!("foo-{}", idx)); - let _ = client.zadd("foo", None, None, false, false, (score, value)).await?; + client.zadd("foo", None, None, false, false, (score, value)).await?; } let count = client @@ -106,7 +106,7 @@ pub async fn should_zscan_sorted_set(client: RedisClient, _: RedisConfig) -> Res for (value, score) in results.into_iter() { let value_str = value.as_str().unwrap(); - let parts: Vec<&str> = value_str.split("-").collect(); + let parts: Vec<&str> = value_str.split('-').collect(); let value_suffix = parts[1].parse::().unwrap(); assert_eq!(value_suffix, score); @@ -115,7 +115,7 @@ pub async fn should_zscan_sorted_set(client: RedisClient, _: RedisConfig) -> Res panic!("Empty zscan result"); } - let _ = result.next()?; + result.next()?; Ok(count) }) .await?; @@ -126,7 +126,7 @@ pub async fn should_zscan_sorted_set(client: RedisClient, _: RedisConfig) -> Res pub async fn should_scan_cluster(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { for idx in 0 .. 
2000 { - let _: () = client.set(idx, idx, None, None, false).await?; + client.set(idx, idx, None, None, false).await?; } let mut count = 0; diff --git a/tests/integration/server/mod.rs b/tests/integration/server/mod.rs index 767e9ef0..3d1d96b0 100644 --- a/tests/integration/server/mod.rs +++ b/tests/integration/server/mod.rs @@ -3,11 +3,11 @@ use std::time::Duration; use tokio::time::sleep; pub async fn should_flushall(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { - let _ = client.set("foo{1}", "bar", None, None, false).await?; + client.set("foo{1}", "bar", None, None, false).await?; if client.is_clustered() { - let _ = client.flushall_cluster().await?; + client.flushall_cluster().await?; } else { - let _: () = client.flushall(false).await?; + client.flushall(false).await?; }; let result: Option = client.get("foo{1}").await?; @@ -24,7 +24,7 @@ pub async fn should_read_server_info(client: RedisClient, _: RedisConfig) -> Res } pub async fn should_ping_server(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { - let _: () = client.ping().await?; + client.ping().await?; Ok(()) } @@ -44,7 +44,7 @@ pub async fn should_read_last_save(client: RedisClient, _: RedisConfig) -> Resul pub async fn should_read_db_size(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { for idx in 0 .. 
50 { - let _: () = client.set(format!("foo-{}", idx), idx, None, None, false).await?; + client.set(format!("foo-{}", idx), idx, None, None, false).await?; } // this is tricky to assert b/c the dbsize command isnt linked to a specific server in the cluster, hence the loop @@ -65,7 +65,7 @@ pub async fn should_start_bgsave(client: RedisClient, _: RedisConfig) -> Result< } pub async fn should_do_bgrewriteaof(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { - let _ = client.bgrewriteaof().await?; + client.bgrewriteaof().await?; // not much we can assert here aside from the command not failing // need to ensure this finishes before it runs again or it'll return an error diff --git a/tests/integration/sets/mod.rs b/tests/integration/sets/mod.rs index 9bcee022..203f4d6b 100644 --- a/tests/integration/sets/mod.rs +++ b/tests/integration/sets/mod.rs @@ -52,8 +52,8 @@ pub async fn should_sdiff_elements(client: RedisClient, _: RedisConfig) -> Resul check_null!(client, "foo{1}"); check_null!(client, "bar{1}"); - let _: () = client.sadd("foo{1}", vec![1, 2, 3, 4, 5, 6]).await?; - let _: () = client.sadd("bar{1}", vec![3, 4, 5, 6, 7, 8]).await?; + client.sadd("foo{1}", vec![1, 2, 3, 4, 5, 6]).await?; + client.sadd("bar{1}", vec![3, 4, 5, 6, 7, 8]).await?; let result: HashSet = client.sdiff(vec!["foo{1}", "bar{1}"]).await?; assert!(sets_eq(&result, &vec_to_set(vec!["1".into(), "2".into()]))); @@ -65,8 +65,8 @@ pub async fn should_sdiffstore_elements(client: RedisClient, _: RedisConfig) -> check_null!(client, "bar{1}"); check_null!(client, "baz{1}"); - let _: () = client.sadd("foo{1}", vec![1, 2, 3, 4, 5, 6]).await?; - let _: () = client.sadd("bar{1}", vec![3, 4, 5, 6, 7, 8]).await?; + client.sadd("foo{1}", vec![1, 2, 3, 4, 5, 6]).await?; + client.sadd("bar{1}", vec![3, 4, 5, 6, 7, 8]).await?; let result: i64 = client.sdiffstore("baz{1}", vec!["foo{1}", "bar{1}"]).await?; assert_eq!(result, 2); let result: HashSet = client.smembers("baz{1}").await?; @@ -80,8 +80,8 
@@ pub async fn should_sinter_elements(client: RedisClient, _: RedisConfig) -> Resu check_null!(client, "bar{1}"); check_null!(client, "baz{1}"); - let _: () = client.sadd("foo{1}", vec![1, 2, 3, 4, 5, 6]).await?; - let _: () = client.sadd("bar{1}", vec![3, 4, 5, 6, 7, 8]).await?; + client.sadd("foo{1}", vec![1, 2, 3, 4, 5, 6]).await?; + client.sadd("bar{1}", vec![3, 4, 5, 6, 7, 8]).await?; let result: HashSet = client.sinter(vec!["foo{1}", "bar{1}"]).await?; assert!(sets_eq( @@ -97,8 +97,8 @@ pub async fn should_sinterstore_elements(client: RedisClient, _: RedisConfig) -> check_null!(client, "bar{1}"); check_null!(client, "baz{1}"); - let _: () = client.sadd("foo{1}", vec![1, 2, 3, 4, 5, 6]).await?; - let _: () = client.sadd("bar{1}", vec![3, 4, 5, 6, 7, 8]).await?; + client.sadd("foo{1}", vec![1, 2, 3, 4, 5, 6]).await?; + client.sadd("bar{1}", vec![3, 4, 5, 6, 7, 8]).await?; let result: i64 = client.sinterstore("baz{1}", vec!["foo{1}", "bar{1}"]).await?; assert_eq!(result, 4); let result: HashSet = client.smembers("baz{1}").await?; @@ -113,27 +113,27 @@ pub async fn should_sinterstore_elements(client: RedisClient, _: RedisConfig) -> pub async fn should_check_sismember(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { check_null!(client, "foo"); - let _ = client.sadd("foo", vec![1, 2, 3, 4, 5, 6]).await?; + client.sadd("foo", vec![1, 2, 3, 4, 5, 6]).await?; let result: bool = client.sismember("foo", 1).await?; assert!(result); let result: bool = client.sismember("foo", 7).await?; - assert_eq!(result, false); + assert!(!result); Ok(()) } pub async fn should_check_smismember(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { check_null!(client, "foo"); - let _ = client.sadd("foo", vec![1, 2, 3, 4, 5, 6]).await?; + client.sadd("foo", vec![1, 2, 3, 4, 5, 6]).await?; let result: Vec = client.smismember("foo", vec![1, 2, 7]).await?; - assert_eq!(result[0], true); - assert_eq!(result[1], true); - assert_eq!(result[2], false); + 
assert!(result[0]); + assert!(result[1]); + assert!(!result[2]); let result: bool = client.sismember("foo", 7).await?; - assert_eq!(result, false); + assert!(!result); Ok(()) } @@ -141,7 +141,7 @@ pub async fn should_check_smismember(client: RedisClient, _: RedisConfig) -> Res pub async fn should_read_smembers(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { check_null!(client, "foo"); - let _: () = client.sadd("foo", vec![1, 2, 3, 4, 5, 6]).await?; + client.sadd("foo", vec![1, 2, 3, 4, 5, 6]).await?; let result: HashSet = client.smembers("foo").await?; assert!(sets_eq( &result, @@ -162,8 +162,8 @@ pub async fn should_smove_elements(client: RedisClient, _: RedisConfig) -> Resul check_null!(client, "foo{1}"); check_null!(client, "bar{1}"); - let _: () = client.sadd("foo{1}", vec![1, 2, 3, 4, 5, 6]).await?; - let _: () = client.sadd("bar{1}", 5).await?; + client.sadd("foo{1}", vec![1, 2, 3, 4, 5, 6]).await?; + client.sadd("bar{1}", 5).await?; let result: i64 = client.smove("foo{1}", "bar{1}", 7).await?; assert_eq!(result, 0); @@ -187,7 +187,7 @@ pub async fn should_spop_elements(client: RedisClient, _: RedisConfig) -> Result check_null!(client, "foo"); let expected = vec_to_set(vec!["1".into(), "2".into(), "3".into()]); - let _: () = client.sadd("foo", vec![1, 2, 3]).await?; + client.sadd("foo", vec![1, 2, 3]).await?; let result = client.spop("foo", None).await?; assert!(expected.contains(&result)); @@ -204,7 +204,7 @@ pub async fn should_get_random_member(client: RedisClient, _: RedisConfig) -> Re check_null!(client, "foo"); let expected = vec_to_set(vec!["1".into(), "2".into(), "3".into()]); - let _: () = client.sadd("foo", vec![1, 2, 3]).await?; + client.sadd("foo", vec![1, 2, 3]).await?; let result = client.srandmember("foo", None).await?; assert!(expected.contains(&result)); @@ -222,7 +222,7 @@ pub async fn should_remove_elements(client: RedisClient, _: RedisConfig) -> Resu let result: i64 = client.srem("foo", 1).await?; assert_eq!(result, 0); - 
let _: () = client.sadd("foo", vec![1, 2, 3, 4, 5, 6]).await?; + client.sadd("foo", vec![1, 2, 3, 4, 5, 6]).await?; let result: i64 = client.srem("foo", 1).await?; assert_eq!(result, 1); let result: i64 = client.srem("foo", vec![2, 3, 4, 7]).await?; @@ -238,8 +238,8 @@ pub async fn should_sunion_elements(client: RedisClient, _: RedisConfig) -> Resu check_null!(client, "foo{1}"); check_null!(client, "bar{1}"); - let _: () = client.sadd("foo{1}", vec![1, 2, 3, 4, 5, 6]).await?; - let _: () = client.sadd("bar{1}", vec![3, 4, 5, 6, 7, 8]).await?; + client.sadd("foo{1}", vec![1, 2, 3, 4, 5, 6]).await?; + client.sadd("bar{1}", vec![3, 4, 5, 6, 7, 8]).await?; let result: HashSet = client.sunion(vec!["foo{1}", "bar{1}"]).await?; assert!(sets_eq( @@ -264,8 +264,8 @@ pub async fn should_sunionstore_elements(client: RedisClient, _: RedisConfig) -> check_null!(client, "bar{1}"); check_null!(client, "baz{1}"); - let _: () = client.sadd("foo{1}", vec![1, 2, 3, 4, 5, 6]).await?; - let _: () = client.sadd("bar{1}", vec![3, 4, 5, 6, 7, 8]).await?; + client.sadd("foo{1}", vec![1, 2, 3, 4, 5, 6]).await?; + client.sadd("bar{1}", vec![3, 4, 5, 6, 7, 8]).await?; let result: i64 = client.sunionstore("baz{1}", vec!["foo{1}", "bar{1}"]).await?; assert_eq!(result, 8); let result: HashSet = client.smembers("baz{1}").await?; diff --git a/tests/integration/slowlog/mod.rs b/tests/integration/slowlog/mod.rs index 960cc774..4195fa3e 100644 --- a/tests/integration/slowlog/mod.rs +++ b/tests/integration/slowlog/mod.rs @@ -1,17 +1,18 @@ -use fred::prelude::*; +use fred::{prelude::*, types::SlowlogEntry}; pub async fn should_read_slowlog_length(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { - let _ = client.slowlog_length().await?; - // cant assert much here since the tests run in any order, and the call to reset the slowlog might run just before this + client.slowlog_length().await?; + // cant assert much here since the tests run in any order, and the call to reset the slowlog 
might run just before + // this Ok(()) } pub async fn should_read_slowlog_entries(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { - let entries = client.slowlog_get(Some(10)).await?; + let entries: Vec = client.slowlog_get(Some(10)).await?; for entry in entries.into_iter() { - assert!(entry.duration > 0); + assert!(!entry.duration.is_zero()); assert!(entry.name.is_some()); } @@ -19,8 +20,8 @@ pub async fn should_read_slowlog_entries(client: RedisClient, _: RedisConfig) -> } pub async fn should_reset_slowlog(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { - let _ = client.slowlog_reset().await?; - let len = client.slowlog_length().await?; + client.slowlog_reset().await?; + let len: i64 = client.slowlog_length().await?; // the slowlog length call might show up here assert!(len < 2); diff --git a/tests/integration/sorted_sets/mod.rs b/tests/integration/sorted_sets/mod.rs index 20b9a358..e230500d 100644 --- a/tests/integration/sorted_sets/mod.rs +++ b/tests/integration/sorted_sets/mod.rs @@ -24,24 +24,21 @@ async fn create_lex_data(client: &RedisClient, key: &str) -> Result Result, RedisError> { - let values: Vec<(f64, RedisValue)> = (0 .. COUNT) - .into_iter() - .map(|idx| (idx as f64, idx.to_string().into())) - .collect(); + let values: Vec<(f64, RedisValue)> = (0 .. COUNT).map(|idx| (idx as f64, idx.to_string().into())).collect(); - let _: () = client.zadd(key, None, None, false, false, values.clone()).await?; + client.zadd(key, None, None, false, false, values.clone()).await?; Ok(values) } pub async fn should_bzpopmin(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { let publisher_client = client.clone_new(); - let _ = publisher_client.connect(); - let _ = publisher_client.wait_for_connect().await?; + publisher_client.connect(); + publisher_client.wait_for_connect().await?; let jh = tokio::task::spawn(async move { for idx in 0 .. 
COUNT { @@ -65,8 +62,8 @@ pub async fn should_bzpopmin(client: RedisClient, _: RedisConfig) -> Result<(), pub async fn should_bzpopmax(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { let publisher_client = client.clone_new(); - let _ = publisher_client.connect(); - let _ = publisher_client.wait_for_connect().await?; + publisher_client.connect(); + publisher_client.wait_for_connect().await?; let jh = tokio::task::spawn(async move { for idx in 0 .. COUNT { @@ -232,7 +229,7 @@ pub async fn should_zdiff_values(client: RedisClient, _: RedisConfig) -> Result< let _expected: Vec = expected.iter().map(|(_, v)| v.clone()).collect(); assert_eq!(result, _expected); - let _: () = client + client .zadd( "bar{1}", None, @@ -269,7 +266,7 @@ pub async fn should_zdiffstore_values(client: RedisClient, _: RedisConfig) -> Re let result: i64 = client.zdiffstore("baz{1}", vec!["foo{1}", "bar{1}"]).await?; assert_eq!(result, COUNT); - let _: () = client + client .zadd( "bar{1}", None, @@ -314,7 +311,7 @@ pub async fn should_zinter_values(client: RedisClient, _: RedisConfig) -> Result let result: Vec = client.zinter(vec!["foo{1}", "bar{1}"], None, None, false).await?; assert!(result.is_empty()); - let _: () = client + client .zadd( "bar{1}", None, @@ -356,7 +353,7 @@ pub async fn should_zinterstore_values(client: RedisClient, _: RedisConfig) -> R .await?; assert_eq!(result, 0); - let _: () = client + client .zadd( "bar{1}", None, @@ -722,11 +719,11 @@ pub async fn should_zunion_values(client: RedisClient, _: RedisConfig) -> Result assert_eq!(result, 1); } - let result = client.zunion(vec!["foo{1}", "bar{1}"], None, None, false).await?; + let result: RedisValue = client.zunion(vec!["foo{1}", "bar{1}"], None, None, false).await?; let _expected: Vec = expected.iter().map(|(_, v)| v.clone()).collect(); assert_eq!(result.into_array(), _expected); - let _: () = client + client .zadd( "bar{1}", None, @@ -736,7 +733,7 @@ pub async fn should_zunion_values(client: RedisClient, _: 
RedisConfig) -> Result expected[0 .. expected.len() - 1].to_vec(), ) .await?; - let result = client.zunion(vec!["foo{1}", "bar{1}"], None, None, true).await?; + let result: RedisValue = client.zunion(vec!["foo{1}", "bar{1}"], None, None, true).await?; // scores are added together with a weight of 1 in this example let mut _expected: Vec<(RedisValue, f64)> = expected[0 .. expected.len() - 1] .iter() @@ -772,7 +769,7 @@ pub async fn should_zunionstore_values(client: RedisClient, _: RedisConfig) -> R .await?; assert_eq!(result, COUNT); - let _: () = client + client .zadd( "bar{1}", None, @@ -792,13 +789,13 @@ pub async fn should_zunionstore_values(client: RedisClient, _: RedisConfig) -> R pub async fn should_zmscore_values(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { for idx in 0 .. COUNT { - let _: () = client.zadd("foo", None, None, false, false, (idx as f64, idx)).await?; + client.zadd("foo", None, None, false, false, (idx as f64, idx)).await?; } let result: Vec = client.zmscore("foo", vec![0, 1]).await?; assert_eq!(result, vec![0.0, 1.0]); - let result: Option = client.zmscore("foo", vec![11]).await?; - assert!(result.is_none()); + let result: [Option; 1] = client.zmscore("foo", vec![11]).await?; + assert!(result[0].is_none()); Ok(()) } diff --git a/tests/integration/streams/mod.rs b/tests/integration/streams/mod.rs index 05e89868..de921783 100644 --- a/tests/integration/streams/mod.rs +++ b/tests/integration/streams/mod.rs @@ -41,13 +41,13 @@ pub async fn should_xinfo_consumers(client: RedisClient, _: RedisConfig) -> Resu let result: Result<(), RedisError> = client.xinfo_consumers("foo{1}", "group1").await; assert!(result.is_err()); - let _ = create_fake_group_and_stream(&client, "foo{1}").await?; - let _: () = client.xgroup_createconsumer("foo{1}", "group1", "consumer1").await?; + create_fake_group_and_stream(&client, "foo{1}").await?; + client.xgroup_createconsumer("foo{1}", "group1", "consumer1").await?; let consumers: Vec> = 
client.xinfo_consumers("foo{1}", "group1").await?; assert_eq!(consumers.len(), 1); assert_eq!(consumers[0].get("name"), Some(&"consumer1".to_owned())); - let _: () = client.xgroup_createconsumer("foo{1}", "group1", "consumer2").await?; + client.xgroup_createconsumer("foo{1}", "group1", "consumer2").await?; let consumers: Vec> = client.xinfo_consumers("foo{1}", "group1").await?; assert_eq!(consumers.len(), 2); assert_eq!(consumers[0].get("name"), Some(&"consumer1".to_owned())); @@ -61,16 +61,16 @@ pub async fn should_xinfo_groups(client: RedisClient, _: RedisConfig) -> Result< let result: Result<(), RedisError> = client.xinfo_groups("foo{1}").await; assert!(result.is_err()); - let _ = create_fake_group_and_stream(&client, "foo{1}").await?; - let result: Vec> = client.xinfo_groups("foo{1}").await?; + create_fake_group_and_stream(&client, "foo{1}").await?; + let result: Vec> = client.xinfo_groups("foo{1}").await?; assert_eq!(result.len(), 1); - assert_eq!(result[0].get("name"), Some(&"group1".to_owned())); + assert_eq!(result[0].get("name"), Some(&"group1".into())); - let _: () = client.xgroup_create("foo{1}", "group2", "$", true).await?; - let result: Vec> = client.xinfo_groups("foo{1}").await?; + client.xgroup_create("foo{1}", "group2", "$", true).await?; + let result: Vec> = client.xinfo_groups("foo{1}").await?; assert_eq!(result.len(), 2); - assert_eq!(result[0].get("name"), Some(&"group1".to_owned())); - assert_eq!(result[1].get("name"), Some(&"group2".to_owned())); + assert_eq!(result[0].get("name"), Some(&"group1".into())); + assert_eq!(result[1].get("name"), Some(&"group2".into())); Ok(()) } @@ -80,13 +80,13 @@ pub async fn should_xinfo_streams(client: RedisClient, _: RedisConfig) -> Result let result: Result<(), RedisError> = client.xinfo_stream("foo{1}", true, None).await; assert!(result.is_err()); - let _ = create_fake_group_and_stream(&client, "foo{1}").await?; + create_fake_group_and_stream(&client, "foo{1}").await?; let mut result: HashMap = 
client.xinfo_stream("foo{1}", true, None).await?; assert!(result.len() >= 6); assert_eq!(result.get("length"), Some(&RedisValue::Integer(0))); - let groups: HashMap = result.remove("groups").unwrap().convert()?; - assert_eq!(groups.get("name"), Some(&RedisValue::from("group1"))); + let groups: [HashMap; 1] = result.remove("groups").unwrap().convert()?; + assert_eq!(groups[0].get("name"), Some(&RedisValue::from("group1"))); Ok(()) } @@ -113,7 +113,7 @@ pub async fn should_xadd_manual_id_to_a_stream(client: RedisClient, _: RedisConf pub async fn should_xadd_with_cap_to_a_stream(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { check_null!(client, "foo{1}"); - let _: () = client + client .xadd("foo{1}", false, ("MAXLEN", "=", 1), "*", ("a", "b")) .await?; @@ -127,8 +127,8 @@ pub async fn should_xadd_nomkstream_to_a_stream(client: RedisClient, _: RedisCon let result: Option = client.xadd("foo{1}", true, None, "*", ("a", "b")).await?; assert!(result.is_none()); - let _ = create_fake_group_and_stream(&client, "foo{1}").await?; - let _: () = client.xadd("foo{1}", true, None, "*", ("a", "b")).await?; + create_fake_group_and_stream(&client, "foo{1}").await?; + client.xadd("foo{1}", true, None, "*", ("a", "b")).await?; let len: usize = client.xlen("foo{1}").await?; assert_eq!(len, 1); Ok(()) @@ -136,7 +136,7 @@ pub async fn should_xadd_nomkstream_to_a_stream(client: RedisClient, _: RedisCon pub async fn should_xtrim_a_stream_approx_cap(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { check_null!(client, "foo{1}"); - let _ = create_fake_group_and_stream(&client, "foo{1}").await?; + create_fake_group_and_stream(&client, "foo{1}").await?; let _ = add_stream_entries(&client, "foo{1}", 3).await?; let deleted: usize = client.xtrim("foo{1}", ("MAXLEN", "~", 1)).await?; @@ -144,8 +144,8 @@ pub async fn should_xtrim_a_stream_approx_cap(client: RedisClient, _: RedisConfi let len: usize = client.xlen("foo{1}").await?; assert_eq!(len, 3 - deleted); - let _ 
= client.del("foo{1}").await?; - let _ = create_fake_group_and_stream(&client, "foo{1}").await?; + client.del("foo{1}").await?; + create_fake_group_and_stream(&client, "foo{1}").await?; let _ = add_stream_entries(&client, "foo{1}", 3).await?; let deleted: usize = client .xtrim("foo{1}", (XCapKind::MaxLen, XCapTrim::AlmostExact, 1)) @@ -159,7 +159,7 @@ pub async fn should_xtrim_a_stream_approx_cap(client: RedisClient, _: RedisConfi pub async fn should_xtrim_a_stream_eq_cap(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { check_null!(client, "foo{1}"); - let _ = create_fake_group_and_stream(&client, "foo{1}").await?; + create_fake_group_and_stream(&client, "foo{1}").await?; let _ = add_stream_entries(&client, "foo{1}", 3).await?; let deleted: usize = client.xtrim("foo{1}", ("MAXLEN", "=", 1)).await?; @@ -167,8 +167,8 @@ pub async fn should_xtrim_a_stream_eq_cap(client: RedisClient, _: RedisConfig) - let len: usize = client.xlen("foo{1}").await?; assert_eq!(len, 1); - let _ = client.del("foo{1}").await?; - let _ = create_fake_group_and_stream(&client, "foo{1}").await?; + client.del("foo{1}").await?; + create_fake_group_and_stream(&client, "foo{1}").await?; let _ = add_stream_entries(&client, "foo{1}", 3).await?; let deleted: usize = client.xtrim("foo{1}", (XCapKind::MaxLen, XCapTrim::Exact, 1)).await?; assert_eq!(deleted, 2); @@ -180,7 +180,7 @@ pub async fn should_xtrim_a_stream_eq_cap(client: RedisClient, _: RedisConfig) - pub async fn should_xdel_one_id_in_a_stream(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { check_null!(client, "foo{1}"); - let _ = create_fake_group_and_stream(&client, "foo{1}").await?; + create_fake_group_and_stream(&client, "foo{1}").await?; let (ids, _) = add_stream_entries(&client, "foo{1}", 2).await?; let deleted: usize = client.xdel("foo{1}", &ids[0]).await?; @@ -192,7 +192,7 @@ pub async fn should_xdel_one_id_in_a_stream(client: RedisClient, _: RedisConfig) pub async fn 
should_xdel_multiple_ids_in_a_stream(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { check_null!(client, "foo{1}"); - let _ = create_fake_group_and_stream(&client, "foo{1}").await?; + create_fake_group_and_stream(&client, "foo{1}").await?; let (ids, _) = add_stream_entries(&client, "foo{1}", 3).await?; let deleted: usize = client.xdel("foo{1}", ids[0 .. 2].to_vec()).await?; @@ -204,7 +204,7 @@ pub async fn should_xdel_multiple_ids_in_a_stream(client: RedisClient, _: RedisC pub async fn should_xrange_no_count(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { check_null!(client, "foo{1}"); - let _ = create_fake_group_and_stream(&client, "foo{1}").await?; + create_fake_group_and_stream(&client, "foo{1}").await?; let (_, expected) = add_stream_entries(&client, "foo{1}", 3).await?; let result: FakeExpectedValues = client.xrange("foo{1}", "-", "+", None).await?; @@ -214,7 +214,7 @@ pub async fn should_xrange_no_count(client: RedisClient, _: RedisConfig) -> Resu pub async fn should_xrange_values_no_count(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { check_null!(client, "foo{1}"); - let _ = create_fake_group_and_stream(&client, "foo{1}").await?; + create_fake_group_and_stream(&client, "foo{1}").await?; let (ids, _) = add_stream_entries(&client, "foo{1}", 3).await?; let result: Vec> = client.xrange_values("foo{1}", "-", "+", None).await?; @@ -225,7 +225,7 @@ pub async fn should_xrange_values_no_count(client: RedisClient, _: RedisConfig) pub async fn should_xrevrange_values_no_count(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { check_null!(client, "foo{1}"); - let _ = create_fake_group_and_stream(&client, "foo{1}").await?; + create_fake_group_and_stream(&client, "foo{1}").await?; let (mut ids, _) = add_stream_entries(&client, "foo{1}", 3).await?; ids.reverse(); @@ -237,7 +237,7 @@ pub async fn should_xrevrange_values_no_count(client: RedisClient, _: RedisConfi pub async fn should_xrange_with_count(client: 
RedisClient, _: RedisConfig) -> Result<(), RedisError> { check_null!(client, "foo{1}"); - let _ = create_fake_group_and_stream(&client, "foo{1}").await?; + create_fake_group_and_stream(&client, "foo{1}").await?; let (_, expected) = add_stream_entries(&client, "foo{1}", 3).await?; let result: FakeExpectedValues = client.xrange("foo{1}", "-", "+", Some(1)).await?; @@ -247,7 +247,7 @@ pub async fn should_xrange_with_count(client: RedisClient, _: RedisConfig) -> Re pub async fn should_xrevrange_no_count(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { check_null!(client, "foo{1}"); - let _ = create_fake_group_and_stream(&client, "foo{1}").await?; + create_fake_group_and_stream(&client, "foo{1}").await?; let (_, mut expected) = add_stream_entries(&client, "foo{1}", 3).await?; expected.reverse(); @@ -258,7 +258,7 @@ pub async fn should_xrevrange_no_count(client: RedisClient, _: RedisConfig) -> R pub async fn should_xrevrange_with_count(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { check_null!(client, "foo{1}"); - let _ = create_fake_group_and_stream(&client, "foo{1}").await?; + create_fake_group_and_stream(&client, "foo{1}").await?; let (_, mut expected) = add_stream_entries(&client, "foo{1}", 3).await?; expected.reverse(); @@ -269,7 +269,7 @@ pub async fn should_xrevrange_with_count(client: RedisClient, _: RedisConfig) -> pub async fn should_run_xlen_on_stream(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { check_null!(client, "foo{1}"); - let _ = create_fake_group_and_stream(&client, "foo{1}").await?; + create_fake_group_and_stream(&client, "foo{1}").await?; let len: usize = client.xlen("foo{1}").await?; assert_eq!(len, 0); @@ -280,12 +280,12 @@ pub async fn should_run_xlen_on_stream(client: RedisClient, _: RedisConfig) -> R } pub async fn should_xread_map_one_key(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { - let _ = create_fake_group_and_stream(&client, "foo{1}").await?; + 
create_fake_group_and_stream(&client, "foo{1}").await?; let _ = add_stream_entries(&client, "foo{1}", 3).await?; let result: XReadResponse = client.xread_map(None, None, "foo{1}", "0").await?; - for (idx, (_, record)) in result.get("foo{1}").unwrap().into_iter().enumerate() { + for (idx, (_, record)) in result.get("foo{1}").unwrap().iter().enumerate() { let count = record.get("count").expect("Failed to read count"); assert_eq!(*count, idx); } @@ -294,7 +294,7 @@ pub async fn should_xread_map_one_key(client: RedisClient, _: RedisConfig) -> Re } pub async fn should_xread_one_key_count_1(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { - let _ = create_fake_group_and_stream(&client, "foo{1}").await?; + create_fake_group_and_stream(&client, "foo{1}").await?; let (mut ids, mut expected) = add_stream_entries(&client, "foo{1}", 3).await?; let _ = ids.pop().unwrap(); let most_recent_expected = expected.pop().unwrap(); @@ -314,8 +314,8 @@ pub async fn should_xread_one_key_count_1(client: RedisClient, _: RedisConfig) - } pub async fn should_xread_multiple_keys_count_2(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { - let _ = create_fake_group_and_stream(&client, "foo{1}").await?; - let _ = create_fake_group_and_stream(&client, "bar{1}").await?; + create_fake_group_and_stream(&client, "foo{1}").await?; + create_fake_group_and_stream(&client, "bar{1}").await?; let (foo_ids, foo_inner) = add_stream_entries(&client, "foo{1}", 3).await?; let (bar_ids, bar_inner) = add_stream_entries(&client, "bar{1}", 3).await?; @@ -336,7 +336,7 @@ pub async fn should_xread_multiple_keys_count_2(client: RedisClient, _: RedisCon pub async fn should_xread_with_blocking(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { let expected_id = "123456789-0"; - let _ = create_fake_group_and_stream(&client, "foo{1}").await?; + create_fake_group_and_stream(&client, "foo{1}").await?; let mut expected = HashMap::new(); let mut inner = HashMap::new(); @@ -347,14 
+347,14 @@ pub async fn should_xread_with_blocking(client: RedisClient, _: RedisConfig) -> let add_client = client.clone_new(); tokio::spawn(async move { - let _ = add_client.connect(); - let _ = add_client.wait_for_connect().await?; + add_client.connect(); + add_client.wait_for_connect().await?; sleep(Duration::from_millis(500)).await; - let _: () = add_client + add_client .xadd("foo{1}", false, None, expected_id, ("count", 100)) .await?; - let _ = add_client.quit().await?; + add_client.quit().await?; Ok::<(), RedisError>(()) }); @@ -371,14 +371,14 @@ pub async fn should_xread_with_blocking(client: RedisClient, _: RedisConfig) -> pub async fn should_xgroup_create_no_mkstream(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { let result: Result = client.xgroup_create("foo{1}", "group1", "$", false).await; assert!(result.is_err()); - let _: () = client.xadd("foo{1}", false, None, "*", ("count", 1)).await?; - let _: () = client.xgroup_create("foo{1}", "group1", "$", false).await?; + client.xadd("foo{1}", false, None, "*", ("count", 1)).await?; + client.xgroup_create("foo{1}", "group1", "$", false).await?; Ok(()) } pub async fn should_xgroup_create_mkstream(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { - let _: () = client.xgroup_create("foo{1}", "group1", "$", true).await?; + client.xgroup_create("foo{1}", "group1", "$", true).await?; let len: usize = client.xlen("foo{1}").await?; assert_eq!(len, 0); @@ -386,7 +386,7 @@ pub async fn should_xgroup_create_mkstream(client: RedisClient, _: RedisConfig) } pub async fn should_xgroup_createconsumer(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { - let _ = create_fake_group_and_stream(&client, "foo{1}").await?; + create_fake_group_and_stream(&client, "foo{1}").await?; let len: usize = client.xgroup_createconsumer("foo{1}", "group1", "consumer1").await?; assert_eq!(len, 1); @@ -398,7 +398,7 @@ pub async fn should_xgroup_createconsumer(client: RedisClient, _: RedisConfig) - } 
pub async fn should_xgroup_delconsumer(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { - let _ = create_fake_group_and_stream(&client, "foo{1}").await?; + create_fake_group_and_stream(&client, "foo{1}").await?; let len: usize = client.xgroup_createconsumer("foo{1}", "group1", "consumer1").await?; assert_eq!(len, 1); @@ -411,7 +411,7 @@ pub async fn should_xgroup_delconsumer(client: RedisClient, _: RedisConfig) -> R } pub async fn should_xgroup_destroy(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { - let _ = create_fake_group_and_stream(&client, "foo{1}").await?; + create_fake_group_and_stream(&client, "foo{1}").await?; let len: usize = client.xgroup_destroy("foo{1}", "group1").await?; assert_eq!(len, 1); @@ -419,23 +419,23 @@ pub async fn should_xgroup_destroy(client: RedisClient, _: RedisConfig) -> Resul } pub async fn should_xgroup_setid(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { - let _ = create_fake_group_and_stream(&client, "foo{1}").await?; - let _: () = client.xgroup_setid("foo{1}", "group1", "12345-0").await?; + create_fake_group_and_stream(&client, "foo{1}").await?; + client.xgroup_setid("foo{1}", "group1", "12345-0").await?; Ok(()) } pub async fn should_xreadgroup_one_stream(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { - let _ = create_fake_group_and_stream(&client, "foo{1}").await?; + create_fake_group_and_stream(&client, "foo{1}").await?; let _ = add_stream_entries(&client, "foo{1}", 3).await?; - let _: () = client.xgroup_createconsumer("foo{1}", "group1", "consumer1").await?; + client.xgroup_createconsumer("foo{1}", "group1", "consumer1").await?; let result: XReadResponse = client .xreadgroup_map("group1", "consumer1", None, None, false, "foo{1}", ">") .await?; assert_eq!(result.len(), 1); - for (idx, (_, record)) in result.get("foo{1}").unwrap().into_iter().enumerate() { + for (idx, (_, record)) in result.get("foo{1}").unwrap().iter().enumerate() { let value = 
record.get("count").expect("Failed to read count"); assert_eq!(idx, *value); } @@ -444,12 +444,12 @@ pub async fn should_xreadgroup_one_stream(client: RedisClient, _: RedisConfig) - } pub async fn should_xreadgroup_multiple_stream(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { - let _ = create_fake_group_and_stream(&client, "foo{1}").await?; - let _ = create_fake_group_and_stream(&client, "bar{1}").await?; + create_fake_group_and_stream(&client, "foo{1}").await?; + create_fake_group_and_stream(&client, "bar{1}").await?; let _ = add_stream_entries(&client, "foo{1}", 3).await?; let _ = add_stream_entries(&client, "bar{1}", 1).await?; - let _: () = client.xgroup_createconsumer("foo{1}", "group1", "consumer1").await?; - let _: () = client.xgroup_createconsumer("bar{1}", "group1", "consumer1").await?; + client.xgroup_createconsumer("foo{1}", "group1", "consumer1").await?; + client.xgroup_createconsumer("bar{1}", "group1", "consumer1").await?; let result: XReadResponse = client .xreadgroup_map( @@ -464,7 +464,7 @@ pub async fn should_xreadgroup_multiple_stream(client: RedisClient, _: RedisConf .await?; assert_eq!(result.len(), 2); - for (idx, (_, record)) in result.get("foo{1}").unwrap().into_iter().enumerate() { + for (idx, (_, record)) in result.get("foo{1}").unwrap().iter().enumerate() { let value = record.get("count").expect("Failed to read count"); assert_eq!(idx, *value); } @@ -476,17 +476,17 @@ pub async fn should_xreadgroup_multiple_stream(client: RedisClient, _: RedisConf } pub async fn should_xreadgroup_block(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { - let _ = create_fake_group_and_stream(&client, "foo{1}").await?; - let _: () = client.xgroup_createconsumer("foo{1}", "group1", "consumer1").await?; + create_fake_group_and_stream(&client, "foo{1}").await?; + client.xgroup_createconsumer("foo{1}", "group1", "consumer1").await?; let add_client = client.clone_new(); tokio::spawn(async move { - let _ = add_client.connect(); - 
let _ = add_client.wait_for_connect().await?; + add_client.connect(); + add_client.wait_for_connect().await?; sleep(Duration::from_secs(1)).await; - let _: () = add_client.xadd("foo{1}", false, None, "*", ("count", 100)).await?; - let _ = add_client.quit().await?; + add_client.xadd("foo{1}", false, None, "*", ("count", 100)).await?; + add_client.quit().await?; Ok::<_, RedisError>(()) }); @@ -504,9 +504,9 @@ pub async fn should_xreadgroup_block(client: RedisClient, _: RedisConfig) -> Res } pub async fn should_xack_one_id(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { - let _ = create_fake_group_and_stream(&client, "foo{1}").await?; + create_fake_group_and_stream(&client, "foo{1}").await?; let _ = add_stream_entries(&client, "foo{1}", 1).await?; - let _: () = client.xgroup_createconsumer("foo{1}", "group1", "consumer1").await?; + client.xgroup_createconsumer("foo{1}", "group1", "consumer1").await?; let result: XReadResponse = client .xreadgroup_map("group1", "consumer1", None, None, false, "foo{1}", ">") @@ -521,9 +521,9 @@ pub async fn should_xack_one_id(client: RedisClient, _: RedisConfig) -> Result<( } pub async fn should_xack_multiple_ids(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { - let _ = create_fake_group_and_stream(&client, "foo{1}").await?; + create_fake_group_and_stream(&client, "foo{1}").await?; let _ = add_stream_entries(&client, "foo{1}", 3).await?; - let _: () = client.xgroup_createconsumer("foo{1}", "group1", "consumer1").await?; + client.xgroup_createconsumer("foo{1}", "group1", "consumer1").await?; let result: XReadResponse = client .xreadgroup_map("group1", "consumer1", None, None, false, "foo{1}", ">") @@ -538,10 +538,10 @@ pub async fn should_xack_multiple_ids(client: RedisClient, _: RedisConfig) -> Re } pub async fn should_xclaim_one_id(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { - let _ = create_fake_group_and_stream(&client, "foo{1}").await?; + create_fake_group_and_stream(&client, 
"foo{1}").await?; let _ = add_stream_entries(&client, "foo{1}", 3).await?; - let _: () = client.xgroup_createconsumer("foo{1}", "group1", "consumer1").await?; - let _: () = client.xgroup_createconsumer("foo{1}", "group1", "consumer2").await?; + client.xgroup_createconsumer("foo{1}", "group1", "consumer1").await?; + client.xgroup_createconsumer("foo{1}", "group1", "consumer2").await?; let mut result: XReadResponse = client .xreadgroup_map("group1", "consumer1", Some(1), None, false, "foo{1}", ">") @@ -584,10 +584,10 @@ pub async fn should_xclaim_one_id(client: RedisClient, _: RedisConfig) -> Result } pub async fn should_xclaim_multiple_ids(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { - let _ = create_fake_group_and_stream(&client, "foo{1}").await?; + create_fake_group_and_stream(&client, "foo{1}").await?; let _ = add_stream_entries(&client, "foo{1}", 3).await?; - let _: () = client.xgroup_createconsumer("foo{1}", "group1", "consumer1").await?; - let _: () = client.xgroup_createconsumer("foo{1}", "group1", "consumer2").await?; + client.xgroup_createconsumer("foo{1}", "group1", "consumer1").await?; + client.xgroup_createconsumer("foo{1}", "group1", "consumer2").await?; let mut result: XReadResponse = client .xreadgroup_map("group1", "consumer1", Some(2), None, false, "foo{1}", ">") @@ -636,10 +636,10 @@ pub async fn should_xclaim_multiple_ids(client: RedisClient, _: RedisConfig) -> } pub async fn should_xclaim_with_justid(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { - let _ = create_fake_group_and_stream(&client, "foo{1}").await?; + create_fake_group_and_stream(&client, "foo{1}").await?; let _ = add_stream_entries(&client, "foo{1}", 3).await?; - let _: () = client.xgroup_createconsumer("foo{1}", "group1", "consumer1").await?; - let _: () = client.xgroup_createconsumer("foo{1}", "group1", "consumer2").await?; + client.xgroup_createconsumer("foo{1}", "group1", "consumer1").await?; + client.xgroup_createconsumer("foo{1}", "group1", 
"consumer2").await?; let mut result: XReadResponse = client .xreadgroup_map("group1", "consumer1", Some(2), None, false, "foo{1}", ">") @@ -681,10 +681,10 @@ pub async fn should_xclaim_with_justid(client: RedisClient, _: RedisConfig) -> R } pub async fn should_xautoclaim_default(client: RedisClient, _: RedisConfig) -> Result<(), RedisError> { - let _ = create_fake_group_and_stream(&client, "foo{1}").await?; + create_fake_group_and_stream(&client, "foo{1}").await?; let _ = add_stream_entries(&client, "foo{1}", 3).await?; - let _: () = client.xgroup_createconsumer("foo{1}", "group1", "consumer1").await?; - let _: () = client.xgroup_createconsumer("foo{1}", "group1", "consumer2").await?; + client.xgroup_createconsumer("foo{1}", "group1", "consumer1").await?; + client.xgroup_createconsumer("foo{1}", "group1", "consumer2").await?; let mut result: XReadResponse = client .xreadgroup_map("group1", "consumer1", Some(2), None, false, "foo{1}", ">") diff --git a/tests/integration/tracking/mod.rs b/tests/integration/tracking/mod.rs index 5b47e9d5..a0c0a610 100644 --- a/tests/integration/tracking/mod.rs +++ b/tests/integration/tracking/mod.rs @@ -22,7 +22,7 @@ pub async fn should_invalidate_foo_resp3(client: RedisClient, _: RedisConfig) -> let invalidated = Arc::new(AtomicBool::new(false)); let _invalidated = invalidated.clone(); - let mut invalidations = client.on_invalidation(); + let mut invalidations = client.invalidation_rx(); tokio::spawn(async move { while let Ok(invalidation) = invalidations.recv().await { if invalidation.keys.contains(&key) { @@ -61,7 +61,7 @@ pub async fn should_invalidate_foo_resp2_centralized(client: RedisClient, _: Red let invalidated = Arc::new(AtomicBool::new(false)); let _invalidated = invalidated.clone(); - let mut invalidations = subscriber.on_invalidation(); + let mut invalidations = subscriber.invalidation_rx(); tokio::spawn(async move { while let Ok(invalidation) = invalidations.recv().await { if invalidation.keys.contains(&key) { diff 
--git a/tests/integration/utils.rs b/tests/integration/utils.rs index 155cf603..cc032ade 100644 --- a/tests/integration/utils.rs +++ b/tests/integration/utils.rs @@ -2,6 +2,7 @@ #![allow(unused_imports)] #![allow(dead_code)] #![allow(unused_variables)] +#![allow(clippy::match_like_matches_macro)] use fred::{ clients::RedisClient, @@ -10,11 +11,11 @@ use fred::{ types::{PerformanceConfig, ReconnectPolicy, RedisConfig, ServerConfig}, }; use redis_protocol::resp3::prelude::RespVersion; -use std::{convert::TryInto, default::Default, env, fmt, fmt::Formatter, fs, future::Future}; +use std::{convert::TryInto, default::Default, env, fmt, fmt::Formatter, fs, future::Future, time::Duration}; const RECONNECT_DELAY: u32 = 1000; -use fred::types::Server; +use fred::types::{Builder, ConnectionConfig, Server}; #[cfg(any(feature = "enable-rustls", feature = "enable-native-tls"))] use fred::types::{TlsConfig, TlsConnector, TlsHostMapping}; #[cfg(feature = "enable-native-tls")] @@ -30,6 +31,24 @@ pub fn read_env_var(name: &str) -> Option { env::var_os(name).and_then(|s| s.into_string().ok()) } +pub fn should_use_sentinel_config() -> bool { + read_env_var("FRED_SENTINEL_TESTS") + .map(|s| match s.as_ref() { + "1" | "t" | "true" | "yes" => true, + _ => false, + }) + .unwrap_or(false) +} + +pub fn should_flushall_between_tests() -> bool { + read_env_var("FRED_NO_FLUSHALL_DURING_TESTS") + .map(|s| match s.as_ref() { + "1" | "t" | "true" | "yes" => false, + _ => true, + }) + .unwrap_or(true) +} + pub fn read_ci_tls_env() -> bool { match env::var_os("FRED_CI_TLS") { Some(s) => match s.into_string() { @@ -56,6 +75,17 @@ fn read_fail_fast_env() -> bool { } } +#[cfg(feature = "redis-stack")] +pub fn read_redis_centralized_host() -> (String, u16) { + let host = read_env_var("FRED_REDIS_STACK_HOST").unwrap_or("redis-main".into()); + let port = read_env_var("FRED_REDIS_STACK_PORT") + .and_then(|s| s.parse::().ok()) + .unwrap_or(6379); + + (host, port) +} + +#[cfg(not(feature = "redis-stack"))] 
pub fn read_redis_centralized_host() -> (String, u16) { let host = read_env_var("FRED_REDIS_CENTRALIZED_HOST").unwrap_or("redis-main".into()); let port = read_env_var("FRED_REDIS_CENTRALIZED_PORT") @@ -89,16 +119,22 @@ pub fn read_redis_password() -> String { read_env_var("REDIS_PASSWORD").expect("Failed to read REDIS_PASSWORD env") } +#[cfg(not(feature = "redis-stack"))] pub fn read_redis_username() -> String { read_env_var("REDIS_USERNAME").expect("Failed to read REDIS_USERNAME env") } +// the CI settings for redis-stack don't set up custom ACL rules +#[cfg(feature = "redis-stack")] +pub fn read_redis_username() -> String { + read_env_var("REDIS_USERNAME").unwrap_or("default".into()) +} + #[cfg(feature = "sentinel-auth")] pub fn read_sentinel_password() -> String { read_env_var("REDIS_SENTINEL_PASSWORD").expect("Failed to read REDIS_SENTINEL_PASSWORD env") } -#[cfg(feature = "sentinel-tests")] pub fn read_sentinel_server() -> (String, u16) { let host = read_env_var("FRED_REDIS_SENTINEL_HOST").unwrap_or("127.0.0.1".into()); let port = read_env_var("FRED_REDIS_SENTINEL_PORT") @@ -174,7 +210,7 @@ fn create_rustls_config() -> TlsConnector { ClientConfig::builder() .with_safe_defaults() .with_root_certificates(root_store) - .with_single_cert(cert_chain, PrivateKey(creds.client_key_der)) + .with_client_auth_cert(cert_chain, PrivateKey(creds.client_key_der)) .expect("Failed to build rustls client config") .into() } @@ -227,7 +263,7 @@ fn create_normal_redis_config(cluster: bool, pipeline: bool, resp3: bool) -> (Re }; let perf = PerformanceConfig { auto_pipeline: pipeline, - default_command_timeout_ms: 20_000, + default_command_timeout: Duration::from_secs(20), ..Default::default() }; @@ -266,7 +302,7 @@ fn create_redis_config(cluster: bool, pipeline: bool, resp3: bool) -> (RedisConf }; let perf = PerformanceConfig { auto_pipeline: pipeline, - default_command_timeout_ms: 20_000, + default_command_timeout: Duration::from_secs(20), ..Default::default() }; @@ -294,44 
+330,53 @@ fn create_redis_config(cluster: bool, pipeline: bool, resp3: bool) -> (RedisConf }; let perf = PerformanceConfig { auto_pipeline: pipeline, - default_command_timeout_ms: 20_000, + default_command_timeout: Duration::from_secs(20), ..Default::default() }; (config, perf) } -#[cfg(feature = "sentinel-tests")] -pub async fn run_sentinel(func: F, pipeline: bool) +async fn flushall_between_tests(client: &RedisClient) -> Result<(), RedisError> { + if should_flushall_between_tests() { + client.flushall_cluster().await + } else { + Ok(()) + } +} + +pub async fn run_sentinel(func: F, pipeline: bool, resp3: bool) where F: Fn(RedisClient, RedisConfig) -> Fut, Fut: Future>, { let policy = ReconnectPolicy::new_constant(300, RECONNECT_DELAY); + let connection = ConnectionConfig::default(); let config = RedisConfig { fail_fast: read_fail_fast_env(), + version: if resp3 { RespVersion::RESP3 } else { RespVersion::RESP2 }, server: ServerConfig::Sentinel { - hosts: vec![read_sentinel_server().into()], - service_name: "redis-sentinel-main".into(), - // TODO fix this so sentinel-tests can run without sentinel-auth - username: None, - password: Some(read_sentinel_password()), + hosts: vec![read_sentinel_server().into()], + service_name: "redis-sentinel-main".into(), + #[cfg(feature = "sentinel-auth")] + username: None, + #[cfg(feature = "sentinel-auth")] + password: Some(read_sentinel_password()), }, password: Some(read_redis_password()), ..Default::default() }; let perf = PerformanceConfig { auto_pipeline: pipeline, - default_command_timeout_ms: 10_000, ..Default::default() }; - let client = RedisClient::new(config.clone(), Some(perf), Some(policy)); + let client = RedisClient::new(config.clone(), Some(perf), Some(connection), Some(policy)); let _client = client.clone(); let _jh = client.connect(); - let _ = client.wait_for_connect().await.expect("Failed to connect client"); + client.wait_for_connect().await.expect("Failed to connect client"); - let _: () = 
client.flushall(false).await.expect("Failed to flushall"); + flushall_between_tests(&client).await.expect("Failed to flushall"); func(_client, config.clone()).await.expect("Failed to run test"); let _ = client.quit().await; } @@ -342,17 +387,19 @@ where Fut: Future>, { let (policy, cmd_attempts, fail_fast) = resilience_settings(); - let (mut config, mut perf) = create_redis_config(true, pipeline, resp3); - perf.max_command_attempts = cmd_attempts; + let mut connection = ConnectionConfig::default(); + let (mut config, perf) = create_redis_config(true, pipeline, resp3); + connection.max_command_attempts = cmd_attempts; + connection.max_redirections = 10; config.fail_fast = fail_fast; - let client = RedisClient::new(config.clone(), Some(perf), policy); + let client = RedisClient::new(config.clone(), Some(perf), Some(connection), policy); let _client = client.clone(); let _jh = client.connect(); - let _ = client.wait_for_connect().await.expect("Failed to connect client"); + client.wait_for_connect().await.expect("Failed to connect client"); - let _: () = client.flushall_cluster().await.expect("Failed to flushall"); + flushall_between_tests(&client).await.expect("Failed to flushall"); func(_client, config.clone()).await.expect("Failed to run test"); let _ = client.quit().await; } @@ -362,25 +409,30 @@ where F: Fn(RedisClient, RedisConfig) -> Fut, Fut: Future>, { + if should_use_sentinel_config() { + return run_sentinel(func, pipeline, resp3).await; + } + let (policy, cmd_attempts, fail_fast) = resilience_settings(); - let (mut config, mut perf) = create_redis_config(false, pipeline, resp3); - perf.max_command_attempts = cmd_attempts; + let mut connection = ConnectionConfig::default(); + let (mut config, perf) = create_redis_config(false, pipeline, resp3); + connection.max_command_attempts = cmd_attempts; config.fail_fast = fail_fast; - let client = RedisClient::new(config.clone(), Some(perf), policy); + let client = RedisClient::new(config.clone(), Some(perf), 
Some(connection), policy); let _client = client.clone(); let _jh = client.connect(); - let _ = client.wait_for_connect().await.expect("Failed to connect client"); + client.wait_for_connect().await.expect("Failed to connect client"); - let _: () = client.flushall(false).await.expect("Failed to flushall"); + flushall_between_tests(&client).await.expect("Failed to flushall"); func(_client, config.clone()).await.expect("Failed to run test"); let _ = client.quit().await; } macro_rules! centralized_test_panic( ($module:tt, $name:tt) => { - #[cfg(not(any(feature="sentinel-tests", feature = "enable-rustls", feature = "enable-native-tls")))] + #[cfg(not(any(feature = "enable-rustls", feature = "enable-native-tls")))] mod $name { mod resp2 { #[tokio::test(flavor = "multi_thread")] @@ -430,34 +482,21 @@ macro_rules! centralized_test_panic( } } } - - #[cfg(feature="sentinel-tests")] - mod $name { - #[tokio::test(flavor = "multi_thread")] - #[should_panic] - async fn sentinel_pipelined() { - let _ = pretty_env_logger::try_init(); - crate::integration::utils::run_sentinel(crate::integration::$module::$name, true).await; - } - - #[tokio::test(flavor = "multi_thread")] - #[should_panic] - async fn sentinel_no_pipeline() { - let _ = pretty_env_logger::try_init(); - crate::integration::utils::run_sentinel(crate::integration::$module::$name, false).await; - } - } } ); macro_rules! cluster_test_panic( ($module:tt, $name:tt) => { - #[cfg(not(feature="sentinel-tests"))] mod $name { + #[cfg(not(feature = "redis-stack"))] mod resp2 { #[tokio::test(flavor = "multi_thread")] #[should_panic] async fn pipelined() { + if crate::integration::utils::should_use_sentinel_config() { + panic!(""); + } + let _ = pretty_env_logger::try_init(); crate::integration::utils::run_cluster(crate::integration::$module::$name, true, false).await; } @@ -465,15 +504,24 @@ macro_rules! 
cluster_test_panic( #[tokio::test(flavor = "multi_thread")] #[should_panic] async fn no_pipeline() { + if crate::integration::utils::should_use_sentinel_config() { + panic!(""); + } + let _ = pretty_env_logger::try_init(); crate::integration::utils::run_cluster(crate::integration::$module::$name, false, false).await; } } + #[cfg(not(feature = "redis-stack"))] mod resp3 { #[tokio::test(flavor = "multi_thread")] #[should_panic] async fn pipelined() { + if crate::integration::utils::should_use_sentinel_config() { + panic!(""); + } + let _ = pretty_env_logger::try_init(); crate::integration::utils::run_cluster(crate::integration::$module::$name, true, true).await; } @@ -481,6 +529,10 @@ macro_rules! cluster_test_panic( #[tokio::test(flavor = "multi_thread")] #[should_panic] async fn no_pipeline() { + if crate::integration::utils::should_use_sentinel_config() { + panic!(""); + } + let _ = pretty_env_logger::try_init(); crate::integration::utils::run_cluster(crate::integration::$module::$name, false, true).await; } @@ -491,7 +543,7 @@ macro_rules! cluster_test_panic( macro_rules! centralized_test( ($module:tt, $name:tt) => { - #[cfg(not(any(feature="sentinel-tests", feature = "enable-rustls", feature = "enable-native-tls")))] + #[cfg(not(any(feature = "enable-rustls", feature = "enable-native-tls")))] mod $name { mod resp2 { #[tokio::test(flavor = "multi_thread")] @@ -537,51 +589,53 @@ macro_rules! centralized_test( } } } - - #[cfg(feature="sentinel-tests")] - mod $name { - #[tokio::test(flavor = "multi_thread")] - async fn sentinel_pipelined() { - let _ = pretty_env_logger::try_init(); - crate::integration::utils::run_sentinel(crate::integration::$module::$name, true).await; - } - - #[tokio::test(flavor = "multi_thread")] - async fn sentinel_no_pipeline() { - let _ = pretty_env_logger::try_init(); - crate::integration::utils::run_sentinel(crate::integration::$module::$name, false).await; - } - } } ); macro_rules! 
cluster_test( ($module:tt, $name:tt) => { - #[cfg(not(feature="sentinel-tests"))] mod $name { + #[cfg(not(feature = "redis-stack"))] mod resp2 { #[tokio::test(flavor = "multi_thread")] async fn pipelined() { + if crate::integration::utils::should_use_sentinel_config() { + return; + } + let _ = pretty_env_logger::try_init(); crate::integration::utils::run_cluster(crate::integration::$module::$name, true, false).await; } #[tokio::test(flavor = "multi_thread")] async fn no_pipeline() { + if crate::integration::utils::should_use_sentinel_config() { + return; + } + let _ = pretty_env_logger::try_init(); crate::integration::utils::run_cluster(crate::integration::$module::$name, false, false).await; } } + #[cfg(not(feature = "redis-stack"))] mod resp3 { #[tokio::test(flavor = "multi_thread")] async fn pipelined() { + if crate::integration::utils::should_use_sentinel_config() { + return; + } + let _ = pretty_env_logger::try_init(); crate::integration::utils::run_cluster(crate::integration::$module::$name, true, true).await; } #[tokio::test(flavor = "multi_thread")] async fn no_pipeline() { + if crate::integration::utils::should_use_sentinel_config() { + return; + } + let _ = pretty_env_logger::try_init(); crate::integration::utils::run_cluster(crate::integration::$module::$name, false, true).await; } diff --git a/tests/lib.rs b/tests/lib.rs index 30939619..43b06138 100644 --- a/tests/lib.rs +++ b/tests/lib.rs @@ -1,3 +1,14 @@ +#![allow(clippy::redundant_pattern_matching)] +#![allow(clippy::mutable_key_type)] +#![allow(clippy::derivable_impls)] +#![allow(clippy::enum_variant_names)] +#![allow(clippy::iter_kv_map)] +#![allow(clippy::len_without_is_empty)] +#![allow(clippy::vec_init_then_push)] +#![allow(clippy::while_let_on_iterator)] +#![allow(clippy::type_complexity)] +#![allow(clippy::too_many_arguments)] +#![allow(clippy::disallowed_names)] #![allow(unused_imports)] #[macro_use] diff --git a/tests/runners/default-features.sh b/tests/runners/default-features.sh index 
89405a0f..4b907be9 100755 --- a/tests/runners/default-features.sh +++ b/tests/runners/default-features.sh @@ -2,4 +2,5 @@ TEST_ARGV="$1" docker-compose -f tests/docker/compose/centralized.yml \ -f tests/docker/compose/cluster.yml \ - -f tests/docker/runners/compose/default-features.yml run -u $(id -u ${USER}):$(id -g ${USER}) --rm default-features-tests \ No newline at end of file + -f tests/docker/runners/compose/default-features.yml run \ + -u $(id -u ${USER}):$(id -g ${USER}) --rm default-features-tests \ No newline at end of file diff --git a/tests/runners/default-nil-types.sh b/tests/runners/default-nil-types.sh new file mode 100755 index 00000000..8ec1daa8 --- /dev/null +++ b/tests/runners/default-nil-types.sh @@ -0,0 +1,7 @@ +#!/bin/bash + +TEST_ARGV="$1" docker-compose -f tests/docker/compose/centralized.yml \ + -f tests/docker/compose/cluster.yml \ + -f tests/docker/runners/compose/default-nil-types.yml run \ + -u $(id -u ${USER}):$(id -g ${USER}) --rm \ + default-nil-types-tests \ No newline at end of file diff --git a/tests/runners/docker-bash.sh b/tests/runners/docker-bash.sh index 942aa241..c1549564 100755 --- a/tests/runners/docker-bash.sh +++ b/tests/runners/docker-bash.sh @@ -1,5 +1,9 @@ #!/bin/bash # boot all the redis servers and start a bash shell on a new container -docker-compose -f tests/docker/compose/cluster-tls.yml -f tests/docker/compose/centralized.yml -f tests/docker/compose/cluster.yml \ - -f tests/docker/compose/sentinel.yml -f tests/docker/compose/base.yml run -u $(id -u ${USER}):$(id -g ${USER}) --rm debug \ No newline at end of file +docker-compose -f tests/docker/compose/cluster-tls.yml \ + -f tests/docker/compose/centralized.yml \ + -f tests/docker/compose/cluster.yml \ + -f tests/docker/compose/sentinel.yml \ + -f tests/docker/compose/redis-stack.yml \ + -f tests/docker/compose/base.yml run -u $(id -u ${USER}):$(id -g ${USER}) --rm debug \ No newline at end of file diff --git a/tests/runners/everything.sh 
b/tests/runners/everything.sh index 3b17006b..55bfeeb1 100755 --- a/tests/runners/everything.sh +++ b/tests/runners/everything.sh @@ -5,4 +5,6 @@ tests/runners/no-features.sh "$1"\ && tests/runners/all-features.sh "$1"\ && tests/runners/sentinel-features.sh "$1"\ && tests/runners/cluster-native-tls.sh "$1"\ - && tests/runners/cluster-rustls.sh "$1" \ No newline at end of file + && tests/runners/cluster-rustls.sh "$1" \ + && tests/runners/default-nil-types.sh "$1" \ + && tests/runners/redis-stack.sh "$1" \ No newline at end of file diff --git a/tests/runners/redis-stack.sh b/tests/runners/redis-stack.sh new file mode 100755 index 00000000..b61d44db --- /dev/null +++ b/tests/runners/redis-stack.sh @@ -0,0 +1,5 @@ +#!/bin/bash + +TEST_ARGV="$1" docker-compose -f tests/docker/compose/redis-stack.yml \ + -f tests/docker/runners/compose/redis-stack.yml run \ + -u $(id -u ${USER}):$(id -g ${USER}) --rm redis-stack-tests \ No newline at end of file diff --git a/tests/runners/sentinel-features.sh b/tests/runners/sentinel-features.sh index 7985d027..0f717bb9 100755 --- a/tests/runners/sentinel-features.sh +++ b/tests/runners/sentinel-features.sh @@ -1,4 +1,5 @@ #!/bin/bash -TEST_ARGV="$1" docker-compose -f tests/docker/compose/sentinel.yml -f tests/docker/runners/compose/sentinel-features.yml \ +TEST_ARGV="$1" docker-compose -f tests/docker/compose/sentinel.yml \ + -f tests/docker/runners/compose/sentinel-features.yml \ run -u $(id -u ${USER}):$(id -g ${USER}) --rm sentinel-tests \ No newline at end of file