From 0d928a9abf40427fb4f057185e5477725b7cf796 Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Tue, 10 Dec 2024 16:43:31 +0800 Subject: [PATCH] address review comments Signed-off-by: zhichao-aws --- .../NeuralSparseTwoPhaseProcessor.java | 19 +++++++------ .../processor/SparseEncodingProcessor.java | 12 +++++--- .../SparseEncodingProcessorFactory.java | 26 +++++++++-------- .../neuralsearch/util/prune/PruneType.java | 6 ++-- .../neuralsearch/util/prune/PruneUtils.java | 28 +++++++++++-------- ...ncodingEmbeddingProcessorFactoryTests.java | 2 +- .../util/prune/PruneUtilsTests.java | 23 +++++---------- 7 files changed, 61 insertions(+), 55 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NeuralSparseTwoPhaseProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/NeuralSparseTwoPhaseProcessor.java index 9e0ecefd4..bc5971e3f 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/NeuralSparseTwoPhaseProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/NeuralSparseTwoPhaseProcessor.java @@ -221,14 +221,17 @@ public NeuralSparseTwoPhaseProcessor create( twoPhaseConfigMap.getOrDefault(PruneUtils.PRUNE_TYPE_FIELD, pruneType.getValue()).toString() ); } - if (!PruneUtils.isValidPruneRatio(pruneType, pruneRatio)) throw new IllegalArgumentException( - "Illegal prune_ratio " - + pruneRatio - + " for prune_type: " - + pruneType.getValue() - + ". " - + PruneUtils.getValidPruneRatioDescription(pruneType) - ); + if (!PruneUtils.isValidPruneRatio(pruneType, pruneRatio)) { + throw new IllegalArgumentException( + String.format( + Locale.ROOT, + "Illegal prune_ratio %f for prune_type: %s. %s", + pruneRatio, + pruneType.getValue(), + PruneUtils.getValidPruneRatioDescription(pruneType) + ) + ); + } return new NeuralSparseTwoPhaseProcessor( tag, diff --git a/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java index 1f06e81e2..9250c8d64 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java @@ -60,8 +60,10 @@ public void doExecute( BiConsumer handler ) { mlCommonsClientAccessor.inferenceSentencesWithMapResult(this.modelId, inferenceList, ActionListener.wrap(resultMaps -> { - List> sparseVectors = TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps); - sparseVectors = sparseVectors.stream().map(vector -> PruneUtils.pruneSparseVector(pruneType, pruneRatio, vector)).toList(); + List> sparseVectors = TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps) + .stream() + .map(vector -> PruneUtils.pruneSparseVector(pruneType, pruneRatio, vector)) + .toList(); setVectorFieldsToDocument(ingestDocument, ProcessMap, sparseVectors); handler.accept(ingestDocument, null); }, e -> { handler.accept(null, e); })); @@ -70,8 +72,10 @@ public void doExecute( @Override public void doBatchExecute(List inferenceList, Consumer> handler, Consumer onException) { mlCommonsClientAccessor.inferenceSentencesWithMapResult(this.modelId, inferenceList, ActionListener.wrap(resultMaps -> { - List> sparseVectors = TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps); - sparseVectors = sparseVectors.stream().map(vector -> PruneUtils.pruneSparseVector(pruneType, pruneRatio, vector)).toList(); + List> sparseVectors = TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps) + .stream() + .map(vector -> PruneUtils.pruneSparseVector(pruneType, pruneRatio, vector)) + .toList(); handler.accept(sparseVectors); }, onException)); } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingProcessorFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingProcessorFactory.java index f5066369f..7a7d7dfde 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingProcessorFactory.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingProcessorFactory.java @@ -12,6 +12,7 @@ import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.FIELD_MAP_FIELD; import static org.opensearch.neuralsearch.processor.SparseEncodingProcessor.TYPE; +import java.util.Locale; import java.util.Map; import org.opensearch.cluster.service.ClusterService; @@ -51,19 +52,20 @@ protected AbstractBatchingProcessor newProcessor(String tag, String description, // if we have prune type, then prune ratio field must have value // readDoubleProperty will throw exception if value is not present pruneRatio = readDoubleProperty(TYPE, tag, config, PruneUtils.PRUNE_RATIO_FIELD).floatValue(); - if (!PruneUtils.isValidPruneRatio(pruneType, pruneRatio)) throw new IllegalArgumentException( - "Illegal prune_ratio " - + pruneRatio - + " for prune_type: " - + pruneType.getValue() - + ". " - + PruneUtils.getValidPruneRatioDescription(pruneType) - ); - } else { - // if we don't have prune type, then prune ratio field must not have value - if (config.containsKey(PruneUtils.PRUNE_RATIO_FIELD)) { - throw new IllegalArgumentException("prune_ratio field is not supported when prune_type is not provided"); + if (!PruneUtils.isValidPruneRatio(pruneType, pruneRatio)) { + throw new IllegalArgumentException( + String.format( + Locale.ROOT, + "Illegal prune_ratio %f for prune_type: %s. %s", + pruneRatio, + pruneType.getValue(), + PruneUtils.getValidPruneRatioDescription(pruneType) + ) + ); } + } else if (config.containsKey(PruneUtils.PRUNE_RATIO_FIELD)) { + // if we don't have prune type, then prune ratio field must not have value + throw new IllegalArgumentException("prune_ratio field is not supported when prune_type is not provided"); } return new SparseEncodingProcessor( diff --git a/src/main/java/org/opensearch/neuralsearch/util/prune/PruneType.java b/src/main/java/org/opensearch/neuralsearch/util/prune/PruneType.java index 9e228ae27..5f8e62b7c 100644 --- a/src/main/java/org/opensearch/neuralsearch/util/prune/PruneType.java +++ b/src/main/java/org/opensearch/neuralsearch/util/prune/PruneType.java @@ -6,6 +6,8 @@ import org.apache.commons.lang.StringUtils; +import java.util.Locale; + /** * Enum representing different types of prune methods for sparse vectors */ @@ -33,13 +35,13 @@ public String getValue() { * @return corresponding PruneType enum * @throws IllegalArgumentException if value doesn't match any prune type */ - public static PruneType fromString(String value) { + public static PruneType fromString(final String value) { if (StringUtils.isEmpty(value)) return NONE; for (PruneType type : PruneType.values()) { if (type.value.equals(value)) { return type; } } - throw new IllegalArgumentException("Unknown prune type: " + value); + throw new IllegalArgumentException(String.format(Locale.ROOT, "Unknown prune type: %s", value)); } } diff --git a/src/main/java/org/opensearch/neuralsearch/util/prune/PruneUtils.java b/src/main/java/org/opensearch/neuralsearch/util/prune/PruneUtils.java index 34a9cff2d..77836972e 100644 --- a/src/main/java/org/opensearch/neuralsearch/util/prune/PruneUtils.java +++ b/src/main/java/org/opensearch/neuralsearch/util/prune/PruneUtils.java @@ -33,13 +33,13 @@ public class PruneUtils { */ private static Tuple, Map> pruneByTopK( Map sparseVector, - int k, + float k, boolean requiresPrunedEntries ) { PriorityQueue> pq = new PriorityQueue<>((a, b) -> Float.compare(a.getValue(), b.getValue())); for (Map.Entry entry : sparseVector.entrySet()) { - if (pq.size() < k) { + if (pq.size() < (int) k) { pq.offer(entry); } else if (entry.getValue() > pq.peek().getValue()) { pq.poll(); @@ -172,8 +172,8 @@ public static Tuple, Map> splitSparseVector( float pruneRatio, Map sparseVector ) { - if (Objects.isNull(pruneType) || Objects.isNull(pruneRatio)) { - throw new IllegalArgumentException("Prune type and prune ratio must be provided"); + if (Objects.isNull(pruneType)) { + throw new IllegalArgumentException("Prune type must be provided"); } if (Objects.isNull(sparseVector)) { @@ -188,7 +188,7 @@ public static Tuple, Map> splitSparseVector( switch (pruneType) { case TOP_K: - return pruneByTopK(sparseVector, (int) pruneRatio, true); + return pruneByTopK(sparseVector, pruneRatio, true); case ALPHA_MASS: return pruneByAlphaMass(sparseVector, pruneRatio, true); case MAX_RATIO: @@ -208,9 +208,13 @@ public static Tuple, Map> splitSparseVector( * @param sparseVector The input sparse vector as a map of string keys to float values * @return A map with high-scoring elements */ - public static Map pruneSparseVector(PruneType pruneType, float pruneRatio, Map sparseVector) { - if (Objects.isNull(pruneType) || Objects.isNull(pruneRatio)) { - throw new IllegalArgumentException("Prune type and prune ratio must be provided"); + public static Map pruneSparseVector( + final PruneType pruneType, + final float pruneRatio, + final Map sparseVector + ) { + if (Objects.isNull(pruneType)) { + throw new IllegalArgumentException("Prune type must be provided"); } if (Objects.isNull(sparseVector)) { @@ -225,7 +229,7 @@ public static Map pruneSparseVector(PruneType pruneType, float pr switch (pruneType) { case TOP_K: - return pruneByTopK(sparseVector, (int) pruneRatio, false).v1(); + return pruneByTopK(sparseVector, pruneRatio, false).v1(); case ALPHA_MASS: return pruneByAlphaMass(sparseVector, pruneRatio, false).v1(); case MAX_RATIO: @@ -245,7 +249,7 @@ public static Map pruneSparseVector(PruneType pruneType, float pr * @return true if the ratio is valid for the given prune type, false otherwise * @throws IllegalArgumentException if prune type is null */ - public static boolean isValidPruneRatio(PruneType pruneType, float pruneRatio) { + public static boolean isValidPruneRatio(final PruneType pruneType, final float pruneRatio) { if (pruneType == null) { throw new IllegalArgumentException("Prune type cannot be null"); } @@ -269,7 +273,7 @@ public static boolean isValidPruneRatio(PruneType pruneType, float pruneRatio) { * @param pruneType The type of prune strategy * @throws IllegalArgumentException if prune type is null */ - public static String getValidPruneRatioDescription(PruneType pruneType) { + public static String getValidPruneRatioDescription(final PruneType pruneType) { if (pruneType == null) { throw new IllegalArgumentException("Prune type cannot be null"); } @@ -283,7 +287,7 @@ public static String getValidPruneRatioDescription(PruneType pruneType) { case ABS_VALUE: return "prune_ratio should be non-negative."; default: - return ""; + return "prune_ratio field is not supported when prune_type is none"; } } } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingEmbeddingProcessorFactoryTests.java b/src/test/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingEmbeddingProcessorFactoryTests.java index 9d1b45866..5d098e77e 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingEmbeddingProcessorFactoryTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingEmbeddingProcessorFactoryTests.java @@ -149,7 +149,7 @@ public void testCreateProcessor_whenInvalidPruneRatio_thenFail() { IllegalArgumentException.class, () -> sparseEncodingProcessorFactory.create(Map.of(), PROCESSOR_TAG, DESCRIPTION, config) ); - assertEquals("Illegal prune_ratio 0.2 for prune_type: top_k. prune_ratio should be positive integer.", exception.getMessage()); + assertEquals("Illegal prune_ratio 0.200000 for prune_type: top_k. prune_ratio should be positive integer.", exception.getMessage()); } @SneakyThrows diff --git a/src/test/java/org/opensearch/neuralsearch/util/prune/PruneUtilsTests.java b/src/test/java/org/opensearch/neuralsearch/util/prune/PruneUtilsTests.java index f0869ac53..536125152 100644 --- a/src/test/java/org/opensearch/neuralsearch/util/prune/PruneUtilsTests.java +++ b/src/test/java/org/opensearch/neuralsearch/util/prune/PruneUtilsTests.java @@ -183,26 +183,14 @@ public void testInvalidPruneType() { IllegalArgumentException.class, () -> PruneUtils.pruneSparseVector(null, 2, input) ); - assertEquals(exception1.getMessage(), "Prune type and prune ratio must be provided"); - - IllegalArgumentException exception2 = assertThrows( - IllegalArgumentException.class, - () -> PruneUtils.pruneSparseVector(null, 2, input) - ); - assertEquals(exception2.getMessage(), "Prune type and prune ratio must be provided"); + assertEquals(exception1.getMessage(), "Prune type must be provided"); // Test split - IllegalArgumentException exception3 = assertThrows( - IllegalArgumentException.class, - () -> PruneUtils.splitSparseVector(null, 2, input) - ); - assertEquals(exception3.getMessage(), "Prune type and prune ratio must be provided"); - - IllegalArgumentException exception4 = assertThrows( + IllegalArgumentException exception2 = assertThrows( IllegalArgumentException.class, () -> PruneUtils.splitSparseVector(null, 2, input) ); - assertEquals(exception4.getMessage(), "Prune type and prune ratio must be provided"); + assertEquals(exception2.getMessage(), "Prune type must be provided"); } public void testNullSparseVector() { @@ -264,7 +252,10 @@ public void testGetValidPruneRatioDescription() { assertEquals("prune_ratio should be in the range [0, 1).", PruneUtils.getValidPruneRatioDescription(PruneType.MAX_RATIO)); assertEquals("prune_ratio should be in the range [0, 1).", PruneUtils.getValidPruneRatioDescription(PruneType.ALPHA_MASS)); assertEquals("prune_ratio should be non-negative.", PruneUtils.getValidPruneRatioDescription(PruneType.ABS_VALUE)); - assertEquals("", PruneUtils.getValidPruneRatioDescription(PruneType.NONE)); + assertEquals( + "prune_ratio field is not supported when prune_type is none", + PruneUtils.getValidPruneRatioDescription(PruneType.NONE) + ); IllegalArgumentException exception = assertThrows( IllegalArgumentException.class,