diff --git a/build.gradle b/build.gradle index e40ac6b6..826d94ea 100644 --- a/build.gradle +++ b/build.gradle @@ -12,14 +12,6 @@ buildscript { opensearch_version = System.getProperty("opensearch.version", "3.0.0-SNAPSHOT") isSnapshot = "true" == System.getProperty("build.snapshot", "true") buildVersionQualifier = System.getProperty("build.version_qualifier", "") - version_tokens = opensearch_version.tokenize('-') - opensearch_build = version_tokens[0] + '.0' - if (buildVersionQualifier) { - opensearch_build += "-${buildVersionQualifier}" - } - if (isSnapshot) { - opensearch_build += "-SNAPSHOT" - } } repositories { @@ -77,6 +69,8 @@ apply plugin: 'opensearch.testclusters' apply plugin: 'opensearch.pluginzip' def sqlJarDirectory = "$buildDir/dependencies/opensearch-sql-plugin" +def jsJarDirectory = "$buildDir/dependencies/opensearch-job-scheduler" +def adJarDirectory = "$buildDir/dependencies/opensearch-time-series-analytics" configurations { zipArchive @@ -96,28 +90,52 @@ task addJarsToClasspath(type: Copy) { include "protocol-${version}.jar" } into("$buildDir/classes") + + from(fileTree(dir: jsJarDirectory)) { + include "opensearch-job-scheduler-${version}.jar" + } + into("$buildDir/classes") + + from(fileTree(dir: adJarDirectory)) { + include "opensearch-time-series-analytics-${version}.jar" + } + into("$buildDir/classes") } dependencies { - compileOnly group: 'org.opensearch', name:'opensearch-ml-client', version: "${version}" + // 3P dependencies compileOnly group: 'com.google.code.gson', name: 'gson', version: '2.10.1' compileOnly "org.apache.logging.log4j:log4j-slf4j-impl:2.22.0" compileOnly group: 'org.json', name: 'json', version: '20231013' - zipArchive group: 'org.opensearch.plugin', name:'opensearch-sql-plugin', version: "${version}" implementation("com.google.guava:guava:32.1.3-jre") + implementation group: 'org.apache.commons', name: 'commons-lang3', version: '3.13.0' + + // Plugin dependencies + compileOnly group: 'org.opensearch', 
name:'opensearch-ml-client', version: "${version}" + implementation fileTree(dir: jsJarDirectory, include: ["opensearch-job-scheduler-${version}.jar"]) + implementation fileTree(dir: adJarDirectory, include: ["opensearch-time-series-analytics-${version}.jar"]) implementation fileTree(dir: sqlJarDirectory, include: ["opensearch-sql-${version}.jar", "ppl-${version}.jar", "protocol-${version}.jar"]) compileOnly "org.opensearch:common-utils:${version}" + + // ZipArchive dependencies used for integration tests + zipArchive group: 'org.opensearch.plugin', name:'opensearch-ml-plugin', version: "${version}" + zipArchive group: 'org.opensearch.plugin', name:'opensearch-job-scheduler', version: "${version}" + zipArchive "org.opensearch.plugin:opensearch-anomaly-detection:${version}" + zipArchive group: 'org.opensearch.plugin', name:'opensearch-sql-plugin', version: "${version}" + + // Test dependencies testImplementation "org.opensearch.test:framework:${opensearch_version}" - testImplementation "org.mockito:mockito-core:5.8.0" + testImplementation group: 'junit', name: 'junit', version: '4.13.2' + testImplementation group: 'org.mockito', name: 'mockito-core', version: '5.8.0' + testImplementation group: 'org.mockito', name: 'mockito-inline', version: '5.2.0' + testImplementation("net.bytebuddy:byte-buddy:1.14.7") + testImplementation("net.bytebuddy:byte-buddy-agent:1.14.7") testImplementation 'org.junit.jupiter:junit-jupiter-api:5.10.1' testImplementation 'org.mockito:mockito-junit-jupiter:5.8.0' testImplementation "com.nhaarman.mockitokotlin2:mockito-kotlin:2.2.0" testImplementation "com.cronutils:cron-utils:9.2.1" testImplementation "commons-validator:commons-validator:1.8.0" testRuntimeOnly 'org.junit.jupiter:junit-jupiter-engine:5.10.1' - - // ZipArchive dependencies used for integration tests - zipArchive group: 'org.opensearch.plugin', name:'opensearch-ml-plugin', version: "${opensearch_build}" } task extractSqlJar(type: Copy) { @@ -126,7 +144,21 @@ task 
extractSqlJar(type: Copy) { into sqlJarDirectory } +task extractJsJar(type: Copy) { + mustRunAfter() + from(zipTree(configurations.zipArchive.find { it.name.startsWith("opensearch-job-scheduler")})) + into jsJarDirectory +} + +task extractAdJar(type: Copy) { + mustRunAfter() + from(zipTree(configurations.zipArchive.find { it.name.startsWith("opensearch-anomaly-detection")})) + into adJarDirectory +} + tasks.addJarsToClasspath.dependsOn(extractSqlJar) +tasks.addJarsToClasspath.dependsOn(extractJsJar) +tasks.addJarsToClasspath.dependsOn(extractAdJar) project.tasks.delombok.dependsOn(addJarsToClasspath) tasks.publishNebulaPublicationToMavenLocal.dependsOn ':generatePomFileForPluginZipPublication' tasks.validateNebulaPom.dependsOn ':generatePomFileForPluginZipPublication' @@ -137,12 +169,13 @@ testingConventions.enabled = false thirdPartyAudit.enabled = false test { - useJUnitPlatform() testLogging { exceptionFormat "full" events "skipped", "passed", "failed" // "started" showStandardStreams true } + include '**/*Tests.class' + systemProperty 'tests.security.manager', 'false' } spotless { @@ -161,6 +194,8 @@ spotless { compileJava { dependsOn extractSqlJar + dependsOn extractJsJar + dependsOn extractAdJar dependsOn delombok options.compilerArgs.addAll(["-processor", 'lombok.launch.AnnotationProcessorHider$AnnotationProcessor']) } @@ -169,6 +204,8 @@ compileTestJava { options.compilerArgs.addAll(["-processor", 'lombok.launch.AnnotationProcessorHider$AnnotationProcessor']) } +forbiddenApisTest.ignoreFailures = true + opensearchplugin { name 'skills' diff --git a/src/main/java/org/opensearch/agent/tools/SearchAnomalyDetectorsTool.java b/src/main/java/org/opensearch/agent/tools/SearchAnomalyDetectorsTool.java new file mode 100644 index 00000000..f8f43bb0 --- /dev/null +++ b/src/main/java/org/opensearch/agent/tools/SearchAnomalyDetectorsTool.java @@ -0,0 +1,203 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package 
org.opensearch.agent.tools; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +import org.apache.commons.lang3.StringUtils; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.ad.client.AnomalyDetectionNodeClient; +import org.opensearch.client.Client; +import org.opensearch.core.action.ActionListener; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.RangeQueryBuilder; +import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.index.query.WildcardQueryBuilder; +import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.ml.common.spi.tools.Parser; +import org.opensearch.ml.common.spi.tools.Tool; +import org.opensearch.ml.common.spi.tools.ToolAnnotation; +import org.opensearch.search.SearchHit; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.sort.SortOrder; + +import lombok.Getter; +import lombok.Setter; + +@ToolAnnotation(SearchAnomalyDetectorsTool.TYPE) +public class SearchAnomalyDetectorsTool implements Tool { + public static final String TYPE = "SearchAnomalyDetectorsTool"; + private static final String DEFAULT_DESCRIPTION = "Use this tool to search anomaly detectors."; + + @Setter + @Getter + private String name = TYPE; + @Getter + @Setter + private String description = DEFAULT_DESCRIPTION; + + @Getter + private String version; + + private Client client; + + private AnomalyDetectionNodeClient adClient; + + @Setter + private Parser inputParser; + @Setter + private Parser outputParser; + + public SearchAnomalyDetectorsTool(Client client) { + this.client = client; + this.adClient = new AnomalyDetectionNodeClient(client); + + // probably keep this overridden output parser. 
need to ensure the output matches what's expected + outputParser = new Parser<>() { + @Override + public Object parse(Object o) { + @SuppressWarnings("unchecked") + List mlModelOutputs = (List) o; + return mlModelOutputs.get(0).getMlModelTensors().get(0).getDataAsMap().get("response"); + } + }; + } + + // Response is currently in a simple string format including the list of anomaly detectors (only name and ID attached), and + // number of total detectors. The output will likely need to be updated, standardized, and include more fields in the + // future to cover a sufficient amount of potential questions the agent will need to handle. + @Override + public void run(Map parameters, ActionListener listener) { + final String detectorName = parameters.getOrDefault("detectorName", null); + final String detectorNamePattern = parameters.getOrDefault("detectorNamePattern", null); + final String indices = parameters.getOrDefault("indices", null); + final Boolean highCardinality = parameters.containsKey("highCardinality") + ? Boolean.parseBoolean(parameters.get("highCardinality")) + : null; + final Long lastUpdateTime = parameters.containsKey("lastUpdateTime") && StringUtils.isNumeric(parameters.get("lastUpdateTime")) + ? Long.parseLong(parameters.get("lastUpdateTime")) : null; + final String sortOrderStr = parameters.getOrDefault("sortOrder", "asc"); + final SortOrder sortOrder = sortOrderStr.equalsIgnoreCase("asc") ? SortOrder.ASC : SortOrder.DESC; + final String sortString = parameters.getOrDefault("sortString", "name.keyword"); + final int size = parameters.containsKey("size") ? Integer.parseInt(parameters.get("size")) : 20; + final int startIndex = parameters.containsKey("startIndex") ? Integer.parseInt(parameters.get("startIndex")) : 0; + final Boolean running = parameters.containsKey("running") ? Boolean.parseBoolean(parameters.get("running")) : null; + final Boolean disabled = parameters.containsKey("disabled") ? 
Boolean.parseBoolean(parameters.get("disabled")) : null; + final Boolean failed = parameters.containsKey("failed") ? Boolean.parseBoolean(parameters.get("failed")) : null; + + List mustList = new ArrayList(); + if (detectorName != null) { + mustList.add(new TermQueryBuilder("name.keyword", detectorName)); + } + if (detectorNamePattern != null) { + mustList.add(new WildcardQueryBuilder("name.keyword", detectorNamePattern)); + } + if (indices != null) { + mustList.add(new TermQueryBuilder("indices", indices)); + } + if (highCardinality != null) { + mustList.add(new TermQueryBuilder("detector_type", highCardinality ? "MULTI_ENTITY" : "SINGLE_ENTITY")); + } + if (lastUpdateTime != null) { + mustList.add(new BoolQueryBuilder().filter(new RangeQueryBuilder("last_update_time").gte(lastUpdateTime))); + + } + + BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); + boolQueryBuilder.must().addAll(mustList); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder() + .query(boolQueryBuilder) + .size(size) + .from(startIndex) + .sort(sortString, sortOrder); + + SearchRequest searchDetectorRequest = new SearchRequest().source(searchSourceBuilder); + + if (running != null || disabled != null || failed != null) { + // TODO: add a listener to trigger when the first response is received, to trigger the profile API call + // to fetch the detector state, etc. + // Will need AD client to onboard the profile API first. 
+ } + + ActionListener searchDetectorListener = ActionListener.wrap(response -> { + StringBuilder sb = new StringBuilder(); + SearchHit[] hits = response.getHits().getHits(); + sb.append("AnomalyDetectors=["); + for (SearchHit hit : hits) { + sb.append("{"); + sb.append("id=").append(hit.getId()).append(","); + sb.append("name=").append(hit.getSourceAsMap().get("name")); + sb.append("}"); + } + sb.append("]"); + sb.append("TotalAnomalyDetectors=").append(response.getHits().getTotalHits().value); + listener.onResponse((T) sb.toString()); + }, e -> { listener.onFailure(e); }); + + adClient.searchAnomalyDetectors(searchDetectorRequest, searchDetectorListener); + } + + @Override + public boolean validate(Map parameters) { + return true; + } + + @Override + public String getType() { + return TYPE; + } + + /** + * Factory for the {@link SearchAnomalyDetectorsTool} + */ + public static class Factory implements Tool.Factory { + private Client client; + + private AnomalyDetectionNodeClient adClient; + + private static Factory INSTANCE; + + /** + * Create or return the singleton factory instance + */ + public static Factory getInstance() { + if (INSTANCE != null) { + return INSTANCE; + } + synchronized (SearchAnomalyDetectorsTool.class) { + if (INSTANCE != null) { + return INSTANCE; + } + INSTANCE = new Factory(); + return INSTANCE; + } + } + + /** + * Initialize this factory + * @param client The OpenSearch client + */ + public void init(Client client) { + this.client = client; + this.adClient = new AnomalyDetectionNodeClient(client); + } + + @Override + public SearchAnomalyDetectorsTool create(Map map) { + return new SearchAnomalyDetectorsTool(client); + } + + @Override + public String getDefaultDescription() { + return DEFAULT_DESCRIPTION; + } + } + +} diff --git a/src/test/java/org/opensearch/agent/tools/SearchAnomalyDetectorsToolTests.java b/src/test/java/org/opensearch/agent/tools/SearchAnomalyDetectorsToolTests.java new file mode 100644 index 00000000..37ff02a1 --- 
/*
 * Copyright OpenSearch Contributors
 * SPDX-License-Identifier: Apache-2.0
 */

package org.opensearch.agent.tools;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Locale;
import java.util.Map;

import org.apache.lucene.search.TotalHits;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.MockitoAnnotations;
import org.opensearch.action.ActionType;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.search.SearchResponseSections;
import org.opensearch.client.AdminClient;
import org.opensearch.client.ClusterAdminClient;
import org.opensearch.client.IndicesAdminClient;
import org.opensearch.client.node.NodeClient;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.ml.common.spi.tools.Tool;
import org.opensearch.search.SearchHit;
import org.opensearch.search.SearchHits;
import org.opensearch.search.aggregations.Aggregations;

/**
 * Unit tests for {@link SearchAnomalyDetectorsTool}. The AD node client is exercised
 * indirectly by stubbing {@link NodeClient#execute} to return canned search responses.
 */
public class SearchAnomalyDetectorsToolTests {
    @Mock
    private NodeClient nodeClient;
    // NOTE(review): the admin-client mocks below are not referenced by any test yet —
    // presumably reserved for upcoming profile-API tests; confirm or remove.
    @Mock
    private AdminClient adminClient;
    @Mock
    private IndicesAdminClient indicesAdminClient;
    @Mock
    private ClusterAdminClient clusterAdminClient;

    // Handle returned by openMocks(); must be closed after each test to release mocks.
    private AutoCloseable openedMocks;

    private Map<String, String> nullParams;
    private Map<String, String> emptyParams;
    private Map<String, String> nonEmptyParams;

    @Before
    public void setup() {
        openedMocks = MockitoAnnotations.openMocks(this);
        SearchAnomalyDetectorsTool.Factory.getInstance().init(nodeClient);

        nullParams = null;
        emptyParams = Collections.emptyMap();
        nonEmptyParams = Map.of("detectorName", "foo");
    }

    @After
    public void tearDown() throws Exception {
        // openMocks() returns an AutoCloseable that must be closed (Mockito API contract).
        openedMocks.close();
    }

    @Test
    public void testRunWithNoDetectors() throws Exception {
        Tool tool = SearchAnomalyDetectorsTool.Factory.getInstance().create(Collections.emptyMap());

        SearchHit[] hits = new SearchHit[0];
        SearchResponse getDetectorsResponse = buildDetectorSearchResponse(hits);
        // Locale.ROOT keeps %d formatting deterministic regardless of the JVM's default locale.
        String expectedResponseStr = String.format(Locale.ROOT, "AnomalyDetectors=[]TotalAnomalyDetectors=%d", hits.length);

        @SuppressWarnings("unchecked")
        ActionListener<String> listener = Mockito.mock(ActionListener.class);

        stubSearchExecution(getDetectorsResponse);

        tool.run(emptyParams, listener);
        ArgumentCaptor<String> responseCaptor = ArgumentCaptor.forClass(String.class);
        verify(listener, times(1)).onResponse(responseCaptor.capture());
        assertEquals(expectedResponseStr, responseCaptor.getValue());
    }

    @Test
    public void testRunWithSingleAnomalyDetector() throws Exception {
        final String detectorName = "detector-1";
        final String detectorId = "detector-1-id";
        Tool tool = SearchAnomalyDetectorsTool.Factory.getInstance().create(Collections.emptyMap());

        XContentBuilder content = XContentBuilder.builder(XContentType.JSON.xContent());
        content.startObject();
        content.field("name", detectorName);
        content.endObject();
        SearchHit[] hits = new SearchHit[1];
        hits[0] = new SearchHit(0, detectorId, null, null).sourceRef(BytesReference.bytes(content));

        SearchResponse getDetectorsResponse = buildDetectorSearchResponse(hits);
        String expectedResponseStr = String
            .format(Locale.ROOT, "AnomalyDetectors=[{id=%s,name=%s}]TotalAnomalyDetectors=%d", detectorId, detectorName, hits.length);

        @SuppressWarnings("unchecked")
        ActionListener<String> listener = Mockito.mock(ActionListener.class);

        stubSearchExecution(getDetectorsResponse);

        tool.run(emptyParams, listener);
        ArgumentCaptor<String> responseCaptor = ArgumentCaptor.forClass(String.class);
        verify(listener, times(1)).onResponse(responseCaptor.capture());
        assertEquals(expectedResponseStr, responseCaptor.getValue());
    }

    @Test
    public void testValidate() {
        Tool tool = SearchAnomalyDetectorsTool.Factory.getInstance().create(Collections.emptyMap());
        assertEquals(SearchAnomalyDetectorsTool.TYPE, tool.getType());
        assertTrue(tool.validate(emptyParams));
        assertTrue(tool.validate(nonEmptyParams));
        assertTrue(tool.validate(nullParams));
    }

    /** Builds a minimal SearchResponse wrapping the given hits with an exact total-hit count. */
    private SearchResponse buildDetectorSearchResponse(SearchHit[] hits) {
        TotalHits totalHits = new TotalHits(hits.length, TotalHits.Relation.EQUAL_TO);
        return new SearchResponse(
            new SearchResponseSections(new SearchHits(hits, totalHits, 0), new Aggregations(new ArrayList<>()), null, false, null, null, 0),
            null,
            0,
            0,
            0,
            0,
            null,
            null
        );
    }

    /** Stubs nodeClient.execute(...) to deliver the given response to the search listener (argument index 2). */
    private void stubSearchExecution(SearchResponse response) {
        doAnswer((invocation) -> {
            ActionListener<SearchResponse> responseListener = invocation.getArgument(2);
            responseListener.onResponse(response);
            return null;
        }).when(nodeClient).execute(any(ActionType.class), any(), any());
    }
}