Skip to content

Commit

Permalink
fix(search_family): Support multiple fields in SORTBY option in the F…
Browse files Browse the repository at this point in the history
…T.AGGREGATE command

fixes dragonfly#3631

Signed-off-by: Stepan Bagritsevich <[email protected]>
  • Loading branch information
BagritsevichStepan committed Nov 29, 2024
1 parent c108b12 commit e039560
Show file tree
Hide file tree
Showing 5 changed files with 251 additions and 23 deletions.
43 changes: 32 additions & 11 deletions src/server/search/aggregator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -114,21 +114,42 @@ PipelineStep MakeGroupStep(absl::Span<const std::string_view> fields,
return GroupStep{std::vector<std::string>(fields.begin(), fields.end()), std::move(reducers)};
}

PipelineStep MakeSortStep(std::string_view field, bool descending) {
return [field = std::string(field), descending](PipelineResult result) -> PipelineResult {
auto& values = result.values;
PipelineStep MakeSortStep(SortParams sort_params) {
return [params = std::move(sort_params)](PipelineResult result) -> PipelineResult {
auto comparator = [&params](const DocValues& l, const DocValues& r) {
for (const auto& [field, order] : params.fields) {
auto l_it = l.find(field);
auto r_it = r.find(field);

// Handle cases where one of the fields is missing
if (l_it == l.end() || r_it == r.end()) {
return l_it != l.end() || r_it == r.end();
}

if (l_it->second < r_it->second) {
return order == SortParams::SortOrder::ASC;
}
if (l_it->second > r_it->second) {
return order == SortParams::SortOrder::DESC;
}
}
return false; // Elements are equal
};

std::sort(values.begin(), values.end(), [field](const DocValues& l, const DocValues& r) {
auto it1 = l.find(field);
auto it2 = r.find(field);
return it1 == l.end() || (it2 != r.end() && it1->second < it2->second);
});
auto& values = result.values;
if (params.SortAll()) {
std::sort(values.begin(), values.end(), comparator);
} else {
DCHECK_GE(params.max, 0);
const size_t limit = std::min(values.size(), size_t(params.max));
std::partial_sort(values.begin(), values.begin() + limit, values.end(), comparator);
values.resize(limit);
}

if (descending) {
std::reverse(values.begin(), values.end());
for (auto& field : params.fields) {
result.fields_to_print.insert(field.first); // TODO: move
}

result.fields_to_print.insert(field);
return result;
};
}
Expand Down
21 changes: 20 additions & 1 deletion src/server/search/aggregator.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,25 @@ struct Reducer {
Func func;
};

struct SortParams {
enum class SortOrder { ASC, DESC };

constexpr static int64_t kSortAll = -1;

bool SortAll() const {
return max == kSortAll;
}

/* Fields to sort by. If multiple fields are provided, sorting works hierarchically:
- First, the i-th field is compared.
- If the i-th field values are equal, the (i + 1)-th field is compared, and so on. */
absl::InlinedVector<std::pair<std::string, SortOrder>, 1> fields;

/* Max number of elements to include in the sorted result.
If set, only the first [max] elements are fully sorted using partial_sort. */
int64_t max = kSortAll;
};

enum class ReducerFunc { COUNT, COUNT_DISTINCT, SUM, AVG, MAX, MIN };

// Find reducer function by uppercase name (COUNT, MAX, etc...), empty functor if not found
Expand All @@ -83,7 +102,7 @@ PipelineStep MakeGroupStep(absl::Span<const std::string_view> fields,
std::vector<Reducer> reducers);

// Make `SORTBY field [DESC]` step
PipelineStep MakeSortStep(std::string_view field, bool descending = false);
PipelineStep MakeSortStep(SortParams sort_params);

// Make `LIMIT offset num` step
PipelineStep MakeLimitStep(size_t offset, size_t num);
Expand Down
5 changes: 4 additions & 1 deletion src/server/search/aggregator_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@ TEST(AggregatorTest, Sort) {
DocValues{{"a", 0.5}},
DocValues{{"a", 1.5}},
};
PipelineStep steps[] = {MakeSortStep("a", false)};

SortParams params;
params.fields.emplace_back("a", SortParams::SortOrder::ASC);
PipelineStep steps[] = {MakeSortStep(std::move(params))};

auto result = Process(values, {"a"}, steps);

Expand Down
59 changes: 50 additions & 9 deletions src/server/search/search_family.cc
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,42 @@ optional<SearchParams> ParseSearchParamsOrReply(CmdArgParser* parser, SinkReplyB
return params;
}

std::optional<aggregate::SortParams> ParseAggregatorSortParams(CmdArgParser* parser) {
using SordOrder = aggregate::SortParams::SortOrder;

size_t strings_num = parser->Next<size_t>();

aggregate::SortParams sort_params;
sort_params.fields.reserve(strings_num / 2);

while (parser->HasNext() && strings_num > 0) {
// TODO: Throw an error if the field has no '@' sign at the beginning
std::string_view parsed_field = ParseFieldWithAtSign(parser);
strings_num--;

SordOrder sord_order = SordOrder::ASC;
if (strings_num > 0) {
auto order = parser->TryMapNext("ASC", SordOrder::ASC, "DESC", SordOrder::DESC);
if (order) {
sord_order = order.value();
strings_num--;
}
}

sort_params.fields.emplace_back(parsed_field, sord_order);
}

if (strings_num) {
return std::nullopt;
}

if (parser->Check("MAX")) {
sort_params.max = parser->Next<size_t>();
}

return sort_params;
}

optional<AggregateParams> ParseAggregatorParamsOrReply(CmdArgParser parser,
SinkReplyBuilder* builder) {
AggregateParams params;
Expand Down Expand Up @@ -369,11 +405,13 @@ optional<AggregateParams> ParseAggregatorParamsOrReply(CmdArgParser parser,

// SORTBY nargs
if (parser.Check("SORTBY")) {
parser.ExpectTag("1");
string_view field = parser.Next();
bool desc = bool(parser.Check("DESC"));
auto sort_params = ParseAggregatorSortParams(&parser);
if (!sort_params) {
builder->SendError("bad arguments for SORTBY: specified invalid number of strings");
return nullopt;
}

params.steps.push_back(aggregate::MakeSortStep(field, desc));
params.steps.push_back(aggregate::MakeSortStep(std::move(sort_params).value()));
continue;
}

Expand Down Expand Up @@ -997,17 +1035,20 @@ void SearchFamily::FtAggregate(CmdArgList args, Transaction* tx, SinkReplyBuilde
rb->StartArray(result_size + 1);
rb->SendLong(result_size);

const size_t field_count = agg_results.fields_to_print.size();
for (const auto& value : agg_results.values) {
rb->StartArray(field_count * 2);
size_t fields_count = 0;
for (const auto& field : agg_results.fields_to_print) {
rb->SendBulkString(field);
if (value.find(field) != value.end()) {
fields_count++;
}
}

rb->StartArray(fields_count * 2);
for (const auto& field : agg_results.fields_to_print) {
auto it = value.find(field);
if (it != value.end()) {
rb->SendBulkString(field);
std::visit(sortable_value_sender, it->second);
} else {
rb->SendNull();
}
}
}
Expand Down
146 changes: 145 additions & 1 deletion src/server/search/search_family_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -852,7 +852,6 @@ TEST_F(SearchFamilyTest, FtProfileInvalidQuery) {

TEST_F(SearchFamilyTest, FtProfileErrorReply) {
Run({"ft.create", "i1", "schema", "name", "text"});
;

auto resp = Run({"ft.profile", "i1", "not_search", "query", "(a | b) c d"});
EXPECT_THAT(resp, ErrArg("no `SEARCH` or `AGGREGATE` provided"));
Expand Down Expand Up @@ -1657,4 +1656,149 @@ TEST_F(SearchFamilyTest, AggregateResultFields) {
IsMap("b", "\"5\"", "a", "4", "count", "1")));
}

TEST_F(SearchFamilyTest, AggregateSortByJson) {
Run({"JSON.SET", "j1", "$", R"({"name": "first", "number": 1200, "group": "first"})"});
Run({"JSON.SET", "j2", "$", R"({"name": "second", "number": 800, "group": "first"})"});
Run({"JSON.SET", "j3", "$", R"({"name": "third", "number": 300, "group": "first"})"});
Run({"JSON.SET", "j4", "$", R"({"name": "fourth", "number": 400, "group": "second"})"});
Run({"JSON.SET", "j5", "$", R"({"name": "fifth", "number": 900, "group": "second"})"});
Run({"JSON.SET", "j6", "$", R"({"name": "sixth", "number": 300, "group": "first"})"});
Run({"JSON.SET", "j7", "$", R"({"name": "seventh", "number": 400, "group": "second"})"});
Run({"JSON.SET", "j8", "$", R"({"name": "eighth", "group": "first"})"});
Run({"JSON.SET", "j9", "$", R"({"name": "ninth", "group": "second"})"});

Run({"FT.CREATE", "index", "ON", "JSON", "SCHEMA", "$.name", "AS", "name", "TEXT", "$.number",
"AS", "number", "NUMERIC", "$.group", "AS", "group", "TAG"});

// Test sorting by name (DESC) and number (ASC)
auto resp = Run({"FT.AGGREGATE", "index", "*", "SORTBY", "4", "@name", "DESC", "@number", "ASC"});
EXPECT_THAT(resp, IsUnordArrayWithSize(
IsMap("name", "\"third\"", "number", "300"),
IsMap("name", "\"sixth\"", "number", "300"),
IsMap("name", "\"seventh\"", "number", "400"),
IsMap("name", "\"second\"", "number", "800"), IsMap("name", "\"ninth\""),
IsMap("name", "\"fourth\"", "number", "400"),
IsMap("name", "\"first\"", "number", "1200"),
IsMap("name", "\"fifth\"", "number", "900"), IsMap("name", "\"eighth\"")));

// Test sorting by name (ASC) and number (DESC)
resp = Run({"FT.AGGREGATE", "index", "*", "SORTBY", "4", "@name", "ASC", "@number", "DESC"});
EXPECT_THAT(resp, IsUnordArrayWithSize(
IsMap("name", "\"eighth\""), IsMap("name", "\"fifth\"", "number", "900"),
IsMap("name", "\"first\"", "number", "1200"),
IsMap("name", "\"fourth\"", "number", "400"), IsMap("name", "\"ninth\""),
IsMap("name", "\"second\"", "number", "800"),
IsMap("name", "\"seventh\"", "number", "400"),
IsMap("name", "\"sixth\"", "number", "300"),
IsMap("name", "\"third\"", "number", "300")));

// Test sorting by group (ASC), number (DESC), and name
resp = Run(
{"FT.AGGREGATE", "index", "*", "SORTBY", "5", "@group", "ASC", "@number", "DESC", "@name"});
EXPECT_THAT(resp, IsUnordArrayWithSize(
IsMap("group", "\"first\"", "number", "1200", "name", "\"first\""),
IsMap("group", "\"first\"", "number", "800", "name", "\"second\""),
IsMap("group", "\"first\"", "number", "300", "name", "\"sixth\""),
IsMap("group", "\"first\"", "number", "300", "name", "\"third\""),
IsMap("group", "\"first\"", "name", "\"eighth\""),
IsMap("group", "\"second\"", "number", "900", "name", "\"fifth\""),
IsMap("group", "\"second\"", "number", "400", "name", "\"fourth\""),
IsMap("group", "\"second\"", "number", "400", "name", "\"seventh\""),
IsMap("group", "\"second\"", "name", "\"ninth\"")));

// Test sorting by number (ASC), group (DESC), and name
resp = Run(
{"FT.AGGREGATE", "index", "*", "SORTBY", "5", "@number", "ASC", "@group", "DESC", "@name"});
EXPECT_THAT(resp, IsUnordArrayWithSize(
IsMap("number", "300", "group", "\"first\"", "name", "\"sixth\""),
IsMap("number", "300", "group", "\"first\"", "name", "\"third\""),
IsMap("number", "400", "group", "\"second\"", "name", "\"fourth\""),
IsMap("number", "400", "group", "\"second\"", "name", "\"seventh\""),
IsMap("number", "800", "group", "\"first\"", "name", "\"second\""),
IsMap("number", "900", "group", "\"second\"", "name", "\"fifth\""),
IsMap("number", "1200", "group", "\"first\"", "name", "\"first\""),
IsMap("group", "\"second\"", "name", "\"ninth\""),
IsMap("group", "\"first\"", "name", "\"eighth\"")));

// Test sorting with MAX 3
resp = Run({"FT.AGGREGATE", "index", "*", "SORTBY", "1", "@number", "MAX", "3"});
EXPECT_THAT(resp, IsUnordArrayWithSize(IsMap("number", "300"), IsMap("number", "300"),
IsMap("number", "400")));

// Test sorting with MAX 3
resp = Run({"FT.AGGREGATE", "index", "*", "SORTBY", "2", "@number", "DESC", "MAX", "3"});
EXPECT_THAT(resp, IsUnordArrayWithSize(IsMap("number", "1200"), IsMap("number", "900"),
IsMap("number", "800")));

// Test sorting by number (ASC) with MAX 999
resp = Run({"FT.AGGREGATE", "index", "*", "SORTBY", "1", "@number", "MAX", "999"});
EXPECT_THAT(resp, IsUnordArrayWithSize(IsMap("number", "300"), IsMap("number", "300"),
IsMap("number", "400"), IsMap("number", "400"),
IsMap("number", "800"), IsMap("number", "900"),
IsMap("number", "1200"), IsMap(), IsMap()));

// Test sorting by name and number (DESC)
resp = Run({"FT.AGGREGATE", "index", "*", "SORTBY", "3", "@name", "@number", "DESC"});
EXPECT_THAT(resp, IsUnordArrayWithSize(
IsMap("name", "\"eighth\""), IsMap("name", "\"fifth\"", "number", "900"),
IsMap("name", "\"first\"", "number", "1200"),
IsMap("name", "\"fourth\"", "number", "400"), IsMap("name", "\"ninth\""),
IsMap("name", "\"second\"", "number", "800"),
IsMap("name", "\"seventh\"", "number", "400"),
IsMap("name", "\"sixth\"", "number", "300"),
IsMap("name", "\"third\"", "number", "300")));

// Test SORTBY with MAX, GROUPBY, and REDUCE COUNT
resp = Run({"FT.AGGREGATE", "index", "*", "SORTBY", "1", "@name", "MAX", "3", "GROUPBY", "1",
"@number", "REDUCE", "COUNT", "0", "AS", "count"});
EXPECT_THAT(resp, IsUnordArrayWithSize(IsMap("number", "900", "count", "1"),
IsMap("number", ArgType(RespExpr::NIL), "count", "1"),
IsMap("number", "1200", "count", "1")));

// Test SORTBY with MAX, GROUPBY (0 fields), and REDUCE COUNT
resp = Run({"FT.AGGREGATE", "index", "*", "SORTBY", "1", "@name", "MAX", "3", "GROUPBY", "0",
"REDUCE", "COUNT", "0", "AS", "count"});
EXPECT_THAT(resp, IsUnordArrayWithSize(IsMap("count", "3")));
}

TEST_F(SearchFamilyTest, AggregateSortByParsingErrors) {
Run({"JSON.SET", "j1", "$", R"({"name": "first", "number": 1200, "group": "first"})"});
Run({"JSON.SET", "j2", "$", R"({"name": "second", "number": 800, "group": "first"})"});
Run({"JSON.SET", "j3", "$", R"({"name": "third", "number": 300, "group": "first"})"});
Run({"JSON.SET", "j4", "$", R"({"name": "fourth", "number": 400, "group": "second"})"});
Run({"JSON.SET", "j5", "$", R"({"name": "fifth", "number": 900, "group": "second"})"});
Run({"JSON.SET", "j6", "$", R"({"name": "sixth", "number": 300, "group": "first"})"});
Run({"JSON.SET", "j7", "$", R"({"name": "seventh", "number": 400, "group": "second"})"});
Run({"JSON.SET", "j8", "$", R"({"name": "eighth", "group": "first"})"});
Run({"JSON.SET", "j9", "$", R"({"name": "ninth", "group": "second"})"});

Run({"FT.CREATE", "index", "ON", "JSON", "SCHEMA", "$.name", "AS", "name", "TEXT", "$.number",
"AS", "number", "NUMERIC", "$.group", "AS", "group", "TAG"});

// Test SORTBY with invalid argument count
auto resp = Run({"FT.AGGREGATE", "index", "*", "SORTBY", "999", "@name", "@number", "DESC"});
EXPECT_THAT(resp, ErrArg("bad arguments for SORTBY: specified invalid number of strings"));

// Test SORTBY with negative argument count
resp = Run({"FT.AGGREGATE", "index", "*", "SORTBY", "-3", "@name", "@number", "DESC"});
EXPECT_THAT(resp, ErrArg("value is not an integer or out of range"));

// Test MAX with invalid value
resp = Run({"FT.AGGREGATE", "index", "*", "SORTBY", "1", "@name", "MAX", "-10"});
EXPECT_THAT(resp, ErrArg("value is not an integer or out of range"));

// Test MAX without a value
resp = Run({"FT.AGGREGATE", "index", "*", "SORTBY", "1", "@name", "MAX"});
EXPECT_THAT(resp, ErrArg("syntax error"));

// Test SORTBY with a non-existing field
/* Temporary unsupported
resp = Run({"FT.AGGREGATE", "index", "*", "SORTBY", "1", "@nonexistingfield"});
EXPECT_THAT(resp, ErrArg("Property `nonexistingfield` not loaded nor in schema")); */

// Test SORTBY with an invalid value
resp = Run({"FT.AGGREGATE", "index", "*", "SORTBY", "notvalue", "@name"});
EXPECT_THAT(resp, ErrArg("value is not an integer or out of range"));
}

} // namespace dfly

0 comments on commit e039560

Please sign in to comment.