diff --git a/solr/core/src/java/org/apache/solr/search/AbstractReRankQuery.java b/solr/core/src/java/org/apache/solr/search/AbstractReRankQuery.java index cc3276152fd4..b37d79ed5db3 100644 --- a/solr/core/src/java/org/apache/solr/search/AbstractReRankQuery.java +++ b/solr/core/src/java/org/apache/solr/search/AbstractReRankQuery.java @@ -23,6 +23,7 @@ import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.Query; import org.apache.lucene.search.Rescorer; +import org.apache.lucene.search.Sort; import org.apache.lucene.search.TopDocsCollector; import org.apache.lucene.search.Weight; import org.apache.lucene.util.BytesRef; @@ -54,6 +55,11 @@ public MergeStrategy getMergeStrategy() { } public TopDocsCollector getTopDocsCollector(int len, QueryCommand cmd, IndexSearcher searcher) throws IOException { + return getTopDocsCollector(null, len, cmd.getSort(), searcher); + } + + + public TopDocsCollector getTopDocsCollector(TopDocsCollector previousCollector, int len, Sort sort, IndexSearcher searcher) throws IOException { if(this.boostedPriority == null) { SolrRequestInfo info = SolrRequestInfo.getRequestInfo(); @@ -63,7 +69,7 @@ public TopDocsCollector getTopDocsCollector(int len, QueryCommand cmd, IndexSear } } - return new ReRankCollector(reRankDocs, len, reRankQueryRescorer, cmd, searcher, boostedPriority); + return new ReRankCollector(previousCollector, reRankDocs, len, sort, reRankQueryRescorer, searcher, boostedPriority); } public Query rewrite(IndexReader reader) throws IOException { @@ -81,3 +87,4 @@ public Weight createWeight(IndexSearcher searcher, boolean needsScores, float bo return new ReRankWeight(mainQuery, reRankQueryRescorer, searcher, mainWeight); } } + diff --git a/solr/core/src/java/org/apache/solr/search/ExportQParserPlugin.java b/solr/core/src/java/org/apache/solr/search/ExportQParserPlugin.java index 38bb74f9533a..a3c3cec8ba48 100644 --- a/solr/core/src/java/org/apache/solr/search/ExportQParserPlugin.java +++ b/solr/core/src/java/org/apache/solr/search/ExportQParserPlugin.java @@ -87,6 +87,15 @@ public Query rewrite(IndexReader reader) throws IOException { public TopDocsCollector getTopDocsCollector(int len, QueryCommand cmd, IndexSearcher searcher) throws IOException { + + return getTopDocsCollector(null, len, cmd.getSort(), searcher); + } + + @Override + public TopDocsCollector getTopDocsCollector(TopDocsCollector previousCollector, + int len, + Sort sort, + IndexSearcher searcher) throws IOException { int leafCount = searcher.getTopReaderContext().leaves().size(); FixedBitSet[] sets = new FixedBitSet[leafCount]; return new ExportCollector(sets); diff --git a/solr/core/src/java/org/apache/solr/search/Grouping.java b/solr/core/src/java/org/apache/solr/search/Grouping.java index 75011e77401b..9474709a6438 100644 --- a/solr/core/src/java/org/apache/solr/search/Grouping.java +++ b/solr/core/src/java/org/apache/solr/search/Grouping.java @@ -19,7 +19,9 @@ import java.io.IOException; import java.lang.invoke.MethodHandles; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collection; +import java.util.Comparator; import java.util.LinkedHashSet; import java.util.List; import java.util.Locale; @@ -68,6 +70,7 @@ import org.apache.solr.schema.SchemaField; import org.apache.solr.schema.StrFieldSource; import org.apache.solr.search.grouping.collector.FilterCollector; +import org.apache.solr.search.grouping.collector.RerankTermSecondPassGroupingCollector; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -758,9 +761,18 @@ protected Collector createSecondPassCollector() throws IOException { int groupedDocsToCollect = getMax(groupOffset, docsPerGroup, maxDoc); groupedDocsToCollect = Math.max(groupedDocsToCollect, 1); Sort withinGroupSort = this.withinGroupSort != null ? this.withinGroupSort : Sort.RELEVANCE; - secondPass = new TermSecondPassGroupingCollector( - groupBy, topGroups, groupSort, withinGroupSort, groupedDocsToCollect, needScores, needScores, false - ); + if (query instanceof RankQuery) { + secondPass = new RerankTermSecondPassGroupingCollector( + groupBy, topGroups, groupSort, withinGroupSort, searcher, (RankQuery)query, groupedDocsToCollect, needScores, + needScores, false + ); + } else { + secondPass = new TermSecondPassGroupingCollector( + groupBy, topGroups, groupSort, withinGroupSort, groupedDocsToCollect, needScores, + needScores, false + ); + } + if (totalCount == TotalCount.grouped) { allGroupsCollector = new TermAllGroupsCollector(groupBy); @@ -785,6 +797,20 @@ public AllGroupHeadsCollector createAllGroupCollector() throws IOException { @Override protected void finish() throws IOException { result = secondPass != null ? secondPass.getTopGroups(0) : null; + if (result != null && query instanceof RankQuery && groupSort == Sort.RELEVANCE) { + // if we are sorting for relevance and query is a RankQuery, it may be that + // the order of the groups changed, we need to reorder + GroupDocs[] groups = result.groups; + Arrays.sort(groups, new Comparator() { + @Override + public int compare(GroupDocs o1, GroupDocs o2) { + if (o1.maxScore > o2.maxScore) return -1; + if (o1.maxScore < o2.maxScore) return 1; + return 0; + } + }); + } + if (main) { mainResult = createSimpleResponse(); return; diff --git a/solr/core/src/java/org/apache/solr/search/RankQuery.java b/solr/core/src/java/org/apache/solr/search/RankQuery.java index 4812e4117a91..83e033cc8d85 100644 --- a/solr/core/src/java/org/apache/solr/search/RankQuery.java +++ b/solr/core/src/java/org/apache/solr/search/RankQuery.java @@ -17,8 +17,9 @@ package org.apache.solr.search; import org.apache.lucene.search.IndexSearcher; -import org.apache.lucene.search.TopDocsCollector; import org.apache.lucene.search.Query; +import org.apache.lucene.search.Sort; +import org.apache.lucene.search.TopDocsCollector; import org.apache.solr.handler.component.MergeStrategy; import java.io.IOException; @@ -29,7 +30,9 @@ public abstract class RankQuery extends ExtendedQueryBase { + @Deprecated public abstract TopDocsCollector getTopDocsCollector(int len, QueryCommand cmd, IndexSearcher searcher) throws IOException; + public abstract TopDocsCollector getTopDocsCollector(TopDocsCollector previousCollector, int len, Sort sort, IndexSearcher searcher) throws IOException; public abstract MergeStrategy getMergeStrategy(); public abstract RankQuery wrap(Query mainQuery); diff --git a/solr/core/src/java/org/apache/solr/search/ReRankCollector.java b/solr/core/src/java/org/apache/solr/search/ReRankCollector.java index 1ac1eaff436e..efd2479b50bd 100644 --- a/solr/core/src/java/org/apache/solr/search/ReRankCollector.java +++ b/solr/core/src/java/org/apache/solr/search/ReRankCollector.java @@ -47,21 +47,32 @@ public class ReRankCollector extends TopDocsCollector { final private int length; final private Map boostedPriority; final private Rescorer reRankQueryRescorer; + final private TopDocsCollector previousCollector; - + @Deprecated public ReRankCollector(int reRankDocs, int length, Rescorer reRankQueryRescorer, QueryCommand cmd, IndexSearcher searcher, Map boostedPriority) throws IOException { + this(null, reRankDocs, length, cmd.getSort(), reRankQueryRescorer, searcher, boostedPriority); + } + + public ReRankCollector(TopDocsCollector previousCollector, + int reRankDocs, + int length, + Sort sort, + Rescorer reRankQueryRescorer, + IndexSearcher searcher, + Map boostedPriority) throws IOException { super(null); - this.reRankDocs = reRankDocs; this.length = length; + this.reRankDocs = reRankDocs; this.boostedPriority = boostedPriority; - Sort sort = cmd.getSort(); + this.previousCollector = previousCollector; if(sort == null) { - this.mainCollector = TopScoreDocCollector.create(Math.max(this.reRankDocs, length)); + this.mainCollector = TopScoreDocCollector.create( Math.max(this.reRankDocs, length)); } else { sort = sort.rewrite(searcher); this.mainCollector = TopFieldCollector.create(sort, Math.max(this.reRankDocs, length), false, true, true); diff --git a/solr/core/src/java/org/apache/solr/search/SolrIndexSearcher.java b/solr/core/src/java/org/apache/solr/search/SolrIndexSearcher.java index c65084536374..50d87b3a95a9 100644 --- a/solr/core/src/java/org/apache/solr/search/SolrIndexSearcher.java +++ b/solr/core/src/java/org/apache/solr/search/SolrIndexSearcher.java @@ -1858,7 +1858,7 @@ private TopDocsCollector buildTopDocsCollector(int len, QueryCommand cmd) throws Query q = cmd.getQuery(); if (q instanceof RankQuery) { RankQuery rq = (RankQuery) q; - return rq.getTopDocsCollector(len, cmd, this); + return rq.getTopDocsCollector(TopScoreDocCollector.create(len), len, cmd.getSort(), this); } if (null == cmd.getSort()) { diff --git a/solr/core/src/java/org/apache/solr/search/grouping/collector/RerankTermSecondPassGroupingCollector.java b/solr/core/src/java/org/apache/solr/search/grouping/collector/RerankTermSecondPassGroupingCollector.java new file mode 100644 index 000000000000..c53f313c9be3 --- /dev/null +++ b/solr/core/src/java/org/apache/solr/search/grouping/collector/RerankTermSecondPassGroupingCollector.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.solr.search.grouping.collector; + +import java.io.IOException; +import java.lang.invoke.MethodHandles; +import java.util.Collection; + +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Sort; +import org.apache.lucene.search.TopDocsCollector; +import org.apache.lucene.search.grouping.SearchGroup; +import org.apache.lucene.search.grouping.term.TermSecondPassGroupingCollector; +import org.apache.lucene.util.BytesRef; +import org.apache.solr.search.RankQuery; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class RerankTermSecondPassGroupingCollector extends TermSecondPassGroupingCollector { + + private static final Logger logger = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass()); + private static final int DEFAULT_GROUPING_RERANKING = 10; + + + public RerankTermSecondPassGroupingCollector(String groupField, Collection> groups, + Sort groupSort, Sort withinGroupSort, IndexSearcher searcher, RankQuery query, int maxDocsPerGroup, boolean getScores, boolean getMaxScores, + boolean fillSortFields) throws IOException { + super(groupField, groups, groupSort, withinGroupSort, maxDocsPerGroup, getScores, getMaxScores, fillSortFields); + + + for (SearchGroup group : groups) { + TopDocsCollector collector; + if (query != null) { + collector = groupMap.get(group.groupValue).collector; + collector = query.getTopDocsCollector(collector, DEFAULT_GROUPING_RERANKING, groupSort, searcher); + groupMap.put(group.groupValue, new SearchGroupDocs(group.groupValue, collector)); + } + } + } +} diff --git a/solr/core/src/test/org/apache/solr/search/TestRankQueryPlugin.java b/solr/core/src/test/org/apache/solr/search/TestRankQueryPlugin.java index b42861ace7b5..72ed60fe5195 100644 --- a/solr/core/src/test/org/apache/solr/search/TestRankQueryPlugin.java +++ b/solr/core/src/test/org/apache/solr/search/TestRankQueryPlugin.java @@ -128,12 +128,6 @@ public TestRankQuery(int collector, int mergeStrategy) { this.mergeStrategy = mergeStrategy; } - public TopDocsCollector getTopDocsCollector(int len, QueryCommand cmd, IndexSearcher searcher) { - if(collector == 0) - return new TestCollector(null); - else - return new TestCollector1(null); - } public MergeStrategy getMergeStrategy() { if(mergeStrategy == 0) @@ -141,6 +135,21 @@ public MergeStrategy getMergeStrategy() { else return new TestMergeStrategy1(); } + + @Override + public TopDocsCollector getTopDocsCollector(int len, QueryCommand cmd, IndexSearcher searcher) throws IOException { + return getTopDocsCollector(null, len, cmd.getSort(), searcher); + } + + @Override + public TopDocsCollector getTopDocsCollector(TopDocsCollector previousCollector, int length, Sort sort, IndexSearcher searcher) + throws IOException { + if(collector == 0) + return new TestCollector(null); + else + return new TestCollector1(null); + } + } class TestMergeStrategy implements MergeStrategy { diff --git a/solr/core/src/test/org/apache/solr/search/TestReRankQParserPlugin.java b/solr/core/src/test/org/apache/solr/search/TestReRankQParserPlugin.java index e4d6a5b5fff2..787831a62662 100644 --- a/solr/core/src/test/org/apache/solr/search/TestReRankQParserPlugin.java +++ b/solr/core/src/test/org/apache/solr/search/TestReRankQParserPlugin.java @@ -86,6 +86,9 @@ public void testReRankQueries() throws Exception { ModifiableSolrParams params = new ModifiableSolrParams(); params.add("rq", "{!"+ReRankQParserPlugin.NAME+" "+ReRankQParserPlugin.RERANK_QUERY+"=$rqq "+ReRankQParserPlugin.RERANK_DOCS+"=200}"); params.add("q", "term_s:YYYY"); + // rank query, it will match all the + // documents containing the term YYYY, and rerank + // them using the value in the field test_fi params.add("rqq", "{!edismax bf=$bff}*:*"); params.add("bff", "field(test_ti)"); params.add("start", "0"); @@ -600,4 +603,127 @@ public void testRerankQueryParsingShouldFailWithoutMandatoryReRankQueryParameter } } + @Test + //@BadApple(bugUrl = "https://issues.apache.org/jira/browse/SOLR-8776") + public void testRerankQueryAndGrouping() throws Exception { + assertU(delQ("*:*")); + assertU(commit()); + + String[] doc1 = {"id", "1", "term_s", "YYYY", "group_s", "group1", "test_ti", "5", "test_tl", "9", "test_tf", + "2000"}; + assertU(adoc(doc1)); + assertU(commit()); + String[] doc2 = {"id", "2", "term_s", "YYYY", "group_s", "group1", "test_ti", "50", "test_tl", "8", "test_tf", + "200"}; + assertU(adoc(doc2)); + assertU(commit()); + String[] doc3 = {"id", "3", "term_s", "YYYY", "group_s", "group2", "test_ti", "100", "test_tl", "10", "test_tf", + "2000"}; + assertU(adoc(doc3)); + assertU(commit()); + String[] doc4 = {"id", "4", "term_s", "YYYY", "group_s", "group2", "test_ti", "74", "test_tl", "11", "test_tf", + "200"}; + assertU(adoc(doc4)); + assertU(commit()); + + ModifiableSolrParams params = new ModifiableSolrParams(); + + params.add("rq", "{!" + ReRankQParserPlugin.NAME + " " + ReRankQParserPlugin.RERANK_QUERY + "=$rqq " + + ReRankQParserPlugin.RERANK_DOCS + "=200}"); + + // first query will sort the documents on the value of test_tl + params.add("q", "{!edismax bq=$bqq1}*:*"); + params.add("bqq1", "{!func }field(test_tl)"); + // rank query, rerank documents on the value of test_ti + params.add("rqq", "{!func }field(test_ti)"); + params.add("start", "0"); + params.add("rows", "6"); + params.add("fl", "id,score"); + + assertQ(req(params), "*[count(//doc)=4]", + "//result/doc[1]/float[@name='id'][.='3.0']", // group2 + "//result/doc[2]/float[@name='id'][.='4.0']", // group2 + "//result/doc[3]/float[@name='id'][.='2.0']", // group1 + "//result/doc[4]/float[@name='id'][.='1.0']");// group1 + + + params.add("group", "true"); + params.add("group.field", "group_s"); + + assertQ(req(params),"*[count(//doc)=2]", + "//arr/lst[1]/result/doc/float[@name='id'][.='3.0']", // instead is 4.0 + "//arr/lst[2]/result/doc/float[@name='id'][.='2.0']" // instead is 1.0 + ); + + } + + @Test + public void testRerankQueryAndGroupingRerankGroups() throws Exception { + assertU(delQ("*:*")); + assertU(commit()); + String[] doc1 = {"id", "1", "term_s", "YYYY", "group_s", "group1", "test_ti", "5", "test_tl", "9", "test_tf", + "2000"}; + assertU(adoc(doc1)); + assertU(commit()); + String[] doc2 = {"id", "2", "term_s", "YYYY", "group_s", "group1", "test_ti", "50", "test_tl", "8", "test_tf", + "200"}; + assertU(adoc(doc2)); + assertU(commit()); + String[] doc3 = {"id", "3", "term_s", "YYYY", "group_s", "group2", "test_ti", "100", "test_tl", "2", "test_tf", + "2000"}; + assertU(adoc(doc3)); + assertU(commit()); + String[] doc4 = {"id", "4", "term_s", "YYYY", "group_s", "group2", "test_ti", "74", "test_tl", "3", "test_tf", + "200"}; + assertU(adoc(doc4)); + String[] doc5 = {"id", "5", "term_s", "YYYY", "group_s", "group3", "test_ti", "1000", "test_tl", "1", "test_tf", + "200"}; + assertU(adoc(doc5)); + assertU(commit()); + + ModifiableSolrParams params = new ModifiableSolrParams(); + + // first query will sort the documents on the value of test_tl + params.add("q", "{!edismax bq=$bqq1}*:*"); + params.add("bqq1", "{!func }field(test_tl)"); + + params.add("start", "0"); + params.add("rows", "6"); + params.add("fl", "id,score,[explain]"); + + assertQ(req(params), "*[count(//doc)=5]", + "//result/doc[1]/float[@name='id'][.='1.0']", // group1 + "//result/doc[2]/float[@name='id'][.='2.0']", // group1 + "//result/doc[3]/float[@name='id'][.='4.0']", // group2 + "//result/doc[4]/float[@name='id'][.='3.0']", // group2 + "//result/doc[5]/float[@name='id'][.='5.0']"); // group3 + + // add reranking + params.add("rq", "{!" + ReRankQParserPlugin.NAME + " " + ReRankQParserPlugin.RERANK_QUERY + "=$rqq " + + ReRankQParserPlugin.RERANK_DOCS + "=200}"); + //rank query, rerank documents on the value of test_ti + params.add("rqq", "{!func }field(test_ti)"); + + assertQ(req(params), "*[count(//doc)=5]", + "//result/doc[1]/float[@name='id'][.='5.0']", // group3 + "//result/doc[2]/float[@name='id'][.='3.0']", // group2 + "//result/doc[3]/float[@name='id'][.='4.0']", // group2 + "//result/doc[4]/float[@name='id'][.='2.0']", // group1 + "//result/doc[5]/float[@name='id'][.='1.0']");// group1 + + // now grouping and reranking should rescore the documents inside the groups and then + // reorder the groups if the scores changed: + // so: + // + // the result should be group3[doc5], group2[doc3], group1[doc2] + + params.add("group", "true"); + params.add("group.field", "group_s"); + + assertQ(req(params),"*[count(//doc)=3]", + "//arr/lst[1]/result/doc/float[@name='id'][.='5.0']", + "//arr/lst[2]/result/doc/float[@name='id'][.='3.0']", + "//arr/lst[3]/result/doc/float[@name='id'][.='2.0']" + ); + } }