diff --git a/src/main/java/redis/clients/jedis/search/SearchResult.java b/src/main/java/redis/clients/jedis/search/SearchResult.java index 55afbe0b24..b51e791927 100644 --- a/src/main/java/redis/clients/jedis/search/SearchResult.java +++ b/src/main/java/redis/clients/jedis/search/SearchResult.java @@ -21,10 +21,16 @@ public class SearchResult { private final long totalResults; private final List documents; + private final List warnings; private SearchResult(long totalResults, List documents) { + this(totalResults, documents, (List) null); + } + + private SearchResult(long totalResults, List documents, List warnings) { this.totalResults = totalResults; this.documents = documents; + this.warnings = warnings; } public long getTotalResults() { @@ -35,10 +41,16 @@ public List getDocuments() { return Collections.unmodifiableList(documents); } + public List getWarnings() { + return warnings; + } + @Override public String toString() { return getClass().getSimpleName() + "{Total results:" + totalResults - + ", Documents:" + documents + "}"; + + ", Documents:" + documents + + (warnings != null ? ", Warnings:" + warnings : "") + + "}"; } public static class SearchResultBuilder extends Builder { @@ -104,6 +116,7 @@ public static final class PerFieldDecoderSearchResultBuilder extends Builder documentBuilder; @@ -120,20 +133,25 @@ public SearchResult build(Object data) { List list = (List) data; long totalResults = -1; List results = null; + List warnings = null; for (KeyValue kv : list) { String key = BuilderFactory.STRING.build(kv.getKey()); + Object rawVal = kv.getValue(); switch (key) { case TOTAL_RESULTS_STR: - totalResults = BuilderFactory.LONG.build(kv.getValue()); + totalResults = BuilderFactory.LONG.build(rawVal); break; case RESULTS_STR: - results = ((List) kv.getValue()).stream() + results = ((List) rawVal).stream() .map(documentBuilder::build) .collect(Collectors.toList()); break; + case WARNINGS_STR: + warnings = BuilderFactory.STRING_LIST.build(rawVal); + break; } } - return new SearchResult(totalResults, results); + return new SearchResult(totalResults, results, warnings); } }; /// <-- RESP3 diff --git a/src/main/java/redis/clients/jedis/search/aggr/AggregationResult.java b/src/main/java/redis/clients/jedis/search/aggr/AggregationResult.java index cec65f9cd9..99c47e9fcb 100644 --- a/src/main/java/redis/clients/jedis/search/aggr/AggregationResult.java +++ b/src/main/java/redis/clients/jedis/search/aggr/AggregationResult.java @@ -19,37 +19,18 @@ public class AggregationResult { private final List> results; - private Long cursorId = -1L; - - private AggregationResult(Object resp, long cursorId) { - this(resp); - this.cursorId = cursorId; - } + private final List warnings; - private AggregationResult(Object resp) { - List list = (List) SafeEncoder.encodeObject(resp); - - // the first element is always the number of results - totalResults = (Long) list.get(0); - results = new ArrayList<>(list.size() - 1); + private Long cursorId = -1L; - for (int i = 1; i < list.size(); i++) { - List mapList = (List) list.get(i); - Map map = new HashMap<>(mapList.size() / 2, 1f); - for (int j = 0; j < mapList.size(); j += 2) { - Object r = mapList.get(j); - if (r instanceof JedisDataException) { - throw (JedisDataException) r; - } - map.put((String) r, mapList.get(j + 1)); - } - results.add(map); - } + private AggregationResult(long totalResults, List> results) { + this(totalResults, results, (List) null); } - private AggregationResult(long totalResults, List> results) { + public AggregationResult(long totalResults, List> results, List warnings) { this.totalResults = totalResults; this.results = results; + this.warnings = warnings; } private void setCursorId(Long cursorId) { @@ -80,12 +61,17 @@ public Row getRow(int index) { return new Row(results.get(index)); } + public List getWarnings() { + return warnings; + } + public static final Builder SEARCH_AGGREGATION_RESULT = new Builder() { private static final String TOTAL_RESULTS_STR = "total_results"; private static final String RESULTS_STR = "results"; // private static final String FIELDS_STR = "fields"; private static final String FIELDS_STR = "extra_attributes"; + private static final String WARNINGS_STR = "warning"; @Override public AggregationResult build(Object data) { @@ -96,14 +82,16 @@ public AggregationResult build(Object data) { List kvList = (List) data; long totalResults = -1; List> results = null; + List warnings = null; for (KeyValue kv : kvList) { String key = BuilderFactory.STRING.build(kv.getKey()); + Object rawVal = kv.getValue(); switch (key) { case TOTAL_RESULTS_STR: - totalResults = BuilderFactory.LONG.build(kv.getValue()); + totalResults = BuilderFactory.LONG.build(rawVal); break; case RESULTS_STR: - List> resList = (List>) kv.getValue(); + List> resList = (List>) rawVal; results = new ArrayList<>(resList.size()); for (List rikv : resList) { for (KeyValue ikv : rikv) { @@ -114,9 +102,12 @@ public AggregationResult build(Object data) { } } break; + case WARNINGS_STR: + warnings = BuilderFactory.STRING_LIST.build(rawVal); + break; } } - return new AggregationResult(totalResults, results); + return new AggregationResult(totalResults, results, warnings); } list = (List) SafeEncoder.encodeObject(data); diff --git a/src/test/java/redis/clients/jedis/modules/search/SearchDefaultDialectTest.java b/src/test/java/redis/clients/jedis/modules/search/SearchDefaultDialectTest.java index 34adccf1d5..3c7d4f2372 100644 --- a/src/test/java/redis/clients/jedis/modules/search/SearchDefaultDialectTest.java +++ b/src/test/java/redis/clients/jedis/modules/search/SearchDefaultDialectTest.java @@ -1,6 +1,7 @@ package redis.clients.jedis.modules.search; import static org.junit.Assert.*; +import static redis.clients.jedis.util.AssertUtil.assertEqualsByProtocol; import static redis.clients.jedis.util.AssertUtil.assertOK; import java.util.*; @@ -17,6 +18,7 @@ import redis.clients.jedis.exceptions.JedisDataException; import redis.clients.jedis.search.*; import redis.clients.jedis.search.schemafields.NumericField; +import redis.clients.jedis.search.schemafields.TagField; import redis.clients.jedis.search.schemafields.TextField; import redis.clients.jedis.modules.RedisModuleCommandsTestBase; import redis.clients.jedis.search.aggr.AggregationBuilder; @@ -57,6 +59,14 @@ private void addDocument(Document doc) { client.hset(key, map); } + private static Map toMap(String... values) { + Map map = new HashMap<>(); + for (int i = 0; i < values.length; i += 2) { + map.put(values[i], values[i + 1]); + } + return map; + } + @Test public void testQueryParams() { Schema sc = new Schema().addNumericField("numval"); @@ -199,4 +209,30 @@ public void dialectBoundSpellCheck() { FTSpellCheckParams.spellCheckParams().dialect(0))); MatcherAssert.assertThat(error.getMessage(), Matchers.containsString("DIALECT requires a non negative integer")); } + + @org.junit.Ignore + @Test + public void warningMaxPrefixExpansions() { + final String configParam = "MAXPREFIXEXPANSIONS"; + String configValue = (String) client.ftConfigGet(configParam).get(configParam); + try { + assertOK(client.ftCreate(INDEX, FTCreateParams.createParams().on(IndexDataType.HASH), + TextField.of("t"), TagField.of("t2"))); + + client.hset("doc13", toMap("t", "foo", "t2", "foo")); + + client.ftConfigSet(configParam, "1"); + + SearchResult srcResult = client.ftSearch(INDEX, "fo*"); + assertEqualsByProtocol(protocol, null, Arrays.asList(), srcResult.getWarnings()); + + client.hset("doc23", toMap("t", "fooo", "t2", "fooo")); + + AggregationResult aggResult = client.ftAggregate(INDEX, new AggregationBuilder("fo*").loadAll()); + assertEqualsByProtocol(protocol, null, Arrays.asList("Max prefix expansions limit was reached"), aggResult.getWarnings()); + } finally { + client.ftConfigSet(configParam, configValue); + } + } + } diff --git a/src/test/java/redis/clients/jedis/util/AssertUtil.java b/src/test/java/redis/clients/jedis/util/AssertUtil.java index 110be7e48e..152f981d84 100644 --- a/src/test/java/redis/clients/jedis/util/AssertUtil.java +++ b/src/test/java/redis/clients/jedis/util/AssertUtil.java @@ -7,7 +7,6 @@ import java.util.Collection; import java.util.Iterator; import java.util.List; -import java.util.Map; import java.util.Objects; import java.util.Set;