Skip to content

Commit

Permalink
address review comments
Browse files: browse the repository at this point in the history
Signed-off-by: zhichao-aws <[email protected]>
  • Loading branch information
zhichao-aws committed Dec 10, 2024
1 parent cffd829 commit 0d928a9
Show file tree
Hide file tree
Showing 7 changed files with 61 additions and 55 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -221,14 +221,17 @@ public NeuralSparseTwoPhaseProcessor create(
twoPhaseConfigMap.getOrDefault(PruneUtils.PRUNE_TYPE_FIELD, pruneType.getValue()).toString()
);
}
if (!PruneUtils.isValidPruneRatio(pruneType, pruneRatio)) throw new IllegalArgumentException(
"Illegal prune_ratio "
+ pruneRatio
+ " for prune_type: "
+ pruneType.getValue()
+ ". "
+ PruneUtils.getValidPruneRatioDescription(pruneType)
);
if (!PruneUtils.isValidPruneRatio(pruneType, pruneRatio)) {
throw new IllegalArgumentException(
String.format(
Locale.ROOT,
"Illegal prune_ratio %f for prune_type: %s. %s",
pruneRatio,
pruneType.getValue(),
PruneUtils.getValidPruneRatioDescription(pruneType)
)
);
}

return new NeuralSparseTwoPhaseProcessor(
tag,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,10 @@ public void doExecute(
BiConsumer<IngestDocument, Exception> handler
) {
mlCommonsClientAccessor.inferenceSentencesWithMapResult(this.modelId, inferenceList, ActionListener.wrap(resultMaps -> {
List<Map<String, Float>> sparseVectors = TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps);
sparseVectors = sparseVectors.stream().map(vector -> PruneUtils.pruneSparseVector(pruneType, pruneRatio, vector)).toList();
List<Map<String, Float>> sparseVectors = TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps)
.stream()
.map(vector -> PruneUtils.pruneSparseVector(pruneType, pruneRatio, vector))
.toList();
setVectorFieldsToDocument(ingestDocument, ProcessMap, sparseVectors);
handler.accept(ingestDocument, null);
}, e -> { handler.accept(null, e); }));
Expand All @@ -70,8 +72,10 @@ public void doExecute(
@Override
public void doBatchExecute(List<String> inferenceList, Consumer<List<?>> handler, Consumer<Exception> onException) {
mlCommonsClientAccessor.inferenceSentencesWithMapResult(this.modelId, inferenceList, ActionListener.wrap(resultMaps -> {
List<Map<String, Float>> sparseVectors = TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps);
sparseVectors = sparseVectors.stream().map(vector -> PruneUtils.pruneSparseVector(pruneType, pruneRatio, vector)).toList();
List<Map<String, Float>> sparseVectors = TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps)
.stream()
.map(vector -> PruneUtils.pruneSparseVector(pruneType, pruneRatio, vector))
.toList();
handler.accept(sparseVectors);
}, onException));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.FIELD_MAP_FIELD;
import static org.opensearch.neuralsearch.processor.SparseEncodingProcessor.TYPE;

import java.util.Locale;
import java.util.Map;

import org.opensearch.cluster.service.ClusterService;
Expand Down Expand Up @@ -51,19 +52,20 @@ protected AbstractBatchingProcessor newProcessor(String tag, String description,
// if we have prune type, then prune ratio field must have value
// readDoubleProperty will throw exception if value is not present
pruneRatio = readDoubleProperty(TYPE, tag, config, PruneUtils.PRUNE_RATIO_FIELD).floatValue();
if (!PruneUtils.isValidPruneRatio(pruneType, pruneRatio)) throw new IllegalArgumentException(
"Illegal prune_ratio "
+ pruneRatio
+ " for prune_type: "
+ pruneType.getValue()
+ ". "
+ PruneUtils.getValidPruneRatioDescription(pruneType)
);
} else {
// if we don't have prune type, then prune ratio field must not have value
if (config.containsKey(PruneUtils.PRUNE_RATIO_FIELD)) {
throw new IllegalArgumentException("prune_ratio field is not supported when prune_type is not provided");
if (!PruneUtils.isValidPruneRatio(pruneType, pruneRatio)) {
throw new IllegalArgumentException(
String.format(
Locale.ROOT,
"Illegal prune_ratio %f for prune_type: %s. %s",
pruneRatio,
pruneType.getValue(),
PruneUtils.getValidPruneRatioDescription(pruneType)
)
);
}
} else if (config.containsKey(PruneUtils.PRUNE_RATIO_FIELD)) {
// if we don't have prune type, then prune ratio field must not have value
throw new IllegalArgumentException("prune_ratio field is not supported when prune_type is not provided");
}

return new SparseEncodingProcessor(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

import org.apache.commons.lang.StringUtils;

import java.util.Locale;

/**
* Enum representing different types of prune methods for sparse vectors
*/
Expand Down Expand Up @@ -33,13 +35,13 @@ public String getValue() {
* @return corresponding PruneType enum
* @throws IllegalArgumentException if value doesn't match any prune type
*/
public static PruneType fromString(String value) {
public static PruneType fromString(final String value) {
if (StringUtils.isEmpty(value)) return NONE;
for (PruneType type : PruneType.values()) {
if (type.value.equals(value)) {
return type;
}
}
throw new IllegalArgumentException("Unknown prune type: " + value);
throw new IllegalArgumentException(String.format(Locale.ROOT, "Unknown prune type: %s", value));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,13 @@ public class PruneUtils {
*/
private static Tuple<Map<String, Float>, Map<String, Float>> pruneByTopK(
Map<String, Float> sparseVector,
int k,
float k,
boolean requiresPrunedEntries
) {
PriorityQueue<Map.Entry<String, Float>> pq = new PriorityQueue<>((a, b) -> Float.compare(a.getValue(), b.getValue()));

for (Map.Entry<String, Float> entry : sparseVector.entrySet()) {
if (pq.size() < k) {
if (pq.size() < (int) k) {
pq.offer(entry);
} else if (entry.getValue() > pq.peek().getValue()) {
pq.poll();
Expand Down Expand Up @@ -172,8 +172,8 @@ public static Tuple<Map<String, Float>, Map<String, Float>> splitSparseVector(
float pruneRatio,
Map<String, Float> sparseVector
) {
if (Objects.isNull(pruneType) || Objects.isNull(pruneRatio)) {
throw new IllegalArgumentException("Prune type and prune ratio must be provided");
if (Objects.isNull(pruneType)) {
throw new IllegalArgumentException("Prune type must be provided");
}

if (Objects.isNull(sparseVector)) {
Expand All @@ -188,7 +188,7 @@ public static Tuple<Map<String, Float>, Map<String, Float>> splitSparseVector(

switch (pruneType) {
case TOP_K:
return pruneByTopK(sparseVector, (int) pruneRatio, true);
return pruneByTopK(sparseVector, pruneRatio, true);
case ALPHA_MASS:
return pruneByAlphaMass(sparseVector, pruneRatio, true);
case MAX_RATIO:
Expand All @@ -208,9 +208,13 @@ public static Tuple<Map<String, Float>, Map<String, Float>> splitSparseVector(
* @param sparseVector The input sparse vector as a map of string keys to float values
* @return A map with high-scoring elements
*/
public static Map<String, Float> pruneSparseVector(PruneType pruneType, float pruneRatio, Map<String, Float> sparseVector) {
if (Objects.isNull(pruneType) || Objects.isNull(pruneRatio)) {
throw new IllegalArgumentException("Prune type and prune ratio must be provided");
public static Map<String, Float> pruneSparseVector(
final PruneType pruneType,
final float pruneRatio,
final Map<String, Float> sparseVector
) {
if (Objects.isNull(pruneType)) {
throw new IllegalArgumentException("Prune type must be provided");
}

if (Objects.isNull(sparseVector)) {
Expand All @@ -225,7 +229,7 @@ public static Map<String, Float> pruneSparseVector(PruneType pruneType, float pr

switch (pruneType) {
case TOP_K:
return pruneByTopK(sparseVector, (int) pruneRatio, false).v1();
return pruneByTopK(sparseVector, pruneRatio, false).v1();
case ALPHA_MASS:
return pruneByAlphaMass(sparseVector, pruneRatio, false).v1();
case MAX_RATIO:
Expand All @@ -245,7 +249,7 @@ public static Map<String, Float> pruneSparseVector(PruneType pruneType, float pr
* @return true if the ratio is valid for the given prune type, false otherwise
* @throws IllegalArgumentException if prune type is null
*/
public static boolean isValidPruneRatio(PruneType pruneType, float pruneRatio) {
public static boolean isValidPruneRatio(final PruneType pruneType, final float pruneRatio) {
if (pruneType == null) {
throw new IllegalArgumentException("Prune type cannot be null");
}
Expand All @@ -269,7 +273,7 @@ public static boolean isValidPruneRatio(PruneType pruneType, float pruneRatio) {
* @param pruneType The type of prune strategy
* @throws IllegalArgumentException if prune type is null
*/
public static String getValidPruneRatioDescription(PruneType pruneType) {
public static String getValidPruneRatioDescription(final PruneType pruneType) {
if (pruneType == null) {
throw new IllegalArgumentException("Prune type cannot be null");
}
Expand All @@ -283,7 +287,7 @@ public static String getValidPruneRatioDescription(PruneType pruneType) {
case ABS_VALUE:
return "prune_ratio should be non-negative.";
default:
return "";
return "prune_ratio field is not supported when prune_type is none";
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ public void testCreateProcessor_whenInvalidPruneRatio_thenFail() {
IllegalArgumentException.class,
() -> sparseEncodingProcessorFactory.create(Map.of(), PROCESSOR_TAG, DESCRIPTION, config)
);
assertEquals("Illegal prune_ratio 0.2 for prune_type: top_k. prune_ratio should be positive integer.", exception.getMessage());
assertEquals("Illegal prune_ratio 0.200000 for prune_type: top_k. prune_ratio should be positive integer.", exception.getMessage());
}

@SneakyThrows
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,26 +183,14 @@ public void testInvalidPruneType() {
IllegalArgumentException.class,
() -> PruneUtils.pruneSparseVector(null, 2, input)
);
assertEquals(exception1.getMessage(), "Prune type and prune ratio must be provided");

IllegalArgumentException exception2 = assertThrows(
IllegalArgumentException.class,
() -> PruneUtils.pruneSparseVector(null, 2, input)
);
assertEquals(exception2.getMessage(), "Prune type and prune ratio must be provided");
assertEquals(exception1.getMessage(), "Prune type must be provided");

// Test split
IllegalArgumentException exception3 = assertThrows(
IllegalArgumentException.class,
() -> PruneUtils.splitSparseVector(null, 2, input)
);
assertEquals(exception3.getMessage(), "Prune type and prune ratio must be provided");

IllegalArgumentException exception4 = assertThrows(
IllegalArgumentException exception2 = assertThrows(
IllegalArgumentException.class,
() -> PruneUtils.splitSparseVector(null, 2, input)
);
assertEquals(exception4.getMessage(), "Prune type and prune ratio must be provided");
assertEquals(exception2.getMessage(), "Prune type must be provided");
}

public void testNullSparseVector() {
Expand Down Expand Up @@ -264,7 +252,10 @@ public void testGetValidPruneRatioDescription() {
assertEquals("prune_ratio should be in the range [0, 1).", PruneUtils.getValidPruneRatioDescription(PruneType.MAX_RATIO));
assertEquals("prune_ratio should be in the range [0, 1).", PruneUtils.getValidPruneRatioDescription(PruneType.ALPHA_MASS));
assertEquals("prune_ratio should be non-negative.", PruneUtils.getValidPruneRatioDescription(PruneType.ABS_VALUE));
assertEquals("", PruneUtils.getValidPruneRatioDescription(PruneType.NONE));
assertEquals(
"prune_ratio field is not supported when prune_type is none",
PruneUtils.getValidPruneRatioDescription(PruneType.NONE)
);

IllegalArgumentException exception = assertThrows(
IllegalArgumentException.class,
Expand Down

0 comments on commit 0d928a9

Please sign in to comment.