diff --git a/src/integrationTest/java/org/opensearch/security/MaskingTests.java b/src/integrationTest/java/org/opensearch/security/MaskingTests.java index cd071666f9..77aa9305f2 100644 --- a/src/integrationTest/java/org/opensearch/security/MaskingTests.java +++ b/src/integrationTest/java/org/opensearch/security/MaskingTests.java @@ -35,6 +35,7 @@ import org.junit.ClassRule; import org.junit.Test; import org.junit.runner.RunWith; +import org.opensearch.action.admin.indices.delete.DeleteIndexRequest; import org.opensearch.action.bulk.BulkRequest; import org.opensearch.action.index.IndexRequest; import org.opensearch.action.search.SearchRequest; @@ -168,13 +169,6 @@ public void testMaskingBaslineScenarios() throws Exception { return null; }; - createIndices(1, 50); - check.call(); - - setup(); - createIndices(1, 50 * 100); - check.call(); - setup(); createIndices(3, 50 * 100); check.call(); @@ -185,7 +179,7 @@ public void testMaskingBaslineScenarios() throws Exception { } @Test - public void testMaskingAggregateFilterScenarios() throws Exception { + public void testMaskingAggregateFilterStringQueryScenarios() throws Exception { final Callable check = () -> { final long startMs = System.currentTimeMillis(); @@ -220,12 +214,50 @@ public void testMaskingAggregateFilterScenarios() throws Exception { return null; }; - createIndices(1, 50); + setup(); + createIndices(3, 50 * 100); check.call(); setup(); - createIndices(1, 50 * 100); + createIndices(3, 50 * 100 * 10); check.call(); + } + + @Test + public void testMaskingAggregateFilterTermQueryScenarios() throws Exception { + final Callable check = () -> { + final long startMs = System.currentTimeMillis(); + + SearchSourceBuilder ssb = new SearchSourceBuilder(); + ssb.aggregation(AggregationBuilders.filters("my-filter", QueryBuilders.termQuery("title","last"))); + ssb.aggregation(AggregationBuilders.count("counting").field("genre.keyword")); + ssb.aggregation(AggregationBuilders.avg("averaging").field("longId")); + ssb.size(0); + + queryAndGetStats(ADMIN_USER, ssb); + queryAndGetStats(READER, ssb); + + removeRolesFromReader(); + attachRoleToReader(TestRoles.ROLE_WITH_NO_MASKING); + queryAndGetStats(READER, ssb); + + removeRolesFromReader(); + attachRoleToReader(TestRoles.MASKING_LOW_REPEAT_VALUE); + queryAndGetStats(READER, ssb); + + removeRolesFromReader(); + attachRoleToReader(TestRoles.MASKING_RANDOM_LONG); + queryAndGetStats(READER, ssb); + + removeRolesFromReader(); + attachRoleToReader(TestRoles.MASKING_RANDOM_STRING); + queryAndGetStats(READER, ssb); + + final long endMs = System.currentTimeMillis() - startMs; + System.out.println("Finished checks in " + endMs + "ms"); + + return null; + }; setup(); createIndices(3, 50 * 100); @@ -250,6 +282,8 @@ private void createIndices(final int count, final int docCount) throws IOExcepti System.out.println("Creating " + count + " indices with " + docCount + " documents"); final long currentTimeMillis = System.currentTimeMillis(); try (Client client = cluster.getInternalNodeClient()) { + client.admin().indices().delete(new DeleteIndexRequest().indices("*")).actionGet(); + final ExecutorService pool = Executors.newFixedThreadPool(25); final List> futures = IntStream.range(1, count + 1).mapToObj(n -> { final String indexName = INDEX_NAME_PREFIX + n; @@ -264,6 +298,12 @@ private void createIndices(final int count, final int docCount) throws IOExcepti docs.add(new IndexRequest().index(indexName).id(uuid).source(baseDoc)); } + var uuid = UUID.randomUUID().toString(); + baseDoc.put("guid", uuid); + baseDoc.put("longId", random.nextLong()); + baseDoc.put("title", "last"); + docs.add(new IndexRequest().index(indexName).id(uuid).source(baseDoc)); + for (int indexReqGroupN = 0; indexReqGroupN < docCount / 250; indexReqGroupN++) { BulkRequest br = new BulkRequest(); docs.stream().skip(n * 250).limit(250).forEach(ir -> { @@ -323,10 +363,10 @@ private void queryAndGetStats(final TestSecurityConfig.User user, final SearchSo user.getName() + ", " + results.size() + ", " + attempts + - ", " + String.format("%,f", results.stream().mapToLong(a -> a).average().getAsDouble()) + + ", " + String.format("%,.0f", results.stream().mapToLong(a -> a).average().getAsDouble()) + ", " + String.format("%,d", results.stream().mapToLong(a -> a).max().getAsLong()) + ", " + String.format("%,d", results.stream().mapToLong(a -> a).min().getAsLong()) + - ", " + String.format("%,.2f", calcStd(results))); + ", " + String.format("%,.0f", calcStd(results))); } }