Skip to content

Commit

Permalink
feat(zset_family): support WITHSCORE in zrevrank/zrank commands (#3921)…
Browse files Browse the repository at this point in the history
… (#4001)

Signed-off-by: Diskein <[email protected]>
  • Loading branch information
Diskein authored Oct 29, 2024
1 parent 92be74f commit f16a325
Show file tree
Hide file tree
Showing 5 changed files with 112 additions and 22 deletions.
8 changes: 6 additions & 2 deletions src/core/bptree_set.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ template <typename T, typename Policy = BPTreePolicy<T>> class BPTree {

bool Delete(KeyT item);

std::optional<uint32_t> GetRank(KeyT item) const;
std::optional<uint32_t> GetRank(KeyT item, bool reverse = false) const;

size_t Height() const {
return height_;
Expand Down Expand Up @@ -222,7 +222,7 @@ template <typename T, typename Policy> bool BPTree<T, Policy>::Delete(KeyT item)
}

template <typename T, typename Policy>
std::optional<uint32_t> BPTree<T, Policy>::GetRank(KeyT item) const {
std::optional<uint32_t> BPTree<T, Policy>::GetRank(KeyT item, bool reverse) const {
if (!root_)
return std::nullopt;

Expand All @@ -231,6 +231,10 @@ std::optional<uint32_t> BPTree<T, Policy>::GetRank(KeyT item) const {
if (!found)
return std::nullopt;

if (reverse) {
return count_ - path.Rank() - 1;
}

return path.Rank();
}

Expand Down
14 changes: 12 additions & 2 deletions src/core/sorted_map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -309,9 +309,9 @@ optional<unsigned> SortedMap::GetRank(sds ele, bool reverse) const {
if (obj == nullptr)
return std::nullopt;

optional rank = score_tree->GetRank(obj);
optional rank = score_tree->GetRank(obj, reverse);
DCHECK(rank);
return reverse ? score_map->UpperBoundSize() - *rank - 1 : *rank;
return *rank;
}

SortedMap::ScoredArray SortedMap::GetRange(const zrangespec& range, unsigned offset, unsigned limit,
Expand Down Expand Up @@ -783,5 +783,15 @@ bool SortedMap::DefragIfNeeded(float ratio) {
return reallocated;
}

std::optional<SortedMap::RankAndScore> SortedMap::GetRankAndScore(sds ele, bool reverse) const {
ScoreSds obj = score_map->FindObj(ele);
if (obj == nullptr)
return std::nullopt;

optional rank = score_tree->GetRank(obj, reverse);
DCHECK(rank);

return SortedMap::RankAndScore{*rank, GetObjScore(obj)};
}
} // namespace detail
} // namespace dfly
2 changes: 2 additions & 0 deletions src/core/sorted_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class SortedMap {
using ScoredMember = std::pair<std::string, double>;
using ScoredArray = std::vector<ScoredMember>;
using ScoreSds = void*;
using RankAndScore = std::pair<unsigned, double>;

SortedMap(PMR_NS::memory_resource* res);
~SortedMap();
Expand Down Expand Up @@ -72,6 +73,7 @@ class SortedMap {

std::optional<double> GetScore(sds ele) const;
std::optional<unsigned> GetRank(sds ele, bool reverse) const;
std::optional<RankAndScore> GetRankAndScore(sds ele, bool reverse) const;
ScoredArray GetRange(const zrangespec& r, unsigned offs, unsigned len, bool rev) const;
ScoredArray GetLexRange(const zlexrangespec& r, unsigned o, unsigned l, bool rev) const;

Expand Down
80 changes: 62 additions & 18 deletions src/server/zset_family.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1387,8 +1387,13 @@ OpResult<unsigned> OpRemRange(const OpArgs& op_args, string_view key,
return iv.removed();
}

OpResult<unsigned> OpRank(const OpArgs& op_args, string_view key, string_view member,
bool reverse) {
struct RankResult {
unsigned rank;
double score = 0;
};

OpResult<RankResult> OpRank(const OpArgs& op_args, string_view key, string_view member,
bool reverse, bool with_score) {
auto res_it = op_args.GetDbSlice().FindReadOnly(op_args.db_cntx, key, OBJ_ZSET);
if (!res_it)
return res_it.status();
Expand Down Expand Up @@ -1417,18 +1422,34 @@ OpResult<unsigned> OpRank(const OpArgs& op_args, string_view key, string_view me
if (eptr == NULL)
return OpStatus::KEY_NOTFOUND;

if (reverse) {
return lpLength(zl) / 2 - rank;
RankResult res{};
res.rank = reverse ? lpLength(zl) / 2 - rank : rank - 1;
if (with_score) {
res.score = zzlGetScore(sptr);
}
return rank - 1;
return res;
}
DCHECK_EQ(robj_wrapper->encoding(), OBJ_ENCODING_SKIPLIST);
detail::SortedMap* ss = (detail::SortedMap*)robj_wrapper->inner_obj();
std::optional<unsigned> rank = ss->GetRank(WrapSds(member), reverse);
if (!rank)
return OpStatus::KEY_NOTFOUND;

return *rank;
RankResult res{};

if (with_score) {
auto rankAndScore = ss->GetRankAndScore(WrapSds(member), reverse);
if (!rankAndScore) {
return OpStatus::KEY_NOTFOUND;
}
res.rank = rankAndScore->first;
res.score = rankAndScore->second;
} else {
std::optional<unsigned> rank = ss->GetRank(WrapSds(member), reverse);
if (!rank) {
return OpStatus::KEY_NOTFOUND;
}
res.rank = *rank;
}

return res;
}

OpResult<unsigned> OpCount(const OpArgs& op_args, std::string_view key,
Expand Down Expand Up @@ -1979,17 +2000,40 @@ void ZRangeGeneric(CmdArgList args, Transaction* tx, SinkReplyBuilder* builder,
}

void ZRankGeneric(CmdArgList args, bool reverse, Transaction* tx, SinkReplyBuilder* builder) {
string_view key = ArgS(args, 0);
string_view member = ArgS(args, 1);
// send this error exact as redis does, it checks number of arguments first
if (args.size() > 3) {
return builder->SendError(WrongNumArgsError(reverse ? "ZREVRANK" : "ZRANK"));
}

facade::CmdArgParser parser(args);

string_view key = parser.Next();
string_view member = parser.Next();
bool with_score = false;

if (parser.HasNext()) {
parser.ExpectTag("WITHSCORE");
with_score = true;
}

if (!parser.Finalize()) {
return builder->SendError(parser.Error()->MakeReply());
}

auto cb = [&](Transaction* t, EngineShard* shard) {
return OpRank(t->GetOpArgs(shard), key, member, reverse);
return OpRank(t->GetOpArgs(shard), key, member, reverse, with_score);
};

OpResult<RankResult> result = tx->ScheduleSingleHopT(std::move(cb));
auto* rb = static_cast<RedisReplyBuilder*>(builder);
OpResult<unsigned> result = tx->ScheduleSingleHopT(std::move(cb));
if (result) {
rb->SendLong(*result);
if (with_score) {
rb->StartArray(2);
rb->SendLong(result->rank);
rb->SendDouble(result->score);
} else {
rb->SendLong(result->rank);
}
} else if (result.status() == OpStatus::KEY_NOTFOUND) {
rb->SendNull();
} else {
Expand Down Expand Up @@ -2340,7 +2384,7 @@ void ZSetFamily::ZRange(CmdArgList args, Transaction* tx, SinkReplyBuilder* buil
}

void ZSetFamily::ZRank(CmdArgList args, Transaction* tx, SinkReplyBuilder* builder) {
ZRankGeneric(std::move(args), false, tx, builder);
ZRankGeneric(args, false, tx, builder);
}

void ZSetFamily::ZRevRange(CmdArgList args, Transaction* tx, SinkReplyBuilder* builder) {
Expand All @@ -2362,7 +2406,7 @@ void ZSetFamily::ZRevRangeByScore(CmdArgList args, Transaction* tx, SinkReplyBui
}

void ZSetFamily::ZRevRank(CmdArgList args, Transaction* tx, SinkReplyBuilder* builder) {
ZRankGeneric(std::move(args), true, tx, builder);
ZRankGeneric(args, true, tx, builder);
}

void ZSetFamily::ZRangeByLex(CmdArgList args, Transaction* tx, SinkReplyBuilder* builder) {
Expand Down Expand Up @@ -3213,7 +3257,7 @@ void ZSetFamily::Register(CommandRegistry* registry) {
<< CI{"ZREM", CO::FAST | CO::WRITE, -3, 1, 1, acl::kZRem}.HFUNC(ZRem)
<< CI{"ZRANGE", CO::READONLY, -4, 1, 1, acl::kZRange}.HFUNC(ZRange)
<< CI{"ZRANDMEMBER", CO::READONLY, -2, 1, 1, acl::kZRandMember}.HFUNC(ZRandMember)
<< CI{"ZRANK", CO::READONLY | CO::FAST, 3, 1, 1, acl::kZRank}.HFUNC(ZRank)
<< CI{"ZRANK", CO::READONLY | CO::FAST, -3, 1, 1, acl::kZRank}.HFUNC(ZRank)
<< CI{"ZRANGEBYLEX", CO::READONLY, -4, 1, 1, acl::kZRangeByLex}.HFUNC(ZRangeByLex)
<< CI{"ZRANGEBYSCORE", CO::READONLY, -4, 1, 1, acl::kZRangeByScore}.HFUNC(ZRangeByScore)
<< CI{"ZRANGESTORE", CO::WRITE | CO::DENYOOM, -5, 1, 2, acl::kZRangeStore}.HFUNC(ZRangeStore)
Expand All @@ -3226,7 +3270,7 @@ void ZSetFamily::Register(CommandRegistry* registry) {
<< CI{"ZREVRANGEBYLEX", CO::READONLY, -4, 1, 1, acl::kZRevRangeByLex}.HFUNC(ZRevRangeByLex)
<< CI{"ZREVRANGEBYSCORE", CO::READONLY, -4, 1, 1, acl::kZRevRangeByScore}.HFUNC(
ZRevRangeByScore)
<< CI{"ZREVRANK", CO::READONLY | CO::FAST, 3, 1, 1, acl::kZRevRank}.HFUNC(ZRevRank)
<< CI{"ZREVRANK", CO::READONLY | CO::FAST, -3, 1, 1, acl::kZRevRank}.HFUNC(ZRevRank)
<< CI{"ZSCAN", CO::READONLY, -3, 1, 1, acl::kZScan}.HFUNC(ZScan)
<< CI{"ZUNION", CO::READONLY | CO::VARIADIC_KEYS, -3, 2, 2, acl::kZUnion}.HFUNC(ZUnion)
<< CI{"ZUNIONSTORE", kStoreMask, -4, 3, 3, acl::kZUnionStore}.HFUNC(ZUnionStore)
Expand Down
30 changes: 30 additions & 0 deletions src/server/zset_family_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -279,13 +279,43 @@ TEST_F(ZSetFamilyTest, ZRangeRank) {
EXPECT_EQ(2, CheckedInt({"zcount", "x", "1.1", "2.1"}));
EXPECT_EQ(1, CheckedInt({"zcount", "x", "(1.1", "2.1"}));
EXPECT_EQ(0, CheckedInt({"zcount", "y", "(1.1", "2.1"}));
}

TEST_F(ZSetFamilyTest, ZRank) {
Run({"zadd", "x", "1.1", "a", "2.1", "b"});
EXPECT_EQ(0, CheckedInt({"zrank", "x", "a"}));
EXPECT_EQ(1, CheckedInt({"zrank", "x", "b"}));
EXPECT_EQ(1, CheckedInt({"zrevrank", "x", "a"}));
EXPECT_EQ(0, CheckedInt({"zrevrank", "x", "b"}));
EXPECT_THAT(Run({"zrevrank", "x", "c"}), ArgType(RespExpr::NIL));
EXPECT_THAT(Run({"zrank", "y", "c"}), ArgType(RespExpr::NIL));
EXPECT_THAT(Run({"zrevrank", "x", "c", "WITHSCORE"}), ArgType(RespExpr::NIL));
EXPECT_THAT(Run({"zrank", "y", "c", "WITHSCORE"}), ArgType(RespExpr::NIL));

auto resp = Run({"zrank", "x", "a", "WITHSCORE"});
ASSERT_THAT(resp, ArgType(RespExpr::ARRAY));
ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(0), "1.1"));

resp = Run({"zrank", "x", "b", "WITHSCORE"});
ASSERT_THAT(resp, ArgType(RespExpr::ARRAY));
ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(1), "2.1"));

resp = Run({"zrevrank", "x", "a", "WITHSCORE"});
ASSERT_THAT(resp, ArgType(RespExpr::ARRAY));
ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(1), "1.1"));

resp = Run({"zrevrank", "x", "b", "WITHSCORE"});
ASSERT_THAT(resp, ArgType(RespExpr::ARRAY));
ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(0), "2.1"));

resp = Run({"zrank", "x", "a", "WITHSCORES"});
ASSERT_THAT(resp, ErrArg("syntax error"));

resp = Run({"zrank", "x", "a", "WITHSCORES", "42"});
ASSERT_THAT(resp, ErrArg("wrong number of arguments for 'zrank' command"));

resp = Run({"zrevrank", "x", "a", "WITHSCORES", "42"});
ASSERT_THAT(resp, ErrArg("wrong number of arguments for 'zrevrank' command"));
}

TEST_F(ZSetFamilyTest, LargeSet) {
Expand Down

0 comments on commit f16a325

Please sign in to comment.