diff --git a/src/core/bptree_set.h b/src/core/bptree_set.h index 227f1bab1889..ae5240b742db 100644 --- a/src/core/bptree_set.h +++ b/src/core/bptree_set.h @@ -48,7 +48,7 @@ template > class BPTree { bool Delete(KeyT item); - std::optional GetRank(KeyT item) const; + std::optional GetRank(KeyT item, bool reverse = false) const; size_t Height() const { return height_; @@ -222,7 +222,7 @@ template bool BPTree::Delete(KeyT item) } template -std::optional BPTree::GetRank(KeyT item) const { +std::optional BPTree::GetRank(KeyT item, bool reverse) const { if (!root_) return std::nullopt; @@ -231,6 +231,10 @@ std::optional BPTree::GetRank(KeyT item) const { if (!found) return std::nullopt; + if (reverse) { + return count_ - path.Rank() - 1; + } + return path.Rank(); } diff --git a/src/core/sorted_map.cc b/src/core/sorted_map.cc index d2f395908c05..497565525291 100644 --- a/src/core/sorted_map.cc +++ b/src/core/sorted_map.cc @@ -309,9 +309,9 @@ optional SortedMap::GetRank(sds ele, bool reverse) const { if (obj == nullptr) return std::nullopt; - optional rank = score_tree->GetRank(obj); + optional rank = score_tree->GetRank(obj, reverse); DCHECK(rank); - return reverse ? score_map->UpperBoundSize() - *rank - 1 : *rank; + return *rank; } SortedMap::ScoredArray SortedMap::GetRange(const zrangespec& range, unsigned offset, unsigned limit, @@ -783,5 +783,15 @@ bool SortedMap::DefragIfNeeded(float ratio) { return reallocated; } +std::optional SortedMap::GetRankAndScore(sds ele, bool reverse) const { + ScoreSds obj = score_map->FindObj(ele); + if (obj == nullptr) + return std::nullopt; + + optional rank = score_tree->GetRank(obj, reverse); + DCHECK(rank); + + return SortedMap::RankAndScore{*rank, GetObjScore(obj)}; +} } // namespace detail } // namespace dfly diff --git a/src/core/sorted_map.h b/src/core/sorted_map.h index 8206779d65f7..34f4f6135db9 100644 --- a/src/core/sorted_map.h +++ b/src/core/sorted_map.h @@ -35,6 +35,7 @@ class SortedMap { using ScoredMember = std::pair; using ScoredArray = std::vector; using ScoreSds = void*; + using RankAndScore = std::pair; SortedMap(PMR_NS::memory_resource* res); ~SortedMap(); @@ -72,6 +73,7 @@ class SortedMap { std::optional GetScore(sds ele) const; std::optional GetRank(sds ele, bool reverse) const; + std::optional GetRankAndScore(sds ele, bool reverse) const; ScoredArray GetRange(const zrangespec& r, unsigned offs, unsigned len, bool rev) const; ScoredArray GetLexRange(const zlexrangespec& r, unsigned o, unsigned l, bool rev) const; diff --git a/src/server/zset_family.cc b/src/server/zset_family.cc index b60c90da5ff7..71e898cd04d7 100644 --- a/src/server/zset_family.cc +++ b/src/server/zset_family.cc @@ -1387,8 +1387,13 @@ OpResult OpRemRange(const OpArgs& op_args, string_view key, return iv.removed(); } -OpResult OpRank(const OpArgs& op_args, string_view key, string_view member, - bool reverse) { +struct RankResult { + unsigned rank; + double score = 0; +}; + +OpResult OpRank(const OpArgs& op_args, string_view key, string_view member, + bool reverse, bool with_score) { auto res_it = op_args.GetDbSlice().FindReadOnly(op_args.db_cntx, key, OBJ_ZSET); if (!res_it) return res_it.status(); @@ -1417,18 +1422,34 @@ OpResult OpRank(const OpArgs& op_args, string_view key, string_view me if (eptr == NULL) return OpStatus::KEY_NOTFOUND; - if (reverse) { - return lpLength(zl) / 2 - rank; + RankResult res{}; + res.rank = reverse ? lpLength(zl) / 2 - rank : rank - 1; + if (with_score) { + res.score = zzlGetScore(sptr); } - return rank - 1; + return res; } DCHECK_EQ(robj_wrapper->encoding(), OBJ_ENCODING_SKIPLIST); detail::SortedMap* ss = (detail::SortedMap*)robj_wrapper->inner_obj(); - std::optional rank = ss->GetRank(WrapSds(member), reverse); - if (!rank) - return OpStatus::KEY_NOTFOUND; - return *rank; + RankResult res{}; + + if (with_score) { + auto rankAndScore = ss->GetRankAndScore(WrapSds(member), reverse); + if (!rankAndScore) { + return OpStatus::KEY_NOTFOUND; + } + res.rank = rankAndScore->first; + res.score = rankAndScore->second; + } else { + std::optional rank = ss->GetRank(WrapSds(member), reverse); + if (!rank) { + return OpStatus::KEY_NOTFOUND; + } + res.rank = *rank; + } + + return res; } OpResult OpCount(const OpArgs& op_args, std::string_view key, @@ -1979,17 +2000,40 @@ void ZRangeGeneric(CmdArgList args, Transaction* tx, SinkReplyBuilder* builder, } void ZRankGeneric(CmdArgList args, bool reverse, Transaction* tx, SinkReplyBuilder* builder) { - string_view key = ArgS(args, 0); - string_view member = ArgS(args, 1); + // send this error exact as redis does, it checks number of arguments first + if (args.size() > 3) { + return builder->SendError(WrongNumArgsError(reverse ? "ZREVRANK" : "ZRANK")); + } + + facade::CmdArgParser parser(args); + + string_view key = parser.Next(); + string_view member = parser.Next(); + bool with_score = false; + + if (parser.HasNext()) { + parser.ExpectTag("WITHSCORE"); + with_score = true; + } + + if (!parser.Finalize()) { + return builder->SendError(parser.Error()->MakeReply()); + } auto cb = [&](Transaction* t, EngineShard* shard) { - return OpRank(t->GetOpArgs(shard), key, member, reverse); + return OpRank(t->GetOpArgs(shard), key, member, reverse, with_score); }; + OpResult result = tx->ScheduleSingleHopT(std::move(cb)); auto* rb = static_cast(builder); - OpResult result = tx->ScheduleSingleHopT(std::move(cb)); if (result) { - rb->SendLong(*result); + if (with_score) { + rb->StartArray(2); + rb->SendLong(result->rank); + rb->SendDouble(result->score); + } else { + rb->SendLong(result->rank); + } } else if (result.status() == OpStatus::KEY_NOTFOUND) { rb->SendNull(); } else { @@ -2340,7 +2384,7 @@ void ZSetFamily::ZRange(CmdArgList args, Transaction* tx, SinkReplyBuilder* buil } void ZSetFamily::ZRank(CmdArgList args, Transaction* tx, SinkReplyBuilder* builder) { - ZRankGeneric(std::move(args), false, tx, builder); + ZRankGeneric(args, false, tx, builder); } void ZSetFamily::ZRevRange(CmdArgList args, Transaction* tx, SinkReplyBuilder* builder) { @@ -2362,7 +2406,7 @@ void ZSetFamily::ZRevRangeByScore(CmdArgList args, Transaction* tx, SinkReplyBui } void ZSetFamily::ZRevRank(CmdArgList args, Transaction* tx, SinkReplyBuilder* builder) { - ZRankGeneric(std::move(args), true, tx, builder); + ZRankGeneric(args, true, tx, builder); } void ZSetFamily::ZRangeByLex(CmdArgList args, Transaction* tx, SinkReplyBuilder* builder) { @@ -3213,7 +3257,7 @@ void ZSetFamily::Register(CommandRegistry* registry) { << CI{"ZREM", CO::FAST | CO::WRITE, -3, 1, 1, acl::kZRem}.HFUNC(ZRem) << CI{"ZRANGE", CO::READONLY, -4, 1, 1, acl::kZRange}.HFUNC(ZRange) << CI{"ZRANDMEMBER", CO::READONLY, -2, 1, 1, acl::kZRandMember}.HFUNC(ZRandMember) - << CI{"ZRANK", CO::READONLY | CO::FAST, 3, 1, 1, acl::kZRank}.HFUNC(ZRank) + << CI{"ZRANK", CO::READONLY | CO::FAST, -3, 1, 1, acl::kZRank}.HFUNC(ZRank) << CI{"ZRANGEBYLEX", CO::READONLY, -4, 1, 1, acl::kZRangeByLex}.HFUNC(ZRangeByLex) << CI{"ZRANGEBYSCORE", CO::READONLY, -4, 1, 1, acl::kZRangeByScore}.HFUNC(ZRangeByScore) << CI{"ZRANGESTORE", CO::WRITE | CO::DENYOOM, -5, 1, 2, acl::kZRangeStore}.HFUNC(ZRangeStore) @@ -3226,7 +3270,7 @@ void ZSetFamily::Register(CommandRegistry* registry) { << CI{"ZREVRANGEBYLEX", CO::READONLY, -4, 1, 1, acl::kZRevRangeByLex}.HFUNC(ZRevRangeByLex) << CI{"ZREVRANGEBYSCORE", CO::READONLY, -4, 1, 1, acl::kZRevRangeByScore}.HFUNC( ZRevRangeByScore) - << CI{"ZREVRANK", CO::READONLY | CO::FAST, 3, 1, 1, acl::kZRevRank}.HFUNC(ZRevRank) + << CI{"ZREVRANK", CO::READONLY | CO::FAST, -3, 1, 1, acl::kZRevRank}.HFUNC(ZRevRank) << CI{"ZSCAN", CO::READONLY, -3, 1, 1, acl::kZScan}.HFUNC(ZScan) << CI{"ZUNION", CO::READONLY | CO::VARIADIC_KEYS, -3, 2, 2, acl::kZUnion}.HFUNC(ZUnion) << CI{"ZUNIONSTORE", kStoreMask, -4, 3, 3, acl::kZUnionStore}.HFUNC(ZUnionStore) diff --git a/src/server/zset_family_test.cc b/src/server/zset_family_test.cc index f7b0a6db62c0..140488741de7 100644 --- a/src/server/zset_family_test.cc +++ b/src/server/zset_family_test.cc @@ -279,13 +279,43 @@ TEST_F(ZSetFamilyTest, ZRangeRank) { EXPECT_EQ(2, CheckedInt({"zcount", "x", "1.1", "2.1"})); EXPECT_EQ(1, CheckedInt({"zcount", "x", "(1.1", "2.1"})); EXPECT_EQ(0, CheckedInt({"zcount", "y", "(1.1", "2.1"})); +} +TEST_F(ZSetFamilyTest, ZRank) { + Run({"zadd", "x", "1.1", "a", "2.1", "b"}); EXPECT_EQ(0, CheckedInt({"zrank", "x", "a"})); EXPECT_EQ(1, CheckedInt({"zrank", "x", "b"})); EXPECT_EQ(1, CheckedInt({"zrevrank", "x", "a"})); EXPECT_EQ(0, CheckedInt({"zrevrank", "x", "b"})); EXPECT_THAT(Run({"zrevrank", "x", "c"}), ArgType(RespExpr::NIL)); EXPECT_THAT(Run({"zrank", "y", "c"}), ArgType(RespExpr::NIL)); + EXPECT_THAT(Run({"zrevrank", "x", "c", "WITHSCORE"}), ArgType(RespExpr::NIL)); + EXPECT_THAT(Run({"zrank", "y", "c", "WITHSCORE"}), ArgType(RespExpr::NIL)); + + auto resp = Run({"zrank", "x", "a", "WITHSCORE"}); + ASSERT_THAT(resp, ArgType(RespExpr::ARRAY)); + ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(0), "1.1")); + + resp = Run({"zrank", "x", "b", "WITHSCORE"}); + ASSERT_THAT(resp, ArgType(RespExpr::ARRAY)); + ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(1), "2.1")); + + resp = Run({"zrevrank", "x", "a", "WITHSCORE"}); + ASSERT_THAT(resp, ArgType(RespExpr::ARRAY)); + ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(1), "1.1")); + + resp = Run({"zrevrank", "x", "b", "WITHSCORE"}); + ASSERT_THAT(resp, ArgType(RespExpr::ARRAY)); + ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(0), "2.1")); + + resp = Run({"zrank", "x", "a", "WITHSCORES"}); + ASSERT_THAT(resp, ErrArg("syntax error")); + + resp = Run({"zrank", "x", "a", "WITHSCORES", "42"}); + ASSERT_THAT(resp, ErrArg("wrong number of arguments for 'zrank' command")); + + resp = Run({"zrevrank", "x", "a", "WITHSCORES", "42"}); + ASSERT_THAT(resp, ErrArg("wrong number of arguments for 'zrevrank' command")); } TEST_F(ZSetFamilyTest, LargeSet) {