Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Speed up filtered disjunctions by loading the filter into a bit set. #14024

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -305,10 +305,26 @@ BulkScorer filteredOptionalBulkScorer() throws IOException {
|| subs.get(Occur.FILTER).isEmpty()
|| scoreMode != ScoreMode.TOP_SCORES
|| subs.get(Occur.SHOULD).size() <= 1
|| minShouldMatch > 1) {
|| minShouldMatch != 1) {
return null;
}
long cost = cost();

long filterCost = Long.MAX_VALUE;
for (ScorerSupplier supplier : subs.get(Occur.FILTER)) {
filterCost = Math.min(filterCost, supplier.cost());
}

long shouldCost = 0;
for (ScorerSupplier supplier : subs.get(Occur.SHOULD)) {
shouldCost += supplier.cost();
}

if (filterCost < shouldCost) {
// Don't do bulk scoring if the filter leads iteration.
return null;
}

long cost = Math.min(shouldCost, filterCost);
List<Scorer> optionalScorers = new ArrayList<>();
for (ScorerSupplier ss : subs.get(Occur.SHOULD)) {
optionalScorers.add(ss.get(cost));
Expand All @@ -317,13 +333,20 @@ BulkScorer filteredOptionalBulkScorer() throws IOException {
for (ScorerSupplier ss : subs.get(Occur.FILTER)) {
filters.add(ss.get(cost));
}
Scorer filterScorer;
if (filters.size() == 1) {
filterScorer = filters.iterator().next();

if (filters.stream().map(Scorer::twoPhaseIterator).anyMatch(Objects::nonNull)) {
Scorer scoring = new WANDScorer(optionalScorers, minShouldMatch, scoreMode, cost);
filters.add(scoring);
return new DefaultBulkScorer(new ConjunctionScorer(filters, Collections.singleton(scoring)));
} else {
filterScorer = new ConjunctionScorer(filters, Collections.emptySet());
Scorer filterScorer;
if (filters.size() == 1) {
filterScorer = filters.iterator().next();
} else {
filterScorer = new ConjunctionScorer(filters, Collections.emptySet());
}
return new MaxScoreBulkScorer(maxDoc, optionalScorers, filterScorer);
}
return new MaxScoreBulkScorer(maxDoc, optionalScorers, filterScorer);
}

// Return a BulkScorer for the required clauses only
Expand Down
130 changes: 85 additions & 45 deletions lucene/core/src/java/org/apache/lucene/search/MaxScoreBulkScorer.java
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,17 @@ final class MaxScoreBulkScorer extends BulkScorer {

private final long[] windowMatches = new long[FixedBitSet.bits2words(INNER_WINDOW_SIZE)];
private final double[] windowScores = new double[INNER_WINDOW_SIZE];
private final FixedBitSet filterMatches;

MaxScoreBulkScorer(int maxDoc, List<Scorer> scorers, Scorer filter) throws IOException {
this.maxDoc = maxDoc;
this.filter = filter == null ? null : new DisiWrapper(filter, false);
if (filter == null) {
this.filter = null;
filterMatches = null;
} else {
this.filter = new DisiWrapper(filter, false);
filterMatches = new FixedBitSet(INNER_WINDOW_SIZE);
}
allScorers = new DisiWrapper[scorers.size()];
scratch = new DisiWrapper[allScorers.length];
int i = 0;
Expand Down Expand Up @@ -143,72 +150,100 @@ public int score(LeafCollector collector, Bits acceptDocs, int min, int max) thr

private void scoreInnerWindow(
LeafCollector collector, Bits acceptDocs, int max, DisiWrapper filter) throws IOException {
if (filter != null) {
scoreInnerWindowWithFilter(collector, acceptDocs, max, filter);
} else if (allScorers.length - firstRequiredScorer >= 2) {

if (allScorers.length - firstRequiredScorer >= 2 && filter == null) {
scoreInnerWindowAsConjunction(collector, acceptDocs, max);
} else {
DisiWrapper top = essentialQueue.top();
DisiWrapper top2 = essentialQueue.top2();
if (top2 == null) {
scoreInnerWindowSingleEssentialClause(collector, acceptDocs, max);
} else if (top2.doc - INNER_WINDOW_SIZE / 2 >= top.doc) {
// The first half of the window would match a single clause. Let's collect this single
// clause until the next doc ID of the next clause.
scoreInnerWindowSingleEssentialClause(collector, acceptDocs, Math.min(max, top2.doc));

if (top2 == null || top2.doc - INNER_WINDOW_SIZE / 2 >= top.doc) {
if (top2 != null) {
// The first half of the window would match a single clause. Let's collect this single
// clause until the next doc ID of the next clause.
max = Math.min(max, top2.doc);
}
if (filter == null) {
scoreInnerWindowSingleEssentialClause(collector, acceptDocs, max);
} else {
scoreInnerWindowSingleEssentialClauseWithFilter(collector, acceptDocs, max, filter);
}
} else {
scoreInnerWindowMultipleEssentialClauses(collector, acceptDocs, max);
if (filter == null) {
scoreInnerWindowMultipleEssentialClauses(collector, acceptDocs, max);
} else {
scoreInnerWindowMultipleEssentialClausesWithFilter(collector, acceptDocs, max, filter);
}
}
}
}

private void scoreInnerWindowWithFilter(
private void scoreInnerWindowSingleEssentialClauseWithFilter(
LeafCollector collector, Bits acceptDocs, int max, DisiWrapper filter) throws IOException {

// TODO: Sometimes load the filter into a bitset and use the more optimized execution paths with
// this bitset as `acceptDocs`

DisiWrapper top = essentialQueue.top();
assert top.doc < max;
if (top.doc < filter.doc) {
top.doc = top.approximation.advance(filter.doc);

// With a single essential clause, we skip the bitset and directly compute the intersection
// between that clause and the filter.
while (top.doc < max) {
if (filter.doc < top.doc) {
filter.doc = filter.iterator.advance(top.doc);
}
if (filter.doc == top.doc) {
if (acceptDocs == null || acceptDocs.get(top.doc)) {
scoreNonEssentialClauses(collector, top.doc, top.scorer.score(), firstEssentialScorer);
}
top.doc = top.iterator.nextDoc();
} else {
top.doc = top.iterator.advance(filter.doc);
}
}

// Only score an inner window, after that we'll check if the min competitive score has increased
// enough for a more favorable partitioning to be used.
top = essentialQueue.updateTop();
}

private void scoreInnerWindowMultipleEssentialClausesWithFilter(
LeafCollector collector, Bits acceptDocs, int max, DisiWrapper filter) throws IOException {

DisiWrapper top = essentialQueue.top();

int innerWindowMin = top.doc;
int innerWindowMax = (int) Math.min(max, (long) innerWindowMin + INNER_WINDOW_SIZE);

while (top.doc < innerWindowMax) {
assert filter.doc <= top.doc; // invariant
if (filter.doc < top.doc) {
filter.doc = filter.approximation.advance(top.doc);
if (filter.doc < top.doc) {
filter.doc = filter.iterator.advance(top.doc);
}
for (int doc = filter.doc; doc < innerWindowMax; doc = filter.iterator.nextDoc()) {
if (acceptDocs == null || acceptDocs.get(doc)) {
filterMatches.set(doc - innerWindowMin);
}
}
filter.doc = filter.iterator.docID();

if (filter.doc != top.doc) {
do {
top.doc = top.iterator.advance(filter.doc);
top = essentialQueue.updateTop();
} while (top.doc < filter.doc);
} else {
int doc = top.doc;
boolean match =
(acceptDocs == null || acceptDocs.get(doc))
&& (filter.twoPhaseView == null || filter.twoPhaseView.matches());
double score = 0;
do {
if (match) {
score += top.scorer.score();
}
top.doc = top.iterator.nextDoc();
top = essentialQueue.updateTop();
} while (top.doc == doc);

if (match) {
scoreNonEssentialClauses(collector, doc, score, firstEssentialScorer);
while (top.doc < filter.doc) {
for (int doc = top.doc; doc < innerWindowMax; ) {
final int delta = doc - innerWindowMin;
int next = filterMatches.nextSetBit(doc - innerWindowMin);
if (next == DocIdSetIterator.NO_MORE_DOCS) {
break;
} else if (delta == next) {
windowMatches[delta >>> 6] |= 1L << delta;
windowScores[delta] += top.scorer.score();
doc = top.iterator.nextDoc();
} else {
doc = top.iterator.advance(innerWindowMin + next);
}
}
top.doc = top.iterator.docID();
if (top.doc < filter.doc) {
top.doc = top.iterator.advance(filter.doc);
}
top = essentialQueue.updateTop();
}

filterMatches.clear();

replayEssentialMatches(collector, innerWindowMin);
}

private void scoreInnerWindowSingleEssentialClause(
Expand Down Expand Up @@ -314,6 +349,11 @@ private void scoreInnerWindowMultipleEssentialClauses(
top = essentialQueue.updateTop();
} while (top.doc < innerWindowMax);

replayEssentialMatches(collector, innerWindowMin);
}

private void replayEssentialMatches(LeafCollector collector, int innerWindowMin)
throws IOException {
for (int wordIndex = 0; wordIndex < windowMatches.length; ++wordIndex) {
long bits = windowMatches[wordIndex];
windowMatches[wordIndex] = 0L;
Expand Down
Loading