diff --git a/src/server/search/aggregator.cc b/src/server/search/aggregator.cc index 4b6b4a5620cf..09ccb841eafa 100644 --- a/src/server/search/aggregator.cc +++ b/src/server/search/aggregator.cc @@ -10,62 +10,99 @@ namespace dfly::aggregate { namespace { -struct GroupStep { - PipelineResult operator()(PipelineResult result) { - // Separate items into groups - absl::flat_hash_map, std::vector> groups; - for (auto& value : result.values) { - groups[Extract(value)].push_back(std::move(value)); - } +using ValuesList = absl::FixedArray; - // Restore DocValues and apply reducers - std::vector out; - while (!groups.empty()) { - auto node = groups.extract(groups.begin()); - DocValues doc = Unpack(std::move(node.key())); - for (auto& reducer : reducers_) { - doc[reducer.result_field] = reducer.func({reducer.source_field, node.mapped()}); - } - out.push_back(std::move(doc)); - } +ValuesList ExtractFieldsValues(const DocValues& dv, absl::Span fields) { + ValuesList out(fields.size()); + for (size_t i = 0; i < fields.size(); i++) { + auto it = dv.find(fields[i]); + out[i] = (it != dv.end()) ? it->second : Value{}; + } + return out; +} - absl::flat_hash_set fields_to_print; - fields_to_print.reserve(fields_.size() + reducers_.size()); +DocValues PackFields(ValuesList values, absl::Span fields) { + DCHECK_EQ(values.size(), fields.size()); + DocValues out; + for (size_t i = 0; i < fields.size(); i++) + out[fields[i]] = std::move(values[i]); + return out; +} - for (auto& field : fields_) { - fields_to_print.insert(std::move(field)); - } - for (auto& reducer : reducers_) { - fields_to_print.insert(std::move(reducer.result_field)); - } +const Value kEmptyValue = Value{}; + +} // namespace - return {std::move(out), std::move(fields_to_print)}; +void Aggregator::DoGroup(absl::Span fields, absl::Span reducers) { + // Separate items into groups + absl::flat_hash_map> groups; + for (auto& value : result.values) { + groups[ExtractFieldsValues(value, fields)].push_back(std::move(value)); } - absl::FixedArray Extract(const DocValues& dv) { - absl::FixedArray out(fields_.size()); - for (size_t i = 0; i < fields_.size(); i++) { - auto it = dv.find(fields_[i]); - out[i] = (it != dv.end()) ? it->second : Value{}; + // Restore DocValues and apply reducers + auto& values = result.values; + values.clear(); + values.reserve(groups.size()); + while (!groups.empty()) { + auto node = groups.extract(groups.begin()); + DocValues doc = PackFields(std::move(node.key()), fields); + for (auto& reducer : reducers) { + doc[reducer.result_field] = reducer.func({reducer.source_field, node.mapped()}); } - return out; + values.push_back(std::move(doc)); } - DocValues Unpack(absl::FixedArray&& values) { - DCHECK_EQ(values.size(), fields_.size()); - DocValues out; - for (size_t i = 0; i < fields_.size(); i++) - out[fields_[i]] = std::move(values[i]); - return out; + auto& fields_to_print = result.fields_to_print; + fields_to_print.clear(); + fields_to_print.reserve(fields.size() + reducers.size()); + + for (auto& field : fields) { + fields_to_print.insert(field); } + for (auto& reducer : reducers) { + fields_to_print.insert(reducer.result_field); + } +} - std::vector fields_; - std::vector reducers_; -}; +void Aggregator::DoSort(std::string_view field, bool descending) { + /* + Comparator for sorting DocValues by field. + If some of the fields is not present in the DocValues, comparator returns: + 1. l_it == l.end() && r_it != r.end() + asc -> false + desc -> false + 2. l_it != l.end() && r_it == r.end() + asc -> true + desc -> true + 3. l_it == l.end() && r_it == r.end() + asc -> false + desc -> false + */ + auto comparator = [&](const DocValues& l, const DocValues& r) { + auto l_it = l.find(field); + auto r_it = r.find(field); + + // If some of the values is not present + if (l_it == l.end() || r_it == r.end()) { + return l_it != l.end(); + } -const Value kEmptyValue = Value{}; + auto& lv = l_it->second; + auto& rv = r_it->second; + return !descending ? lv < rv : lv > rv; + }; -} // namespace + std::sort(result.values.begin(), result.values.end(), std::move(comparator)); + + result.fields_to_print.insert(field); +} + +void Aggregator::DoLimit(size_t offset, size_t num) { + auto& values = result.values; + values.erase(values.begin(), values.begin() + std::min(offset, values.size())); + values.resize(std::min(num, values.size())); +} const Value& ValueIterator::operator*() const { auto it = values_.front().find(field_); @@ -109,48 +146,30 @@ Reducer::Func FindReducerFunc(ReducerFunc name) { return nullptr; } -PipelineStep MakeGroupStep(absl::Span fields, - std::vector reducers) { - return GroupStep{std::vector(fields.begin(), fields.end()), std::move(reducers)}; +AggregationStep MakeGroupStep(std::vector fields, std::vector reducers) { + return [fields = std::move(fields), reducers = std::move(reducers)](Aggregator* aggregator) { + aggregator->DoGroup(fields, reducers); + }; } -PipelineStep MakeSortStep(std::string_view field, bool descending) { - return [field = std::string(field), descending](PipelineResult result) -> PipelineResult { - auto& values = result.values; - - std::sort(values.begin(), values.end(), [field](const DocValues& l, const DocValues& r) { - auto it1 = l.find(field); - auto it2 = r.find(field); - return it1 == l.end() || (it2 != r.end() && it1->second < it2->second); - }); - - if (descending) { - std::reverse(values.begin(), values.end()); - } - - result.fields_to_print.insert(field); - return result; +AggregationStep MakeSortStep(std::string field, bool descending) { + return [field = std::move(field), descending](Aggregator* aggregator) { + aggregator->DoSort(field, descending); }; } -PipelineStep MakeLimitStep(size_t offset, size_t num) { - return [offset, num](PipelineResult result) { - auto& values = result.values; - values.erase(values.begin(), values.begin() + std::min(offset, values.size())); - values.resize(std::min(num, values.size())); - return result; - }; +AggregationStep MakeLimitStep(size_t offset, size_t num) { + return [=](Aggregator* aggregator) { aggregator->DoLimit(offset, num); }; } -PipelineResult Process(std::vector values, - absl::Span fields_to_print, - absl::Span steps) { - PipelineResult result{std::move(values), {fields_to_print.begin(), fields_to_print.end()}}; +AggregationResult Process(std::vector values, + absl::Span fields_to_print, + absl::Span steps) { + Aggregator aggregator{std::move(values), {fields_to_print.begin(), fields_to_print.end()}}; for (auto& step : steps) { - PipelineResult step_result = step(std::move(result)); - result = std::move(step_result); + step(&aggregator); } - return result; + return aggregator.result; } } // namespace dfly::aggregate diff --git a/src/server/search/aggregator.h b/src/server/search/aggregator.h index 4f4008bce238..a298735182f4 100644 --- a/src/server/search/aggregator.h +++ b/src/server/search/aggregator.h @@ -17,19 +17,31 @@ namespace dfly::aggregate { +struct Reducer; + using Value = ::dfly::search::SortableValue; -using DocValues = absl::flat_hash_map; // documents sent through the pipeline -struct PipelineResult { +// DocValues sent through the pipeline +// TODO: Replace DocValues with compact linear search map instead of hash map +using DocValues = absl::flat_hash_map; + +struct AggregationResult { // Values to be passed to the next step - // TODO: Replace DocValues with compact linear search map instead of hash map std::vector values; // Fields from values to be printed - absl::flat_hash_set fields_to_print; + absl::flat_hash_set fields_to_print; +}; + +struct Aggregator { + void DoGroup(absl::Span fields, absl::Span reducers); + void DoSort(std::string_view field, bool descending = false); + void DoLimit(size_t offset, size_t num); + + AggregationResult result; }; -using PipelineStep = std::function; // Group, Sort, etc. +using AggregationStep = std::function; // Group, Sort, etc. // Iterator over Span that yields doc[field] or monostate if not present. // Extra clumsy for STL compatibility! @@ -79,18 +91,17 @@ enum class ReducerFunc { COUNT, COUNT_DISTINCT, SUM, AVG, MAX, MIN }; Reducer::Func FindReducerFunc(ReducerFunc name); // Make `GROUPBY [fields...]` with REDUCE step -PipelineStep MakeGroupStep(absl::Span fields, - std::vector reducers); +AggregationStep MakeGroupStep(std::vector fields, std::vector reducers); // Make `SORTBY field [DESC]` step -PipelineStep MakeSortStep(std::string_view field, bool descending = false); +AggregationStep MakeSortStep(std::string field, bool descending = false); // Make `LIMIT offset num` step -PipelineStep MakeLimitStep(size_t offset, size_t num); +AggregationStep MakeLimitStep(size_t offset, size_t num); // Process values with given steps -PipelineResult Process(std::vector values, - absl::Span fields_to_print, - absl::Span steps); +AggregationResult Process(std::vector values, + absl::Span fields_to_print, + absl::Span steps); } // namespace dfly::aggregate diff --git a/src/server/search/aggregator_test.cc b/src/server/search/aggregator_test.cc index 3ee8b58e1f5a..a9f9544ce3b7 100644 --- a/src/server/search/aggregator_test.cc +++ b/src/server/search/aggregator_test.cc @@ -10,13 +10,15 @@ namespace dfly::aggregate { using namespace std::string_literals; +using StepsList = std::vector; + TEST(AggregatorTest, Sort) { std::vector values = { DocValues{{"a", 1.0}}, DocValues{{"a", 0.5}}, DocValues{{"a", 1.5}}, }; - PipelineStep steps[] = {MakeSortStep("a", false)}; + StepsList steps = {MakeSortStep("a", false)}; auto result = Process(values, {"a"}, steps); @@ -32,7 +34,8 @@ TEST(AggregatorTest, Limit) { DocValues{{"i", 3.0}}, DocValues{{"i", 4.0}}, }; - PipelineStep steps[] = {MakeLimitStep(1, 2)}; + + StepsList steps = {MakeLimitStep(1, 2)}; auto result = Process(values, {"i"}, steps); @@ -49,8 +52,8 @@ TEST(AggregatorTest, SimpleGroup) { DocValues{{"i", 4.0}, {"tag", "even"}}, }; - std::string_view fields[] = {"tag"}; - PipelineStep steps[] = {MakeGroupStep(fields, {})}; + std::vector fields = {"tag"}; + StepsList steps = {MakeGroupStep(std::move(fields), {})}; auto result = Process(values, {"i", "tag"}, steps); EXPECT_EQ(result.values.size(), 2); @@ -72,13 +75,14 @@ TEST(AggregatorTest, GroupWithReduce) { }); } - std::string_view fields[] = {"tag"}; + std::vector fields = {"tag"}; std::vector reducers = { Reducer{"", "count", FindReducerFunc(ReducerFunc::COUNT)}, Reducer{"i", "sum-i", FindReducerFunc(ReducerFunc::SUM)}, Reducer{"half-i", "distinct-hi", FindReducerFunc(ReducerFunc::COUNT_DISTINCT)}, Reducer{"null-field", "distinct-null", FindReducerFunc(ReducerFunc::COUNT_DISTINCT)}}; - PipelineStep steps[] = {MakeGroupStep(fields, std::move(reducers))}; + + StepsList steps = {MakeGroupStep(std::move(fields), std::move(reducers))}; auto result = Process(values, {"i", "half-i", "tag"}, steps); EXPECT_EQ(result.values.size(), 2); diff --git a/src/server/search/doc_index.h b/src/server/search/doc_index.h index 6b5a2da6cf7d..20b2117730e0 100644 --- a/src/server/search/doc_index.h +++ b/src/server/search/doc_index.h @@ -168,7 +168,7 @@ struct AggregateParams { search::QueryParams params; std::optional load_fields; - std::vector steps; + std::vector steps; }; // Stores basic info about a document index. diff --git a/src/server/search/search_family.cc b/src/server/search/search_family.cc index 75ca1f249e43..cd14db6791a5 100644 --- a/src/server/search/search_family.cc +++ b/src/server/search/search_family.cc @@ -320,20 +320,23 @@ optional ParseAggregatorParamsOrReply(CmdArgParser parser, while (parser.HasNext()) { // GROUPBY nargs property [property ...] if (parser.Check("GROUPBY")) { - vector fields(parser.Next()); - for (string_view& field : fields) { + size_t num_fields = parser.Next(); + + std::vector fields; + fields.reserve(num_fields); + while (num_fields > 0 && parser.HasNext()) { auto parsed_field = ParseFieldWithAtSign(&parser); /* TODO: Throw an error if the field has no '@' sign at the beginning - if (!parsed_field) { builder->SendError(absl::StrCat("bad arguments for GROUPBY: Unknown property '", field, "'. Did you mean '@", field, "`?")); return nullopt; } */ - field = parsed_field; + fields.emplace_back(parsed_field); + num_fields--; } vector reducers; @@ -363,7 +366,7 @@ optional ParseAggregatorParamsOrReply(CmdArgParser parser, aggregate::Reducer{std::move(source_field), std::move(result_field), std::move(func)}); } - params.steps.push_back(aggregate::MakeGroupStep(fields, std::move(reducers))); + params.steps.push_back(aggregate::MakeGroupStep(std::move(fields), std::move(reducers))); continue; } @@ -373,7 +376,7 @@ optional ParseAggregatorParamsOrReply(CmdArgParser parser, string_view field = parser.Next(); bool desc = bool(parser.Check("DESC")); - params.steps.push_back(aggregate::MakeSortStep(field, desc)); + params.steps.push_back(aggregate::MakeSortStep(std::string{field}, desc)); continue; } @@ -975,10 +978,18 @@ void SearchFamily::FtAggregate(CmdArgList args, const CommandContext& cmd_cntx) return OpStatus::OK; }); - vector values; + // ResultContainer is absl::flat_hash_map + // DocValues is absl::flat_hash_map + // Keys of values should point to the keys of the query_results + std::vector values; for (auto& sub_results : query_results) { - values.insert(values.end(), make_move_iterator(sub_results.begin()), - make_move_iterator(sub_results.end())); + for (auto& docs : sub_results) { + aggregate::DocValues doc_value; + for (auto& doc : docs) { + doc_value[doc.first] = std::move(doc.second); + } + values.push_back(std::move(doc_value)); + } } std::vector load_fields;