fix(search_family): Support multiple fields in SORTBY option in the FT.AGGREGATE command. SECOND PR #4232

Open · wants to merge 2 commits into base: main
82 changes: 61 additions & 21 deletions src/server/search/aggregator.cc
@@ -11,10 +11,10 @@ namespace dfly::aggregate {
namespace {

struct GroupStep {
PipelineResult operator()(std::vector<DocValues> values) {
PipelineResult operator()(PipelineResult result) {
Reviewer comment (Contributor):
PipelineResult contains more than result.values, but only this field is used here, so maybe the parameter type doesn't need to change.
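A minimal sketch of what this suggestion could look like (hypothetical, not part of the PR; it reuses the PipelineResult type from aggregator.h, and the include path and struct name are illustrative): only the return type widens, while the parameter stays a plain vector of values.

```cpp
// Hypothetical sketch of the reviewer's alternative, not PR code. The grouping
// body is omitted for brevity; only the signature difference matters here.
#include <string>
#include <utility>
#include <vector>

#include "server/search/aggregator.h"

namespace dfly::aggregate {

struct NarrowParamGroupStep {
  // Parameter type stays std::vector<DocValues>; only the return value carries
  // the new fields_to_print information.
  PipelineResult operator()(std::vector<DocValues> values) const {
    absl::flat_hash_set<std::string> fields_to_print(fields_.begin(), fields_.end());
    return {std::move(values), std::move(fields_to_print)};
  }

  std::vector<std::string> fields_;
};

}  // namespace dfly::aggregate
```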

// Separate items into groups
absl::flat_hash_map<absl::FixedArray<Value>, std::vector<DocValues>> groups;
for (auto& value : values) {
for (auto& value : result.values) {
groups[Extract(value)].push_back(std::move(value));
}

@@ -28,7 +28,18 @@ }
}
out.push_back(std::move(doc));
}
return out;

absl::flat_hash_set<std::string> fields_to_print;
fields_to_print.reserve(fields_.size() + reducers_.size());

for (auto& field : fields_) {
fields_to_print.insert(std::move(field));
}
for (auto& reducer : reducers_) {
fields_to_print.insert(std::move(reducer.result_field));
}

return {std::move(out), std::move(fields_to_print)};
}

absl::FixedArray<Value> Extract(const DocValues& dv) {
@@ -103,35 +114,64 @@ PipelineStep MakeGroupStep(absl::Span<const std::string_view> fields,
return GroupStep{std::vector<std::string>(fields.begin(), fields.end()), std::move(reducers)};
}

PipelineStep MakeSortStep(std::string_view field, bool descending) {
return [field = std::string(field), descending](std::vector<DocValues> values) -> PipelineResult {
std::sort(values.begin(), values.end(), [field](const DocValues& l, const DocValues& r) {
auto it1 = l.find(field);
auto it2 = r.find(field);
return it1 == l.end() || (it2 != r.end() && it1->second < it2->second);
});
if (descending)
std::reverse(values.begin(), values.end());
return values;
PipelineStep MakeSortStep(SortParams sort_params) {
return [params = std::move(sort_params)](PipelineResult result) -> PipelineResult {
auto comparator = [&params](const DocValues& l, const DocValues& r) {
for (const auto& [field, order] : params.fields) {
auto l_it = l.find(field);
auto r_it = r.find(field);

// Handle cases where one of the fields is missing
if (l_it == l.end() || r_it == r.end()) {
return l_it != l.end() || r_it == r.end();
}

if (l_it->second < r_it->second) {
return order == SortParams::SortOrder::ASC;
}
if (l_it->second > r_it->second) {
return order == SortParams::SortOrder::DESC;
}
}
return false; // Elements are equal
};

auto& values = result.values;
if (params.SortAll()) {
std::sort(values.begin(), values.end(), comparator);
} else {
DCHECK_GE(params.max, 0);
const size_t limit = std::min(values.size(), size_t(params.max));
std::partial_sort(values.begin(), values.begin() + limit, values.end(), comparator);
values.resize(limit);
}

for (auto& field : params.fields) {
result.fields_to_print.insert(field.first); // TODO: move
}

return result;
};
}

PipelineStep MakeLimitStep(size_t offset, size_t num) {
return [offset, num](std::vector<DocValues> values) -> PipelineResult {
return [offset, num](PipelineResult result) {
auto& values = result.values;
values.erase(values.begin(), values.begin() + std::min(offset, values.size()));
values.resize(std::min(num, values.size()));
return values;
return result;
};
}

PipelineResult Process(std::vector<DocValues> values, absl::Span<const PipelineStep> steps) {
PipelineResult Process(std::vector<DocValues> values,
absl::Span<const std::string_view> fields_to_print,
absl::Span<const PipelineStep> steps) {
PipelineResult result{std::move(values), {fields_to_print.begin(), fields_to_print.end()}};
for (auto& step : steps) {
auto result = step(std::move(values));
if (!result.has_value())
return result;
values = std::move(result.value());
PipelineResult step_result = step(std::move(result));
result = std::move(step_result);
}
return values;
return result;
}

} // namespace dfly::aggregate
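For readers skimming the diff, the following self-contained snippet (not PR code, and with deliberately simplified missing-field handling) illustrates the hierarchical comparison the new sort comparator performs: the first SORTBY field decides the order, and later fields only break ties.

```cpp
#include <algorithm>
#include <iostream>
#include <map>
#include <string>
#include <utility>
#include <vector>

int main() {
  using Doc = std::map<std::string, double>;
  // Sort by "a" ascending; break ties by "b" descending.
  const std::vector<std::pair<std::string, bool>> fields = {{"a", /*asc=*/true},
                                                            {"b", /*asc=*/false}};
  auto cmp = [&fields](const Doc& l, const Doc& r) {
    for (const auto& [field, asc] : fields) {
      auto l_it = l.find(field), r_it = r.find(field);
      if (l_it == l.end() || r_it == r.end())
        return l_it != l.end();  // simplified: documents missing the field sort last
      if (l_it->second != r_it->second)
        return asc ? l_it->second < r_it->second : l_it->second > r_it->second;
    }
    return false;  // equal on every field
  };

  std::vector<Doc> docs = {{{"a", 1}, {"b", 2}}, {{"a", 1}, {"b", 5}}, {{"a", 0}, {"b", 9}}};
  std::sort(docs.begin(), docs.end(), cmp);
  for (const auto& d : docs)
    std::cout << d.at("a") << ' ' << d.at("b") << '\n';  // prints 0 9, then 1 5, then 1 2
}
```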
38 changes: 33 additions & 5 deletions src/server/search/aggregator.h
@@ -5,6 +5,7 @@
#pragma once

#include <absl/container/flat_hash_map.h>
#include <absl/container/flat_hash_set.h>
#include <absl/types/span.h>

#include <string>
@@ -19,10 +20,16 @@ namespace dfly::aggregate {
using Value = ::dfly::search::SortableValue;
using DocValues = absl::flat_hash_map<std::string, Value>; // documents sent through the pipeline

// TODO: Replace DocValues with compact linear search map instead of hash map
struct PipelineResult {
// Values to be passed to the next step
// TODO: Replace DocValues with compact linear search map instead of hash map
std::vector<DocValues> values;

using PipelineResult = io::Result<std::vector<DocValues>, facade::ErrorReply>;
using PipelineStep = std::function<PipelineResult(std::vector<DocValues>)>; // Group, Sort, etc.
// Fields from values to be printed
absl::flat_hash_set<std::string> fields_to_print;
};

using PipelineStep = std::function<PipelineResult(PipelineResult)>; // Group, Sort, etc.
Reviewer comment (Contributor):
Consider modifying the current values in place instead of moving them.
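A short sketch of how that could look (hypothetical, assuming the PipelineResult type declared above; the alias and helper names are made up): steps take the result by mutable reference and edit it in place, so nothing is moved through every call.

```cpp
// Hypothetical in-place variant, not part of the PR.
using InPlacePipelineStep = std::function<void(PipelineResult&)>;

// The pipeline driver then reduces to a plain loop over the steps.
inline void RunInPlace(PipelineResult& result, absl::Span<const InPlacePipelineStep> steps) {
  for (const auto& step : steps)
    step(result);  // mutates result.values / result.fields_to_print directly
}
```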


// Iterator over Span<DocValues> that yields doc[field] or monostate if not present.
// Extra clumsy for STL compatibility!
@@ -66,6 +73,25 @@ struct Reducer {
Func func;
};

struct SortParams {
enum class SortOrder { ASC, DESC };

constexpr static int64_t kSortAll = -1;

bool SortAll() const {
return max == kSortAll;
}

/* Fields to sort by. If multiple fields are provided, sorting works hierarchically:
- First, the i-th field is compared.
- If the i-th field values are equal, the (i + 1)-th field is compared, and so on. */
absl::InlinedVector<std::pair<std::string, SortOrder>, 1> fields;
Reviewer comment (Contributor):
I would suggest using a bigger default capacity


/* Max number of elements to include in the sorted result.
If set, only the first [max] elements are fully sorted using partial_sort. */
int64_t max = kSortAll;
};

enum class ReducerFunc { COUNT, COUNT_DISTINCT, SUM, AVG, MAX, MIN };

// Find reducer function by uppercase name (COUNT, MAX, etc...), empty functor if not found
@@ -76,12 +102,14 @@ PipelineStep MakeGroupStep(absl::Span<const std::string_view> fields,
std::vector<Reducer> reducers);

// Make `SORTBY field [DESC]` step
PipelineStep MakeSortStep(std::string_view field, bool descending = false);
PipelineStep MakeSortStep(SortParams sort_params);

// Make `LIMIT offset num` step
PipelineStep MakeLimitStep(size_t offset, size_t num);

// Process values with given steps
PipelineResult Process(std::vector<DocValues> values, absl::Span<const PipelineStep> steps);
PipelineResult Process(std::vector<DocValues> values,
absl::Span<const std::string_view> fields_to_print,
absl::Span<const PipelineStep> steps);

} // namespace dfly::aggregate
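For orientation, a caller-side usage sketch of the new sort API (the field names, function name, and include path are assumptions for illustration): SORTBY fields are pushed in priority order, and max bounds how many documents survive the step.

```cpp
#include <utility>

#include "server/search/aggregator.h"  // assumed in-repo include path

using dfly::aggregate::MakeSortStep;
using dfly::aggregate::PipelineStep;
using dfly::aggregate::SortParams;

PipelineStep MakePriceThenNameSort() {
  SortParams params;
  // Sort by @price ascending; break ties by @name descending.
  params.fields.emplace_back("price", SortParams::SortOrder::ASC);
  params.fields.emplace_back("name", SortParams::SortOrder::DESC);
  // Keep only the first 10 documents; the step uses std::partial_sort for this.
  params.max = 10;
  return MakeSortStep(std::move(params));
}
```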
57 changes: 28 additions & 29 deletions src/server/search/aggregator_test.cc
@@ -16,14 +16,16 @@ TEST(AggregatorTest, Sort) {
DocValues{{"a", 0.5}},
DocValues{{"a", 1.5}},
};
PipelineStep steps[] = {MakeSortStep("a", false)};

auto result = Process(values, steps);
SortParams params;
params.fields.emplace_back("a", SortParams::SortOrder::ASC);
PipelineStep steps[] = {MakeSortStep(std::move(params))};

EXPECT_TRUE(result);
EXPECT_EQ(result->at(0)["a"], Value(0.5));
EXPECT_EQ(result->at(1)["a"], Value(1.0));
EXPECT_EQ(result->at(2)["a"], Value(1.5));
auto result = Process(values, {"a"}, steps);

EXPECT_EQ(result.values[0]["a"], Value(0.5));
EXPECT_EQ(result.values[1]["a"], Value(1.0));
EXPECT_EQ(result.values[2]["a"], Value(1.5));
}

TEST(AggregatorTest, Limit) {
@@ -35,12 +37,11 @@ TEST(AggregatorTest, Limit) {
};
PipelineStep steps[] = {MakeLimitStep(1, 2)};

auto result = Process(values, steps);
auto result = Process(values, {"i"}, steps);

EXPECT_TRUE(result);
EXPECT_EQ(result->size(), 2);
EXPECT_EQ(result->at(0)["i"], Value(2.0));
EXPECT_EQ(result->at(1)["i"], Value(3.0));
EXPECT_EQ(result.values.size(), 2);
EXPECT_EQ(result.values[0]["i"], Value(2.0));
EXPECT_EQ(result.values[1]["i"], Value(3.0));
}

TEST(AggregatorTest, SimpleGroup) {
@@ -54,12 +55,11 @@ TEST(AggregatorTest, SimpleGroup) {
std::string_view fields[] = {"tag"};
PipelineStep steps[] = {MakeGroupStep(fields, {})};

auto result = Process(values, steps);
EXPECT_TRUE(result);
EXPECT_EQ(result->size(), 2);
auto result = Process(values, {"i", "tag"}, steps);
EXPECT_EQ(result.values.size(), 2);

EXPECT_EQ(result->at(0).size(), 1);
std::set<Value> groups{result->at(0)["tag"], result->at(1)["tag"]};
EXPECT_EQ(result.values[0].size(), 1);
std::set<Value> groups{result.values[0]["tag"], result.values[1]["tag"]};
std::set<Value> expected{"even", "odd"};
EXPECT_EQ(groups, expected);
}
@@ -83,25 +83,24 @@ TEST(AggregatorTest, GroupWithReduce) {
Reducer{"null-field", "distinct-null", FindReducerFunc(ReducerFunc::COUNT_DISTINCT)}};
PipelineStep steps[] = {MakeGroupStep(fields, std::move(reducers))};

auto result = Process(values, steps);
EXPECT_TRUE(result);
EXPECT_EQ(result->size(), 2);
auto result = Process(values, {"i", "half-i", "tag"}, steps);
EXPECT_EQ(result.values.size(), 2);

// Reorder even first
if (result->at(0).at("tag") == Value("odd"))
std::swap(result->at(0), result->at(1));
if (result.values[0].at("tag") == Value("odd"))
std::swap(result.values[0], result.values[1]);

// Even
EXPECT_EQ(result->at(0).at("count"), Value{(double)5});
EXPECT_EQ(result->at(0).at("sum-i"), Value{(double)2 + 4 + 6 + 8});
EXPECT_EQ(result->at(0).at("distinct-hi"), Value{(double)3});
EXPECT_EQ(result->at(0).at("distinct-null"), Value{(double)1});
EXPECT_EQ(result.values[0].at("count"), Value{(double)5});
EXPECT_EQ(result.values[0].at("sum-i"), Value{(double)2 + 4 + 6 + 8});
EXPECT_EQ(result.values[0].at("distinct-hi"), Value{(double)3});
EXPECT_EQ(result.values[0].at("distinct-null"), Value{(double)1});

// Odd
EXPECT_EQ(result->at(1).at("count"), Value{(double)5});
EXPECT_EQ(result->at(1).at("sum-i"), Value{(double)1 + 3 + 5 + 7 + 9});
EXPECT_EQ(result->at(1).at("distinct-hi"), Value{(double)3});
EXPECT_EQ(result->at(1).at("distinct-null"), Value{(double)1});
EXPECT_EQ(result.values[1].at("count"), Value{(double)5});
EXPECT_EQ(result.values[1].at("sum-i"), Value{(double)1 + 3 + 5 + 7 + 9});
EXPECT_EQ(result.values[1].at("distinct-hi"), Value{(double)3});
EXPECT_EQ(result.values[1].at("distinct-null"), Value{(double)1});
}

} // namespace dfly::aggregate
80 changes: 67 additions & 13 deletions src/server/search/search_family.cc
@@ -306,6 +306,42 @@ optional<SearchParams> ParseSearchParamsOrReply(CmdArgParser* parser, SinkReplyB
return params;
}

std::optional<aggregate::SortParams> ParseAggregatorSortParams(CmdArgParser* parser) {
using SortOrder = aggregate::SortParams::SortOrder;

size_t strings_num = parser->Next<size_t>();

aggregate::SortParams sort_params;
sort_params.fields.reserve(strings_num / 2);

while (parser->HasNext() && strings_num > 0) {
// TODO: Throw an error if the field has no '@' sign at the beginning
std::string_view parsed_field = ParseFieldWithAtSign(parser);
strings_num--;

SortOrder sort_order = SortOrder::ASC;
if (strings_num > 0) {
auto order = parser->TryMapNext("ASC", SortOrder::ASC, "DESC", SortOrder::DESC);
if (order) {
sort_order = order.value();
strings_num--;
}
}

sort_params.fields.emplace_back(parsed_field, sort_order);
}

if (strings_num) {
return std::nullopt;
}

if (parser->Check("MAX")) {
sort_params.max = parser->Next<size_t>();
}

return sort_params;
}
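A worked example of how this function is expected to consume its arguments (the command below is hypothetical): the leading count covers the field names and any ASC/DESC tokens, and MAX is parsed only after the counted strings are exhausted.

```cpp
// FT.AGGREGATE idx "*" SORTBY 3 @price @name DESC MAX 5
//
// strings_num = 3 counts "@price", "@name" and "DESC".
// Expected outcome (sketch):
//   sort_params.fields == {{"price", SortOrder::ASC}, {"name", SortOrder::DESC}}
//   sort_params.max    == 5
//
// If the count disagrees with the tokens that follow (e.g. `SORTBY 3 @price DESC`
// with nothing after it), the loop ends with strings_num != 0 and the function
// returns std::nullopt, which the caller turns into the SORTBY error reply.
```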

optional<AggregateParams> ParseAggregatorParamsOrReply(CmdArgParser parser,
SinkReplyBuilder* builder) {
AggregateParams params;
@@ -369,11 +405,13 @@ optional<AggregateParams> ParseAggregatorParamsOrReply(CmdArgParser parser,

// SORTBY nargs
if (parser.Check("SORTBY")) {
parser.ExpectTag("1");
string_view field = parser.Next();
bool desc = bool(parser.Check("DESC"));
auto sort_params = ParseAggregatorSortParams(&parser);
if (!sort_params) {
builder->SendError("bad arguments for SORTBY: specified invalid number of strings");
return nullopt;
}

params.steps.push_back(aggregate::MakeSortStep(field, desc));
params.steps.push_back(aggregate::MakeSortStep(std::move(sort_params).value()));
continue;
}

@@ -980,22 +1018,38 @@ void SearchFamily::FtAggregate(CmdArgList args, Transaction* tx, SinkReplyBuilde
make_move_iterator(sub_results.end()));
}

auto agg_results = aggregate::Process(std::move(values), params->steps);
if (!agg_results.has_value())
return builder->SendError(agg_results.error());
std::vector<std::string_view> load_fields;
if (params->load_fields) {
load_fields.reserve(params->load_fields->size());
for (const auto& field : params->load_fields.value()) {
load_fields.push_back(field.GetShortName());
}
}

auto agg_results = aggregate::Process(std::move(values), load_fields, params->steps);

size_t result_size = agg_results->size();
auto* rb = static_cast<RedisReplyBuilder*>(builder);
auto sortable_value_sender = SortableValueSender(rb);

const size_t result_size = agg_results.values.size();
rb->StartArray(result_size + 1);
rb->SendLong(result_size);

for (const auto& result : agg_results.value()) {
rb->StartArray(result.size() * 2);
for (const auto& [k, v] : result) {
rb->SendBulkString(k);
std::visit(sortable_value_sender, v);
for (const auto& value : agg_results.values) {
size_t fields_count = 0;
for (const auto& field : agg_results.fields_to_print) {
if (value.find(field) != value.end()) {
fields_count++;
}
}

rb->StartArray(fields_count * 2);
for (const auto& field : agg_results.fields_to_print) {
auto it = value.find(field);
if (it != value.end()) {
rb->SendBulkString(field);
std::visit(sortable_value_sender, it->second);
}
}
}
}
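To make the new filtering visible, here is a sketch of the reply this loop produces for two grouped rows (all values and field names are hypothetical): only fields present in fields_to_print are emitted, and a row that lacks one of them simply sends a shorter array instead of a nil.

```cpp
// Hypothetical reply shape with fields_to_print = {"tag", "count"}:
//
// 1) (integer) 2
// 2) 1) "tag"
//    2) "even"
//    3) "count"
//    4) "5"
// 3) 1) "tag"
//    2) "odd"
//    3) "count"
//    4) "5"
```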