diff --git a/src/server/container_utils.cc b/src/server/container_utils.cc
index beb540cc3980..07c5e51ba55d 100644
--- a/src/server/container_utils.cc
+++ b/src/server/container_utils.cc
@@ -270,6 +270,36 @@ bool IterateSortedSet(const detail::RobjWrapper* robj_wrapper, const IterateSort
   return false;
 }
 
+bool IterateMap(const PrimeValue& pv, const IterateKVFunc& func) {
+  bool finished = true;
+
+  if (pv.Encoding() == kEncodingListPack) {
+    // Separate buffers for key and value: LpGetView may materialize integer
+    // elements into the buffer, so sharing one would corrupt the key view.
+    uint8_t intbuf[2][LP_INTBUF_SIZE];
+    uint8_t* lp = (uint8_t*)pv.RObjPtr();
+    uint8_t* fptr = lpFirst(lp);
+    while (fptr) {
+      string_view key = LpGetView(fptr, intbuf[0]);
+      fptr = lpNext(lp, fptr);
+      string_view val = LpGetView(fptr, intbuf[1]);
+      fptr = lpNext(lp, fptr);
+      if (!func(ContainerEntry{key.data(), key.size()}, ContainerEntry{val.data(), val.size()})) {
+        finished = false;
+        break;
+      }
+    }
+  } else {
+    StringMap* sm = static_cast<StringMap*>(pv.RObjPtr());
+    for (const auto& k_v : *sm) {
+      if (!func(ContainerEntry{k_v.first, sdslen(k_v.first)},
+                ContainerEntry{k_v.second, sdslen(k_v.second)})) {
+        finished = false;
+        break;
+      }
+    }
+  }
+  return finished;
+}
+
 StringMap* GetStringMap(const PrimeValue& pv, const DbContext& db_context) {
   DCHECK_EQ(pv.Encoding(), kEncodingStrMap2);
   StringMap* res = static_cast<StringMap*>(pv.RObjPtr());
diff --git a/src/server/container_utils.h b/src/server/container_utils.h
index d2933766276e..e9f6ef44a558 100644
--- a/src/server/container_utils.h
+++ b/src/server/container_utils.h
@@ -54,6 +54,7 @@ struct ContainerEntry {
 
 using IterateFunc = std::function<bool(ContainerEntry)>;
 using IterateSortedFunc = std::function<bool(ContainerEntry, double)>;
+using IterateKVFunc = std::function<bool(ContainerEntry, ContainerEntry)>;
 
 // Iterate over all values and call func(val). Iteration stops as soon
 // as func return false. Returns true if it successfully processed all elements
@@ -72,6 +73,8 @@ bool IterateSortedSet(const detail::RobjWrapper* robj_wrapper, const IterateSort
                       int32_t start = 0, int32_t end = -1, bool reverse = false,
                       bool use_score = false);
 
+bool IterateMap(const PrimeValue& pv, const IterateKVFunc& func);
+
 // Get StringMap pointer from primetable value. Sets expire time from db_context
 StringMap* GetStringMap(const PrimeValue& pv, const DbContext& db_context);
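Note: IterateMap abstracts over both hash encodings (listpack and StringMap), invoking the callback once per field/value pair. As a minimal usage sketch — hypothetical caller, not part of the patch, assuming `pv` holds an OBJ_HASH value — the callback contract looks like this:

    // Count the fields of a hash held in pv (illustrative only).
    size_t num_fields = 0;
    container_utils::IterateMap(pv, [&](container_utils::ContainerEntry k,
                                        container_utils::ContainerEntry v) {
      num_fields++;  // k.ToString() / v.ToString() materialize the elements
      return true;   // returning false stops iteration and makes IterateMap return false
    });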
diff --git a/src/server/journal/streamer.cc b/src/server/journal/streamer.cc
index 6d4bb7f81af9..6e565505f717 100644
--- a/src/server/journal/streamer.cc
+++ b/src/server/journal/streamer.cc
@@ -9,6 +9,7 @@
 #include "base/flags.h"
 #include "base/logging.h"
 #include "server/cluster/cluster_defs.h"
+#include "server/container_utils.h"
 #include "util/fibers/synchronization.h"
 
 using namespace facade;
@@ -317,37 +318,188 @@ void RestoreStreamer::OnDbChange(DbIndex db_index, const DbSlice::ChangeReq& req
 
 void RestoreStreamer::WriteEntry(string_view key, const PrimeValue& pk, const PrimeValue& pv,
                                  uint64_t expire_ms) {
+  // We send RESTORE commands for small objects, and for objects we don't support breaking up.
+  bool use_restore_serialization = true;
+  if (serialization_max_chunk_size > 0 && pv.MallocUsed() > serialization_max_chunk_size) {
+    switch (pv.ObjType()) {
+      case OBJ_SET:
+        WriteSet(key, pv);
+        use_restore_serialization = false;
+        break;
+      case OBJ_ZSET:
+        WriteZSet(key, pv);
+        use_restore_serialization = false;
+        break;
+      case OBJ_HASH:
+        WriteHash(key, pv);
+        use_restore_serialization = false;
+        break;
+      case OBJ_LIST:
+        WriteList(key, pv);
+        use_restore_serialization = false;
+        break;
+      case OBJ_STRING:
+      case OBJ_STREAM:
+      case OBJ_JSON:
+      case OBJ_SBF:
+      default:
+        // We don't support splitting these types into multiple commands, so we
+        // send them as a single RESTORE command.
+        break;
+    }
+  }
+
+  if (use_restore_serialization) {
+    // RESTORE sets STICK and EXPIRE as part of the command.
+    WriteRestore(key, pk, pv, expire_ms);
+  } else {
+    WriteStickIfNeeded(key, pk);
+    WriteExpireIfNeeded(key, expire_ms);
+  }
+}
+
+class CommandAggregator {
+ public:
+  using Callback = std::function<void(absl::Span<const string_view>)>;
+
+  CommandAggregator(string_view key, Callback cb) : key_(key), cb_(cb) {
+  }
+
+  ~CommandAggregator() {
+    CommitPending();
+  }
+
+  enum class CommitMode { kAuto, kNoCommit };
+  void AddArg(string arg, CommitMode commit_mode = CommitMode::kAuto) {
+    agg_bytes_ += arg.size();
+    members_.push_back(std::move(arg));
+
+    if (commit_mode != CommitMode::kNoCommit && agg_bytes_ >= serialization_max_chunk_size) {
+      CommitPending();
+    }
+  }
+
+ private:
+  void CommitPending() {
+    if (members_.empty()) {
+      return;
+    }
+
+    args_.clear();
+    args_.reserve(members_.size() + 1);
+    args_.push_back(key_);
+    for (string_view member : members_) {
+      args_.push_back(member);
+    }
+    cb_(args_);
+    members_.clear();
+    agg_bytes_ = 0;  // Reset so the next chunk aggregates from scratch.
+  }
+
+  string_view key_;
+  Callback cb_;
+  vector<string> members_;
+  absl::InlinedVector<string_view, 5> args_;
+  size_t agg_bytes_ = 0;
+};
+
+void RestoreStreamer::WriteSet(string_view key, const PrimeValue& pv) {
+  CommandAggregator aggregator(
+      key, [&](absl::Span<const string_view> args) { WriteCommand("SADD", args); });
+
+  container_utils::IterateSet(pv, [&](container_utils::ContainerEntry ce) {
+    aggregator.AddArg(ce.ToString());
+    return true;
+  });
+}
+
+void RestoreStreamer::WriteList(string_view key, const PrimeValue& pv) {
+  CommandAggregator aggregator(
+      key, [&](absl::Span<const string_view> args) { WriteCommand("RPUSH", args); });
+
+  container_utils::IterateList(pv, [&](container_utils::ContainerEntry ce) {
+    aggregator.AddArg(ce.ToString());
+    return true;
+  });
+}
+
+void RestoreStreamer::WriteZSet(string_view key, const PrimeValue& pv) {
+  CommandAggregator aggregator(
+      key, [&](absl::Span<const string_view> args) { WriteCommand("ZADD", args); });
+
+  container_utils::IterateSortedSet(
+      pv.GetRobjWrapper(),
+      [&](container_utils::ContainerEntry ce, double score) {
+        // kNoCommit keeps each (score, member) pair within a single ZADD command.
+        aggregator.AddArg(absl::StrCat(score), CommandAggregator::CommitMode::kNoCommit);
+        aggregator.AddArg(ce.ToString());
+        return true;
+      },
+      /*start=*/0, /*end=*/-1, /*reverse=*/false, /*use_score=*/true);
+}
+
+void RestoreStreamer::WriteHash(string_view key, const PrimeValue& pv) {
+  CommandAggregator aggregator(
+      key, [&](absl::Span<const string_view> args) { WriteCommand("HSET", args); });
+
+  container_utils::IterateMap(
+      pv, [&](container_utils::ContainerEntry k, container_utils::ContainerEntry v) {
+        // kNoCommit keeps each field/value pair within a single HSET command.
+        aggregator.AddArg(k.ToString(), CommandAggregator::CommitMode::kNoCommit);
+        aggregator.AddArg(v.ToString());
+        return true;
+      });
+}
+
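+// Fallback path: serialize the whole value into a single RESTORE command. Used for
+// small values and for the types we don't split (string, stream, JSON, SBF).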
+void RestoreStreamer::WriteRestore(std::string_view key, const PrimeValue& pk, const PrimeValue& pv,
+                                   uint64_t expire_ms) {
   absl::InlinedVector<string_view, 5> args;
   args.push_back(key);
 
   string expire_str = absl::StrCat(expire_ms);
   args.push_back(expire_str);
 
-  io::StringSink restore_cmd_sink;
-  {  // to destroy extra copy
-    io::StringSink value_dump_sink;
-    SerializerBase::DumpObject(pv, &value_dump_sink);
-    args.push_back(value_dump_sink.str());
+  io::StringSink value_dump_sink;
+  SerializerBase::DumpObject(pv, &value_dump_sink);
+  args.push_back(value_dump_sink.str());
 
-    args.push_back("ABSTTL");  // Means expire string is since epoch
+  args.push_back("ABSTTL");  // Means the expire is absolute (ms since epoch).
 
-    if (pk.IsSticky()) {
-      args.push_back("STICK");
-    }
+  if (pk.IsSticky()) {
+    args.push_back("STICK");
+  }
 
-  journal::Entry entry(0,                     // txid
-                       journal::Op::COMMAND,  // single command
-                       0,                     // db index
-                       1,                     // shard count
-                       0,                     // slot-id, but it is ignored at this level
-                       journal::Entry::Payload("RESTORE", ArgSlice(args)));
+  WriteCommand("RESTORE", args);
+}
 
-  JournalWriter writer{&restore_cmd_sink};
-  writer.Write(entry);
+void RestoreStreamer::WriteCommand(string_view cmd, absl::Span<const string_view> args) {
+  journal::Entry entry(0,                     // txid
+                       journal::Op::COMMAND,  // single command
+                       0,                     // db index
+                       1,                     // shard count
+                       0,                     // slot-id, but it is ignored at this level
+                       journal::Entry::Payload(cmd, ArgSlice(args)));
+
+  // Serialize the entry into a string.
+  io::StringSink cmd_sink;
+  JournalWriter writer{&cmd_sink};
+  writer.Write(entry);
+
+  // Write the serialized command to dest_.
+  Write(cmd_sink.str());
+}
+
+void RestoreStreamer::WriteStickIfNeeded(string_view key, const PrimeValue& pk) {
+  if (!pk.IsSticky()) {
+    return;
   }
-  // TODO: From DumpObject to till Write we tripple copy the PrimeValue. It's very inefficient and
-  // will burn CPU for large values.
-  Write(restore_cmd_sink.str());
+
+  WriteCommand("STICK", {key});
+}
+
+void RestoreStreamer::WriteExpireIfNeeded(string_view key, uint64_t expire_ms) {
+  if (expire_ms == 0) {
+    return;
+  }
+
+  // expire_ms is an absolute deadline (ms since epoch, see ABSTTL above), hence PEXPIREAT.
+  WriteCommand("PEXPIREAT", {key, absl::StrCat(expire_ms)});
 }
 
 }  // namespace dfly
diff --git a/src/server/journal/streamer.h b/src/server/journal/streamer.h
index c625b60c5157..85da80f34859 100644
--- a/src/server/journal/streamer.h
+++ b/src/server/journal/streamer.h
@@ -98,7 +98,18 @@ class RestoreStreamer : public JournalStreamer {
   // Returns whether anything was written
   void WriteBucket(PrimeTable::bucket_iterator it);
 
-  void WriteEntry(string_view key, const PrimeValue& pk, const PrimeValue& pv, uint64_t expire_ms);
+  void WriteEntry(std::string_view key, const PrimeValue& pk, const PrimeValue& pv,
+                  uint64_t expire_ms);
+  void WriteCommand(std::string_view cmd, absl::Span<const std::string_view> args);
+  void WriteStickIfNeeded(std::string_view key, const PrimeValue& pk);
+  void WriteExpireIfNeeded(std::string_view key, uint64_t expire_ms);
+
+  void WriteSet(std::string_view key, const PrimeValue& pv);
+  void WriteZSet(std::string_view key, const PrimeValue& pv);
+  void WriteHash(std::string_view key, const PrimeValue& pv);
+  void WriteList(std::string_view key, const PrimeValue& pv);
+  void WriteRestore(std::string_view key, const PrimeValue& pk, const PrimeValue& pv,
+                    uint64_t expire_ms);
 
   DbSlice* db_slice_;
   DbTableArray db_array_;
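For intuition about the chunking behavior CommandAggregator implements — accumulate arguments until a byte threshold, flush each chunk as one command, flush the tail in the destructor — here is a self-contained sketch (illustrative names, standard library only; not part of the patch):

    #include <functional>
    #include <iostream>
    #include <string>
    #include <vector>

    class ChunkingAggregator {
     public:
      using Callback = std::function<void(const std::vector<std::string>&)>;

      ChunkingAggregator(size_t max_bytes, Callback cb) : max_bytes_(max_bytes), cb_(std::move(cb)) {}
      ~ChunkingAggregator() { Commit(); }  // flush whatever is left, like ~CommandAggregator()

      void Add(std::string arg) {
        bytes_ += arg.size();
        pending_.push_back(std::move(arg));
        if (bytes_ >= max_bytes_) Commit();
      }

     private:
      void Commit() {
        if (pending_.empty()) return;
        cb_(pending_);
        pending_.clear();
        bytes_ = 0;  // start the next chunk from scratch
      }

      size_t max_bytes_;
      size_t bytes_ = 0;
      std::vector<std::string> pending_;
      Callback cb_;
    };

    int main() {
      // With a 4-byte threshold: "ab"+"cd" flush as one chunk, "efgh" as a
      // second, and "x" is flushed by the destructor.
      ChunkingAggregator agg(4, [](const std::vector<std::string>& chunk) {
        std::cout << "flush: " << chunk.size() << " member(s)\n";
      });
      agg.Add("ab");
      agg.Add("cd");
      agg.Add("efgh");
      agg.Add("x");
    }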
diff --git a/tests/dragonfly/cluster_test.py b/tests/dragonfly/cluster_test.py
index 66f70856e236..e8c29826ab1f 100644
--- a/tests/dragonfly/cluster_test.py
+++ b/tests/dragonfly/cluster_test.py
@@ -1294,10 +1294,11 @@ async def test_network_disconnect_during_migration(df_factory, df_seeder_factory
 
 
 @pytest.mark.parametrize(
-    "node_count, segments, keys",
+    "node_count, segments, keys, huge_values",
     [
-        pytest.param(3, 16, 20_000),
-        pytest.param(5, 20, 30_000, marks=[pytest.mark.slow, pytest.mark.opt_only]),
+        pytest.param(3, 16, 20_000, 10),
+        pytest.param(3, 16, 20_000, 1_000_000),
+        pytest.param(5, 20, 30_000, 1_000_000, marks=[pytest.mark.slow, pytest.mark.opt_only]),
     ],
 )
 @dfly_args({"proactor_threads": 4, "cluster_mode": "yes"})
@@ -1307,12 +1308,14 @@ async def test_cluster_fuzzymigration(
     node_count: int,
     segments: int,
     keys: int,
+    huge_values: int,
 ):
     instances = [
         df_factory.create(
             port=BASE_PORT + i,
             admin_port=BASE_PORT + i + 1000,
             vmodule="outgoing_slot_migration=9,cluster_family=9,incoming_slot_migration=9",
+            serialization_max_chunk_size=huge_values,
         )
         for i in range(node_count)
     ]
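To make the new huge_values parameter concrete: with a tiny serialization_max_chunk_size (the 10-byte case above), a huge hash key h migrates as a stream of plain commands instead of a single RESTORE — roughly (illustrative transcript, not captured output):

    HSET h f1 v1 f2 v2           # chunks of about serialization_max_chunk_size bytes each
    HSET h f3 v3 f4 v4
    STICK h                      # only if the key is sticky
    PEXPIREAT h 1700000000000    # only if the key has an (absolute, ms-since-epoch) expiry

The 1_000_000-byte cases keep most seeded values under the threshold, so they exercise the RESTORE path alongside the split path.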