diff --git a/lucene/core/src/java/org/apache/lucene/search/TopFieldCollector.java b/lucene/core/src/java/org/apache/lucene/search/TopFieldCollector.java index 8ba42abafbee..7c590bf8bf43 100644 --- a/lucene/core/src/java/org/apache/lucene/search/TopFieldCollector.java +++ b/lucene/core/src/java/org/apache/lucene/search/TopFieldCollector.java @@ -384,7 +384,8 @@ public static TopFieldCollector create(Sort sort, int numHits, int totalHitsThre */ public static TopFieldCollector create(Sort sort, int numHits, FieldDoc after, int totalHitsThreshold) { - + assert sort != null : "Sort can't be null"; + assert sort.fields != null : "Sort fields can't be null"; if (sort.fields.length == 0) { throw new IllegalArgumentException("Sort must contain at least one field"); } @@ -451,6 +452,8 @@ public static void populateScores(ScoreDoc[] topDocs, IndexSearcher searcher, Qu } scoreDoc.score = currentScorer.score(); } + + } final void add(int slot, int doc) { diff --git a/lucene/grouping/src/java/org/apache/lucene/search/grouping/SecondPassGroupingCollector.java b/lucene/grouping/src/java/org/apache/lucene/search/grouping/SecondPassGroupingCollector.java index 0d5fc9daa26c..8e31ae56c756 100644 --- a/lucene/grouping/src/java/org/apache/lucene/search/grouping/SecondPassGroupingCollector.java +++ b/lucene/grouping/src/java/org/apache/lucene/search/grouping/SecondPassGroupingCollector.java @@ -51,15 +51,14 @@ public class SecondPassGroupingCollector extends SimpleCollector { */ public SecondPassGroupingCollector(GroupSelector groupSelector, Collection> groups, GroupReducer reducer) { - //System.out.println("SP init"); - if (groups.isEmpty()) { - throw new IllegalArgumentException("no groups to collect (groups is empty)"); + if (groups == null || groups.isEmpty()) { + throw new IllegalArgumentException("no groups to collect (groups is "+ ((groups == null) ? "null" : "empty")+")"); } this.groupSelector = Objects.requireNonNull(groupSelector); this.groupSelector.setGroups(groups); - this.groups = Objects.requireNonNull(groups); + this.groups = groups; this.groupReducer = reducer; reducer.setGroups(groups); } diff --git a/lucene/grouping/src/java/org/apache/lucene/search/grouping/TopGroups.java b/lucene/grouping/src/java/org/apache/lucene/search/grouping/TopGroups.java index 71338f963585..20fcf266206b 100644 --- a/lucene/grouping/src/java/org/apache/lucene/search/grouping/TopGroups.java +++ b/lucene/grouping/src/java/org/apache/lucene/search/grouping/TopGroups.java @@ -169,7 +169,9 @@ public static TopGroups merge(TopGroups[] shardGroups, Sort groupSort, shardGroupDocs.scoreDocs, docSort.getSort()); } - maxScore = Math.max(maxScore, shardGroupDocs.maxScore); + if (! Float.isNaN(shardGroupDocs.maxScore)){ + maxScore = Math.max(maxScore, shardGroupDocs.maxScore); + } assert shardGroupDocs.totalHits.relation == Relation.EQUAL_TO; totalHits += shardGroupDocs.totalHits.value; scoreSum += shardGroupDocs.score; diff --git a/lucene/grouping/src/java/org/apache/lucene/search/grouping/TopGroupsCollector.java b/lucene/grouping/src/java/org/apache/lucene/search/grouping/TopGroupsCollector.java index 01e992822256..747d5cefd7c5 100644 --- a/lucene/grouping/src/java/org/apache/lucene/search/grouping/TopGroupsCollector.java +++ b/lucene/grouping/src/java/org/apache/lucene/search/grouping/TopGroupsCollector.java @@ -49,21 +49,32 @@ public class TopGroupsCollector extends SecondPassGroupingCollector { /** * Create a new TopGroupsCollector + * @param groupReducer the group reducer used to collect the groups * @param groupSelector the group selector used to define groups * @param groups the groups to collect TopDocs for * @param groupSort the order in which groups are returned * @param withinGroupSort the order in which documents are sorted in each group * @param maxDocsPerGroup the maximum number of docs to collect for each group - * @param getMaxScores if true, record the maximum score for each group */ - public TopGroupsCollector(GroupSelector groupSelector, Collection> groups, Sort groupSort, Sort withinGroupSort, - int maxDocsPerGroup, boolean getMaxScores) { - super(groupSelector, groups, - new TopDocsReducer<>(withinGroupSort, maxDocsPerGroup, getMaxScores)); + protected TopGroupsCollector(GroupReducer groupReducer, GroupSelector groupSelector, Collection> groups, Sort groupSort, Sort withinGroupSort, int maxDocsPerGroup) { + super(groupSelector, groups, groupReducer); this.groupSort = Objects.requireNonNull(groupSort); this.withinGroupSort = Objects.requireNonNull(withinGroupSort); this.maxDocsPerGroup = maxDocsPerGroup; + } + /** + * Create a new TopGroupsCollector + * @param groupSelector the group selector used to define groups + * @param groups the groups to collect TopDocs for + * @param groupSort the order in which groups are returned + * @param withinGroupSort the order in which documents are sorted in each group + * @param maxDocsPerGroup the maximum number of docs to collect for each group + * @param getMaxScores if true, record the maximum score for each group + */ + public TopGroupsCollector(GroupSelector groupSelector, Collection> groups, Sort groupSort, Sort withinGroupSort, + int maxDocsPerGroup, boolean getMaxScores) { + this(new TopDocsReducer<>(withinGroupSort, maxDocsPerGroup, getMaxScores), groupSelector, groups, groupSort, withinGroupSort, maxDocsPerGroup); } private static class MaxScoreCollector extends SimpleCollector { diff --git a/solr/core/src/java/org/apache/solr/handler/component/QueryComponent.java b/solr/core/src/java/org/apache/solr/handler/component/QueryComponent.java index fea238b469fe..8ef91881207c 100644 --- a/solr/core/src/java/org/apache/solr/handler/component/QueryComponent.java +++ b/solr/core/src/java/org/apache/solr/handler/component/QueryComponent.java @@ -66,6 +66,7 @@ import org.apache.solr.response.BasicResultContext; import org.apache.solr.response.ResultContext; import org.apache.solr.response.SolrQueryResponse; +import org.apache.solr.search.AbstractReRankQuery; import org.apache.solr.schema.FieldType; import org.apache.solr.schema.IndexSchema; import org.apache.solr.schema.SchemaField; @@ -617,6 +618,7 @@ public void finishStage(ResponseBuilder rb) { protected void groupedFinishStage(final ResponseBuilder rb) { // To have same response as non-distributed request. GroupingSpecification groupSpec = rb.getGroupingSpec(); + if (rb.mergedTopGroups.isEmpty()) { for (String field : groupSpec.getFields()) { rb.mergedTopGroups.put(field, new TopGroups(null, null, 0, 0, new GroupDocs[]{}, Float.NaN)); @@ -1283,10 +1285,17 @@ private void doProcessGroupedDistributedSearchFirstPhase(ResponseBuilder rb, Que .setSearcher(searcher); for (String field : groupingSpec.getFields()) { + final int topNGroups; + Query query = cmd.getQuery(); + if (query instanceof AbstractReRankQuery){ + topNGroups = Math.max(((AbstractReRankQuery)query).getReRankDocs(), cmd.getOffset() + cmd.getLen()); + } else { + topNGroups = cmd.getOffset() + cmd.getLen(); + } topsGroupsActionBuilder.addCommandField(new SearchGroupsFieldCommand.Builder() .setField(schema.getField(field)) .setGroupSort(groupingSpec.getGroupSort()) - .setTopNGroups(cmd.getOffset() + cmd.getLen()) + .setTopNGroups(topNGroups) .setIncludeGroupCount(groupingSpec.isIncludeGroupCount()) .build() ); @@ -1301,11 +1310,27 @@ private void doProcessGroupedDistributedSearchFirstPhase(ResponseBuilder rb, Que rb.setResult(result); } + private static List> createSearchGroups(SchemaField schemaField, String[] topGroupsValues){ + List> topGroups = new ArrayList<>(topGroupsValues.length); + for (String topGroup : topGroupsValues) { + SearchGroup searchGroup = new SearchGroup<>(); + if (!topGroup.equals(TopGroupsShardRequestFactory.GROUP_NULL_VALUE)) { + BytesRefBuilder builder = new BytesRefBuilder(); + schemaField.getType().readableToIndexed(topGroup, builder); + searchGroup.groupValue = builder.get(); + } + topGroups.add(searchGroup); + } + return topGroups; + } + private void doProcessGroupedDistributedSearchSecondPhase(ResponseBuilder rb, QueryCommand cmd, QueryResult result) throws IOException, SyntaxError { GroupingSpecification groupingSpec = rb.getGroupingSpec(); assert null != groupingSpec : "GroupingSpecification is null"; + Query query = cmd.getQuery(); + SolrQueryRequest req = rb.req; SolrQueryResponse rsp = rb.rsp; @@ -1326,25 +1351,16 @@ private void doProcessGroupedDistributedSearchSecondPhase(ResponseBuilder rb, Qu for (String field : groupingSpec.getFields()) { SchemaField schemaField = schema.getField(field); + // get the top groups for each field String[] topGroupsParam = params.getParams(GroupParams.GROUP_DISTRIBUTED_TOPGROUPS_PREFIX + field); if (topGroupsParam == null) { topGroupsParam = new String[0]; } - List> topGroups = new ArrayList<>(topGroupsParam.length); - for (String topGroup : topGroupsParam) { - SearchGroup searchGroup = new SearchGroup<>(); - if (!topGroup.equals(TopGroupsShardRequestFactory.GROUP_NULL_VALUE)) { - BytesRefBuilder builder = new BytesRefBuilder(); - schemaField.getType().readableToIndexed(topGroup, builder); - searchGroup.groupValue = builder.get(); - } - topGroups.add(searchGroup); - } - + List> topGroups = createSearchGroups(schemaField, topGroupsParam); secondPhaseBuilder.addCommandField( new TopGroupsFieldCommand.Builder() - .setQuery(cmd.getQuery()) + .setQuery(query) .setField(schemaField) .setGroupSort(groupingSpec.getGroupSort()) .setSortWithinGroup(groupingSpec.getSortWithinGroup()) @@ -1352,15 +1368,16 @@ private void doProcessGroupedDistributedSearchSecondPhase(ResponseBuilder rb, Qu .setMaxDocPerGroup(docsToCollect) .setNeedScores(needScores) .setNeedMaxScore(needScores) + .setSearcher(searcher) .build() ); } - for (String query : groupingSpec.getQueries()) { + for (String groupingQuery : groupingSpec.getQueries()) { secondPhaseBuilder.addCommandField(new Builder() .setDocsToCollect(docsToCollect) .setSort(groupingSpec.getGroupSort()) - .setQuery(query, rb.req) + .setQuery(groupingQuery, rb.req) .setDocSet(searcher) .build() ); @@ -1401,6 +1418,12 @@ private void doProcessGroupedSearch(ResponseBuilder rb, QueryCommand cmd, QueryR .setGroupOffsetDefault(groupingSpec.getWithinGroupOffset()) .setGetGroupedDocSet(groupingSpec.isTruncateGroups()); + if (cmd.getQuery() instanceof AbstractReRankQuery) { + AbstractReRankQuery rankQuery = (AbstractReRankQuery) cmd.getQuery(); + final int reRankGroups = rankQuery.getReRankDocs(); + grouping.setReRankGroups(reRankGroups); + } + if (groupingSpec.getFields() != null) { for (String field : groupingSpec.getFields()) { grouping.addFieldCommand(field, rb.req); diff --git a/solr/core/src/java/org/apache/solr/search/AbstractReRankQuery.java b/solr/core/src/java/org/apache/solr/search/AbstractReRankQuery.java index c87565813e7a..f7ee42d8c445 100644 --- a/solr/core/src/java/org/apache/solr/search/AbstractReRankQuery.java +++ b/solr/core/src/java/org/apache/solr/search/AbstractReRankQuery.java @@ -25,6 +25,7 @@ import org.apache.lucene.search.Query; import org.apache.lucene.search.QueryVisitor; import org.apache.lucene.search.Rescorer; +import org.apache.lucene.search.Sort; import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.TopDocsCollector; import org.apache.lucene.search.Weight; @@ -58,6 +59,10 @@ public MergeStrategy getMergeStrategy() { @SuppressWarnings("unchecked") public TopDocsCollector getTopDocsCollector(int len, QueryCommand cmd, IndexSearcher searcher) throws IOException { + return getTopDocsCollector(len, cmd.getSort(), searcher); + } + + public TopDocsCollector getTopDocsCollector(int len, Sort sort, IndexSearcher searcher) throws IOException { if(this.boostedPriority == null) { SolrRequestInfo info = SolrRequestInfo.getRequestInfo(); if(info != null) { @@ -65,8 +70,11 @@ public TopDocsCollector getTopDocsCollector(int len, QueryCommand cmd, IndexSear this.boostedPriority = (Set)context.get(QueryElevationComponent.BOOSTED); } } + return new ReRankCollector(reRankDocs, len, sort, mainQuery, reRankQueryRescorer, searcher, boostedPriority); + } - return new ReRankCollector(reRankDocs, len, reRankQueryRescorer, cmd, searcher, boostedPriority); + public int getReRankDocs(){ + return reRankDocs; } public Query rewrite(IndexReader reader) throws IOException { diff --git a/solr/core/src/java/org/apache/solr/search/ExportQParserPlugin.java b/solr/core/src/java/org/apache/solr/search/ExportQParserPlugin.java index bdf943e05514..572c2a56c10c 100644 --- a/solr/core/src/java/org/apache/solr/search/ExportQParserPlugin.java +++ b/solr/core/src/java/org/apache/solr/search/ExportQParserPlugin.java @@ -29,6 +29,7 @@ import org.apache.lucene.search.Scorable; import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.Sort; import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TopDocsCollector; import org.apache.lucene.search.TotalHits; @@ -95,9 +96,18 @@ public Query rewrite(IndexReader reader) throws IOException { } } + @Deprecated public TopDocsCollector getTopDocsCollector(int len, QueryCommand cmd, IndexSearcher searcher) throws IOException { + + return getTopDocsCollector(len, cmd.getSort(), searcher); + } + + @Override + public TopDocsCollector getTopDocsCollector(int len, + Sort sort, + IndexSearcher searcher) throws IOException { int leafCount = searcher.getTopReaderContext().leaves().size(); FixedBitSet[] sets = new FixedBitSet[leafCount]; return new ExportCollector(sets); diff --git a/solr/core/src/java/org/apache/solr/search/Grouping.java b/solr/core/src/java/org/apache/solr/search/Grouping.java index 4869386c3d89..ac6c9ff7a6f9 100644 --- a/solr/core/src/java/org/apache/solr/search/Grouping.java +++ b/solr/core/src/java/org/apache/solr/search/Grouping.java @@ -19,7 +19,9 @@ import java.io.IOException; import java.lang.invoke.MethodHandles; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collection; +import java.util.Comparator; import java.util.LinkedHashSet; import java.util.List; import java.util.Locale; @@ -66,6 +68,7 @@ import org.apache.solr.schema.SchemaField; import org.apache.solr.schema.StrFieldSource; import org.apache.solr.search.grouping.collector.FilterCollector; +import org.apache.solr.search.grouping.collector.ReRankTopGroupsCollector; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -92,6 +95,7 @@ public class Grouping { private int limitDefault; private int docsPerGroupDefault; private int groupOffsetDefault; + private int reRankGroups; private Format defaultFormat; private TotalCount defaultTotalCount; @@ -477,6 +481,10 @@ public boolean isSignalCacheWarning() { return signalCacheWarning; } + public void setReRankGroups(int reRankGroups) { + this.reRankGroups = reRankGroups; + } + //====================================== Inner classes ============================================================= public static enum Format { @@ -718,7 +726,11 @@ public class CommandField extends Command { @Override protected void prepare() throws IOException { - actualGroupsToFind = getMax(offset, numGroups, maxDoc); + if (reRankGroups > 0){ + actualGroupsToFind = getMax(offset, reRankGroups, maxDoc); + } else { + actualGroupsToFind = getMax(offset, numGroups, maxDoc); + } } @Override @@ -736,6 +748,7 @@ protected Collector createFirstPassCollector() throws IOException { @Override protected Collector createSecondPassCollector() throws IOException { + actualGroupsToFind = getMax(offset, numGroups, maxDoc); if (actualGroupsToFind <= 0) { allGroupsCollector = new AllGroupsCollector<>(new TermGroupSelector(groupBy)); return totalCount == TotalCount.grouped ? allGroupsCollector : null; @@ -756,9 +769,17 @@ protected Collector createSecondPassCollector() throws IOException { int groupedDocsToCollect = getMax(groupOffset, docsPerGroup, maxDoc); groupedDocsToCollect = Math.max(groupedDocsToCollect, 1); Sort withinGroupSort = this.withinGroupSort != null ? this.withinGroupSort : Sort.RELEVANCE; - secondPass = new TopGroupsCollector<>(new TermGroupSelector(groupBy), - topGroups, groupSort, withinGroupSort, groupedDocsToCollect, needScores - ); + + if (query instanceof RankQuery) { + secondPass = new ReRankTopGroupsCollector<>(new TermGroupSelector(groupBy), + topGroups, groupSort, withinGroupSort, groupedDocsToCollect, needScores, needScores, false, (RankQuery) query, searcher); + } + else { + secondPass = new TopGroupsCollector<>(new TermGroupSelector(groupBy), + topGroups, groupSort, withinGroupSort, groupedDocsToCollect, needScores + ); + } + if (totalCount == TotalCount.grouped) { allGroupsCollector = new AllGroupsCollector<>(new TermGroupSelector(groupBy)); @@ -780,6 +801,26 @@ protected void finish() throws IOException { result = secondPass.getTopGroups(0); populateScoresIfNecessary(); } + if (result != null && query instanceof RankQuery && groupSort == Sort.RELEVANCE) { + // if we are sorting for relevance and query is a RankQuery, it may be that + // the order of the groups changed, we need to reorder + GroupDocs[] groups = result.groups; + Arrays.sort(groups, new Comparator() { + @Override + public int compare(GroupDocs o1, GroupDocs o2) { + if (o1.maxScore > o2.maxScore) return -1; + if (o1.maxScore < o2.maxScore) return 1; + return 0; + } + }); + + result = new TopGroups(groupSort.getSort(), + withinGroupSort.getSort(), + result.totalHitCount, result.totalGroupedHitCount, Arrays.copyOfRange(result.groups, 0, Math.min(result.groups.length, limitDefault)), + maxScore); + } + + if (main) { mainResult = createSimpleResponse(); return; @@ -926,7 +967,11 @@ private ValueSourceGroupSelector newSelector() { protected void prepare() throws IOException { context = ValueSource.newContext(searcher); groupBy.createWeight(context, searcher); - actualGroupsToFind = getMax(offset, numGroups, maxDoc); + if (reRankGroups > 0){ + actualGroupsToFind = getMax(offset, reRankGroups, maxDoc); + } else { + actualGroupsToFind = getMax(offset, numGroups, maxDoc); + } } @Override @@ -944,6 +989,7 @@ protected Collector createFirstPassCollector() throws IOException { @Override protected Collector createSecondPassCollector() throws IOException { + actualGroupsToFind = getMax(offset, numGroups, maxDoc); if (actualGroupsToFind <= 0) { allGroupsCollector = new AllGroupsCollector<>(newSelector()); return totalCount == TotalCount.grouped ? allGroupsCollector : null; @@ -964,9 +1010,15 @@ protected Collector createSecondPassCollector() throws IOException { int groupdDocsToCollect = getMax(groupOffset, docsPerGroup, maxDoc); groupdDocsToCollect = Math.max(groupdDocsToCollect, 1); Sort withinGroupSort = this.withinGroupSort != null ? this.withinGroupSort : Sort.RELEVANCE; - secondPass = new TopGroupsCollector<>(newSelector(), - topGroups, groupSort, withinGroupSort, groupdDocsToCollect, needScores - ); + + if (query instanceof RankQuery){ + secondPass = new ReRankTopGroupsCollector<>(newSelector(), + topGroups, groupSort, withinGroupSort, groupdDocsToCollect, needScores, needScores, false, (RankQuery)query, searcher); + } else { + secondPass = new TopGroupsCollector<>(newSelector(), + topGroups, groupSort, withinGroupSort, groupdDocsToCollect, needScores + ); + } if (totalCount == TotalCount.grouped) { allGroupsCollector = new AllGroupsCollector<>(newSelector()); @@ -988,6 +1040,27 @@ protected void finish() throws IOException { result = secondPass.getTopGroups(0); populateScoresIfNecessary(); } + + if (result != null && query instanceof RankQuery && groupSort == Sort.RELEVANCE) { + // if we are sorting for relevance and query is a RankQuery, it may be that + // the order of the groups changed, we need to reorder + GroupDocs[] groups = result.groups; + Arrays.sort(groups, new Comparator() { + @Override + public int compare(GroupDocs o1, GroupDocs o2) { + if (o1.maxScore > o2.maxScore) return -1; + if (o1.maxScore < o2.maxScore) return 1; + return 0; + } + }); + result = new TopGroups(groupSort.getSort(), + withinGroupSort.getSort(), + result.totalHitCount, result.totalGroupedHitCount, Arrays.copyOfRange(result.groups, 0, Math.min(result.groups.length, limitDefault)), + maxScore); + + } + + if (main) { mainResult = createSimpleResponse(); return; diff --git a/solr/core/src/java/org/apache/solr/search/RankQuery.java b/solr/core/src/java/org/apache/solr/search/RankQuery.java index 4812e4117a91..d93f4ca1b799 100644 --- a/solr/core/src/java/org/apache/solr/search/RankQuery.java +++ b/solr/core/src/java/org/apache/solr/search/RankQuery.java @@ -17,8 +17,9 @@ package org.apache.solr.search; import org.apache.lucene.search.IndexSearcher; -import org.apache.lucene.search.TopDocsCollector; import org.apache.lucene.search.Query; +import org.apache.lucene.search.Sort; +import org.apache.lucene.search.TopDocsCollector; import org.apache.solr.handler.component.MergeStrategy; import java.io.IOException; @@ -29,8 +30,9 @@ public abstract class RankQuery extends ExtendedQueryBase { + @Deprecated public abstract TopDocsCollector getTopDocsCollector(int len, QueryCommand cmd, IndexSearcher searcher) throws IOException; + public abstract TopDocsCollector getTopDocsCollector(int len, Sort sort, IndexSearcher searcher) throws IOException; public abstract MergeStrategy getMergeStrategy(); public abstract RankQuery wrap(Query mainQuery); - } diff --git a/solr/core/src/java/org/apache/solr/search/ReRankCollector.java b/solr/core/src/java/org/apache/solr/search/ReRankCollector.java index ed917c3f86f3..a8fb804cee66 100644 --- a/solr/core/src/java/org/apache/solr/search/ReRankCollector.java +++ b/solr/core/src/java/org/apache/solr/search/ReRankCollector.java @@ -51,21 +51,29 @@ public class ReRankCollector extends TopDocsCollector { final private Set boostedPriority; // order is the "priority" final private Rescorer reRankQueryRescorer; final private Sort sort; - final private Query query; - + final private Query mainQuery; + @Deprecated public ReRankCollector(int reRankDocs, int length, Rescorer reRankQueryRescorer, QueryCommand cmd, IndexSearcher searcher, Set boostedPriority) throws IOException { + this(reRankDocs, length, cmd.getSort(), cmd.getQuery(), reRankQueryRescorer, searcher, boostedPriority); + } + + public ReRankCollector(int reRankDocs, + int length, + Sort sort, + Query mainQuery, + Rescorer reRankQueryRescorer, + IndexSearcher searcher, + Set boostedPriority) throws IOException { super(null); this.reRankDocs = reRankDocs; this.length = length; this.boostedPriority = boostedPriority; - this.query = cmd.getQuery(); - Sort sort = cmd.getSort(); if(sort == null) { this.sort = null; this.mainCollector = TopScoreDocCollector.create(Math.max(this.reRankDocs, length), Integer.MAX_VALUE); @@ -73,9 +81,11 @@ public ReRankCollector(int reRankDocs, this.sort = sort = sort.rewrite(searcher); //scores are needed for Rescorer (regardless of whether sort needs it) this.mainCollector = TopFieldCollector.create(sort, Math.max(this.reRankDocs, length), Integer.MAX_VALUE); + } this.searcher = searcher; this.reRankQueryRescorer = reRankQueryRescorer; + this.mainQuery = mainQuery; } public int getTotalHits() { @@ -103,7 +113,7 @@ public TopDocs topDocs(int start, int howMany) { } if (sort != null) { - TopFieldCollector.populateScores(mainDocs.scoreDocs, searcher, query); + TopFieldCollector.populateScores(mainDocs.scoreDocs, searcher, mainQuery); } ScoreDoc[] mainScoreDocs = mainDocs.scoreDocs; @@ -138,6 +148,7 @@ public TopDocs topDocs(int start, int howMany) { ScoreDoc[] scoreDocs = new ScoreDoc[howMany]; System.arraycopy(mainScoreDocs, 0, scoreDocs, 0, scoreDocs.length); //lay down the initial docs System.arraycopy(rescoredDocs.scoreDocs, 0, scoreDocs, 0, rescoredDocs.scoreDocs.length);//overlay the re-ranked docs. + rescoredDocs.scoreDocs = scoreDocs; return rescoredDocs; } else { diff --git a/solr/core/src/java/org/apache/solr/search/ReRankWeight.java b/solr/core/src/java/org/apache/solr/search/ReRankWeight.java index 9c11a894200c..b7651e7c2d5f 100644 --- a/solr/core/src/java/org/apache/solr/search/ReRankWeight.java +++ b/solr/core/src/java/org/apache/solr/search/ReRankWeight.java @@ -24,6 +24,7 @@ import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.Query; import org.apache.lucene.search.Rescorer; +import org.apache.lucene.search.ScorerSupplier; import org.apache.lucene.search.Weight; /** @@ -40,9 +41,9 @@ public ReRankWeight(Query mainQuery, Rescorer reRankQueryRescorer, IndexSearcher this.reRankQueryRescorer = reRankQueryRescorer; } + @Override public Explanation explain(LeafReaderContext context, int doc) throws IOException { final Explanation mainExplain = in.explain(context, doc); return reRankQueryRescorer.explain(searcher, mainExplain, context.docBase+doc); } - } diff --git a/solr/core/src/java/org/apache/solr/search/SolrIndexSearcher.java b/solr/core/src/java/org/apache/solr/search/SolrIndexSearcher.java index 210e0ad66a9d..ca68e05b6f38 100644 --- a/solr/core/src/java/org/apache/solr/search/SolrIndexSearcher.java +++ b/solr/core/src/java/org/apache/solr/search/SolrIndexSearcher.java @@ -1504,22 +1504,25 @@ private void populateNextCursorMarkFromTopDocs(QueryResult qr, QueryCommand qc, private TopDocsCollector buildTopDocsCollector(int len, QueryCommand cmd) throws IOException { Query q = cmd.getQuery(); - if (q instanceof RankQuery) { - RankQuery rq = (RankQuery) q; - return rq.getTopDocsCollector(len, cmd, this); - } if (null == cmd.getSort()) { assert null == cmd.getCursorMark() : "have cursor but no sort"; + if (q instanceof RankQuery) { + RankQuery rq = (RankQuery) q; + return rq.getTopDocsCollector(len, cmd.getSort(), this); + } return TopScoreDocCollector.create(len, Integer.MAX_VALUE); - } else { - // we have a sort - final Sort weightedSort = weightSort(cmd.getSort()); - final CursorMark cursor = cmd.getCursorMark(); - - final FieldDoc searchAfter = (null != cursor ? cursor.getSearchAfterFieldDoc() : null); - return TopFieldCollector.create(weightedSort, len, searchAfter, Integer.MAX_VALUE); } + // we have a sort + final Sort weightedSort = weightSort(cmd.getSort()); + if (q instanceof RankQuery) { + RankQuery rq = (RankQuery) q; + return rq.getTopDocsCollector(len, weightedSort, this); + } + final CursorMark cursor = cmd.getCursorMark(); + + final FieldDoc searchAfter = (null != cursor ? cursor.getSearchAfterFieldDoc() : null); + return TopFieldCollector.create(weightedSort, len, searchAfter, Integer.MAX_VALUE); } private void getDocListNC(QueryResult qr, QueryCommand cmd) throws IOException { diff --git a/solr/core/src/java/org/apache/solr/search/grouping/collector/ReRankTopGroupsCollector.java b/solr/core/src/java/org/apache/solr/search/grouping/collector/ReRankTopGroupsCollector.java new file mode 100644 index 000000000000..b15c2cd72470 --- /dev/null +++ b/solr/core/src/java/org/apache/solr/search/grouping/collector/ReRankTopGroupsCollector.java @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.solr.search.grouping.collector; + +import java.io.IOException; +import java.lang.invoke.MethodHandles; +import java.util.Collection; +import java.util.Objects; +import java.util.function.Supplier; + +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Sort; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TopDocsCollector; +import org.apache.lucene.search.TopFieldCollector; +import org.apache.lucene.search.TopScoreDocCollector; +import org.apache.lucene.search.grouping.GroupDocs; +import org.apache.lucene.search.grouping.GroupReducer; +import org.apache.lucene.search.grouping.GroupSelector; +import org.apache.lucene.search.grouping.SearchGroup; +import org.apache.lucene.search.grouping.TopGroups; +import org.apache.lucene.search.grouping.TopGroupsCollector; +import org.apache.solr.search.AbstractReRankQuery; +import org.apache.solr.search.RankQuery; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class ReRankTopGroupsCollector extends TopGroupsCollector { + + private static final Logger log = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass()); + private Sort groupSort; + private Sort withinGroupSort; + private int maxDocsPerGroup; + + /** + * Create a new TopGroupsCollector + * @param groupSelector the group selector used to define groups + * @param groups the groups to collect TopDocs for + * @param groupSort the order in which groups are returned + * @param withinGroupSort the order in which documents are sorted in each group + * @param maxDocsPerGroup the maximum number of docs to collect for each group + * @param getScores if true, record the scores of all docs in each group + * @param getMaxScores if true, record the maximum score for each group + * @param fillSortFields if true, record the sort field values for all docs + * @param query the rankQuery if provided by the user, null otherwise + * @param searcher an index searcher + */ + public ReRankTopGroupsCollector(GroupSelector groupSelector, Collection> groups, Sort groupSort, Sort withinGroupSort, + int maxDocsPerGroup, boolean getScores, boolean getMaxScores, boolean fillSortFields, RankQuery query, IndexSearcher searcher) { + super(new ReRankTopGroupsCollector.TopDocsReducer<>(withinGroupSort, maxDocsPerGroup, getScores, getMaxScores, query, searcher), groupSelector, groups, groupSort, withinGroupSort, maxDocsPerGroup); + this.groupSort = Objects.requireNonNull(groupSort); + this.withinGroupSort = Objects.requireNonNull(withinGroupSort); + this.maxDocsPerGroup = maxDocsPerGroup; + } + + private static class TopDocsReducer extends GroupReducer> { + + private final Supplier> supplier; + private final boolean needsScores; + private final RankQuery query; + private final IndexSearcher searcher; + private final Sort groupSort; + private final int maxDocsPerGroup; + + TopDocsReducer(Sort withinGroupSort, + int maxDocsPerGroup, boolean getScores, boolean getMaxScores, RankQuery query, IndexSearcher searcher) { + this.needsScores = getScores || getMaxScores || withinGroupSort.needsScores(); + if (withinGroupSort == Sort.RELEVANCE) { + this.supplier = () -> TopScoreDocCollector.create(maxDocsPerGroup, Integer.MAX_VALUE); + } + else { + this.supplier = () -> TopFieldCollector.create(withinGroupSort, maxDocsPerGroup, Integer.MAX_VALUE); // TODO: disable exact counts? + } + this.query = query; + this.searcher = searcher; + this.groupSort = withinGroupSort; + this.maxDocsPerGroup = maxDocsPerGroup; + + } + + + @Override + public boolean needsScores() { + return needsScores; + } + + @Override + protected TopDocsCollector newCollector() { + TopDocsCollector collector = supplier.get(); + final int len; + if (query instanceof AbstractReRankQuery){ + len = ((AbstractReRankQuery) query).getReRankDocs(); + } else { + len = maxDocsPerGroup; + } + try { + collector = this.query.getTopDocsCollector(len, groupSort, searcher); + } catch (IOException e) { + // this should never happen + log.error("Cannot rerank groups ", e); + } + return collector; + } + } + + /** + * Get the TopGroups recorded by this collector + * @param withinGroupOffset the offset within each group to start collecting documents + */ + public TopGroups getTopGroups(int withinGroupOffset) { + @SuppressWarnings({"unchecked","rawtypes"}) + final GroupDocs[] groupDocsResult = (GroupDocs[]) new GroupDocs[groups.size()]; + + int groupIDX = 0; + float maxScore = Float.MIN_VALUE; + for(SearchGroup group : groups) { + TopDocsCollector collector = (TopDocsCollector) groupReducer.getCollector(group.groupValue); + final TopDocs topDocs = collector.topDocs(withinGroupOffset, maxDocsPerGroup); + float topDocsMaxScore = topDocs.scoreDocs.length == 0 ? Float.NaN : topDocs.scoreDocs[0].score; + groupDocsResult[groupIDX++] = new GroupDocs<>(Float.NaN, + topDocsMaxScore, + topDocs.totalHits, + topDocs.scoreDocs, + group.groupValue, + group.sortValues); + if (! Float.isNaN(topDocsMaxScore)) { + maxScore = Math.max(maxScore, topDocsMaxScore); + } + } + return new TopGroups<>(groupSort.getSort(), + withinGroupSort.getSort(), + totalHitCount, totalGroupedHitCount, groupDocsResult, + maxScore); + } +} diff --git a/solr/core/src/java/org/apache/solr/search/grouping/distributed/command/TopGroupsFieldCommand.java b/solr/core/src/java/org/apache/solr/search/grouping/distributed/command/TopGroupsFieldCommand.java index b6182864e603..dc4f8a9f14b6 100644 --- a/solr/core/src/java/org/apache/solr/search/grouping/distributed/command/TopGroupsFieldCommand.java +++ b/solr/core/src/java/org/apache/solr/search/grouping/distributed/command/TopGroupsFieldCommand.java @@ -18,6 +18,7 @@ import java.io.IOException; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.HashMap; @@ -39,7 +40,9 @@ import org.apache.lucene.util.mutable.MutableValue; import org.apache.solr.schema.FieldType; import org.apache.solr.schema.SchemaField; +import org.apache.solr.search.RankQuery; import org.apache.solr.search.grouping.Command; +import org.apache.solr.search.grouping.collector.ReRankTopGroupsCollector; /** * Defines all collectors for retrieving the second phase and how to handle the collector result. @@ -56,12 +59,18 @@ public static class Builder { private Integer maxDocPerGroup; private boolean needScores = false; private boolean needMaxScore = false; + private IndexSearcher searcher; public Builder setQuery(Query query) { this.query = query; return this; } + public Builder setSearcher(IndexSearcher searcher) { + this.searcher = searcher; + return this; + } + public Builder setField(SchemaField field) { this.field = field; return this; @@ -103,13 +112,14 @@ public TopGroupsFieldCommand build() { throw new IllegalStateException("All required fields must be set"); } - return new TopGroupsFieldCommand(query, field, groupSort, withinGroupSort, firstPhaseGroups, maxDocPerGroup, needScores, needMaxScore); + return new TopGroupsFieldCommand(query, field, groupSort, withinGroupSort, firstPhaseGroups, maxDocPerGroup, needScores, needMaxScore, searcher); } } private final Query query; private final SchemaField field; + private final IndexSearcher searcher; private final Sort groupSort; private final Sort withinGroupSort; private final Collection> firstPhaseGroups; @@ -126,7 +136,9 @@ private TopGroupsFieldCommand(Query query, Collection> firstPhaseGroups, int maxDocPerGroup, boolean needScores, - boolean needMaxScore) { + boolean needMaxScore, + IndexSearcher searcher + ) { this.query = query; this.field = field; this.groupSort = groupSort; @@ -135,6 +147,7 @@ private TopGroupsFieldCommand(Query query, this.maxDocPerGroup = maxDocPerGroup; this.needScores = needScores; this.needMaxScore = needMaxScore; + this.searcher = searcher; } @Override @@ -142,19 +155,28 @@ public List create() throws IOException { if (firstPhaseGroups.isEmpty()) { return Collections.emptyList(); } - final List collectors = new ArrayList<>(1); final FieldType fieldType = field.getType(); if (fieldType.getNumberType() != null) { ValueSource vs = fieldType.getValueSource(field, null); Collection> v = GroupConverter.toMutable(field, firstPhaseGroups); - secondPassCollector = new TopGroupsCollector<>(new ValueSourceGroupSelector(vs, new HashMap<>()), - v, groupSort, withinGroupSort, maxDocPerGroup, needMaxScore - ); + if (query instanceof RankQuery){ + secondPassCollector = new ReRankTopGroupsCollector<>(new ValueSourceGroupSelector(vs, new HashMap<>()), v, + groupSort, withinGroupSort, maxDocPerGroup, needScores, needMaxScore, true, (RankQuery)query, searcher); + } else { + secondPassCollector = new TopGroupsCollector<>(new ValueSourceGroupSelector(vs, new HashMap<>()), + v, groupSort, withinGroupSort, maxDocPerGroup, needMaxScore + ); + } } else { - secondPassCollector = new TopGroupsCollector<>(new TermGroupSelector(field.getName()), - firstPhaseGroups, groupSort, withinGroupSort, maxDocPerGroup, needMaxScore - ); + if (query instanceof RankQuery) { + secondPassCollector = new ReRankTopGroupsCollector(new TermGroupSelector(field.getName()),firstPhaseGroups, groupSort, withinGroupSort, maxDocPerGroup, needScores, needMaxScore, true, (RankQuery)query, searcher); + } + else { + secondPassCollector = new TopGroupsCollector<>(new TermGroupSelector(field.getName()), + firstPhaseGroups, groupSort, withinGroupSort, maxDocPerGroup, needMaxScore + ); + } } collectors.add(secondPassCollector); return collectors; diff --git a/solr/core/src/java/org/apache/solr/search/grouping/distributed/responseprocessor/SearchGroupShardResponseProcessor.java b/solr/core/src/java/org/apache/solr/search/grouping/distributed/responseprocessor/SearchGroupShardResponseProcessor.java index 163c38d6da83..87de8561c78f 100644 --- a/solr/core/src/java/org/apache/solr/search/grouping/distributed/responseprocessor/SearchGroupShardResponseProcessor.java +++ b/solr/core/src/java/org/apache/solr/search/grouping/distributed/responseprocessor/SearchGroupShardResponseProcessor.java @@ -37,6 +37,8 @@ import org.apache.solr.handler.component.ShardRequest; import org.apache.solr.handler.component.ShardResponse; import org.apache.solr.response.SolrQueryResponse; +import org.apache.solr.search.AbstractReRankQuery; +import org.apache.solr.search.RankQuery; import org.apache.solr.search.SortSpec; import org.apache.solr.search.grouping.distributed.ShardResponseProcessor; import org.apache.solr.search.grouping.distributed.command.SearchGroupsFieldCommandResult; @@ -139,7 +141,14 @@ public void process(ResponseBuilder rb, ShardRequest shardRequest) { rb.firstPhaseElapsedTime = maxElapsedTime; for (String groupField : commandSearchGroups.keySet()) { List>> topGroups = commandSearchGroups.get(groupField); - Collection> mergedTopGroups = SearchGroup.merge(topGroups, groupSortSpec.getOffset(), groupSortSpec.getCount(), groupSort); + final int topN; + RankQuery rq = rb.getRankQuery(); + if (rq instanceof AbstractReRankQuery){ + topN = ((AbstractReRankQuery) rq).getReRankDocs(); + } else { + topN = rb.getSortSpec().getCount(); + } + Collection> mergedTopGroups = SearchGroup.merge(topGroups, groupSortSpec.getOffset(), topN, groupSort); if (mergedTopGroups == null) { continue; } diff --git a/solr/core/src/java/org/apache/solr/search/grouping/distributed/responseprocessor/TopGroupsShardResponseProcessor.java b/solr/core/src/java/org/apache/solr/search/grouping/distributed/responseprocessor/TopGroupsShardResponseProcessor.java index 2db6b22f9ee0..22245c667eba 100644 --- a/solr/core/src/java/org/apache/solr/search/grouping/distributed/responseprocessor/TopGroupsShardResponseProcessor.java +++ b/solr/core/src/java/org/apache/solr/search/grouping/distributed/responseprocessor/TopGroupsShardResponseProcessor.java @@ -19,6 +19,8 @@ import java.io.PrintWriter; import java.io.StringWriter; import java.util.ArrayList; +import java.util.Arrays; +import java.util.Comparator; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -39,7 +41,9 @@ import org.apache.solr.handler.component.ShardRequest; import org.apache.solr.handler.component.ShardResponse; import org.apache.solr.response.SolrQueryResponse; +import org.apache.solr.search.AbstractReRankQuery; import org.apache.solr.search.Grouping; +import org.apache.solr.search.RankQuery; import org.apache.solr.search.grouping.distributed.ShardResponseProcessor; import org.apache.solr.search.grouping.distributed.command.QueryCommandResult; import org.apache.solr.search.grouping.distributed.shardresultserializer.TopGroupsResultTransformer; @@ -65,7 +69,7 @@ public void process(ResponseBuilder rb, ShardRequest shardRequest) { } else { groupOffsetDefault = rb.getGroupingSpec().getWithinGroupOffset(); } - int docsPerGroupDefault = rb.getGroupingSpec().getWithinGroupLimit(); + final int docsPerGroupDefault = rb.getGroupingSpec().getWithinGroupLimit(); Map>> commandTopGroups = new HashMap<>(); for (String field : fields) { @@ -159,9 +163,46 @@ public void process(ResponseBuilder rb, ShardRequest shardRequest) { docsPerGroup += subTopGroups.totalGroupedHitCount; } } - rb.mergedTopGroups.put(groupField, TopGroups.merge(topGroups.toArray(topGroupsArr), groupSort, withinGroupSort, groupOffsetDefault, docsPerGroup, TopGroups.ScoreMergeMode.None)); + + if (rb.getRankQuery() != null){ + docsPerGroup = Math.max(docsPerGroupDefault, ((AbstractReRankQuery)rb.getRankQuery()).getReRankDocs()); + rb.mergedTopGroups.put(groupField, TopGroups.merge(topGroups.toArray(topGroupsArr), groupSort, withinGroupSort, groupOffsetDefault, docsPerGroup, TopGroups.ScoreMergeMode.None)); + + TopGroups group = rb.mergedTopGroups.get(groupField); + for (int i = 0; i < group.groups.length; i++){ + GroupDocs currentGroup = group.groups[i]; + Arrays.sort(currentGroup.scoreDocs, new Comparator() { + @Override + public int compare(ScoreDoc o1, ScoreDoc o2) { + if (o1.score > o2.score) return -1; + if (o2.score < o1.score) return 1; + return 0; + } + }); + ScoreDoc[] scoreDocs = currentGroup.scoreDocs; + if (scoreDocs.length > docsPerGroupDefault) { + scoreDocs = Arrays.copyOf(currentGroup.scoreDocs, docsPerGroupDefault); + } + group.groups[i] = new GroupDocs(Float.NaN, currentGroup.maxScore, currentGroup.totalHits, scoreDocs, currentGroup.groupValue, currentGroup.groupSortValues); + } + Arrays.sort(group.groups, new Comparator>() { + @Override + public int compare(GroupDocs o1, GroupDocs o2) { + if (o1.maxScore > o2.maxScore) return -1; + if (o2.maxScore < o1.maxScore) return 1; + return 0; + } + }); + int topN = Math.min(group.groups.length, rb.getSortSpec().getCount()); + group = new TopGroups(group.groupSort, group.withinGroupSort, group.totalHitCount, group.totalGroupedHitCount, Arrays.copyOfRange(group.groups, 0, topN), group.maxScore); + rb.mergedTopGroups.put(groupField, group); + } + else { + rb.mergedTopGroups.put(groupField, TopGroups.merge(topGroups.toArray(topGroupsArr), groupSort, withinGroupSort, groupOffsetDefault, docsPerGroup, TopGroups.ScoreMergeMode.None)); + } } + for (String query : commandTopDocs.keySet()) { List queryCommandResults = commandTopDocs.get(query); List topDocs = new ArrayList<>(queryCommandResults.size()); diff --git a/solr/core/src/java/org/apache/solr/search/grouping/distributed/shardresultserializer/TopGroupsResultTransformer.java b/solr/core/src/java/org/apache/solr/search/grouping/distributed/shardresultserializer/TopGroupsResultTransformer.java index 3327dd7cf44f..70397edda811 100644 --- a/solr/core/src/java/org/apache/solr/search/grouping/distributed/shardresultserializer/TopGroupsResultTransformer.java +++ b/solr/core/src/java/org/apache/solr/search/grouping/distributed/shardresultserializer/TopGroupsResultTransformer.java @@ -215,6 +215,9 @@ protected NamedList serializeTopGroups(TopGroups data, SchemaField gro } FieldDoc fieldDoc = (FieldDoc) searchGroup.scoreDocs[i]; + + assert(fieldDoc != null && fieldDoc.fields != null); + Object[] convertedSortValues = new Object[fieldDoc.fields.length]; for (int j = 0; j < fieldDoc.fields.length; j++) { Object sortValue = fieldDoc.fields[j]; diff --git a/solr/core/src/test/org/apache/solr/TestDistributedGrouping.java b/solr/core/src/test/org/apache/solr/TestDistributedGrouping.java index 534a29963bd9..831e9bf7bf29 100644 --- a/solr/core/src/test/org/apache/solr/TestDistributedGrouping.java +++ b/solr/core/src/test/org/apache/solr/TestDistributedGrouping.java @@ -28,6 +28,7 @@ import org.apache.solr.common.params.ModifiableSolrParams; import org.apache.solr.common.util.NamedList; import org.apache.solr.SolrTestCaseJ4.SuppressPointFields; +import org.apache.solr.search.ReRankQParserPlugin; import org.junit.Test; /** @@ -310,6 +311,30 @@ public void test() throws Exception { //Debug simpleQuery("q", "*:*", "rows", 10, "fl", "id," + i1, "group", "true", "group.field", i1, "debug", "true"); + + // SOLR-8776 + simpleQuery("q", "{!func}id_i1", "rows", 10, "fl", "id," + i1+",score", "group", "true", + "group.field", i1, "group.limit", 1, "rq", "{!" + ReRankQParserPlugin.NAME + " " + ReRankQParserPlugin.RERANK_QUERY + "=$rqq " + + ReRankQParserPlugin.RERANK_DOCS + "=1000}", "rqq", "{!func}"+i1); // original rank: by id, rerank by i1 field (final score will be 2 * i1 + id) + rsp = query("q", "{!func}id_i1", "rows", 10, "fl", "id," + i1+",score", "group", "true", + "group.field", i1, "group.limit", 1, "rq", "{!" + ReRankQParserPlugin.NAME + " " + ReRankQParserPlugin.RERANK_QUERY + "=$rqq " + + ReRankQParserPlugin.RERANK_DOCS + "=1000}", "rqq", "{!func}"+i1); + int groupLimit = 1; + + rsp = query("q", "{!func}id_i1", "rows", 100, "fl", "id," + i1, "group", "true", + "group.field", i1, "group.limit", groupLimit, "rq", "{!" + ReRankQParserPlugin.NAME + " " + ReRankQParserPlugin.RERANK_QUERY + "=$rqq " + + ReRankQParserPlugin.RERANK_DOCS + "=1000}", "rqq", t1+":eggs"); + + rsp = query("q", "{!func}id_i1", "rows", 2, "fl", "id,score," + i1, "group", "true", + "group.field", i1dv, "group.limit", groupLimit, "rq", "{!" + ReRankQParserPlugin.NAME + " " + ReRankQParserPlugin.RERANK_QUERY + "=$rqq " + + ReRankQParserPlugin.RERANK_DOCS + "=1000}", "rqq", "{!func }"+i1); + + // test random limit between [2..5] + groupLimit = random().nextInt(3)+2; + + rsp = query("q", "{!func}id_i1", "rows", 2, "fl", "id,score," + i1, "group", "true", + "group.field", i1dv, "group.limit", groupLimit, "rq", "{!" + ReRankQParserPlugin.NAME + " " + ReRankQParserPlugin.RERANK_QUERY + "=$rqq " + + ReRankQParserPlugin.RERANK_DOCS + "=1000}", "rqq", "{!func }"+i1); } private void simpleQuery(Object... queryParams) throws SolrServerException, IOException { diff --git a/solr/core/src/test/org/apache/solr/search/RankQueryTestPlugin.java b/solr/core/src/test/org/apache/solr/search/RankQueryTestPlugin.java index 102488313cd4..24463b25ab80 100644 --- a/solr/core/src/test/org/apache/solr/search/RankQueryTestPlugin.java +++ b/solr/core/src/test/org/apache/solr/search/RankQueryTestPlugin.java @@ -132,12 +132,6 @@ public TestRankQuery(int collector, int mergeStrategy) { this.mergeStrategy = mergeStrategy; } - public TopDocsCollector getTopDocsCollector(int len, QueryCommand cmd, IndexSearcher searcher) { - if(collector == 0) - return new TestCollector(null); - else - return new TestCollector1(null); - } public MergeStrategy getMergeStrategy() { if(mergeStrategy == 0) @@ -145,6 +139,21 @@ public MergeStrategy getMergeStrategy() { else return new TestMergeStrategy1(); } + + @Override + public TopDocsCollector getTopDocsCollector(int len, QueryCommand cmd, IndexSearcher searcher) throws IOException { + return getTopDocsCollector(len, cmd.getSort(), searcher); + } + + @Override + public TopDocsCollector getTopDocsCollector(int len, Sort sort, IndexSearcher searcher) + throws IOException { + if(collector == 0) + return new TestCollector(null); + else + return new TestCollector1(null); + } + } static class TestMergeStrategy implements MergeStrategy { diff --git a/solr/core/src/test/org/apache/solr/search/TestReRankQParserPlugin.java b/solr/core/src/test/org/apache/solr/search/TestReRankQParserPlugin.java index b3e01f278437..c7f153152474 100644 --- a/solr/core/src/test/org/apache/solr/search/TestReRankQParserPlugin.java +++ b/solr/core/src/test/org/apache/solr/search/TestReRankQParserPlugin.java @@ -88,6 +88,9 @@ public void testReRankQueries() throws Exception { ModifiableSolrParams params = new ModifiableSolrParams(); params.add("rq", "{!"+ReRankQParserPlugin.NAME+" "+ReRankQParserPlugin.RERANK_QUERY+"=$rqq "+ReRankQParserPlugin.RERANK_DOCS+"=200}"); params.add("q", "term_s:YYYY"); + // rank query, it will match all the + // documents containing the term YYYY, and rerank + // them using the value in the field test_fi params.add("rqq", "{!edismax bf=$bff}*:*"); params.add("bff", "field(test_ti)"); params.add("start", "0"); @@ -605,6 +608,236 @@ public void testRerankQueryParsingShouldFailWithoutMandatoryReRankQueryParameter } + @Test + public void testRerankQueryAndGrouping() throws Exception { + assertU(delQ("*:*")); + assertU(commit()); + + String[] doc1 = {"id", "1", "term_s", "YYYY", "group_s", "group1", "test_ti", "5", "test_tl", "9", "test_tf", + "2000"}; + assertU(adoc(doc1)); + assertU(commit()); + String[] doc2 = {"id", "2", "term_s", "YYYY", "group_s", "group1", "test_ti", "50", "test_tl", "8", "test_tf", + "200"}; + assertU(adoc(doc2)); + assertU(commit()); + String[] doc3 = {"id", "3", "term_s", "YYYY", "group_s", "group2", "test_ti", "100", "test_tl", "10", "test_tf", + "2000"}; + assertU(adoc(doc3)); + assertU(commit()); + String[] doc4 = {"id", "4", "term_s", "YYYY", "group_s", "group2", "test_ti", "74", "test_tl", "11", "test_tf", + "200"}; + assertU(adoc(doc4)); + assertU(commit()); + + ModifiableSolrParams params = new ModifiableSolrParams(); + + params.add("rq", "{!" + ReRankQParserPlugin.NAME + " " + ReRankQParserPlugin.RERANK_QUERY + "=$rqq " + + ReRankQParserPlugin.RERANK_DOCS + "=200}"); + + // first query will sort the documents on the value of test_tl + params.add("q", "{!edismax bq=$bqq1}*:*"); + params.add("bqq1", "{!func }field(test_tl)"); + // rank query, rerank documents on the value of test_ti + params.add("rqq", "{!func }field(test_ti)"); + params.add("start", "0"); + params.add("rows", "6"); + params.add("fl", "id,score"); + + assertQ(req(params), "*[count(//doc)=4]", + "//result/doc[1]/str[@name='id'][.='3']", // group2 + "//result/doc[2]/str[@name='id'][.='4']", // group2 + "//result/doc[3]/str[@name='id'][.='2']", // group1 + "//result/doc[4]/str[@name='id'][.='1']");// group1 + + System.out.println(h.query(req(params))); + + params.add("group", "true"); + params.add("group.field", "group_s"); + + assertQ(req(params),"*[count(//doc)=2]", + "//arr/lst[1]/result/doc/str[@name='id'][.='3']", // instead is 4.0 + "//arr/lst[1]/result[@maxScore=211.0]", + "//arr/lst[2]/result/doc/str[@name='id'][.='2']", // instead is 1.0 + "//arr/lst[2]/result[@maxScore=109.0]" + ); + + } + + @Test + public void testRerankQueryAndGroupingRerankGroups() throws Exception { + assertU(delQ("*:*")); + assertU(commit()); + String[] doc1 = {"id", "1", "term_s", "YYYY", "group_s", "group1", "test_ti", "5", "test_tl", "9", "test_tf", + "2000"}; + assertU(adoc(doc1)); + assertU(commit()); + String[] doc2 = {"id", "2", "term_s", "YYYY", "group_s", "group1", "test_ti", "50", "test_tl", "8", "test_tf", + "200"}; + assertU(adoc(doc2)); + assertU(commit()); + String[] doc3 = {"id", "3", "term_s", "YYYY", "group_s", "group2", "test_ti", "100", "test_tl", "2", "test_tf", + "2000"}; + assertU(adoc(doc3)); + assertU(commit()); + String[] doc4 = {"id", "4", "term_s", "YYYY", "group_s", "group2", "test_ti", "74", "test_tl", "3", "test_tf", + "200"}; + assertU(adoc(doc4)); + String[] doc5 = {"id", "5", "term_s", "YYYY", "group_s", "group3", "test_ti", "1000", "test_tl", "1", "test_tf", + "200"}; + assertU(adoc(doc5)); + assertU(commit()); + + ModifiableSolrParams params = new ModifiableSolrParams(); + + // first query will sort the documents on the value of test_tl + params.add("q", "{!edismax bq=$bqq1}*:*"); + params.add("bqq1", "{!func }field(test_tl)"); + + params.add("start", "0"); + params.add("rows", "6"); + params.add("fl", "id,score,[explain]"); + + assertQ(req(params), "*[count(//doc)=5]", + "//result/doc[1]/str[@name='id'][.='1']", // group1 + "//result/doc[2]/str[@name='id'][.='2']", // group1 + "//result/doc[3]/str[@name='id'][.='4']", // group2 + "//result/doc[4]/str[@name='id'][.='3']", // group2 + "//result/doc[5]/str[@name='id'][.='5']"); // group3 + + // test grouping + params.add("group", "true"); + params.add("group.field", "group_s"); + + assertQ(req(params),"*[count(//doc)=3]", + "//arr/lst[1]/result/doc/str[@name='id'][.='1']", + "//arr/lst[2]/result/doc/str[@name='id'][.='4']", + "//arr/lst[3]/result/doc/str[@name='id'][.='5']" + ); + + // add reranking + params.add("rq", "{!" + ReRankQParserPlugin.NAME + " " + ReRankQParserPlugin.RERANK_QUERY + "=$rqq " + + ReRankQParserPlugin.RERANK_DOCS + "=200}"); + //rank query, rerank documents on the value of test_ti + params.add("rqq", "{!func }field(test_ti)"); + + params.remove("group"); + params.remove("group.field"); + + assertQ(req(params), "*[count(//doc)=5]", + "//result/doc[1]/str[@name='id'][.='5']", // group3 + "//result/doc[2]/str[@name='id'][.='3']", // group2 + "//result/doc[3]/str[@name='id'][.='4']", // group2 + "//result/doc[4]/str[@name='id'][.='2']", // group1 + "//result/doc[5]/str[@name='id'][.='1']");// group1 + + // now grouping and reranking should rescore the documents inside the groups and then + // reorder the groups if the scores changed: + // so: + // + // the result should be group3[doc5], group2[doc3], group1[doc2] + + params.add("group", "true"); + params.add("group.field", "group_s"); + + assertQ(req(params),"*[count(//doc)=3]", + "//arr/lst[1]/result/doc/str[@name='id'][.='5']", + "//arr/lst[2]/result/doc/str[@name='id'][.='3']", + "//arr/lst[3]/result/doc/str[@name='id'][.='2']" + ); + // test grouping by function: + // documents are + // 1. firstly scored by their test_tl score and then + // 2. then grouped by the value of log(test_tf) + // 3. finally reranked by test_ti + // rerank query will add the score so: + + // 1. first pass (doc_id, score): results = (1, 9) (2, 8) (4, 3), (3, 2) (5, 1) + // 2. grouping: group1(max score=9) = [(1, 9) (3, 2)], group2(max score = 8) = [(2, 8) (4, 3) (5, 1)] + // 3. reranking (scoring): group1 = [ (1, 9 + 5), (3, 2 + 100) ] group2 [ (2, 8 + 50), (4, 3 + 74), (5, 1 + 1000)] + // reordered: group2(max score = 1001) [ (5, 1001) (4, 77) (2, 58)) group1(max score = 102) = [ (3, 102) (1, 14) ] + // (actual scores will be offset by 1.0 because of a constant boost (=1) added to the scores) + + params = new ModifiableSolrParams(); + params.add("q", "{!edismax bq=$bqq1}*:*"); + params.add("bqq1", "{!func }field(test_tl)"); + params.add("start", "0"); + params.add("rows", "6"); + params.add("fl", "id,score,[explain]"); + params.add("group", "true"); + params.add("group.func", "log(test_tf)"); + params.remove("rq"); + params.add("rq", "{!" + ReRankQParserPlugin.NAME + " " + ReRankQParserPlugin.RERANK_QUERY + "=$rqq " + + ReRankQParserPlugin.RERANK_DOCS + "=200 "+ ReRankQParserPlugin.RERANK_WEIGHT + "=1 }"); + //rank query, rerank documents on the value of test_ti + params.add("rqq", "{!func }field(test_ti)"); + + assertQ(req(params), "*[count(//doc)=2]", + "//arr/lst[1]/result/doc/str[@name='id'][.='5']", + "//arr/lst[2]/result/doc/str[@name='id'][.='3']", + "//arr/lst[1]/result[@maxScore=1002.0]", + "//arr/lst[2]/result[@maxScore=103.0]" + ); + + } + + @Test + public void testRerankTopGroups() throws Exception { + assertU(delQ("*:*")); + assertU(commit()); + String[] doc1 = {"id", "1", "term_s", "YYYY", "group_s", "group1", "test_ti", "5", "test_tl", "9", "test_tf", + "2000"}; + assertU(adoc(doc1)); + assertU(commit()); + String[] doc2 = {"id", "2", "term_s", "YYYY", "group_s", "group1", "test_ti", "50", "test_tl", "8", "test_tf", + "200"}; + assertU(adoc(doc2)); + assertU(commit()); + String[] doc3 = {"id", "3", "term_s", "YYYY", "group_s", "group2", "test_ti", "100", "test_tl", "2", "test_tf", + "2000"}; + assertU(adoc(doc3)); + assertU(commit()); + String[] doc4 = {"id", "4", "term_s", "YYYY", "group_s", "group2", "test_ti", "74", "test_tl", "3", "test_tf", + "200"}; + assertU(adoc(doc4)); + String[] doc5 = {"id", "5", "term_s", "YYYY", "group_s", "group3", "test_ti", "1000", "test_tl", "1", "test_tf", + "200"}; + assertU(adoc(doc5)); + assertU(commit()); + + // The first pass must return the top N groups reranked, otherwise reranking can change the documents selected + // within a group and the order of the groups, but can put a better group in the ranking. + ModifiableSolrParams params = new ModifiableSolrParams(); + + // first query will sort the documents on the value of test_tl + params.add("q", "{!func}test_tl"); + params.add("start", "0"); + params.add("rows", "1"); + params.add("fl", "id,score,[explain]"); + + // test grouping + params.add("group", "true"); + params.add("group.field", "group_s"); + + assertQ(req(params),"*[count(//doc)=1]", + "//arr/lst[1]/result/doc/str[@name='id'][.='1']"); + + params.remove("rq"); + params.add("rq", "{!" + ReRankQParserPlugin.NAME + " " + ReRankQParserPlugin.RERANK_QUERY + "=$rqq " + + ReRankQParserPlugin.RERANK_DOCS + "=3 "+ ReRankQParserPlugin.RERANK_WEIGHT + "=1 }"); + //rank query, rerank 3 groups on the value of test_ti + params.add("rqq", "{!func }field(test_ti)"); + assertQ(req(params), "*[count(//doc)=1]", + "//arr/lst[1]/result/doc/str[@name='id'][.='5']"); + + params.remove("rq"); + params.add("rq", "{!" + ReRankQParserPlugin.NAME + " " + ReRankQParserPlugin.RERANK_QUERY + "=$rqq " + + ReRankQParserPlugin.RERANK_DOCS + "=2 "+ ReRankQParserPlugin.RERANK_WEIGHT + "=1 }"); + //rank query, rerank 2 (instead of 3) groups on the value of test_ti + assertQ(req(params), "*[count(//doc)=1]", + "//arr/lst[1]/result/doc/str[@name='id'][.='3']"); + } + @Test public void testReRankQueriesWithDefType() throws Exception {