Skip to content

Commit

Permalink
fixed target docid check in LTRRescorer and unit test for extracting… (
Browse files Browse the repository at this point in the history
…#176)

* fixed target docid check in LTRRescorer  and unit test for extracting features when there are multiple segments

(cherry picked from commit dfa0e2cc3baa72cec1b6329891d14b451effbd74)

* renamed unit test file and added comments

(cherry picked from commit fcfb661574c973b8963401d58145ddaf1942b511)
  • Loading branch information
nsanthapuri authored and cpoerschke committed Oct 26, 2016
1 parent be3b843 commit 4f4454c
Show file tree
Hide file tree
Showing 6 changed files with 210 additions and 3 deletions.
4 changes: 2 additions & 2 deletions solr/contrib/ltr/example/solrconfig.xml
Original file line number Diff line number Diff line change
Expand Up @@ -171,11 +171,11 @@
Even older versions of Lucene used LogDocMergePolicy.
-->
<!--
<mergePolicy class="org.apache.lucene.index.TieredMergePolicy">
<mergePolicyFactory class="org.apache.lucene.index.TieredMergePolicyFactory">
<int name="maxMergeAtOnce">10</int>
<int name="segmentsPerTier">10</int>
<double name="noCFSRatio">0.1</double>
</mergePolicy>
</mergePolicyFactory>
-->

<!-- Merge Factor
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ public static LTRScoringQuery.FeatureInfo[] extractFeaturesInfo(ModelWeight mode
final LeafReaderContext atomicContext = leafContexts.get(n);
final int deBasedDoc = docid - atomicContext.docBase;
final ModelScorer r = modelWeight.scorer(atomicContext);
if ( (r == null) || (r.iterator().advance(deBasedDoc) != docid) ) {
if ( (r == null) || (r.iterator().advance(deBasedDoc) != deBasedDoc) ) {
return new LTRScoringQuery.FeatureInfo[0];
} else {
if (originalDocScore != null) {
Expand Down
37 changes: 37 additions & 0 deletions solr/contrib/ltr/src/test-files/featureExamples/comp_features.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
[
{ "name":"origScore",
"class":"org.apache.solr.ltr.feature.OriginalScoreFeature",
"params":{},
"store": "feature-store-6"
},
{
"name": "descriptionTermFreq",
"class": "org.apache.solr.ltr.feature.SolrFeature",
"params": { "q" : "{!func}termfreq(description,${user_text})" },
"store": "feature-store-6"
},
{
"name": "popularity",
"class": "org.apache.solr.ltr.feature.SolrFeature",
"params": { "q" : "{!func}normHits"},
"store": "feature-store-6"
},
{
"name": "isPopular",
"class": "org.apache.solr.ltr.feature.SolrFeature",
"params": {"fq" : ["{!field f=popularity}201"] },
"store": "feature-store-6"
},
{
"name": "queryPartialMatch2",
"class": "org.apache.solr.ltr.feature.SolrFeature",
"params": {"q": "{!dismax qf=description mm=2}${user_text}" },
"store": "feature-store-6"
},
{
"name": "queryPartialMatch2.1",
"class": "org.apache.solr.ltr.feature.SolrFeature",
"params": {"q": "{!dismax qf=description mm=2}${user_text}" },
"store": "feature-store-6"
}
]
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
<field name="description" type="text_general" indexed="true" stored="true"/>
<field name="keywords" type="text_general" indexed="true" stored="true" multiValued="true"/>
<field name="popularity" type="int" indexed="true" stored="true" />
<field name="normHits" type="float" indexed="true" stored="true" />
<field name="text" type="text_general" indexed="true" stored="false" multiValued="true"/>
<field name="_version_" type="long" indexed="true" stored="true"/>

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
<?xml version="1.0" ?>
<!-- Licensed to the Apache Software Foundation (ASF) under one or more contributor
license agreements. See the NOTICE file distributed with this work for additional
information regarding copyright ownership. The ASF licenses this file to
You under the Apache License, Version 2.0 (the "License"); you may not use
this file except in compliance with the License. You may obtain a copy of
the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required
by applicable law or agreed to in writing, software distributed under the
License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS
OF ANY KIND, either express or implied. See the License for the specific
language governing permissions and limitations under the License. -->

<config>
<luceneMatchVersion>6.0.0</luceneMatchVersion>
<dataDir>${solr.data.dir:}</dataDir>
<directoryFactory name="DirectoryFactory"
class="${solr.directoryFactory:solr.RAMDirectoryFactory}" />

<schemaFactory class="ClassicIndexSchemaFactory" />


<!-- Query parser used to rerank top docs with a provided model -->
<queryParser name="ltr"
class="org.apache.solr.search.LTRQParserPlugin" />

<maxBufferedDocs>1</maxBufferedDocs>
<mergePolicyFactory class="org.apache.solr.index.TieredMergePolicyFactory">
<int name="maxMergeAtOnce">10</int>
<int name="segmentsPerTier">1000</int>
</mergePolicyFactory>
<!-- add a transformer that will encode the document features in the response.
For each document the transformer will add the features as an extra field
in the response. The name of the field we will be the the name of the transformer
enclosed between brackets (in this case [fv]). In order to get the feature
vector you will have to specify that you want the field (e.g., fl="*,[fv]) -->
<transformer name="features"
class="org.apache.solr.response.transform.LTRFeatureLoggerTransformerFactory" />

<updateHandler class="solr.DirectUpdateHandler2">
<autoCommit>
<maxTime>15000</maxTime>
<openSearcher>false</openSearcher>
</autoCommit>
<autoSoftCommit>
<maxTime>1000</maxTime>
</autoSoftCommit>
<updateLog>
<str name="dir">${solr.data.dir:}</str>
</updateLog>
</updateHandler>

<requestHandler name="/update" class="solr.UpdateRequestHandler" />
<!-- Query request handler managing models and features -->
<requestHandler name="/query" class="solr.SearchHandler">
<lst name="defaults">
<str name="echoParams">explicit</str>
<str name="wt">json</str>
<str name="indent">true</str>
<str name="df">id</str>
</lst>
</requestHandler>

</config>
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.ltr.feature;

import java.security.SecureRandom;

import java.util.List;
import java.util.Map;
import org.apache.solr.client.solrj.SolrQuery;
import org.apache.solr.ltr.TestRerankBase;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;
import org.noggit.ObjectBuilder;


public class TestFeatureExtractionFromMultipleSegments extends TestRerankBase {
static final String AB = "abcdefghijklmnopqrstuvwxyz";
static SecureRandom rnd = new SecureRandom();

static String randomString( int len ){
StringBuilder sb = new StringBuilder( len );
for( int i = 0; i < len; i++ )
sb.append( AB.charAt( rnd.nextInt(AB.length()) ) );
return sb.toString();
}

@BeforeClass
public static void before() throws Exception {
// solrconfig-multiseg.xml contains the merge policy to restrict merging
setuptest("solrconfig-multiseg.xml", "schema-ltr.xml");
// index 400 documents
for(int i = 0; i<400;i=i+20) {
assertU(adoc("id", new Integer(i).toString(), "popularity", "201", "description", "apple is a company " + randomString(i%6+3), "normHits", "0.1"));
assertU(adoc("id", new Integer(i+1).toString(), "popularity", "201", "description", "d " + randomString(i%6+3), "normHits", "0.11"));

assertU(adoc("id", new Integer(i+2).toString(), "popularity", "201", "description", "apple is a company too " + randomString(i%6+3), "normHits", "0.1"));
assertU(adoc("id", new Integer(i+3).toString(), "popularity", "201", "description", "new york city is big apple " + randomString(i%6+3), "normHits", "0.11"));

assertU(adoc("id", new Integer(i+6).toString(), "popularity", "301", "description", "function name " + randomString(i%6+3), "normHits", "0.1"));
assertU(adoc("id", new Integer(i+7).toString(), "popularity", "301", "description", "function " + randomString(i%6+3), "normHits", "0.1"));

assertU(adoc("id", new Integer(i+8).toString(), "popularity", "301", "description", "This is a sample function for testing " + randomString(i%6+3), "normHits", "0.1"));
assertU(adoc("id", new Integer(i+9).toString(), "popularity", "301", "description", "Function to check out stock prices "+randomString(i%6+3), "normHits", "0.1"));
assertU(adoc("id", new Integer(i+10).toString(),"popularity", "301", "description", "Some descriptions "+randomString(i%6+3), "normHits", "0.1"));

assertU(adoc("id", new Integer(i+11).toString(), "popularity", "201", "description", "apple apple is a company " + randomString(i%6+3), "normHits", "0.1"));
assertU(adoc("id", new Integer(i+12).toString(), "popularity", "201", "description", "Big Apple is New York.", "normHits", "0.01"));
assertU(adoc("id", new Integer(i+13).toString(), "popularity", "201", "description", "New some York is Big. "+ randomString(i%6+3), "normHits", "0.1"));

assertU(adoc("id", new Integer(i+14).toString(), "popularity", "201", "description", "apple apple is a company " + randomString(i%6+3), "normHits", "0.1"));
assertU(adoc("id", new Integer(i+15).toString(), "popularity", "201", "description", "Big Apple is New York.", "normHits", "0.01"));
assertU(adoc("id", new Integer(i+16).toString(), "popularity", "401", "description", "barack h", "normHits", "0.0"));
assertU(adoc("id", new Integer(i+17).toString(), "popularity", "201", "description", "red delicious apple " + randomString(i%6+3), "normHits", "0.1"));
assertU(adoc("id", new Integer(i+18).toString(), "popularity", "201", "description", "nyc " + randomString(i%6+3), "normHits", "0.11"));
}

assertU(commit());

loadFeatures("comp_features.json");
}

@AfterClass
public static void after() throws Exception {
aftertest();
}

@Test
public void testFeatureExtractionFromMultipleSegments() throws Exception {

final SolrQuery query = new SolrQuery();
query.setQuery("{!edismax qf='description^1' boost='sum(product(pow(normHits, 0.7), 1600), .1)' v='apple'}");
// request 100 rows, if any rows are fetched from the second or subsequent segments the tests should succeed if LTRRescorer::extractFeaturesInfo() advances the doc iterator properly
int numRows = 100;
query.add("rows", (new Integer(numRows)).toString());
query.add("wt", "json");
query.add("fq", "popularity:201");
query.add("fl", "*, score,id,normHits,description,fv:[features store='feature-store-6' format='dense' efi.user_text='apple']");
String res = restTestHarness.query("/query" + query.toQueryString());

Map<String,Object> resultJson = (Map<String,Object>) ObjectBuilder.fromJSON(res);

List<Map<String,Object>> docs = (List<Map<String,Object>>)((Map<String,Object>)resultJson.get("response")).get("docs");
int passCount = 0;
for (final Map<String,Object> doc : docs) {
String features = (String)doc.get("fv");
assert(features.length() > 0);
++passCount;
}
assert(passCount == numRows);
}
}

0 comments on commit 4f4454c

Please sign in to comment.