diff --git a/lucene/classification/src/java/org/apache/lucene/classification/CachingNaiveBayesClassifier.java b/lucene/classification/src/java/org/apache/lucene/classification/CachingNaiveBayesClassifier.java
index 6fe683546d06..30a3a358358e 100644
--- a/lucene/classification/src/java/org/apache/lucene/classification/CachingNaiveBayesClassifier.java
+++ b/lucene/classification/src/java/org/apache/lucene/classification/CachingNaiveBayesClassifier.java
@@ -88,11 +88,17 @@ protected List<ClassificationResult<BytesRef>> assignClassNormalizedList(String
     return asignedClassesNorm;
   }
 
+  private double calculateLogPrior(BytesRef cclass) throws IOException {
+    Term term = new Term(this.classFieldName, cclass);
+    int docsWithC = indexReader.docFreq(term);
+    return Math.log((double) docsWithC) - Math.log(docsWithClassSize);
+  }
+
   private List<ClassificationResult<BytesRef>> calculateLogLikelihood(String[] tokenizedText) throws IOException {
     // initialize the return List
     ArrayList<ClassificationResult<BytesRef>> ret = new ArrayList<>();
     for (BytesRef cclass : cclasses) {
-      ClassificationResult<BytesRef> cr = new ClassificationResult<>(cclass, 0d);
+      ClassificationResult<BytesRef> cr = new ClassificationResult<>(cclass, calculateLogPrior(cclass));
       ret.add(cr);
     }
     // for each word