Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(search_family): Add Aggregator class #4290

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 85 additions & 76 deletions src/server/search/aggregator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,62 +10,89 @@ namespace dfly::aggregate {

namespace {

struct GroupStep {
PipelineResult operator()(PipelineResult result) {
// Separate items into groups
absl::flat_hash_map<absl::FixedArray<Value>, std::vector<DocValues>> groups;
for (auto& value : result.values) {
groups[Extract(value)].push_back(std::move(value));
}
using ValuesList = absl::FixedArray<Value>;

// Restore DocValues and apply reducers
std::vector<DocValues> out;
while (!groups.empty()) {
auto node = groups.extract(groups.begin());
DocValues doc = Unpack(std::move(node.key()));
for (auto& reducer : reducers_) {
doc[reducer.result_field] = reducer.func({reducer.source_field, node.mapped()});
}
out.push_back(std::move(doc));
}
ValuesList ExtractFieldsValues(const DocValues& dv, absl::Span<const std::string> fields) {
ValuesList out(fields.size());
for (size_t i = 0; i < fields.size(); i++) {
auto it = dv.find(fields[i]);
out[i] = (it != dv.end()) ? it->second : Value{};
}
return out;
}

absl::flat_hash_set<std::string> fields_to_print;
fields_to_print.reserve(fields_.size() + reducers_.size());
DocValues PackFields(ValuesList values, absl::Span<const std::string> fields) {
DCHECK_EQ(values.size(), fields.size());
DocValues out;
for (size_t i = 0; i < fields.size(); i++)
out[fields[i]] = std::move(values[i]);
return out;
}

for (auto& field : fields_) {
fields_to_print.insert(std::move(field));
}
for (auto& reducer : reducers_) {
fields_to_print.insert(std::move(reducer.result_field));
}
const Value kEmptyValue = Value{};

return {std::move(out), std::move(fields_to_print)};
} // namespace

void Aggregator::DoGroup(absl::Span<const std::string> fields, absl::Span<const Reducer> reducers) {
// Separate items into groups
absl::flat_hash_map<ValuesList, std::vector<DocValues>> groups;
for (auto& value : result.values) {
groups[ExtractFieldsValues(value, fields)].push_back(std::move(value));
}

absl::FixedArray<Value> Extract(const DocValues& dv) {
absl::FixedArray<Value> out(fields_.size());
for (size_t i = 0; i < fields_.size(); i++) {
auto it = dv.find(fields_[i]);
out[i] = (it != dv.end()) ? it->second : Value{};
// Restore DocValues and apply reducers
auto& values = result.values;
values.clear();
values.reserve(groups.size());
while (!groups.empty()) {
auto node = groups.extract(groups.begin());
DocValues doc = PackFields(std::move(node.key()), fields);
for (auto& reducer : reducers) {
doc[reducer.result_field] = reducer.func({reducer.source_field, node.mapped()});
}
return out;
values.push_back(std::move(doc));
}

DocValues Unpack(absl::FixedArray<Value>&& values) {
DCHECK_EQ(values.size(), fields_.size());
DocValues out;
for (size_t i = 0; i < fields_.size(); i++)
out[fields_[i]] = std::move(values[i]);
return out;
auto& fields_to_print = result.fields_to_print;
fields_to_print.clear();
fields_to_print.reserve(fields.size() + reducers.size());

for (auto& field : fields) {
fields_to_print.insert(field);
}
for (auto& reducer : reducers) {
fields_to_print.insert(reducer.result_field);
}
}

std::vector<std::string> fields_;
std::vector<Reducer> reducers_;
};
void Aggregator::DoSort(std::string_view field, bool descending) {
auto comparator = [&](const DocValues& l, const DocValues& r) {
auto l_it = l.find(field);
auto r_it = r.find(field);

const Value kEmptyValue = Value{};
// Handle cases where one of the fields is missing
if (l_it == l.end() || r_it == r.end()) {
return l_it != l.end() || r_it == r.end();
}
BorysTheDev marked this conversation as resolved.
Show resolved Hide resolved
if (l_it->second < r_it->second) {
return !descending;
}
if (l_it->second > r_it->second) {
return descending;
}
return true; // Elements are equal
};
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we return true if elements are equal for less operator? As I understand you could take previous implementation and return something like "res != descending"

Copy link
Contributor Author

@BagritsevichStepan BagritsevichStepan Dec 11, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it makes sense to return descending here. But just to note, we are not guaranteeing any strict ordering in the result when elements are equal. This is because the final order also depends on the initial distribution of keys across shards (results from shards are joined into one vector, and then a sort is performed on this array).

So, it is not a "stable" sort, meaning that the initial order of the data is influenced by the key distribution across shards, which may give the user the impression of an unstable sort

Upd.: After discussion, it was decided to return false if the elements are equal.


} // namespace
std::sort(result.values.begin(), result.values.end(), std::move(comparator));

result.fields_to_print.insert(field);
}

void Aggregator::DoLimit(size_t offset, size_t num) {
auto& values = result.values;
values.erase(values.begin(), values.begin() + std::min(offset, values.size()));
values.resize(std::min(num, values.size()));
}

const Value& ValueIterator::operator*() const {
auto it = values_.front().find(field_);
Expand Down Expand Up @@ -109,48 +136,30 @@ Reducer::Func FindReducerFunc(ReducerFunc name) {
return nullptr;
}

PipelineStep MakeGroupStep(absl::Span<const std::string_view> fields,
std::vector<Reducer> reducers) {
return GroupStep{std::vector<std::string>(fields.begin(), fields.end()), std::move(reducers)};
AggregationStep MakeGroupStep(std::vector<std::string> fields, std::vector<Reducer> reducers) {
return [fields = std::move(fields), reducers = std::move(reducers)](Aggregator* aggregator) {
aggregator->DoGroup(fields, reducers);
};
}

PipelineStep MakeSortStep(std::string_view field, bool descending) {
return [field = std::string(field), descending](PipelineResult result) -> PipelineResult {
auto& values = result.values;

std::sort(values.begin(), values.end(), [field](const DocValues& l, const DocValues& r) {
auto it1 = l.find(field);
auto it2 = r.find(field);
return it1 == l.end() || (it2 != r.end() && it1->second < it2->second);
});

if (descending) {
std::reverse(values.begin(), values.end());
}

result.fields_to_print.insert(field);
return result;
AggregationStep MakeSortStep(std::string field, bool descending) {
return [field = std::move(field), descending](Aggregator* aggregator) {
aggregator->DoSort(field, descending);
};
}

PipelineStep MakeLimitStep(size_t offset, size_t num) {
return [offset, num](PipelineResult result) {
auto& values = result.values;
values.erase(values.begin(), values.begin() + std::min(offset, values.size()));
values.resize(std::min(num, values.size()));
return result;
};
AggregationStep MakeLimitStep(size_t offset, size_t num) {
return [=](Aggregator* aggregator) { aggregator->DoLimit(offset, num); };
}

PipelineResult Process(std::vector<DocValues> values,
absl::Span<const std::string_view> fields_to_print,
absl::Span<const PipelineStep> steps) {
PipelineResult result{std::move(values), {fields_to_print.begin(), fields_to_print.end()}};
AggregationResult Process(std::vector<DocValues> values,
absl::Span<const std::string_view> fields_to_print,
absl::Span<const AggregationStep> steps) {
Aggregator aggregator{std::move(values), {fields_to_print.begin(), fields_to_print.end()}};
for (auto& step : steps) {
PipelineResult step_result = step(std::move(result));
result = std::move(step_result);
step(&aggregator);
}
return result;
return aggregator.result;
}

} // namespace dfly::aggregate
35 changes: 23 additions & 12 deletions src/server/search/aggregator.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,31 @@

namespace dfly::aggregate {

struct Reducer;

using Value = ::dfly::search::SortableValue;
using DocValues = absl::flat_hash_map<std::string, Value>; // documents sent through the pipeline

struct PipelineResult {
// DocValues sent through the pipeline
// TODO: Replace DocValues with compact linear search map instead of hash map
using DocValues = absl::flat_hash_map<std::string_view, Value>;

struct AggregationResult {
// Values to be passed to the next step
// TODO: Replace DocValues with compact linear search map instead of hash map
std::vector<DocValues> values;

// Fields from values to be printed
absl::flat_hash_set<std::string> fields_to_print;
absl::flat_hash_set<std::string_view> fields_to_print;
};

struct Aggregator {
void DoGroup(absl::Span<const std::string> fields, absl::Span<const Reducer> reducers);
void DoSort(std::string_view field, bool descending = false);
void DoLimit(size_t offset, size_t num);

AggregationResult result;
};

using PipelineStep = std::function<PipelineResult(PipelineResult)>; // Group, Sort, etc.
using AggregationStep = std::function<void(Aggregator*)>; // Group, Sort, etc.

// Iterator over Span<DocValues> that yields doc[field] or monostate if not present.
// Extra clumsy for STL compatibility!
Expand Down Expand Up @@ -79,18 +91,17 @@ enum class ReducerFunc { COUNT, COUNT_DISTINCT, SUM, AVG, MAX, MIN };
Reducer::Func FindReducerFunc(ReducerFunc name);

// Make `GROUPBY [fields...]` with REDUCE step
PipelineStep MakeGroupStep(absl::Span<const std::string_view> fields,
std::vector<Reducer> reducers);
AggregationStep MakeGroupStep(std::vector<std::string> fields, std::vector<Reducer> reducers);

// Make `SORTBY field [DESC]` step
PipelineStep MakeSortStep(std::string_view field, bool descending = false);
AggregationStep MakeSortStep(std::string field, bool descending = false);

// Make `LIMIT offset num` step
PipelineStep MakeLimitStep(size_t offset, size_t num);
AggregationStep MakeLimitStep(size_t offset, size_t num);

// Process values with given steps
PipelineResult Process(std::vector<DocValues> values,
absl::Span<const std::string_view> fields_to_print,
absl::Span<const PipelineStep> steps);
AggregationResult Process(std::vector<DocValues> values,
absl::Span<const std::string_view> fields_to_print,
absl::Span<const AggregationStep> steps);

} // namespace dfly::aggregate
16 changes: 10 additions & 6 deletions src/server/search/aggregator_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,15 @@ namespace dfly::aggregate {

using namespace std::string_literals;

using StepsList = std::vector<AggregationStep>;

TEST(AggregatorTest, Sort) {
std::vector<DocValues> values = {
DocValues{{"a", 1.0}},
DocValues{{"a", 0.5}},
DocValues{{"a", 1.5}},
};
PipelineStep steps[] = {MakeSortStep("a", false)};
StepsList steps = {MakeSortStep("a", false)};

auto result = Process(values, {"a"}, steps);

Expand All @@ -32,7 +34,8 @@ TEST(AggregatorTest, Limit) {
DocValues{{"i", 3.0}},
DocValues{{"i", 4.0}},
};
PipelineStep steps[] = {MakeLimitStep(1, 2)};

StepsList steps = {MakeLimitStep(1, 2)};

auto result = Process(values, {"i"}, steps);

Expand All @@ -49,8 +52,8 @@ TEST(AggregatorTest, SimpleGroup) {
DocValues{{"i", 4.0}, {"tag", "even"}},
};

std::string_view fields[] = {"tag"};
PipelineStep steps[] = {MakeGroupStep(fields, {})};
std::vector<std::string> fields = {"tag"};
StepsList steps = {MakeGroupStep(std::move(fields), {})};

auto result = Process(values, {"i", "tag"}, steps);
EXPECT_EQ(result.values.size(), 2);
Expand All @@ -72,13 +75,14 @@ TEST(AggregatorTest, GroupWithReduce) {
});
}

std::string_view fields[] = {"tag"};
std::vector<std::string> fields = {"tag"};
std::vector<Reducer> reducers = {
Reducer{"", "count", FindReducerFunc(ReducerFunc::COUNT)},
Reducer{"i", "sum-i", FindReducerFunc(ReducerFunc::SUM)},
Reducer{"half-i", "distinct-hi", FindReducerFunc(ReducerFunc::COUNT_DISTINCT)},
Reducer{"null-field", "distinct-null", FindReducerFunc(ReducerFunc::COUNT_DISTINCT)}};
PipelineStep steps[] = {MakeGroupStep(fields, std::move(reducers))};

StepsList steps = {MakeGroupStep(std::move(fields), std::move(reducers))};

auto result = Process(values, {"i", "half-i", "tag"}, steps);
EXPECT_EQ(result.values.size(), 2);
Expand Down
2 changes: 1 addition & 1 deletion src/server/search/doc_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ struct AggregateParams {
search::QueryParams params;

std::optional<SearchFieldsList> load_fields;
std::vector<aggregate::PipelineStep> steps;
std::vector<aggregate::AggregationStep> steps;
};

// Stores basic info about a document index.
Expand Down
29 changes: 20 additions & 9 deletions src/server/search/search_family.cc
Original file line number Diff line number Diff line change
Expand Up @@ -320,20 +320,23 @@ optional<AggregateParams> ParseAggregatorParamsOrReply(CmdArgParser parser,
while (parser.HasNext()) {
// GROUPBY nargs property [property ...]
if (parser.Check("GROUPBY")) {
vector<string_view> fields(parser.Next<size_t>());
for (string_view& field : fields) {
size_t num_fields = parser.Next<size_t>();

std::vector<std::string> fields;
fields.reserve(num_fields);
while (num_fields > 0 && parser.HasNext()) {
auto parsed_field = ParseFieldWithAtSign(&parser);

/*
TODO: Throw an error if the field has no '@' sign at the beginning

if (!parsed_field) {
builder->SendError(absl::StrCat("bad arguments for GROUPBY: Unknown property '", field,
"'. Did you mean '@", field, "`?"));
return nullopt;
} */

field = parsed_field;
fields.emplace_back(parsed_field);
num_fields--;
}

vector<aggregate::Reducer> reducers;
Expand Down Expand Up @@ -363,7 +366,7 @@ optional<AggregateParams> ParseAggregatorParamsOrReply(CmdArgParser parser,
aggregate::Reducer{std::move(source_field), std::move(result_field), std::move(func)});
}

params.steps.push_back(aggregate::MakeGroupStep(fields, std::move(reducers)));
params.steps.push_back(aggregate::MakeGroupStep(std::move(fields), std::move(reducers)));
continue;
}

Expand All @@ -373,7 +376,7 @@ optional<AggregateParams> ParseAggregatorParamsOrReply(CmdArgParser parser,
string_view field = parser.Next();
bool desc = bool(parser.Check("DESC"));

params.steps.push_back(aggregate::MakeSortStep(field, desc));
params.steps.push_back(aggregate::MakeSortStep(std::string{field}, desc));
continue;
}

Expand Down Expand Up @@ -975,10 +978,18 @@ void SearchFamily::FtAggregate(CmdArgList args, const CommandContext& cmd_cntx)
return OpStatus::OK;
});

vector<aggregate::DocValues> values;
// ResultContainer is absl::flat_hash_map<std::string, search::SortableValue>
// DocValues is absl::flat_hash_map<std::string_view, SortableValue>
// Keys of values should point to the keys of the query_results
std::vector<aggregate::DocValues> values;
for (auto& sub_results : query_results) {
values.insert(values.end(), make_move_iterator(sub_results.begin()),
make_move_iterator(sub_results.end()));
for (auto& docs : sub_results) {
aggregate::DocValues doc_value;
for (auto& doc : docs) {
doc_value[doc.first] = std::move(doc.second);
}
values.push_back(std::move(doc_value));
}
}

std::vector<std::string_view> load_fields;
Expand Down
Loading