From f3cc241b464e844caa4a53e317c7d5742c47a8a5 Mon Sep 17 00:00:00 2001 From: diego Date: Sat, 1 Apr 2017 21:11:30 +0100 Subject: [PATCH] Support grouping + reranking in distribute mode --- .../lucene/search/grouping/GroupDocs.java | 4 +- .../handler/component/QueryComponent.java | 8 +++- .../apache/solr/search/ReRankCollector.java | 2 +- .../apache/solr/search/SolrIndexSearcher.java | 12 +++-- ...nkFunctionSecondPassGroupingCollector.java | 4 +- ...RerankTermSecondPassGroupingCollector.java | 7 ++- .../distributed/command/QueryCommand.java | 8 ++++ .../command/TopGroupsFieldCommand.java | 45 ++++++++++++++++--- .../TopGroupsShardResponseProcessor.java | 26 +++++++++++ .../TopGroupsResultTransformer.java | 3 ++ .../apache/solr/TestDistributedGrouping.java | 10 ++--- .../solr/search/TestReRankQParserPlugin.java | 6 ++- .../solr/BaseDistributedSearchTestCase.java | 1 - 13 files changed, 109 insertions(+), 27 deletions(-) diff --git a/lucene/grouping/src/java/org/apache/lucene/search/grouping/GroupDocs.java b/lucene/grouping/src/java/org/apache/lucene/search/grouping/GroupDocs.java index 48f12aa57173..4c41eb3b49aa 100644 --- a/lucene/grouping/src/java/org/apache/lucene/search/grouping/GroupDocs.java +++ b/lucene/grouping/src/java/org/apache/lucene/search/grouping/GroupDocs.java @@ -27,11 +27,11 @@ public class GroupDocs { public final T groupValue; /** Max score in this group */ - public final float maxScore; + public float maxScore; /** Overall aggregated score of this group (currently only * set by join queries). */ - public final float score; + public float score; /** Hits; this may be {@link * org.apache.lucene.search.FieldDoc} instances if the diff --git a/solr/core/src/java/org/apache/solr/handler/component/QueryComponent.java b/solr/core/src/java/org/apache/solr/handler/component/QueryComponent.java index 08a0e842e082..e3ebcc6f740f 100644 --- a/solr/core/src/java/org/apache/solr/handler/component/QueryComponent.java +++ b/solr/core/src/java/org/apache/solr/handler/component/QueryComponent.java @@ -427,6 +427,10 @@ public void process(ResponseBuilder rb) throws IOException rb.setResult(result); return; } else if (params.getBool(GroupParams.GROUP_DISTRIBUTED_SECOND, false)) { + RankQuery rq = rb.getRankQuery(); + if (rq != null){ + rq = rq.wrap(rb.getQuery()); + } CommandHandler.Builder secondPhaseBuilder = new CommandHandler.Builder() .setQueryCommand(cmd) .setTruncateGroups(groupingSpec.isTruncateGroups() && groupingSpec.getFields().length > 0) @@ -462,6 +466,8 @@ public void process(ResponseBuilder rb) throws IOException .setMaxDocPerGroup(docsToCollect) .setNeedScores(needScores) .setNeedMaxScore(needScores) + .setQuery(rq) + .setSearcher(searcher) .build() ); } @@ -853,8 +859,8 @@ protected void regularFinishStage(ResponseBuilder rb) { } protected void createDistributedStats(ResponseBuilder rb) { - StatsCache cache = rb.req.getCore().getStatsCache(); if ( (rb.getFieldFlags() & SolrIndexSearcher.GET_SCORES)!=0 || rb.getSortSpec().includesScore()) { + StatsCache cache = rb.req.getCore().getStatsCache(); ShardRequest sreq = cache.retrieveStatsRequest(rb); if (sreq != null) { rb.addRequest(this, sreq); diff --git a/solr/core/src/java/org/apache/solr/search/ReRankCollector.java b/solr/core/src/java/org/apache/solr/search/ReRankCollector.java index f36e4f83e7c1..ecd0550c42b1 100644 --- a/solr/core/src/java/org/apache/solr/search/ReRankCollector.java +++ b/solr/core/src/java/org/apache/solr/search/ReRankCollector.java @@ -75,7 +75,7 @@ public ReRankCollector(TopDocsCollector previousCollector, this.mainCollector = TopScoreDocCollector.create( Math.max(this.reRankDocs, length)); } else { sort = sort.rewrite(searcher); - this.mainCollector = TopFieldCollector.create(sort, Math.max(this.reRankDocs, length), false, true, true); + this.mainCollector = TopFieldCollector.create(sort, Math.max(this.reRankDocs, length), true, true, true); } this.searcher = searcher; this.reRankQueryRescorer = reRankQueryRescorer; diff --git a/solr/core/src/java/org/apache/solr/search/SolrIndexSearcher.java b/solr/core/src/java/org/apache/solr/search/SolrIndexSearcher.java index 6251621b4b06..8ab0a9627e9f 100644 --- a/solr/core/src/java/org/apache/solr/search/SolrIndexSearcher.java +++ b/solr/core/src/java/org/apache/solr/search/SolrIndexSearcher.java @@ -1471,13 +1471,13 @@ private void populateNextCursorMarkFromTopDocs(QueryResult qr, QueryCommand qc, private TopDocsCollector buildTopDocsCollector(int len, QueryCommand cmd) throws IOException { Query q = cmd.getQuery(); - if (q instanceof RankQuery) { - RankQuery rq = (RankQuery) q; - return rq.getTopDocsCollector(TopScoreDocCollector.create(len), len, cmd.getSort(), this); - } if (null == cmd.getSort()) { assert null == cmd.getCursorMark() : "have cursor but no sort"; + if (q instanceof RankQuery) { + RankQuery rq = (RankQuery) q; + return rq.getTopDocsCollector(TopScoreDocCollector.create(len), len, cmd.getSort(), this); + } return TopScoreDocCollector.create(len); } else { // we have a sort @@ -1489,6 +1489,10 @@ private TopDocsCollector buildTopDocsCollector(int len, QueryCommand cmd) throws // ... see comments in populateNextCursorMarkFromTopDocs for cache issues (SOLR-5595) final boolean fillFields = (null != cursor); final FieldDoc searchAfter = (null != cursor ? cursor.getSearchAfterFieldDoc() : null); + if (q instanceof RankQuery) { + RankQuery rq = (RankQuery) q; + return rq.getTopDocsCollector(TopFieldCollector.create(weightedSort, len, searchAfter, fillFields, needScores, needScores), len, cmd.getSort(), this); + } return TopFieldCollector.create(weightedSort, len, searchAfter, fillFields, needScores, needScores); } } diff --git a/solr/core/src/java/org/apache/solr/search/grouping/collector/RerankFunctionSecondPassGroupingCollector.java b/solr/core/src/java/org/apache/solr/search/grouping/collector/RerankFunctionSecondPassGroupingCollector.java index 4f0bff65ecb0..7ff224c3bad5 100644 --- a/solr/core/src/java/org/apache/solr/search/grouping/collector/RerankFunctionSecondPassGroupingCollector.java +++ b/solr/core/src/java/org/apache/solr/search/grouping/collector/RerankFunctionSecondPassGroupingCollector.java @@ -26,6 +26,7 @@ import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.Sort; import org.apache.lucene.search.TopDocsCollector; +import org.apache.lucene.search.grouping.GroupDocs; import org.apache.lucene.search.grouping.SearchGroup; import org.apache.lucene.search.grouping.TopGroups; import org.apache.lucene.search.grouping.function.FunctionSecondPassGroupingCollector; @@ -34,8 +35,7 @@ public class RerankFunctionSecondPassGroupingCollector extends FunctionSecondPassGroupingCollector { - - private static final int DEFAULT_GROUPING_RERANKING = 10; + private static final int DEFAULT_GROUPING_RERANKING = 1000; /** * Constructs a {@link RerankFunctionSecondPassGroupingCollector} instance. diff --git a/solr/core/src/java/org/apache/solr/search/grouping/collector/RerankTermSecondPassGroupingCollector.java b/solr/core/src/java/org/apache/solr/search/grouping/collector/RerankTermSecondPassGroupingCollector.java index c53f313c9be3..9502eccf1c98 100644 --- a/solr/core/src/java/org/apache/solr/search/grouping/collector/RerankTermSecondPassGroupingCollector.java +++ b/solr/core/src/java/org/apache/solr/search/grouping/collector/RerankTermSecondPassGroupingCollector.java @@ -18,12 +18,17 @@ import java.io.IOException; import java.lang.invoke.MethodHandles; +import java.util.Arrays; import java.util.Collection; +import java.util.Collections; +import java.util.Comparator; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.Sort; import org.apache.lucene.search.TopDocsCollector; +import org.apache.lucene.search.grouping.GroupDocs; import org.apache.lucene.search.grouping.SearchGroup; +import org.apache.lucene.search.grouping.TopGroups; import org.apache.lucene.search.grouping.term.TermSecondPassGroupingCollector; import org.apache.lucene.util.BytesRef; import org.apache.solr.search.RankQuery; @@ -33,7 +38,7 @@ public class RerankTermSecondPassGroupingCollector extends TermSecondPassGroupingCollector { private static final Logger logger = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass()); - private static final int DEFAULT_GROUPING_RERANKING = 10; + private static final int DEFAULT_GROUPING_RERANKING = 1000; public RerankTermSecondPassGroupingCollector(String groupField, Collection> groups, diff --git a/solr/core/src/java/org/apache/solr/search/grouping/distributed/command/QueryCommand.java b/solr/core/src/java/org/apache/solr/search/grouping/distributed/command/QueryCommand.java index afb8ba78a9c8..c0519bc4e64c 100644 --- a/solr/core/src/java/org/apache/solr/search/grouping/distributed/command/QueryCommand.java +++ b/solr/core/src/java/org/apache/solr/search/grouping/distributed/command/QueryCommand.java @@ -109,21 +109,29 @@ public QueryCommand build() { private final int docsToCollect; private final boolean needScores; private final String queryString; + private final IndexSearcher searcher; private TopDocsCollector collector; private FilterCollector filterCollector; private QueryCommand(Sort sort, Query query, int docsToCollect, boolean needScores, DocSet docSet, String queryString) { + this(sort, query, docsToCollect, needScores, docSet, queryString, null); + } + + + private QueryCommand(Sort sort, Query query, int docsToCollect, boolean needScores, DocSet docSet, String queryString, IndexSearcher searcher) { this.sort = sort; this.query = query; this.docsToCollect = docsToCollect; this.needScores = needScores; this.docSet = docSet; this.queryString = queryString; + this.searcher = searcher; } @Override public List create() throws IOException { + if (sort == null || sort.equals(Sort.RELEVANCE)) { collector = TopScoreDocCollector.create(docsToCollect); } else { diff --git a/solr/core/src/java/org/apache/solr/search/grouping/distributed/command/TopGroupsFieldCommand.java b/solr/core/src/java/org/apache/solr/search/grouping/distributed/command/TopGroupsFieldCommand.java index 2c6c40148af9..4520da4db3dd 100644 --- a/solr/core/src/java/org/apache/solr/search/grouping/distributed/command/TopGroupsFieldCommand.java +++ b/solr/core/src/java/org/apache/solr/search/grouping/distributed/command/TopGroupsFieldCommand.java @@ -18,6 +18,8 @@ import org.apache.lucene.queries.function.ValueSource; import org.apache.lucene.search.Collector; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Query; import org.apache.lucene.search.Sort; import org.apache.lucene.search.grouping.SecondPassGroupingCollector; import org.apache.lucene.search.grouping.GroupDocs; @@ -29,7 +31,10 @@ import org.apache.lucene.util.mutable.MutableValue; import org.apache.solr.schema.FieldType; import org.apache.solr.schema.SchemaField; +import org.apache.solr.search.RankQuery; import org.apache.solr.search.grouping.Command; +import org.apache.solr.search.grouping.collector.RerankFunctionSecondPassGroupingCollector; +import org.apache.solr.search.grouping.collector.RerankTermSecondPassGroupingCollector; import java.io.IOException; import java.util.ArrayList; @@ -52,6 +57,8 @@ public static class Builder { private Integer maxDocPerGroup; private boolean needScores = false; private boolean needMaxScore = false; + private Query query; + private IndexSearcher searcher; public Builder setField(SchemaField field) { this.field = field; @@ -83,6 +90,16 @@ public Builder setNeedScores(Boolean needScores) { return this; } + public Builder setQuery(Query query) { + this.query = query; + return this; + } + + public Builder setSearcher(IndexSearcher searcher) { + this.searcher = searcher; + return this; + } + public Builder setNeedMaxScore(Boolean needMaxScore) { this.needMaxScore = needMaxScore; return this; @@ -94,7 +111,7 @@ public TopGroupsFieldCommand build() { throw new IllegalStateException("All required fields must be set"); } - return new TopGroupsFieldCommand(field, groupSort, sortWithinGroup, firstPhaseGroups, maxDocPerGroup, needScores, needMaxScore); + return new TopGroupsFieldCommand(field, groupSort, sortWithinGroup, query, searcher, firstPhaseGroups, maxDocPerGroup, needScores, needMaxScore); } } @@ -102,6 +119,8 @@ public TopGroupsFieldCommand build() { private final SchemaField field; private final Sort groupSort; private final Sort sortWithinGroup; + private final Query query; + private final IndexSearcher searcher; private final Collection> firstPhaseGroups; private final int maxDocPerGroup; private final boolean needScores; @@ -111,6 +130,8 @@ public TopGroupsFieldCommand build() { private TopGroupsFieldCommand(SchemaField field, Sort groupSort, Sort sortWithinGroup, + Query query, + IndexSearcher searcher, Collection> firstPhaseGroups, int maxDocPerGroup, boolean needScores, @@ -122,6 +143,8 @@ private TopGroupsFieldCommand(SchemaField field, this.maxDocPerGroup = maxDocPerGroup; this.needScores = needScores; this.needMaxScore = needMaxScore; + this.query = query; + this.searcher = searcher; } @Override @@ -135,13 +158,21 @@ public List create() throws IOException { if (fieldType.getNumberType() != null) { ValueSource vs = fieldType.getValueSource(field, null); Collection> v = GroupConverter.toMutable(field, firstPhaseGroups); - secondPassCollector = new FunctionSecondPassGroupingCollector( - v, groupSort, sortWithinGroup, maxDocPerGroup, needScores, needMaxScore, true, vs, new HashMap() - ); + if (query instanceof RankQuery){ + secondPassCollector = new RerankFunctionSecondPassGroupingCollector(v, groupSort, sortWithinGroup, (RankQuery) query, searcher, maxDocPerGroup, needScores, needMaxScore, true, vs, new HashMap()); + } else { + secondPassCollector = new FunctionSecondPassGroupingCollector( + v, groupSort, sortWithinGroup, maxDocPerGroup, needScores, needMaxScore, true, vs, new HashMap() + ); + } } else { - secondPassCollector = new TermSecondPassGroupingCollector( - field.getName(), firstPhaseGroups, groupSort, sortWithinGroup, maxDocPerGroup, needScores, needMaxScore, true - ); + if (query instanceof RankQuery){ + secondPassCollector = new RerankTermSecondPassGroupingCollector(field.getName(), firstPhaseGroups, groupSort, sortWithinGroup, searcher, (RankQuery)query, maxDocPerGroup, needScores, needMaxScore, true ); + } else { + secondPassCollector = new TermSecondPassGroupingCollector( + field.getName(), firstPhaseGroups, groupSort, sortWithinGroup, maxDocPerGroup, needScores, needMaxScore, true + ); + } } collectors.add(secondPassCollector); return collectors; diff --git a/solr/core/src/java/org/apache/solr/search/grouping/distributed/responseprocessor/TopGroupsShardResponseProcessor.java b/solr/core/src/java/org/apache/solr/search/grouping/distributed/responseprocessor/TopGroupsShardResponseProcessor.java index 2ac83c64e299..a957854a20bb 100644 --- a/solr/core/src/java/org/apache/solr/search/grouping/distributed/responseprocessor/TopGroupsShardResponseProcessor.java +++ b/solr/core/src/java/org/apache/solr/search/grouping/distributed/responseprocessor/TopGroupsShardResponseProcessor.java @@ -19,6 +19,8 @@ import java.io.PrintWriter; import java.io.StringWriter; import java.util.ArrayList; +import java.util.Arrays; +import java.util.Comparator; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -40,6 +42,7 @@ import org.apache.solr.handler.component.ShardResponse; import org.apache.solr.response.SolrQueryResponse; import org.apache.solr.search.Grouping; +import org.apache.solr.search.RankQuery; import org.apache.solr.search.grouping.distributed.ShardResponseProcessor; import org.apache.solr.search.grouping.distributed.command.QueryCommandResult; import org.apache.solr.search.grouping.distributed.shardresultserializer.TopGroupsResultTransformer; @@ -165,6 +168,29 @@ public void process(ResponseBuilder rb, ShardRequest shardRequest) { } } rb.mergedTopGroups.put(groupField, TopGroups.merge(topGroups.toArray(topGroupsArr), groupSort, sortWithinGroup, groupOffsetDefault, docsPerGroup, TopGroups.ScoreMergeMode.None)); + if (rb.getRankQuery() != null){ + TopGroups group = rb.mergedTopGroups.get(groupField); + for (GroupDocs g : group.groups){ + Arrays.sort(g.scoreDocs, new Comparator() { + @Override + public int compare(ScoreDoc o1, ScoreDoc o2) { + if (o1.score > o2.score) return -1; + if (o2.score < o1.score) return 1; + return 0; + } + }); + g.maxScore = g.scoreDocs[0].score; + g.score = g.scoreDocs[0].score; + } + Arrays.sort(group.groups, new Comparator>() { + @Override + public int compare(GroupDocs o1, GroupDocs o2) { + if (o1.score > o2.score) return -1; + if (o2.score < o1.score) return 1; + return 0; + } + }); + } } for (String query : commandTopDocs.keySet()) { diff --git a/solr/core/src/java/org/apache/solr/search/grouping/distributed/shardresultserializer/TopGroupsResultTransformer.java b/solr/core/src/java/org/apache/solr/search/grouping/distributed/shardresultserializer/TopGroupsResultTransformer.java index 83c81e5e9fd8..f9b75f787bee 100644 --- a/solr/core/src/java/org/apache/solr/search/grouping/distributed/shardresultserializer/TopGroupsResultTransformer.java +++ b/solr/core/src/java/org/apache/solr/search/grouping/distributed/shardresultserializer/TopGroupsResultTransformer.java @@ -219,6 +219,9 @@ protected NamedList serializeTopGroups(TopGroups data, SchemaField gro } FieldDoc fieldDoc = (FieldDoc) searchGroup.scoreDocs[i]; + + assert(fieldDoc != null && fieldDoc.fields != null); + Object[] convertedSortValues = new Object[fieldDoc.fields.length]; for (int j = 0; j < fieldDoc.fields.length; j++) { Object sortValue = fieldDoc.fields[j]; diff --git a/solr/core/src/test/org/apache/solr/TestDistributedGrouping.java b/solr/core/src/test/org/apache/solr/TestDistributedGrouping.java index 2ef2e1d223f3..c4068393a945 100644 --- a/solr/core/src/test/org/apache/solr/TestDistributedGrouping.java +++ b/solr/core/src/test/org/apache/solr/TestDistributedGrouping.java @@ -310,12 +310,10 @@ public void test() throws Exception { "group.field", i1, "group.limit", 1, "rq", "{!" + ReRankQParserPlugin.NAME + " " + ReRankQParserPlugin.RERANK_QUERY + "=$rqq " + ReRankQParserPlugin.RERANK_DOCS + "=1000}", "rqq", "{!func }field("+i1+")"); - nl = (NamedList) rsp.getResponse().get("grouped"); - nl = (NamedList) nl.get(i1); - nl = ((List>) nl.get("groups")).get(0); - int groupValue = (int)nl.get("groupValue"); - int maxScore = ((SolrDocumentList)nl.get("doclist")).getMaxScore().intValue(); - assertEquals(groupValue, maxScore); + + rsp = query("q", "{!func}id", "rows", 100, "fl", "id," + i1, "group", "true", + "group.field", i1, "group.limit", 1, "rq", "{!" + ReRankQParserPlugin.NAME + " " + ReRankQParserPlugin.RERANK_QUERY + "=$rqq " + + ReRankQParserPlugin.RERANK_DOCS + "=1000}", "rqq", t1+":eggs"); } diff --git a/solr/core/src/test/org/apache/solr/search/TestReRankQParserPlugin.java b/solr/core/src/test/org/apache/solr/search/TestReRankQParserPlugin.java index 02986c9efd49..542085e4c8bf 100644 --- a/solr/core/src/test/org/apache/solr/search/TestReRankQParserPlugin.java +++ b/solr/core/src/test/org/apache/solr/search/TestReRankQParserPlugin.java @@ -753,8 +753,10 @@ public void testRerankQueryAndGroupingRerankGroups() throws Exception { params.add("rqq", "{!func }field(test_ti)"); assertQ(req(params), "*[count(//doc)=2]", - "//arr/lst[1]/result/doc/float[@name='id'][.='5.0']", // should be 3.0 - "//arr/lst[2]/result/doc/float[@name='id'][.='3.0']" // should be 4.0 + "//arr/lst[1]/result/doc/float[@name='id'][.='5.0']", + "//arr/lst[2]/result/doc/float[@name='id'][.='3.0']", + "//arr/lst[1]/result[@maxScore=1002.0]", + "//arr/lst[2]/result[@maxScore=103.0]" ); } diff --git a/solr/test-framework/src/java/org/apache/solr/BaseDistributedSearchTestCase.java b/solr/test-framework/src/java/org/apache/solr/BaseDistributedSearchTestCase.java index 32e30e6d13ce..8c6eb6093dab 100644 --- a/solr/test-framework/src/java/org/apache/solr/BaseDistributedSearchTestCase.java +++ b/solr/test-framework/src/java/org/apache/solr/BaseDistributedSearchTestCase.java @@ -329,7 +329,6 @@ protected void createServers(int numShards) throws Exception { shardsArr = new String[numShards]; StringBuilder sb = new StringBuilder(); - for (int i = 0; i < numShards; i++) { if (sb.length() > 0) sb.append(','); final String shardname = "shard" + i;