From 0220ef817a29ef0ea84591288b0e5fa5ebb60b0a Mon Sep 17 00:00:00 2001 From: Jason Hinch <46059987+jhinch-at-atlassian-com@users.noreply.github.com> Date: Wed, 30 Oct 2024 05:38:49 +1100 Subject: [PATCH] Fix data race which can occur when using script and derived expression features with concurrent segment search (#54) Signed-off-by: Jason Hinch --- .../es/ltr/action/BaseIntegrationTest.java | 6 +-- .../com/o19s/es/ltr/logging/LoggingIT.java | 1 - .../es/ltr/feature/store/ScriptFeature.java | 30 +++++++----- .../com/o19s/es/ltr/query/RankerQuery.java | 46 ++++++++++++------- .../java/com/o19s/es/ltr/utils/Suppliers.java | 17 ------- .../o19s/es/termstat/TermStatSupplier.java | 9 ---- .../feature/store/FeatureSupplierTests.java | 13 ++---- 7 files changed, 54 insertions(+), 68 deletions(-) diff --git a/src/javaRestTest/java/com/o19s/es/ltr/action/BaseIntegrationTest.java b/src/javaRestTest/java/com/o19s/es/ltr/action/BaseIntegrationTest.java index 463e0db3..98326bb5 100644 --- a/src/javaRestTest/java/com/o19s/es/ltr/action/BaseIntegrationTest.java +++ b/src/javaRestTest/java/com/o19s/es/ltr/action/BaseIntegrationTest.java @@ -176,9 +176,9 @@ public ScoreScript newInstance(LeafReaderContext ctx) throws IOException { public double execute(ExplanationHolder explainationHolder) { // For testing purposes just look for the "terms" key and see if stats were injected if(p.containsKey("termStats")) { - AbstractMap> termStats = (AbstractMap>) p.get("termStats"); - ArrayList dfStats = termStats.get("df"); + Supplier>> termStats = (Supplier>>) p.get("termStats"); + ArrayList dfStats = termStats.get().get("df"); return dfStats.size() > 0 ? 
dfStats.get(0) : 0.0; } else { return 0.0; diff --git a/src/javaRestTest/java/com/o19s/es/ltr/logging/LoggingIT.java b/src/javaRestTest/java/com/o19s/es/ltr/logging/LoggingIT.java index 4be44573..188e0f7d 100644 --- a/src/javaRestTest/java/com/o19s/es/ltr/logging/LoggingIT.java +++ b/src/javaRestTest/java/com/o19s/es/ltr/logging/LoggingIT.java @@ -31,7 +31,6 @@ import org.opensearch.action.search.SearchResponse; import org.opensearch.common.lucene.search.function.FieldValueFactorFunction; import org.opensearch.common.lucene.search.function.FunctionScoreQuery; -import org.opensearch.common.xcontent.XContentType; import org.opensearch.index.query.InnerHitBuilder; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryBuilders; diff --git a/src/main/java/com/o19s/es/ltr/feature/store/ScriptFeature.java b/src/main/java/com/o19s/es/ltr/feature/store/ScriptFeature.java index 5ff0d856..dce39036 100644 --- a/src/main/java/com/o19s/es/ltr/feature/store/ScriptFeature.java +++ b/src/main/java/com/o19s/es/ltr/feature/store/ScriptFeature.java @@ -60,6 +60,7 @@ import java.util.Map; import java.util.Objects; import java.util.Set; +import java.util.function.Supplier; import java.util.stream.Collectors; public class ScriptFeature implements Feature { @@ -71,6 +72,14 @@ public class ScriptFeature implements Feature { public static final String EXTRA_LOGGING = "extra_logging"; public static final String EXTRA_SCRIPT_PARAMS = "extra_script_params"; + /** + * A thread local allowing for term stats to made available for the script score feature. + * This is needed as the parameters for the script score are created up-front when creating the + * lucene query with their values being swapped out for each document using a Supplier. A thread + * local is used to allow for different documents to have their scores computed concurrently. 
+ */ + private static final ThreadLocal CURRENT_TERM_STATS = new ThreadLocal<>(); + private final String name; private final Script script; private final Collection queryParams; @@ -143,7 +152,6 @@ public Query doToQuery(LtrQueryContext context, FeatureSet featureSet, Map nparams = new HashMap<>(); // Parse terms if set @@ -220,8 +228,8 @@ public Query doToQuery(LtrQueryContext context, FeatureSet featureSet, Map) CURRENT_TERM_STATS::get); + nparams.put(MATCH_COUNT, (Supplier) () -> CURRENT_TERM_STATS.get().getMatchedTermCount()); nparams.put(UNIQUE_TERMS, terms.size()); } @@ -240,25 +248,22 @@ public Query doToQuery(LtrQueryContext context, FeatureSet featureSet, Map terms; LtrScript(ScriptScoreFunction function, FeatureSupplier supplier, ExtraLoggingSupplier extraLoggingSupplier, - TermStatSupplier termStatSupplier, Set terms) { this.function = function; this.supplier = supplier; this.extraLoggingSupplier = extraLoggingSupplier; - this.termStatSupplier = termStatSupplier; this.terms = terms; } @@ -285,7 +290,7 @@ public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float bo if (!scoreMode.needsScores()) { return new MatchAllDocsQuery().createWeight(searcher, scoreMode, 1F); } - return new LtrScriptWeight(this, this.function, termStatSupplier, terms, searcher, scoreMode); + return new LtrScriptWeight(this, this.function, terms, searcher, scoreMode); } @Override @@ -317,18 +322,15 @@ static class LtrScriptWeight extends Weight { private final IndexSearcher searcher; private final ScoreMode scoreMode; private final ScriptScoreFunction function; - private final TermStatSupplier termStatSupplier; private final Set terms; private final HashMap termContexts; LtrScriptWeight(Query query, ScriptScoreFunction function, - TermStatSupplier termStatSupplier, Set terms, IndexSearcher searcher, ScoreMode scoreMode) throws IOException { super(query); this.function = function; - this.termStatSupplier = termStatSupplier; this.terms = terms; this.searcher = 
searcher; this.scoreMode = scoreMode; @@ -355,6 +357,7 @@ public Explanation explain(LeafReaderContext context, int doc) throws IOExceptio public Scorer scorer(LeafReaderContext context) throws IOException { LeafScoreFunction leafScoreFunction = function.getLeafScoreFunction(context); DocIdSetIterator iterator = DocIdSetIterator.all(context.reader().maxDoc()); + TermStatSupplier termStatSupplier = new TermStatSupplier(); return new Scorer(this) { @Override public int docID() { @@ -363,12 +366,15 @@ public int docID() { @Override public float score() throws IOException { + CURRENT_TERM_STATS.set(termStatSupplier); // Do the terms magic if the user asked for it if (terms.size() > 0) { termStatSupplier.bump(searcher, context, docID(), terms, scoreMode, termContexts); } - return (float) leafScoreFunction.score(iterator.docID(), 0F); + float score = (float) leafScoreFunction.score(iterator.docID(), 0F); + CURRENT_TERM_STATS.remove(); + return score; } @Override diff --git a/src/main/java/com/o19s/es/ltr/query/RankerQuery.java b/src/main/java/com/o19s/es/ltr/query/RankerQuery.java index bc91daf6..6b93f5ba 100644 --- a/src/main/java/com/o19s/es/ltr/query/RankerQuery.java +++ b/src/main/java/com/o19s/es/ltr/query/RankerQuery.java @@ -24,8 +24,6 @@ import com.o19s.es.ltr.ranker.LogLtrRanker; import com.o19s.es.ltr.ranker.LtrRanker; import com.o19s.es.ltr.ranker.NullRanker; -import com.o19s.es.ltr.utils.Suppliers; -import com.o19s.es.ltr.utils.Suppliers.MutableSupplier; import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.Term; @@ -61,6 +59,26 @@ * or within a BooleanQuery and an appropriate filter clause. */ public class RankerQuery extends Query { + /** + * A thread local to allow for sharing the current feature vector between features. This + * is used primarily for derived expression and script features which derive one feature + * score from another. It relies on the following invariants to work: + *
<ul> + * <li>
+ * Any call to {@link LtrRanker#newFeatureVector(LtrRanker.FeatureVector)} is + * followed by a subsequent call to {@link LtrRanker#score(LtrRanker.FeatureVector)} + * </li>
+ * <li>
+ * All feature scorers are invoked only between the creation of the feature vector and + * the final score being computed (the calls outlined above) + * </li>
+ * <li>
+ * All calls described above happen on the same thread for a single document + * </li>
+ * </ul>
+ */ + private static final ThreadLocal CURRENT_VECTOR = new ThreadLocal<>(); + private final List queries; private final FeatureSet features; private final LtrRanker ranker; @@ -200,12 +218,9 @@ public boolean isCacheable(LeafReaderContext ctx) { } List weights = new ArrayList<>(queries.size()); - // XXX: this is not thread safe and may run into extremely weird issues - // if the searcher uses the parallel collector - // Hopefully elastic never runs - MutableSupplier vectorSupplier = new Suppliers.MutableSupplier<>(); - FVLtrRankerWrapper ltrRankerWrapper = new FVLtrRankerWrapper(ranker, vectorSupplier); - LtrRewriteContext context = new LtrRewriteContext(ranker, vectorSupplier); + + FVLtrRankerWrapper ltrRankerWrapper = new FVLtrRankerWrapper(ranker); + LtrRewriteContext context = new LtrRewriteContext(ranker, CURRENT_VECTOR::get); for (Query q : queries) { if (q instanceof LtrRewritableQuery) { q = ((LtrRewritableQuery) q).ltrRewrite(context); @@ -442,11 +457,9 @@ public long cost() { static class FVLtrRankerWrapper implements LtrRanker { private final LtrRanker wrapped; - private final MutableSupplier vectorSupplier; - FVLtrRankerWrapper(LtrRanker wrapped, MutableSupplier vectorSupplier) { + FVLtrRankerWrapper(LtrRanker wrapped) { this.wrapped = Objects.requireNonNull(wrapped); - this.vectorSupplier = Objects.requireNonNull(vectorSupplier); } @Override @@ -457,13 +470,15 @@ public String name() { @Override public FeatureVector newFeatureVector(FeatureVector reuse) { FeatureVector fv = wrapped.newFeatureVector(reuse); - vectorSupplier.set(fv); + CURRENT_VECTOR.set(fv); return fv; } @Override public float score(FeatureVector point) { - return wrapped.score(point); + float score = wrapped.score(point); + CURRENT_VECTOR.remove(); + return score; } @Override @@ -471,13 +486,12 @@ public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; FVLtrRankerWrapper that = (FVLtrRankerWrapper) o; - return 
Objects.equals(wrapped, that.wrapped) && - Objects.equals(vectorSupplier, that.vectorSupplier); + return Objects.equals(wrapped, that.wrapped); } @Override public int hashCode() { - return Objects.hash(wrapped, vectorSupplier); + return Objects.hash(wrapped); } } diff --git a/src/main/java/com/o19s/es/ltr/utils/Suppliers.java b/src/main/java/com/o19s/es/ltr/utils/Suppliers.java index d54fbcea..2a77239d 100644 --- a/src/main/java/com/o19s/es/ltr/utils/Suppliers.java +++ b/src/main/java/com/o19s/es/ltr/utils/Suppliers.java @@ -17,7 +17,6 @@ package com.o19s.es.ltr.utils; import java.util.Objects; -import java.util.concurrent.atomic.AtomicReference; import java.util.function.Supplier; public final class Suppliers { @@ -59,20 +58,4 @@ public E get() { return value; } } - - /** - * A mutable supplier - */ - public static class MutableSupplier implements Supplier { - private final AtomicReference ref = new AtomicReference<>(); - - @Override - public T get() { - return ref.get(); - } - - public void set(T obj) { - this.ref.set(obj); - } - } } diff --git a/src/main/java/com/o19s/es/termstat/TermStatSupplier.java b/src/main/java/com/o19s/es/termstat/TermStatSupplier.java index 6167b27f..2d3d29f2 100644 --- a/src/main/java/com/o19s/es/termstat/TermStatSupplier.java +++ b/src/main/java/com/o19s/es/termstat/TermStatSupplier.java @@ -17,7 +17,6 @@ import com.o19s.es.explore.StatisticsHelper; import com.o19s.es.explore.StatisticsHelper.AggrType; -import com.o19s.es.ltr.utils.Suppliers; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.PostingsEnum; import org.apache.lucene.index.ReaderUtil; @@ -48,12 +47,10 @@ public class TermStatSupplier extends AbstractMap> { private final ClassicSimilarity sim; private final StatisticsHelper df_stats, idf_stats, tf_stats, ttf_stats, tp_stats; - private final Suppliers.MutableSupplier matchedCountSupplier; private int matchedTermCount = 0; public TermStatSupplier() { - this.matchedCountSupplier = new 
Suppliers.MutableSupplier<>(); this.sim = new ClassicSimilarity(); this.df_stats = new StatisticsHelper(); this.idf_stats = new StatisticsHelper(); @@ -124,8 +121,6 @@ public void bump (IndexSearcher searcher, LeafReaderContext context, tp_stats.add(0.0f); } } - - matchedCountSupplier.set(matchedTermCount); } /** @@ -229,10 +224,6 @@ public int getMatchedTermCount() { return matchedTermCount; } - public Suppliers.MutableSupplier getMatchedTermCountSupplier() { - return matchedCountSupplier; - } - public void setPosAggr(AggrType type) { this.posAggrType = type; } diff --git a/src/test/java/com/o19s/es/ltr/feature/store/FeatureSupplierTests.java b/src/test/java/com/o19s/es/ltr/feature/store/FeatureSupplierTests.java index 0c54be7d..fa0d07bf 100644 --- a/src/test/java/com/o19s/es/ltr/feature/store/FeatureSupplierTests.java +++ b/src/test/java/com/o19s/es/ltr/feature/store/FeatureSupplierTests.java @@ -19,7 +19,6 @@ import com.o19s.es.ltr.feature.FeatureSet; import com.o19s.es.ltr.ranker.DenseFeatureVector; import com.o19s.es.ltr.ranker.LtrRanker; -import com.o19s.es.ltr.utils.Suppliers; import org.apache.lucene.tests.util.LuceneTestCase; import org.opensearch.index.query.QueryBuilders; @@ -45,10 +44,8 @@ public void testGetWhenFeatureVectorNotSet() { public void testGetWhenFeatureVectorSet() { FeatureSupplier featureSupplier = new FeatureSupplier(getFeatureSet()); - Suppliers.MutableSupplier vectorSupplier = new Suppliers.MutableSupplier<>(); LtrRanker.FeatureVector featureVector = new DenseFeatureVector(1); - vectorSupplier.set(featureVector); - featureSupplier.set(vectorSupplier); + featureSupplier.set(() -> featureVector); assertEquals(featureVector, featureSupplier.get()); } @@ -60,11 +57,9 @@ public void testContainsKey() { public void testGetFeatureScore() { FeatureSupplier featureSupplier = new FeatureSupplier(getFeatureSet()); - Suppliers.MutableSupplier vectorSupplier = new Suppliers.MutableSupplier<>(); LtrRanker.FeatureVector featureVector = new 
DenseFeatureVector(1); featureVector.setFeatureScore(0, 10.0f); - vectorSupplier.set(featureVector); - featureSupplier.set(vectorSupplier); + featureSupplier.set(() -> featureVector); assertEquals(10.0f, featureSupplier.get("test"), 0.0f); assertNull(featureSupplier.get("bad_test")); } @@ -81,11 +76,9 @@ public void testEntrySetWhenFeatureVectorNotSet(){ public void testEntrySetWhenFeatureVectorIsSet(){ FeatureSupplier featureSupplier = new FeatureSupplier(getFeatureSet()); - Suppliers.MutableSupplier vectorSupplier = new Suppliers.MutableSupplier<>(); LtrRanker.FeatureVector featureVector = new DenseFeatureVector(1); featureVector.setFeatureScore(0, 10.0f); - vectorSupplier.set(featureVector); - featureSupplier.set(vectorSupplier); + featureSupplier.set(() -> featureVector); Set> entrySet = featureSupplier.entrySet(); assertFalse(entrySet.isEmpty());