diff --git a/lucene/core/src/java/org/apache/lucene/search/BooleanScorerSupplier.java b/lucene/core/src/java/org/apache/lucene/search/BooleanScorerSupplier.java index 7732445e8cd4..7f2cb381953c 100644 --- a/lucene/core/src/java/org/apache/lucene/search/BooleanScorerSupplier.java +++ b/lucene/core/src/java/org/apache/lucene/search/BooleanScorerSupplier.java @@ -305,10 +305,26 @@ BulkScorer filteredOptionalBulkScorer() throws IOException { || subs.get(Occur.FILTER).isEmpty() || scoreMode != ScoreMode.TOP_SCORES || subs.get(Occur.SHOULD).size() <= 1 - || minShouldMatch > 1) { + || minShouldMatch != 1) { return null; } - long cost = cost(); + + long filterCost = Long.MAX_VALUE; + for (ScorerSupplier supplier : subs.get(Occur.FILTER)) { + filterCost = Math.min(filterCost, supplier.cost()); + } + + long shouldCost = 0; + for (ScorerSupplier supplier : subs.get(Occur.SHOULD)) { + shouldCost += supplier.cost(); + } + + if (filterCost < shouldCost) { + // Don't do bulk scoring if the filter leads iteration. + return null; + } + + long cost = Math.min(shouldCost, filterCost); List optionalScorers = new ArrayList<>(); for (ScorerSupplier ss : subs.get(Occur.SHOULD)) { optionalScorers.add(ss.get(cost)); @@ -317,13 +333,20 @@ BulkScorer filteredOptionalBulkScorer() throws IOException { for (ScorerSupplier ss : subs.get(Occur.FILTER)) { filters.add(ss.get(cost)); } - Scorer filterScorer; - if (filters.size() == 1) { - filterScorer = filters.iterator().next(); + + if (filters.stream().map(Scorer::twoPhaseIterator).anyMatch(Objects::nonNull)) { + Scorer scoring = new WANDScorer(optionalScorers, minShouldMatch, scoreMode, cost); + filters.add(scoring); + return new DefaultBulkScorer(new ConjunctionScorer(filters, Collections.singleton(scoring))); } else { - filterScorer = new ConjunctionScorer(filters, Collections.emptySet()); + Scorer filterScorer; + if (filters.size() == 1) { + filterScorer = filters.iterator().next(); + } else { + filterScorer = new ConjunctionScorer(filters, Collections.emptySet()); + } + return new MaxScoreBulkScorer(maxDoc, optionalScorers, filterScorer); } - return new MaxScoreBulkScorer(maxDoc, optionalScorers, filterScorer); } // Return a BulkScorer for the required clauses only diff --git a/lucene/core/src/java/org/apache/lucene/search/MaxScoreBulkScorer.java b/lucene/core/src/java/org/apache/lucene/search/MaxScoreBulkScorer.java index 93dd1ea91e31..223bf0dd1d9a 100644 --- a/lucene/core/src/java/org/apache/lucene/search/MaxScoreBulkScorer.java +++ b/lucene/core/src/java/org/apache/lucene/search/MaxScoreBulkScorer.java @@ -50,10 +50,17 @@ final class MaxScoreBulkScorer extends BulkScorer { private final long[] windowMatches = new long[FixedBitSet.bits2words(INNER_WINDOW_SIZE)]; private final double[] windowScores = new double[INNER_WINDOW_SIZE]; + private final FixedBitSet filterMatches; MaxScoreBulkScorer(int maxDoc, List scorers, Scorer filter) throws IOException { this.maxDoc = maxDoc; - this.filter = filter == null ? null : new DisiWrapper(filter, false); + if (filter == null) { + this.filter = null; + filterMatches = null; + } else { + this.filter = new DisiWrapper(filter, false); + filterMatches = new FixedBitSet(INNER_WINDOW_SIZE); + } allScorers = new DisiWrapper[scorers.size()]; scratch = new DisiWrapper[allScorers.length]; int i = 0; @@ -143,72 +150,100 @@ public int score(LeafCollector collector, Bits acceptDocs, int min, int max) thr private void scoreInnerWindow( LeafCollector collector, Bits acceptDocs, int max, DisiWrapper filter) throws IOException { - if (filter != null) { - scoreInnerWindowWithFilter(collector, acceptDocs, max, filter); - } else if (allScorers.length - firstRequiredScorer >= 2) { + + if (allScorers.length - firstRequiredScorer >= 2 && filter == null) { scoreInnerWindowAsConjunction(collector, acceptDocs, max); } else { DisiWrapper top = essentialQueue.top(); DisiWrapper top2 = essentialQueue.top2(); - if (top2 == null) { - scoreInnerWindowSingleEssentialClause(collector, acceptDocs, max); - } else if (top2.doc - INNER_WINDOW_SIZE / 2 >= top.doc) { - // The first half of the window would match a single clause. Let's collect this single - // clause until the next doc ID of the next clause. - scoreInnerWindowSingleEssentialClause(collector, acceptDocs, Math.min(max, top2.doc)); + + if (top2 == null || top2.doc - INNER_WINDOW_SIZE / 2 >= top.doc) { + if (top2 != null) { + // The first half of the window would match a single clause. Let's collect this single + // clause until the next doc ID of the next clause. + max = Math.min(max, top2.doc); + } + if (filter == null) { + scoreInnerWindowSingleEssentialClause(collector, acceptDocs, max); + } else { + scoreInnerWindowSingleEssentialClauseWithFilter(collector, acceptDocs, max, filter); + } } else { - scoreInnerWindowMultipleEssentialClauses(collector, acceptDocs, max); + if (filter == null) { + scoreInnerWindowMultipleEssentialClauses(collector, acceptDocs, max); + } else { + scoreInnerWindowMultipleEssentialClausesWithFilter(collector, acceptDocs, max, filter); + } } } } - private void scoreInnerWindowWithFilter( + private void scoreInnerWindowSingleEssentialClauseWithFilter( LeafCollector collector, Bits acceptDocs, int max, DisiWrapper filter) throws IOException { - // TODO: Sometimes load the filter into a bitset and use the more optimized execution paths with - // this bitset as `acceptDocs` - DisiWrapper top = essentialQueue.top(); - assert top.doc < max; - if (top.doc < filter.doc) { - top.doc = top.approximation.advance(filter.doc); + + // With a single essential clause we skip the bitset and compute directly the intersection + // between the single essential clause and the filter. + while (top.doc < max) { + if (filter.doc < top.doc) { + filter.doc = filter.iterator.advance(top.doc); + } + if (filter.doc == top.doc) { + if (acceptDocs == null || acceptDocs.get(top.doc)) { + scoreNonEssentialClauses(collector, top.doc, top.scorer.score(), firstEssentialScorer); + } + top.doc = top.iterator.nextDoc(); + } else { + top.doc = top.iterator.advance(filter.doc); + } } - // Only score an inner window, after that we'll check if the min competitive score has increased - // enough for a more favorable partitioning to be used. + top = essentialQueue.updateTop(); + } + + private void scoreInnerWindowMultipleEssentialClausesWithFilter( + LeafCollector collector, Bits acceptDocs, int max, DisiWrapper filter) throws IOException { + + DisiWrapper top = essentialQueue.top(); + int innerWindowMin = top.doc; int innerWindowMax = (int) Math.min(max, (long) innerWindowMin + INNER_WINDOW_SIZE); - while (top.doc < innerWindowMax) { - assert filter.doc <= top.doc; // invariant - if (filter.doc < top.doc) { - filter.doc = filter.approximation.advance(top.doc); + if (filter.doc < top.doc) { + filter.doc = filter.iterator.advance(top.doc); + } + for (int doc = filter.doc; doc < innerWindowMax; doc = filter.iterator.nextDoc()) { + if (acceptDocs == null || acceptDocs.get(doc)) { + filterMatches.set(doc - innerWindowMin); } + } + filter.doc = filter.iterator.docID(); - if (filter.doc != top.doc) { - do { - top.doc = top.iterator.advance(filter.doc); - top = essentialQueue.updateTop(); - } while (top.doc < filter.doc); - } else { - int doc = top.doc; - boolean match = - (acceptDocs == null || acceptDocs.get(doc)) - && (filter.twoPhaseView == null || filter.twoPhaseView.matches()); - double score = 0; - do { - if (match) { - score += top.scorer.score(); - } - top.doc = top.iterator.nextDoc(); - top = essentialQueue.updateTop(); - } while (top.doc == doc); - - if (match) { - scoreNonEssentialClauses(collector, doc, score, firstEssentialScorer); + while (top.doc < filter.doc) { + for (int doc = top.doc; doc < innerWindowMax; ) { + final int delta = doc - innerWindowMin; + int next = filterMatches.nextSetBit(doc - innerWindowMin); + if (next == DocIdSetIterator.NO_MORE_DOCS) { + break; + } else if (delta == next) { + windowMatches[delta >>> 6] |= 1L << delta; + windowScores[delta] += top.scorer.score(); + doc = top.iterator.nextDoc(); + } else { + doc = top.iterator.advance(innerWindowMin + next); } } + top.doc = top.iterator.docID(); + if (top.doc < filter.doc) { + top.doc = top.iterator.advance(filter.doc); + } + top = essentialQueue.updateTop(); } + + filterMatches.clear(); + + replayEssentialMatches(collector, innerWindowMin); } private void scoreInnerWindowSingleEssentialClause( @@ -314,6 +349,11 @@ private void scoreInnerWindowMultipleEssentialClauses( top = essentialQueue.updateTop(); } while (top.doc < innerWindowMax); + replayEssentialMatches(collector, innerWindowMin); + } + + private void replayEssentialMatches(LeafCollector collector, int innerWindowMin) + throws IOException { for (int wordIndex = 0; wordIndex < windowMatches.length; ++wordIndex) { long bits = windowMatches[wordIndex]; windowMatches[wordIndex] = 0L;