From 1e3ee66305a6786984d579727cccf3af9e5f13b8 Mon Sep 17 00:00:00 2001 From: Christine Poerschke Date: Thu, 22 Jul 2021 17:29:27 +0100 Subject: [PATCH] SOLR-15537: split 10-args LTRRescorer.scoreSingleHit method (#192) --- .../java/org/apache/solr/ltr/LTRRescorer.java | 45 ++++++++++++++----- .../interleaving/LTRInterleavingRescorer.java | 5 ++- 2 files changed, 39 insertions(+), 11 deletions(-) diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/LTRRescorer.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/LTRRescorer.java index 3c1f85b2eb32..0f1f2c98e220 100644 --- a/solr/contrib/ltr/src/java/org/apache/solr/ltr/LTRRescorer.java +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/LTRRescorer.java @@ -178,13 +178,41 @@ public void scoreFeatures(IndexSearcher indexSearcher, docBase = readerContext.docBase; scorer = modelWeight.scorer(readerContext); } - scoreSingleHit(indexSearcher, topN, modelWeight, docBase, hitUpto, hit, docID, scoringQuery, scorer, reranked); + if (scoreSingleHit(topN, docBase, hitUpto, hit, docID, scorer, reranked)) { + logSingleHit(indexSearcher, modelWeight, hit.doc, scoringQuery); + } hitUpto++; } } + /** + * @deprecated Use {@link #scoreSingleHit(int, int, int, ScoreDoc, int, org.apache.solr.ltr.LTRScoringQuery.ModelWeight.ModelScorer, ScoreDoc[])} + * and {@link #logSingleHit(IndexSearcher, org.apache.solr.ltr.LTRScoringQuery.ModelWeight, int, LTRScoringQuery)} instead. + */ + @Deprecated protected static void scoreSingleHit(IndexSearcher indexSearcher, int topN, LTRScoringQuery.ModelWeight modelWeight, int docBase, int hitUpto, ScoreDoc hit, int docID, LTRScoringQuery rerankingQuery, LTRScoringQuery.ModelWeight.ModelScorer scorer, ScoreDoc[] reranked) throws IOException { - final FeatureLogger featureLogger = rerankingQuery.getFeatureLogger(); + if (scoreSingleHit(topN, docBase, hitUpto, hit, docID, scorer, reranked)) { + logSingleHit(indexSearcher, modelWeight, hit.doc, rerankingQuery); + } + } + + /** + * Call this method if the {@link #scoreSingleHit(int, int, int, ScoreDoc, int, org.apache.solr.ltr.LTRScoringQuery.ModelWeight.ModelScorer, ScoreDoc[])} + * method indicated that the document's feature info should be logged. + */ + protected static void logSingleHit(IndexSearcher indexSearcher, LTRScoringQuery.ModelWeight modelWeight, int docid, LTRScoringQuery scoringQuery) { + final FeatureLogger featureLogger = scoringQuery.getFeatureLogger(); + if (featureLogger != null && indexSearcher instanceof SolrIndexSearcher) { + featureLogger.log(docid, scoringQuery, (SolrIndexSearcher)indexSearcher, modelWeight.getFeaturesInfo()); + } + } + + /** + * Scores a single document and returns true if the document's feature info should be logged via the + * {@link #logSingleHit(IndexSearcher, org.apache.solr.ltr.LTRScoringQuery.ModelWeight, int, LTRScoringQuery)} + * method. Feature info logging is only necessary for the topN documents. + */ + protected static boolean scoreSingleHit(int topN, int docBase, int hitUpto, ScoreDoc hit, int docID, LTRScoringQuery.ModelWeight.ModelScorer scorer, ScoreDoc[] reranked) throws IOException { // Scorer for a LTRScoringQuery.ModelWeight should never be null since we always have to // call score // even if no feature scorers match, since a model might use that info to @@ -198,16 +226,15 @@ protected static void scoreSingleHit(IndexSearcher indexSearcher, int topN, LTRS scorer.docID(); scorer.iterator().advance(targetDoc); + boolean logHit = false; + scorer.getDocInfo().setOriginalDocScore(hit.score); hit.score = scorer.score(); if (hitUpto < topN) { reranked[hitUpto] = hit; // if the heap is not full, maybe I want to log the features for this // document - if (featureLogger != null && indexSearcher instanceof SolrIndexSearcher) { - featureLogger.log(hit.doc, rerankingQuery, (SolrIndexSearcher) indexSearcher, - modelWeight.getFeaturesInfo()); - } + logHit = true; } else if (hitUpto == topN) { // collected topN document, I create the heap heapify(reranked, topN); @@ -221,12 +248,10 @@ protected static void scoreSingleHit(IndexSearcher indexSearcher, int topN, LTRS if (hit.score > reranked[0].score) { reranked[0] = hit; heapAdjust(reranked, topN, 0); - if (featureLogger != null && indexSearcher instanceof SolrIndexSearcher) { - featureLogger.log(hit.doc, rerankingQuery, (SolrIndexSearcher) indexSearcher, - modelWeight.getFeaturesInfo()); - } + logHit = true; } } + return logHit; } @Override diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/interleaving/LTRInterleavingRescorer.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/interleaving/LTRInterleavingRescorer.java index 082138b76e55..799f4d9e36a8 100644 --- a/solr/contrib/ltr/src/java/org/apache/solr/ltr/interleaving/LTRInterleavingRescorer.java +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/interleaving/LTRInterleavingRescorer.java @@ -142,7 +142,10 @@ public void scoreFeatures(IndexSearcher indexSearcher, } for (int i = 0; i < rerankingQueries.length; i++) { if (modelWeights[i] != null) { - scoreSingleHit(indexSearcher, topN, modelWeights[i], docBase, hitUpto, new ScoreDoc(hit.doc, hit.score, hit.shardIndex), docID, rerankingQueries[i], scorers[i], rerankedPerModel[i]); + final ScoreDoc hit_i = new ScoreDoc(hit.doc, hit.score, hit.shardIndex); + if (scoreSingleHit(topN, docBase, hitUpto, hit_i, docID, scorers[i], rerankedPerModel[i])) { + logSingleHit(indexSearcher, modelWeights[i], hit_i.doc, rerankingQueries[i]); + } } } hitUpto++;